Coverage for csrf_protection.py: 30.49%
64 statements
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-13 14:26 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2026-04-13 14:26 +0000
1"""
2CSRF protection middleware for the sports league backend.
3"""
5import logging
6import os
7import secrets
9from fastapi import Request, Response
10from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)

# ---- CSRF settings ----
# Secret taken from the CSRF_SECRET_KEY environment variable; when the
# variable is unset a fresh random value is generated at import time.
CSRF_SECRET_KEY = os.environ.get("CSRF_SECRET_KEY", secrets.token_urlsafe(32))
CSRF_TOKEN_LIFETIME = 60 * 60  # cookie max-age: one hour
CSRF_HEADER_NAME = "X-CSRF-Token"  # request header carrying the token
CSRF_COOKIE_NAME = "csrf_token"  # cookie carrying the token

# HTTP methods whose requests must pass the CSRF check.
PROTECTED_METHODS = {"DELETE", "PATCH", "POST", "PUT"}

# Request paths that bypass the CSRF check.
CSRF_EXEMPT_PATHS = {
    # Authentication endpoints: must work before the client holds a token.
    "/api/auth/login",
    "/api/auth/signup",
    "/api/auth/refresh",
    # Generated API documentation.
    "/docs",
    "/openapi.json",
    "/redoc",
}
class CSRFProtection:
    """CSRF protection helper for FastAPI.

    A random token is stored in a cookie and the client must echo the same
    value in the ``X-CSRF-Token`` request header; the two values are compared
    to validate a request.
    """

    def __init__(self):
        # Kept as an attribute for callers; no method of this class reads it
        # (tokens are plain random values, they are not signed).
        self.secret_key = CSRF_SECRET_KEY

    def generate_csrf_token(self) -> str:
        """Return a new random URL-safe CSRF token."""
        return secrets.token_urlsafe(32)

    def verify_csrf_token(self, token: str, cookie_token: str) -> bool:
        """Return True when the header token equals the cookie token.

        Args:
            token: Value taken from the CSRF request header (may be None).
            cookie_token: Value taken from the CSRF cookie (may be None).

        Returns:
            False when either value is missing or empty, otherwise the
            result of a constant-time equality check.
        """
        if not token or not cookie_token:
            return False

        # compare_digest() raises TypeError for str arguments containing
        # non-ASCII characters. Both values come from the request and are not
        # guaranteed to be ASCII, so compare their UTF-8 bytes instead; this
        # keeps the comparison constant-time and never raises on odd input.
        return secrets.compare_digest(
            token.encode("utf-8"), cookie_token.encode("utf-8")
        )

    def get_csrf_token_from_request(self, request: Request) -> str | None:
        """Return the CSRF token sent in the request header, or None."""
        return request.headers.get(CSRF_HEADER_NAME)

    def get_csrf_cookie_from_request(self, request: Request) -> str | None:
        """Return the CSRF token stored in the request cookies, or None."""
        return request.cookies.get(CSRF_COOKIE_NAME)

    def is_exempt(self, path: str) -> bool:
        """Return True when ``path`` is exempt from CSRF protection.

        A path is exempt when it equals an entry of ``CSRF_EXEMPT_PATHS`` or
        lies below one (``<entry>/...``). Matching stops at a path-segment
        boundary so that an entry such as ``/docs`` does not also exempt an
        unrelated route such as ``/docs-internal``.
        """
        for exempt_path in CSRF_EXEMPT_PATHS:
            if path == exempt_path:
                return True
            # Sub-paths (e.g. paths with parameters) must continue with "/".
            if path.startswith(exempt_path.rstrip("/") + "/"):
                return True
        return False

    def set_csrf_cookie(self, response: Response, token: str) -> None:
        """Attach the CSRF token cookie to ``response``.

        The cookie lives for ``CSRF_TOKEN_LIFETIME`` seconds, is HTTP-only and
        SameSite=strict, and is marked Secure only when the ``ENVIRONMENT``
        environment variable equals ``"production"``.
        """
        response.set_cookie(
            key=CSRF_COOKIE_NAME,
            value=token,
            max_age=CSRF_TOKEN_LIFETIME,
            httponly=True,
            secure=os.getenv("ENVIRONMENT", "development") == "production",
            samesite="strict",
        )
# Global CSRF protection instance, created once at import time and shared by
# the middleware and the token helper functions in this module.
csrf_protection = CSRFProtection()
async def csrf_middleware(request: Request, call_next):
    """CSRF protection middleware.

    Behaviour by request type:

    * Methods outside ``PROTECTED_METHODS`` are passed through unchecked. For
      a GET request that carries no CSRF cookie, a new token cookie is added
      to the response.
    * Protected methods on exempt paths are passed through unchecked.
    * All other requests must send a header token equal to the cookie token;
      otherwise a 403 JSON response is returned and the downstream handler
      is not invoked.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that runs the rest of the application
            and returns its response.

    Returns:
        The downstream response (possibly with a CSRF cookie attached), or a
        403 ``JSONResponse`` when validation fails.
    """
    # Safe methods: no CSRF check.
    if request.method not in PROTECTED_METHODS:
        # The downstream handler must run exactly once per request; the
        # cookie is attached to that single response afterwards.
        response = await call_next(request)

        # For GET requests, make sure the client ends up with a CSRF token.
        if request.method == "GET" and not csrf_protection.get_csrf_cookie_from_request(request):
            csrf_protection.set_csrf_cookie(response, csrf_protection.generate_csrf_token())

        return response

    # Skip CSRF check for exempt paths.
    if csrf_protection.is_exempt(request.url.path):
        return await call_next(request)

    header_token = csrf_protection.get_csrf_token_from_request(request)
    cookie_token = csrf_protection.get_csrf_cookie_from_request(request)

    if not csrf_protection.verify_csrf_token(header_token, cookie_token):
        logger.warning("CSRF token validation failed for %s", request.url.path)
        return JSONResponse(status_code=403, content={"detail": "CSRF token validation failed"})

    response = await call_next(request)

    # Re-issue the same token so the cookie's max-age starts over. Successful
    # verification above implies cookie_token is non-empty.
    csrf_protection.set_csrf_cookie(response, cookie_token)

    return response
def get_csrf_token(request: Request) -> str:
    """Return the request's CSRF cookie value, or a new token if it has none."""
    existing = csrf_protection.get_csrf_cookie_from_request(request)
    # A missing or empty cookie value falls through to a freshly generated token.
    return existing or csrf_protection.generate_csrf_token()
134# Dependency for endpoints that need to provide CSRF token
async def provide_csrf_token(request: Request, response: Response) -> dict:
    """Set the CSRF cookie on the response and return the token in a dict.

    The token is the one already present in the request's cookie, or a newly
    generated one when the request has none.
    """
    csrf_token = get_csrf_token(request)
    payload = {"csrf_token": csrf_token}
    csrf_protection.set_csrf_cookie(response, csrf_token)
    return payload