Skip to content

Commit 404df7d

Browse files
committed
middleware.py
1 parent 840ab22 commit 404df7d

6 files changed

Lines changed: 372 additions & 2 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,4 +205,3 @@ cython_debug/
205205
marimo/_static/
206206
marimo/_lsp/
207207
__marimo__/
208-
.pre-commit-config.yaml

.pre-commit-config.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.6.0
4+
hooks:
5+
- id: trailing-whitespace
6+
- id: end-of-file-fixer
7+
- id: check-yaml
8+
- id: check-toml
9+
- id: check-merge-conflict
10+
11+
- repo: https://github.com/psf/black
12+
rev: 24.10.0
13+
hooks:
14+
- id: black

fastapi_observer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .config import ObserverConfig
22
from .logger import build_logger, log_event
3+
from .middleware import ObserverMiddleware
34
from .models import LogEvent
45

5-
__all__ = ["ObserverConfig", "LogEvent", "build_logger", "log_event"]
6+
__all__ = ["ObserverConfig", "ObserverMiddleware", "LogEvent", "build_logger", "log_event"]

fastapi_observer/middleware.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import logging
5+
import time
6+
import uuid
7+
from typing import Any, Callable
8+
9+
from .config import ObserverConfig
10+
from .logger import build_logger, log_event
11+
from .models import LogEvent
12+
13+
try:
14+
from starlette.middleware.base import BaseHTTPMiddleware
15+
from starlette.requests import Request
16+
from starlette.responses import Response
17+
except ModuleNotFoundError as exc: # pragma: no cover - guarded runtime fallback
18+
_IMPORT_ERROR = exc
19+
else:
20+
_IMPORT_ERROR = None
21+
22+
23+
def _redact_sensitive(value: Any, redact_fields: set[str]) -> Any:
24+
if isinstance(value, dict):
25+
redacted: dict[str, Any] = {}
26+
for key, nested_value in value.items():
27+
if str(key).lower() in redact_fields:
28+
redacted[key] = "***"
29+
else:
30+
redacted[key] = _redact_sensitive(nested_value, redact_fields)
31+
return redacted
32+
if isinstance(value, list):
33+
return [_redact_sensitive(item, redact_fields) for item in value]
34+
return value
35+
36+
37+
def _redact_headers(headers: dict[str, str], redact_headers: set[str]) -> dict[str, str]:
38+
result: dict[str, str] = {}
39+
for key, value in headers.items():
40+
if key.lower() in redact_headers:
41+
result[key] = "***"
42+
else:
43+
result[key] = value
44+
return result
45+
46+
47+
def _to_log_level(status_code: int) -> str:
48+
if status_code >= 500:
49+
return "ERROR"
50+
if status_code >= 400:
51+
return "WARNING"
52+
return "INFO"
53+
54+
55+
if _IMPORT_ERROR is None:
56+
57+
class ObserverMiddleware(BaseHTTPMiddleware):
58+
def __init__(
59+
self,
60+
app: Any,
61+
*,
62+
config: ObserverConfig | None = None,
63+
logger: logging.Logger | None = None,
64+
) -> None:
65+
super().__init__(app)
66+
self.config = config or ObserverConfig()
67+
self.logger = logger or build_logger(self.config)
68+
69+
async def dispatch(
70+
self,
71+
request: Request,
72+
call_next: Callable[[Request], Any],
73+
) -> Response:
74+
if not self.config.enabled:
75+
return await call_next(request)
76+
77+
path = request.url.path
78+
method = request.method
79+
if not self.config.should_log_method(method) or not self.config.should_log_path(path):
80+
return await call_next(request)
81+
82+
correlation_id = self._resolve_correlation_id(request)
83+
request_body_for_log: Any | None = None
84+
request_for_next = request
85+
if self.config.log_request_body:
86+
body = await request.body()
87+
request_body_for_log = self._format_body_for_log(body)
88+
request_for_next = Request(request.scope, receive=_build_body_replay_receive(body))
89+
90+
start = time.perf_counter()
91+
try:
92+
response = await call_next(request_for_next)
93+
except Exception as exc:
94+
duration_ms = round((time.perf_counter() - start) * 1000, 3)
95+
log_event(
96+
self.logger,
97+
self.config,
98+
LogEvent(
99+
level="ERROR",
100+
message="HTTP request failed",
101+
method=method,
102+
path=path,
103+
status_code=500,
104+
duration_ms=duration_ms,
105+
correlation_id=correlation_id,
106+
error=str(exc),
107+
metadata=self._build_metadata(
108+
request,
109+
request_body=request_body_for_log,
110+
),
111+
),
112+
)
113+
raise
114+
115+
duration_ms = round((time.perf_counter() - start) * 1000, 3)
116+
response_body_for_log = (
117+
self._extract_response_body_for_log(response)
118+
if self.config.log_response_body
119+
else None
120+
)
121+
122+
log_event(
123+
self.logger,
124+
self.config,
125+
LogEvent(
126+
level=_to_log_level(response.status_code),
127+
message="HTTP request completed",
128+
method=method,
129+
path=path,
130+
status_code=response.status_code,
131+
duration_ms=duration_ms,
132+
correlation_id=correlation_id,
133+
metadata=self._build_metadata(
134+
request,
135+
request_body=request_body_for_log,
136+
response_body=response_body_for_log,
137+
),
138+
),
139+
)
140+
141+
if correlation_id:
142+
response.headers.setdefault(self.config.correlation_id_header, correlation_id)
143+
return response
144+
145+
def _resolve_correlation_id(self, request: Request) -> str | None:
146+
value = request.headers.get(self.config.correlation_id_header)
147+
if value:
148+
request.state.correlation_id = value
149+
return value
150+
if self.config.generate_correlation_id:
151+
generated = str(uuid.uuid4())
152+
request.state.correlation_id = generated
153+
return generated
154+
return None
155+
156+
def _build_metadata(
157+
self,
158+
request: Request,
159+
*,
160+
request_body: Any | None = None,
161+
response_body: Any | None = None,
162+
) -> dict[str, Any]:
163+
metadata: dict[str, Any] = {
164+
"query_params": dict(request.query_params),
165+
}
166+
if request.client:
167+
metadata["client"] = {
168+
"host": request.client.host,
169+
"port": request.client.port,
170+
}
171+
if self.config.log_headers:
172+
metadata["headers"] = _redact_headers(
173+
dict(request.headers),
174+
self.config.redact_headers,
175+
)
176+
if request_body is not None:
177+
metadata["request_body"] = request_body
178+
if response_body is not None:
179+
metadata["response_body"] = response_body
180+
return metadata
181+
182+
def _extract_response_body_for_log(self, response: Response) -> Any | None:
183+
body = getattr(response, "body", None)
184+
if body is None:
185+
return None
186+
if isinstance(body, bytes):
187+
return self._format_body_for_log(body)
188+
if isinstance(body, str):
189+
return self._format_body_for_log(body.encode("utf-8"))
190+
return None
191+
192+
def _format_body_for_log(self, body: bytes) -> Any:
193+
body_size = len(body)
194+
body_for_parse = body[: self.config.max_body_bytes]
195+
parsed: Any
196+
try:
197+
parsed = json.loads(body_for_parse.decode("utf-8"))
198+
except (UnicodeDecodeError, json.JSONDecodeError):
199+
parsed = body_for_parse.decode("utf-8", errors="replace")
200+
201+
redacted = _redact_sensitive(parsed, self.config.redact_fields)
202+
if body_size <= self.config.max_body_bytes:
203+
return redacted
204+
return {
205+
"truncated": True,
206+
"original_size": body_size,
207+
"content": redacted,
208+
}
209+
210+
211+
def _build_body_replay_receive(body: bytes) -> Callable[[], Any]:
212+
sent = False
213+
214+
async def receive() -> dict[str, Any]:
215+
nonlocal sent
216+
if not sent:
217+
sent = True
218+
return {"type": "http.request", "body": body, "more_body": False}
219+
return {"type": "http.request", "body": b"", "more_body": False}
220+
221+
return receive
222+
223+
224+
else:
225+
226+
class ObserverMiddleware: # pragma: no cover - import-time guard only
227+
def __init__(self, *_args: Any, **_kwargs: Any) -> None:
228+
raise RuntimeError(
229+
"ObserverMiddleware requires FastAPI/Starlette. "
230+
"Install optional runtime dependencies first."
231+
) from _IMPORT_ERROR
232+
233+
234+
__all__ = ["ObserverMiddleware"]

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ requires-python = ">=3.10,<3.15"
1111
license = { file = "LICENSE" }
1212
authors = [{ name = "fastapi-observer maintainers" }]
1313
dependencies = [
14+
"fastapi>=0.100,<1.0",
1415
"pydantic>=2.0,<3.0",
1516
]
1617

