Skip to content

Commit 6caa5e6

Browse files
committed
Add unit tests for ServerClassificationService
- Implement tests for initialization, classification logic, leader election, polling decisions, and Redis state management.
- Cover various scenarios, including hot/cold classification, tie-breaking logic, and service lifecycle management.
- Ensure comprehensive testing of the server classification algorithm and its integration with the GatewayService.

Signed-off-by: Lang-Akshay <akshay.shinde26@ibm.com>
1 parent 9e67dfe commit 6caa5e6

4 files changed

Lines changed: 1379 additions & 15 deletions

File tree

mcpgateway/services/gateway_service.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3351,10 +3351,14 @@ async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Opt
33513351

33523352
# Hot/cold classification: Check if this server should be health-checked now
33533353
if self._classification_service:
3354-
should_check = await self._classification_service.should_poll_server(gateway_url, "health")
3355-
if not should_check:
3356-
logger.debug(f"Skipping health check for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
3357-
return
3354+
try:
3355+
should_check = await self._classification_service.should_poll_server(gateway_url, "health")
3356+
if not should_check:
3357+
logger.debug(f"Skipping health check for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
3358+
return
3359+
except Exception as e:
3360+
# Fail open: proceed with health check if classification check fails
3361+
logger.warning(f"Classification check failed for {gateway_name}, proceeding with health check (fail-open): {e}")
33583362

33593363
# Create span for individual gateway health check
33603364
with create_span(
@@ -3556,9 +3560,14 @@ def get_httpx_client_factory(
35563560
if settings.auto_refresh_servers:
35573561
# Hot/cold classification: Check if this server should have tools refreshed now
35583562
if self._classification_service:
3559-
should_auto_refresh = await self._classification_service.should_poll_server(gateway_url, "tools")
3560-
if not should_auto_refresh:
3561-
logger.debug(f"Skipping auto-refresh for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
3563+
try:
3564+
should_auto_refresh = await self._classification_service.should_poll_server(gateway_url, "tools")
3565+
if not should_auto_refresh:
3566+
logger.debug(f"Skipping auto-refresh for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
3567+
except Exception as e:
3568+
# Fail open: proceed with auto-refresh if classification check fails
3569+
logger.warning(f"Classification check failed for {gateway_name}, proceeding with auto-refresh (fail-open): {e}")
3570+
should_auto_refresh = True
35623571
else:
35633572
should_auto_refresh = True
35643573

mcpgateway/services/server_classification_service.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,15 @@ def _classify_servers_from_pool(self, pool: MCPSessionPool, all_gateway_urls: Li
227227
Returns:
228228
ClassificationResult with hot/cold servers and metadata
229229
"""
230-
N = len(all_gateway_urls)
231-
hot_cap = floor(0.20 * N)
230+
total_servers = len(all_gateway_urls)
231+
hot_cap = floor(0.20 * total_servers)
232232

233233
# Step 3: Extract server usage from pooled sessions
234234
server_metrics: Dict[str, ServerUsageMetrics] = {}
235235

236236
# Iterate over pool._pools (Dict[PoolKey, Queue[PooledSession]])
237237
# PoolKey = (user_identity, url, identity_hash, transport_type, gateway_id)
238-
for pool_key, session_queue in pool._pools.items():
238+
for pool_key, session_queue in pool._pools.items(): # pylint: disable=protected-access
239239
url = pool_key[1] # Extract server URL from pool key
240240

241241
if url not in server_metrics:
@@ -244,7 +244,7 @@ def _classify_servers_from_pool(self, pool: MCPSessionPool, all_gateway_urls: Li
244244
# Process each pooled session in the queue
245245
try:
246246
# Access queue items (asyncio.Queue has internal _queue deque)
247-
sessions_list = list(session_queue._queue) if hasattr(session_queue, "_queue") else []
247+
sessions_list = list(session_queue._queue) if hasattr(session_queue, "_queue") else [] # pylint: disable=protected-access
248248

249249
for session in sessions_list:
250250
# PooledSession has: last_used, use_count
@@ -258,7 +258,7 @@ def _classify_servers_from_pool(self, pool: MCPSessionPool, all_gateway_urls: Li
258258
continue
259259

260260
# Count active sessions from _active dict
261-
for pool_key, active_set in pool._active.items():
261+
for pool_key, active_set in pool._active.items(): # pylint: disable=protected-access
262262
url = pool_key[1]
263263
if url in server_metrics:
264264
server_metrics[url].active_session_count += len(active_set)
@@ -293,7 +293,7 @@ def _classify_servers_from_pool(self, pool: MCPSessionPool, all_gateway_urls: Li
293293
return ClassificationResult(
294294
hot_servers=hot_servers,
295295
cold_servers=cold_servers,
296-
metadata=ClassificationMetadata(N=N, hot_cap=hot_cap, hot_actual=hot_actual, eligible_count=eligible_count, timestamp=time.time(), underutilized_reason=underutilized_reason),
296+
metadata=ClassificationMetadata(N=total_servers, hot_cap=hot_cap, hot_actual=hot_actual, eligible_count=eligible_count, timestamp=time.time(), underutilized_reason=underutilized_reason),
297297
)
298298

299299
async def _get_all_gateway_urls(self) -> List[str]:
@@ -306,11 +306,11 @@ async def _get_all_gateway_urls(self) -> List[str]:
306306
from sqlalchemy import select
307307

308308
# First-Party
309-
from mcpgateway.db import DbGateway, SessionLocal
309+
from mcpgateway.db import Gateway, SessionLocal
310310

311311
try:
312312
with SessionLocal() as db:
313-
result = db.execute(select(DbGateway.url).where(DbGateway.enabled == True)) # noqa: E712
313+
result = db.execute(select(Gateway.url).where(Gateway.enabled is True))
314314
urls = [row[0] for row in result]
315315
return urls
316316
except Exception as e:

0 commit comments

Comments
 (0)