Coverage for middleware/trace_middleware.py: 90.62%
32 statements
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-13 00:07 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-13 00:07 +0000
1"""
2Trace Middleware - Request Tracing for Distributed Logging
4Extracts session_id and request_id from HTTP headers and binds them
5to structlog context for all subsequent logging in the request.
7Headers:
8- X-Session-ID: Frontend session identifier (mt-sess-{uuid})
9- X-Request-ID: Per-request identifier (mt-req-{uuid})
11If headers are not provided, generates new identifiers.
13Note: This middleware only handles context binding, not request logging.
14Request lifecycle logging should be done at the application level where
15the actual business logic resides, so log callsites show meaningful filenames.
16"""
18import uuid
19from contextvars import ContextVar
21import structlog
22from starlette.middleware.base import BaseHTTPMiddleware
23from starlette.requests import Request
24from starlette.responses import Response
# Context variables for trace IDs - accessible throughout the request lifecycle.
# Both default to None, so code running outside a request handled by
# TraceMiddleware sees None from the getters in this module.
# TraceMiddleware.dispatch sets them for each request and resets them with the
# returned tokens in its `finally` clause.
_session_id: ContextVar[str | None] = ContextVar("session_id", default=None)
_request_id: ContextVar[str | None] = ContextVar("request_id", default=None)
def generate_request_id() -> str:
    """
    Create a fresh per-request identifier.

    Format: ``mt-req-`` followed by the first 8 lowercase hex digits of a
    random UUID4 -- short enough to copy/paste by hand.
    """
    random_hex = uuid.uuid4().hex
    return "mt-req-" + random_hex[:8]
def generate_session_id() -> str:
    """
    Create a session identifier for requests that arrive without one.

    Format: ``mt-sess-`` followed by the first 8 lowercase hex digits of a
    random UUID4.
    """
    random_hex = uuid.uuid4().hex
    return "mt-sess-" + random_hex[:8]
def get_trace_context() -> tuple[str | None, str | None]:
    """
    Read both trace identifiers bound to the current context.

    Returns:
        A ``(session_id, request_id)`` pair. Either element is None when
        nothing has been set in the current context.
    """
    current_session = _session_id.get()
    current_request = _request_id.get()
    return (current_session, current_request)
def get_session_id() -> str | None:
    """Return the session ID bound to the current context, or None if unset."""
    current = _session_id.get()
    return current
def get_request_id() -> str | None:
    """Return the request ID bound to the current context, or None if unset."""
    current = _request_id.get()
    return current
class TraceMiddleware(BaseHTTPMiddleware):
    """
    Middleware that extracts trace IDs from request headers and binds
    them to structlog context for distributed tracing.

    For every request it:

    * reads ``X-Session-ID`` / ``X-Request-ID`` from the request headers,
      generating a new ID when a header is missing, empty, or rejected by
      ``_sanitize_trace_id``;
    * stores both IDs in this module's context variables so
      ``get_trace_context()`` and friends work downstream;
    * binds both IDs to structlog's contextvars while the downstream app
      runs, restoring any previous bindings of those keys afterwards;
    * echoes the request ID in the ``X-Request-ID`` response header.

    This middleware only handles context binding - no logging is done here
    so that log callsites show the actual business logic location.
    """

    # Longest client-supplied trace ID accepted. Header values are attached to
    # every log record for the request, so an unbounded value would be copied
    # into each line. Longer values are replaced with a generated ID.
    MAX_TRACE_ID_LENGTH = 128

    @classmethod
    def _sanitize_trace_id(cls, value: str | None) -> str | None:
        """
        Return ``value`` if it is acceptable as a trace ID, otherwise None.

        Accepted values are non-empty, at most ``MAX_TRACE_ID_LENGTH``
        characters long, and made only of visible ASCII characters
        (``!`` through ``~``). Whitespace and control/non-ASCII characters
        are rejected so a client cannot smuggle extra ``key=value`` pairs or
        odd bytes into log output through these headers.

        Args:
            value: Raw header value, or None when the header is absent.

        Returns:
            The unchanged value when valid, else None (caller generates one).
        """
        if not value or len(value) > cls.MAX_TRACE_ID_LENGTH:
            return None
        if not all("!" <= ch <= "~" for ch in value):
            return None
        return value

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Process request with trace context.

        Args:
            request: Incoming request; trace IDs are read from its headers.
            call_next: Starlette callable that runs the rest of the app.

        Returns:
            The downstream response, with ``X-Request-ID`` set to the ID used
            for this request.
        """
        # Extract (validated) or generate trace IDs
        session_id = (
            self._sanitize_trace_id(request.headers.get("X-Session-ID"))
            or generate_session_id()
        )
        request_id = (
            self._sanitize_trace_id(request.headers.get("X-Request-ID"))
            or generate_request_id()
        )

        # Store in context variables; tokens let us restore prior values.
        session_token = _session_id.set(session_id)
        request_token = _request_id.set(request_id)

        try:
            # bound_contextvars restores whatever was bound under these keys
            # before this request, instead of unconditionally removing them.
            with structlog.contextvars.bound_contextvars(
                session_id=session_id,
                request_id=request_id,
            ):
                response = await call_next(request)

            # Echo the request ID so clients can quote it when debugging.
            response.headers["X-Request-ID"] = request_id
            return response

        finally:
            # Restore the module-level context variables to their prior values.
            _request_id.reset(request_token)
            _session_id.reset(session_token)