1718
[project.optional-dependencies]
1819
test = [
20+
"httpx>=0.24",
1921
"pytest>=8.0",
2022
"pytest-cov>=5.0",
2123
]

tests/test_middleware.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import logging
2+
3+
import pytest
4+
5+
fastapi = pytest.importorskip("fastapi")
6+
testclient = pytest.importorskip("fastapi.testclient")
7+
8+
from fastapi_observer.config import ObserverConfig
9+
from fastapi_observer.middleware import ObserverMiddleware
10+
11+
FastAPI = fastapi.FastAPI
12+
TestClient = testclient.TestClient
13+
14+
15+
class InMemoryHandler(logging.Handler):
16+
def __init__(self) -> None:
17+
super().__init__()
18+
self.records: list[logging.LogRecord] = []
19+
20+
def emit(self, record: logging.LogRecord) -> None:
21+
self.records.append(record)
22+
23+
24+
def _build_memory_logger(name: str) -> tuple[logging.Logger, InMemoryHandler]:
25+
handler = InMemoryHandler()
26+
logger = logging.getLogger(name)
27+
logger.handlers.clear()
28+
logger.setLevel(logging.INFO)
29+
logger.propagate = False
30+
logger.addHandler(handler)
31+
return logger, handler
32+
33+
34+
def test_middleware_logs_completed_request_and_sets_correlation_id():
35+
app = FastAPI()
36+
37+
@app.get("/items")
38+
async def read_items():
39+
return {"ok": True}
40+
41+
logger, memory = _build_memory_logger("fastapi_observer.test.middleware.success")
42+
config = ObserverConfig(log_headers=True)
43+
app.add_middleware(ObserverMiddleware, config=config, logger=logger)
44+
45+
client = TestClient(app)
46+
response = client.get("/items?limit=10")
47+
48+
assert response.status_code == 200
49+
assert config.correlation_id_header in response.headers
50+
assert len(memory.records) == 1
51+
52+
event = memory.records[0].event
53+
assert event["message"] == "HTTP request completed"
54+
assert event["path"] == "/items"
55+
assert event["status_code"] == 200
56+
assert event["correlation_id"] == response.headers[config.correlation_id_header]
57+
assert event["metadata"]["query_params"] == {"limit": "10"}
58+
59+
60+
def test_middleware_skips_excluded_path():
61+
app = FastAPI()
62+
63+
@app.get("/health")
64+
async def health():
65+
return {"ok": True}
66+
67+
logger, memory = _build_memory_logger("fastapi_observer.test.middleware.exclude")
68+
config = ObserverConfig()
69+
app.add_middleware(ObserverMiddleware, config=config, logger=logger)
70+
71+
client = TestClient(app)
72+
response = client.get("/health")
73+
74+
assert response.status_code == 200
75+
assert memory.records == []
76+
77+
78+
def test_middleware_logs_request_and_response_body_with_redaction():
79+
app = FastAPI()
80+
81+
@app.post("/echo")
82+
async def echo(payload: dict):
83+
return payload
84+
85+
logger, memory = _build_memory_logger("fastapi_observer.test.middleware.body")
86+
config = ObserverConfig(log_request_body=True, log_response_body=True)
87+
app.add_middleware(ObserverMiddleware, config=config, logger=logger)
88+
89+
client = TestClient(app)
90+
response = client.post("/echo", json={"password": "secret", "name": "alice"})
91+
92+
assert response.status_code == 200
93+
assert len(memory.records) == 1
94+
95+
event = memory.records[0].event
96+
assert event["metadata"]["request_body"]["password"] == "***"
97+
assert event["metadata"]["response_body"]["password"] == "***"
98+
assert event["metadata"]["request_body"]["name"] == "alice"
99+
100+
101+
def test_middleware_logs_exceptions():
102+
app = FastAPI()
103+
104+
@app.get("/boom")
105+
async def boom():
106+
raise RuntimeError("something failed")
107+
108+
logger, memory = _build_memory_logger("fastapi_observer.test.middleware.error")
109+
config = ObserverConfig()
110+
app.add_middleware(ObserverMiddleware, config=config, logger=logger)
111+
112+
client = TestClient(app, raise_server_exceptions=False)
113+
response = client.get("/boom")
114+
115+
assert response.status_code == 500
116+
assert len(memory.records) == 1
117+
event = memory.records[0].event
118+
assert event["message"] == "HTTP request failed"
119+
assert event["status_code"] == 500
120+
assert "something failed" in event["error"]

0 commit comments

Comments
 (0)