Commit 6182318
MohanLaksh authored and crivetimihai committed
[BUG FIX][DB]: Restore transaction control to get_db() for middleware sessions (#3731) (#3813)
* fix(db): restore transaction control to get_db() for middleware sessions

  PR #3600 introduced a transaction management violation where
  ObservabilityMiddleware commits the shared database session instead of
  get_db(), breaking the established contract where get_db() controls
  transaction boundaries. This creates data integrity risks where failed
  validations can be committed to the database.

  This fix restores the correct behavior:
  - Middleware manages session lifecycle (create/close)
  - get_db() manages transactions (commit/rollback)

  Changes:
  - Remove commit logic from ObservabilityMiddleware (observability_middleware.py:210-216)
  - Add commit/rollback handling to get_db() for middleware sessions (main.py:3137-3164)
  - Update get_db() docstring to document transaction control responsibility
  - Update 2 existing tests to reflect new behavior
  - Add 7 comprehensive tests for transaction semantics

  Security implications:
  - Fixes data integrity bug where invalid data could be committed
  - Maintains proper transaction isolation per request
  - Preserves connection invalidation on broken connections
  - No impact on auth/RBAC (middleware runs before route handlers)

  Trade-offs:
  - Observability data (traces/spans) is rolled back on errors (acceptable - best-effort tracing)

  Closes #3731

  Signed-off-by: Mohan Lakshmaiah <mohan.economist@gmail.com>

* test(db): add coverage for double-failure edge case in get_db()

  Signed-off-by: Mohan Lakshmaiah <mohan.economist@gmail.com>

* fix(tests): clean up lint violations in transaction control tests

  Remove unused AsyncMock import and unused variable assignments flagged by
  ruff (F401, F841). Apply isort/black formatting.

  Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>

---------

Signed-off-by: Mohan Lakshmaiah <mohan.economist@gmail.com>
Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>
Co-authored-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent 10224d8 commit 6182318
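
The data integrity risk named in the commit message can be illustrated with a small, self-contained sketch. It uses the standard-library `sqlite3` module rather than the project's SQLAlchemy session, and the failure sequence (a handler writes a row, validation then fails, and an exception handler turns the error into a normal response before the middleware runs) is an assumption about how the bug manifests, not something the commit spells out. The names `handle_request` and `rows_after_failed_request` are hypothetical.

```python
import sqlite3


def handle_request(conn, transaction_owner):
    """Simulate one request whose handler writes a row and then fails validation."""
    try:
        conn.execute("INSERT INTO items (name) VALUES (?)", ("invalid-item",))
        raise ValueError("validation failed after the write")
    except ValueError:
        if transaction_owner == "get_db":
            # The dependency sees the exception, so it can roll the write back.
            conn.rollback()
        # Assumption: an exception handler now turns the error into a response.
    if transaction_owner == "middleware":
        # The middleware only sees a response, so it commits whatever is pending.
        conn.commit()


def rows_after_failed_request(transaction_owner):
    """Return how many rows were durably committed by the failed request."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE items (name TEXT)")
    conn.commit()
    handle_request(conn, transaction_owner)
    conn.rollback()  # discard anything still uncommitted before counting
    count = conn.execute("SELECT COUNT(*) FROM items").fetchone()[0]
    conn.close()
    return count


print(rows_after_failed_request("middleware"))  # the invalid row survives
print(rows_after_failed_request("get_db"))      # the invalid row is discarded
```

In this sketch, the component that observes the exception is the only one that can decide to roll back, which is the argument for keeping commit and rollback together in one place.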

4 files changed

Lines changed: 403 additions & 17 deletions

mcpgateway/main.py

Lines changed: 41 additions & 6 deletions
@@ -3092,8 +3092,17 @@ def get_db(request: Request = None):
 
     When observability is enabled, this reuses the session created by
     ObservabilityMiddleware (stored in request.state.db) to avoid duplicate
-    session creation. When observability is disabled or the
-    middleware hasn't created a session, this creates its own session.
+    session creation. When observability is disabled or the middleware hasn't
+    created a session, this creates its own session.
+
+    **Transaction Control**: This function ALWAYS controls transaction boundaries
+    (commit/rollback) regardless of whether it creates the session or reuses one
+    from middleware. This ensures predictable transaction semantics for route
+    handlers and maintains data integrity.
+
+    **Session Lifecycle**: Middleware manages session lifecycle (create/close)
+    while this function manages transactions (commit/rollback). This separation
+    of concerns prevents the transaction management violation described in #3731.
 
     Commits the transaction on successful completion to avoid implicit rollbacks
     for read-only operations. Rolls back explicitly on exception.
@@ -3114,7 +3123,10 @@ def get_db(request: Request = None):
         Exception: Re-raises any exception after rolling back the transaction.
 
     Ensures:
-        The database session is closed after the request completes, even in the case of an exception.
+        - Transaction is committed on success (for both owned and reused sessions)
+        - Transaction is rolled back on error (for both owned and reused sessions)
+        - Session is closed only if created by this function (not if reused from middleware)
+        - Broken connections are invalidated to prevent pool corruption
 
     Examples:
         >>> # Test that get_db returns a generator
@@ -3138,9 +3150,32 @@ def get_db(request: Request = None):
         db = request.state.db
         if db is not None:
             logger.debug(f"[GET_DB] Reusing session from middleware: {id(db)}")
-            # Yield the middleware's session without closing it
-            # The middleware will handle commit/rollback/close
-            yield db
+            # Yield the middleware's session. We control transactions, middleware controls lifecycle.
+            try:
+                yield db
+                # Commit on successful completion (only if transaction still active)
+                # The transaction can become inactive if an exception occurred during
+                # async context manager cleanup (e.g., CancelledError during MCP session teardown).
+                if db.is_active:
+                    db.commit()
+            except Exception:
+                try:
+                    # Always call rollback() in exception handler.
+                    # rollback() is safe to call even when is_active=False - it succeeds and
+                    # restores the session to a usable state. When is_active=False (e.g., after
+                    # IntegrityError), rollback() is actually REQUIRED to clear the failed state.
+                    # Skipping rollback when is_active=False would leave the session unusable.
+                    db.rollback()
+                except Exception:
+                    # Connection is broken - invalidate to remove from pool
+                    # This handles cases like PgBouncer query_wait_timeout where
+                    # the connection is dead and rollback itself fails
+                    try:
+                        db.invalidate()
+                    except Exception:
+                        pass  # nosec B110 - Best effort cleanup on connection failure
+                raise
+            # Don't close - middleware owns the session lifecycle
             return
 
     # Fallback: Create our own session (observability disabled or middleware didn't create one)
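
The three outcomes of the reused-session branch above (commit on success, rollback on error, invalidate when the rollback itself fails) can be exercised without a database. The following is a minimal sketch, not the project's code: `FakeSession`, `use_shared_session`, and `drive` are hypothetical names, and `FakeSession` only records calls instead of talking to SQLAlchemy.

```python
class FakeSession:
    """Hypothetical stand-in for a database session that records method calls."""

    def __init__(self, rollback_fails=False):
        self.is_active = True
        self.calls = []
        self._rollback_fails = rollback_fails

    def commit(self):
        self.calls.append("commit")

    def rollback(self):
        self.calls.append("rollback")
        if self._rollback_fails:
            raise ConnectionError("connection is dead")

    def invalidate(self):
        self.calls.append("invalidate")

    def close(self):
        self.calls.append("close")


def use_shared_session(db):
    """Same control flow as the reused-session branch: transactions only, no close."""
    try:
        yield db
        if db.is_active:
            db.commit()
    except Exception:
        try:
            db.rollback()
        except Exception:
            try:
                db.invalidate()
            except Exception:
                pass  # best-effort cleanup
        raise


def drive(db, error=None):
    """Run the generator the way a framework would and return the recorded calls."""
    gen = use_shared_session(db)
    next(gen)  # advance to the yield, as if the route handler received the session
    try:
        if error is None:
            next(gen)  # handler finished normally
        else:
            gen.throw(error)  # handler raised
    except StopIteration:
        pass  # generator finished after committing
    except ValueError:
        pass  # the original error is re-raised after cleanup
    return db.calls


print(drive(FakeSession()))                                         # ['commit']
print(drive(FakeSession(), ValueError("boom")))                     # ['rollback']
print(drive(FakeSession(rollback_fails=True), ValueError("boom")))  # ['rollback', 'invalidate']
```

None of the three paths records `close`, which matches the contract that the middleware, not the dependency, owns the session lifecycle.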

mcpgateway/middleware/observability_middleware.py

Lines changed: 4 additions & 7 deletions
@@ -207,13 +207,10 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
                 except Exception as end_trace_error:
                     logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")
 
-            # Commit the shared session (used by both observability and route handler)
-            # Note: Some route handlers may have already committed. The is_active check
-            # ensures we only commit if the transaction is still open. Services that
-            # explicitly commit will have already closed their transaction.
-            # Only commit if the transaction is still active AND has uncommitted changes
-            if db.is_active and db.in_transaction():
-                db.commit()
+            # NOTE: Transaction control delegated to get_db()
+            # Middleware only manages session lifecycle (create/close), not transactions.
+            # get_db() will commit on success or rollback on error to maintain
+            # predictable transaction semantics for route handlers (Issue #3731).
 
             return response
 
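
The middleware's remaining responsibility after this change is lifecycle only: create the session, expose it on `request.state.db`, and close it when the request ends. A minimal sketch of that shape follows. It is not the real `ObservabilityMiddleware`: `lifecycle_only_dispatch`, `RecordingSession`, and `demo` are hypothetical names, the tracing logic is omitted, and a `SimpleNamespace` stands in for the Starlette request.

```python
import asyncio
from types import SimpleNamespace


class RecordingSession:
    """Hypothetical session double that records commit and close calls."""

    def __init__(self):
        self.calls = []

    def commit(self):
        self.calls.append("commit")

    def close(self):
        self.calls.append("close")


async def lifecycle_only_dispatch(request, call_next, session_factory):
    """Create and close the session; never commit or roll back here."""
    db = session_factory()
    request.state.db = db  # shared with the dependency for this request
    try:
        return await call_next(request)
    finally:
        db.close()  # lifecycle ends here, whether or not the handler failed


async def demo():
    session = RecordingSession()
    request = SimpleNamespace(state=SimpleNamespace())

    async def call_next(req):
        # Stand-in for the route handler plus get_db(), which commits on success.
        req.state.db.commit()
        return "ok"

    response = await lifecycle_only_dispatch(request, call_next, lambda: session)
    return response, session.calls


print(asyncio.run(demo()))  # ('ok', ['commit', 'close'])
```

The recorded order shows the intended split: the commit happens inside the request (from the dependency), and the middleware only closes afterwards.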

tests/unit/mcpgateway/middleware/test_observability_middleware.py

Lines changed: 9 additions & 4 deletions
@@ -288,8 +288,9 @@ async def mock_call_next(request):
     # Verify only one session was created
     assert len(session_instances) == 1, f"Expected 1 session, but {len(session_instances)} were created"
 
-    # Verify session was committed and closed
-    session_instances[0].commit.assert_called_once()
+    # Verify session was closed (lifecycle management)
+    # Note: Middleware no longer commits - transaction control delegated to get_db() (Issue #3731)
+    session_instances[0].commit.assert_not_called()
     session_instances[0].close.assert_called_once()
 
     # Verify response
@@ -341,6 +342,7 @@ async def test_get_db_reuses_middleware_session():
     # Create a mock request with a session in state
     mock_request = MagicMock(spec=Request)
     mock_session = MagicMock()
+    mock_session.is_active = True  # Required for commit check
     mock_request.state.db = mock_session
 
     # Call get_db with the request
@@ -350,14 +352,17 @@ async def test_get_db_reuses_middleware_session():
     # Verify it returns the same session
     assert db is mock_session, "get_db should return the middleware's session"
 
-    # Verify the session is NOT closed (middleware will handle that)
+    # Complete the generator (simulating successful request)
     try:
         next(db_generator)
     except StopIteration:
         pass
 
+    # Verify get_db() commits the middleware session (Issue #3731 fix)
+    # Transaction control is now delegated to get_db(), not middleware
+    mock_session.commit.assert_called_once()
+    # Verify the session is NOT closed (middleware will handle that)
     mock_session.close.assert_not_called()
-    mock_session.commit.assert_not_called()
 
 
 @pytest.mark.asyncio