Coverage for middleware/trace_middleware.py: 90.62%

32 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2026-04-15 12:24 +0000

"""
Trace Middleware - request tracing for distributed logging.

Extracts session_id and request_id from HTTP headers and binds them to the
structlog context, so every log call made while the request is being handled
carries both identifiers.

Headers:
- X-Session-ID: Frontend session identifier (mt-sess-{uuid})
- X-Request-ID: Per-request identifier (mt-req-{uuid}); the request ID in use
  is also echoed back in the X-Request-ID response header.

When a header is missing or empty, a new identifier is generated server-side:
"mt-sess-" / "mt-req-" followed by 8 random hex characters.

The identifiers for the current request are also stored in module-level
ContextVars and can be read with get_trace_context(), get_session_id() and
get_request_id().

Note: This middleware only handles context binding, not request logging.
Request lifecycle logging should be done at the application level where the
actual business logic resides, so log callsites show meaningful filenames.
"""

17 

18import uuid 

19from contextvars import ContextVar 

20 

21import structlog 

22from starlette.middleware.base import BaseHTTPMiddleware 

23from starlette.requests import Request 

24from starlette.responses import Response 

25 

26# Context variables for trace IDs - accessible throughout the request lifecycle 

27_session_id: ContextVar[str | None] = ContextVar("session_id", default=None) 

28_request_id: ContextVar[str | None] = ContextVar("request_id", default=None) 

29 

30 

def generate_request_id() -> str:
    """Return a fresh request ID: "mt-req-" plus 8 random hex characters (short enough to copy/paste)."""
    random_suffix = uuid.uuid4().hex[:8]
    return "mt-req-" + random_suffix

34 

35 

def generate_session_id() -> str:
    """Return a fresh session ID: "mt-sess-" plus 8 random hex characters.

    Used as a fallback when the frontend does not send its own session ID.
    """
    random_suffix = uuid.uuid4().hex[:8]
    return "mt-sess-" + random_suffix

39 

40 

def get_trace_context() -> tuple[str | None, str | None]:
    """
    Return the trace identifiers bound to the current context.

    Returns:
        A (session_id, request_id) pair. Either element is None when no value
        has been set in the current context.
    """
    session = _session_id.get()
    request = _request_id.get()
    return (session, request)

49 

50 

def get_session_id() -> str | None:
    """Return the session ID bound to the current context, or None if unset."""
    current_session = _session_id.get()
    return current_session

54 

55 

def get_request_id() -> str | None:
    """Return the request ID bound to the current context, or None if unset."""
    current_request = _request_id.get()
    return current_request

59 

60 

class TraceMiddleware(BaseHTTPMiddleware):
    """
    Middleware that extracts trace IDs from request headers and binds
    them to structlog context for distributed tracing.

    For each request it:

    * reads X-Session-ID / X-Request-ID, keeping a header value only if it is
      at most 128 visible ASCII characters; a missing, empty or malformed value
      is replaced by a freshly generated ID,
    * stores both IDs in this module's ContextVars for the duration of the
      request,
    * binds both IDs into structlog's contextvars for the downstream call,
      restoring whatever those keys held before once the call finishes,
    * echoes the request ID in the X-Request-ID response header.

    This middleware only handles context binding - no logging is done here
    so that log callsites show the actual business logic location.
    """

    # Header values are client-controlled and end up in every log line and in
    # a response header, so bound their length and restrict the alphabet.
    _MAX_TRACE_ID_LENGTH = 128

    @classmethod
    def _trace_id_from_header(cls, request: Request, header: str) -> str | None:
        """
        Return the value of ``header`` if it is usable as a trace ID, else None.

        A value is rejected when it is missing, empty, longer than
        ``_MAX_TRACE_ID_LENGTH``, or contains anything other than visible
        ASCII characters (0x21-0x7E). This keeps whitespace, control
        characters and non-ASCII text out of logs and response headers.
        """
        value = request.headers.get(header)
        if not value or len(value) > cls._MAX_TRACE_ID_LENGTH:
            return None
        if not all("!" <= ch <= "~" for ch in value):
            return None
        return value

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request with trace context."""
        # Extract (validated) or generate trace IDs
        session_id = (
            self._trace_id_from_header(request, "X-Session-ID")
            or generate_session_id()
        )
        request_id = (
            self._trace_id_from_header(request, "X-Request-ID")
            or generate_request_id()
        )

        # Store in context variables; the tokens let us restore prior values.
        session_token = _session_id.set(session_id)
        request_token = _request_id.set(request_id)
        try:
            # bound_contextvars restores any previous structlog bindings for
            # these keys on exit (rather than deleting them outright), matching
            # the token-based reset of the module ContextVars below.
            with structlog.contextvars.bound_contextvars(
                session_id=session_id,
                request_id=request_id,
            ):
                response = await call_next(request)

                # Echo the request ID so clients can correlate responses with logs.
                response.headers["X-Request-ID"] = request_id

            return response

        finally:
            # Reset context variables to their pre-request values
            _request_id.reset(request_token)
            _session_id.reset(session_token)