Skip to content

Commit a1f3a9d

Browse files
committed
test: add async tests for classification loop and error handling in ServerClassificationService
Signed-off-by: Lang-Akshay <akshay.shinde26@ibm.com>
1 parent a698569 commit a1f3a9d

1 file changed

Lines changed: 139 additions & 0 deletions

File tree

tests/unit/mcpgateway/services/test_server_classification_service.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,145 @@ async def test_publish_classification_handles_redis_error(self):
12031203
# Should not raise exception, just log error
12041204
await service._publish_classification_to_redis(result)
12051205

1206+
@pytest.mark.asyncio
1207+
async def test_classification_loop_as_leader_calls_perform(self):
1208+
"""Test classification loop calls _perform_classification when leader (lines 147-148)."""
1209+
mock_redis = AsyncMock()
1210+
service = ServerClassificationService(redis_client=mock_redis)
1211+
service._running = True
1212+
1213+
perform_called = asyncio.Event()
1214+
1215+
async def mock_perform():
1216+
perform_called.set()
1217+
service._running = False # Stop after first call
1218+
1219+
with patch.object(service, "_try_acquire_leader_lock", AsyncMock(return_value=True)):
1220+
with patch.object(service, "_perform_classification", side_effect=mock_perform):
1221+
with patch("mcpgateway.services.server_classification_service.settings") as mock_settings:
1222+
mock_settings.gateway_auto_refresh_interval = 0.05
1223+
await asyncio.wait_for(service._run_classification_loop(), timeout=2.0)
1224+
1225+
assert perform_called.is_set()
1226+
1227+
@pytest.mark.asyncio
1228+
async def test_classification_loop_cancelled_error(self):
1229+
"""Test classification loop handles CancelledError cleanly (lines 155-156)."""
1230+
mock_redis = AsyncMock()
1231+
service = ServerClassificationService(redis_client=mock_redis)
1232+
service._running = True
1233+
1234+
with patch.object(service, "_try_acquire_leader_lock", AsyncMock(return_value=False)):
1235+
with patch("mcpgateway.services.server_classification_service.settings") as mock_settings:
1236+
mock_settings.gateway_auto_refresh_interval = 10 # Long sleep so we can cancel
1237+
1238+
task = asyncio.create_task(service._run_classification_loop())
1239+
await asyncio.sleep(0.05)
1240+
task.cancel()
1241+
try:
1242+
await task
1243+
except asyncio.CancelledError:
1244+
pass # CancelledError propagates from asyncio.sleep, not the break
1245+
1246+
@pytest.mark.asyncio
1247+
async def test_perform_classification_pool_not_initialized(self):
1248+
"""Test _perform_classification returns early when pool not initialized (lines 188-190)."""
1249+
service = ServerClassificationService(redis_client=None)
1250+
1251+
with patch("mcpgateway.services.mcp_session_pool.get_mcp_session_pool", side_effect=RuntimeError("pool not initialized")):
1252+
# Should return early without error
1253+
await service._perform_classification()
1254+
1255+
@pytest.mark.asyncio
1256+
async def test_perform_classification_logs_underutilized_reason(self):
1257+
"""Test _perform_classification logs underutilized_reason when present (line 210)."""
1258+
mock_redis = AsyncMock()
1259+
mock_pipeline = AsyncMock()
1260+
mock_pipeline.__aenter__ = AsyncMock(return_value=mock_pipeline)
1261+
mock_pipeline.__aexit__ = AsyncMock(return_value=False)
1262+
mock_pipeline.execute = AsyncMock(return_value=None)
1263+
mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
1264+
1265+
service = ServerClassificationService(redis_client=mock_redis)
1266+
1267+
# Use 20 URLs so hot_cap=4, but only 1 has session activity → underutilized
1268+
all_urls = [f"http://server{i}:8080" for i in range(20)]
1269+
mock_pool = MagicMock()
1270+
pool_key = ("anonymous", all_urls[0], "hash123", TransportType.STREAMABLE_HTTP, None)
1271+
mock_queue = MagicMock()
1272+
active_session = MagicMock()
1273+
active_session.last_used = time.time()
1274+
active_session.use_count = 5
1275+
mock_queue._queue = deque([active_session])
1276+
mock_pool._pools = {pool_key: mock_queue}
1277+
1278+
with patch("mcpgateway.services.server_classification_service.settings") as mock_settings:
1279+
mock_settings.hot_cold_classification_enabled = True
1280+
mock_settings.hot_server_check_interval = 300
1281+
mock_settings.cold_server_check_interval = 900
1282+
mock_settings.gateway_auto_refresh_interval = 60
1283+
1284+
with patch.object(service, "_get_all_gateway_urls", AsyncMock(return_value=all_urls)):
1285+
with patch("mcpgateway.services.mcp_session_pool.get_mcp_session_pool", return_value=mock_pool):
1286+
await service._perform_classification()
1287+
1288+
mock_redis.pipeline.assert_called()
1289+
1290+
@pytest.mark.asyncio
1291+
async def test_perform_classification_full_happy_path(self):
1292+
"""Test _perform_classification runs classification and publishes to Redis (lines 199-213)."""
1293+
mock_redis = AsyncMock()
1294+
mock_pipeline = AsyncMock()
1295+
mock_pipeline.__aenter__ = AsyncMock(return_value=mock_pipeline)
1296+
mock_pipeline.__aexit__ = AsyncMock(return_value=False)
1297+
mock_pipeline.execute = AsyncMock(return_value=None)
1298+
mock_redis.pipeline = MagicMock(return_value=mock_pipeline)
1299+
1300+
service = ServerClassificationService(redis_client=mock_redis)
1301+
1302+
mock_pool = MagicMock()
1303+
mock_pool._pools = {} # Empty pool — all servers will be cold
1304+
1305+
with patch("mcpgateway.services.server_classification_service.settings") as mock_settings:
1306+
mock_settings.hot_cold_classification_enabled = True
1307+
mock_settings.hot_server_check_interval = 300
1308+
mock_settings.cold_server_check_interval = 900
1309+
mock_settings.gateway_auto_refresh_interval = 60
1310+
1311+
with patch.object(service, "_get_all_gateway_urls", AsyncMock(return_value=["http://test:8080"])):
1312+
with patch("mcpgateway.services.mcp_session_pool.get_mcp_session_pool", return_value=mock_pool):
1313+
await service._perform_classification()
1314+
1315+
# Redis pipeline should have been called to publish classification
1316+
mock_redis.pipeline.assert_called()
1317+
1318+
@pytest.mark.asyncio
1319+
async def test_perform_classification_exception_path(self):
1320+
"""Test _perform_classification catches and logs unexpected exceptions (line 212-213)."""
1321+
service = ServerClassificationService(redis_client=None)
1322+
1323+
with patch.object(service, "_get_all_gateway_urls", AsyncMock(side_effect=RuntimeError("unexpected"))):
1324+
with patch("mcpgateway.services.mcp_session_pool.get_mcp_session_pool", return_value=MagicMock()):
1325+
# Should not raise — exception is caught inside _perform_classification
1326+
await service._perform_classification()
1327+
1328+
@pytest.mark.asyncio
1329+
async def test_classify_servers_queue_missing_queue_attr(self):
1330+
"""Test _classify_servers_from_pool handles queue without _queue attribute (lines 254-255)."""
1331+
mock_pool = MagicMock()
1332+
pool_key = ("anonymous", "http://test:8080", "hash123", TransportType.STREAMABLE_HTTP, None)
1333+
1334+
# Queue that does NOT have _queue attribute
1335+
mock_queue = MagicMock(spec=[]) # spec=[] means no attributes allowed
1336+
mock_pool._pools = {pool_key: mock_queue}
1337+
1338+
service = ServerClassificationService(redis_client=None)
1339+
result = service._classify_servers_from_pool(mock_pool, ["http://test:8080"])
1340+
1341+
# Should complete without error, server goes cold (no session data)
1342+
assert result is not None
1343+
assert "http://test:8080" in result.cold_servers
1344+
12061345
@pytest.mark.asyncio
12071346
async def test_get_all_gateway_urls_handles_db_error(self):
12081347
"""Test _get_all_gateway_urls handles database errors."""

0 commit comments

Comments
 (0)