Coverage for csrf_protection.py: 30.49%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2026-04-15 13:02 +0000

1""" 

2CSRF protection middleware for the sports league backend. 

3""" 

4 

5import logging 

6import os 

7import secrets 

8 

9from fastapi import Request, Response 

10from fastapi.responses import JSONResponse 

11 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- CSRF settings --------------------------------------------------------
# The secret is read from the environment; when the variable is absent a
# random value is generated at import time (so it differs per process).
CSRF_SECRET_KEY = os.environ.get("CSRF_SECRET_KEY", secrets.token_urlsafe(32))
CSRF_TOKEN_LIFETIME = 60 * 60  # seconds (1 hour); used as the cookie max_age
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_COOKIE_NAME = "csrf_token"

# HTTP methods whose requests are subject to CSRF validation.
PROTECTED_METHODS = {"POST", "PUT", "PATCH", "DELETE"}

# Paths that bypass CSRF validation.
CSRF_EXEMPT_PATHS = {
    # Authentication endpoints must work before the client holds a token.
    "/api/auth/login",
    "/api/auth/signup",
    "/api/auth/refresh",
    # Interactive API documentation and the OpenAPI schema.
    "/docs",
    "/openapi.json",
    "/redoc",
}

32 

33 

class CSRFProtection:
    """CSRF helper used by the FastAPI middleware in this module.

    Bundles token generation, comparison of a header-supplied token against
    a cookie-supplied token, extraction of both values from a request,
    exempt-path matching, and issuing the CSRF cookie on a response.
    """

    def __init__(self):
        # Stored from module configuration; the methods below do not use it
        # (tokens are plain random values, not signed with this key).
        self.secret_key = CSRF_SECRET_KEY

    def generate_csrf_token(self) -> str:
        """Return a new random URL-safe token built from 32 random bytes."""
        return secrets.token_urlsafe(32)

    def verify_csrf_token(self, token: str | None, cookie_token: str | None) -> bool:
        """Return True when both tokens are present and equal.

        Args:
            token: Token supplied by the client (request header).
            cookie_token: Token stored in the CSRF cookie.

        Returns:
            False if either value is missing/empty or the values differ.
        """
        if not token or not cookie_token:
            return False

        # Constant-time comparison to prevent timing attacks. The values are
        # encoded first because compare_digest() raises TypeError for str
        # arguments containing non-ASCII characters, and the header value is
        # client-controlled; a bad header must yield False, not an exception.
        return secrets.compare_digest(
            token.encode("utf-8"), cookie_token.encode("utf-8")
        )

    def get_csrf_token_from_request(self, request: Request) -> str | None:
        """Return the value of the CSRF request header, or None if absent."""
        return request.headers.get(CSRF_HEADER_NAME)

    def get_csrf_cookie_from_request(self, request: Request) -> str | None:
        """Return the value of the CSRF cookie, or None if absent."""
        return request.cookies.get(CSRF_COOKIE_NAME)

    @staticmethod
    def _matches_exempt_path(path: str, exempt_paths: set[str]) -> bool:
        """Return True if *path* equals an exempt path or lies beneath one.

        Matching is done on path-segment boundaries: "/docs" exempts "/docs"
        and "/docs/anything", but not "/docs-internal". A bare prefix test
        would let any route that merely shares a prefix skip CSRF checks.
        """
        return any(
            path == exempt or path.startswith(exempt.rstrip("/") + "/")
            for exempt in exempt_paths
        )

    def is_exempt(self, path: str) -> bool:
        """Check if the path is exempt from CSRF protection.

        A path is exempt when it is listed in CSRF_EXEMPT_PATHS or is a
        sub-path of a listed entry (for paths with parameters).
        """
        return self._matches_exempt_path(path, CSRF_EXEMPT_PATHS)

    def set_csrf_cookie(self, response: Response, token: str) -> None:
        """Attach the CSRF cookie carrying *token* to *response*.

        The cookie lives for CSRF_TOKEN_LIFETIME seconds, is HTTP-only and
        SameSite=strict, and is marked Secure only when the ENVIRONMENT
        variable equals "production" (default: "development").
        """
        response.set_cookie(
            key=CSRF_COOKIE_NAME,
            value=token,
            max_age=CSRF_TOKEN_LIFETIME,
            httponly=True,
            secure=os.getenv("ENVIRONMENT", "development") == "production",
            samesite="strict",
        )

80 

81 

# Shared module-level instance used by the middleware and helpers below.
csrf_protection: CSRFProtection = CSRFProtection()

84 

85 

async def csrf_middleware(request: Request, call_next):
    """CSRF protection middleware.

    Requests whose method is not in PROTECTED_METHODS are passed through;
    a GET that arrives without a CSRF cookie additionally gets a freshly
    generated token set as a cookie on its response. Requests with a
    protected method are passed through if their path is exempt; otherwise
    the header token must match the cookie token, or a 403 JSON response is
    returned without calling the downstream handler.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that forwards the request downstream
            and returns the response.

    Returns:
        The downstream response (possibly with a CSRF cookie attached), or
        a 403 JSONResponse when validation fails.
    """
    # Skip CSRF check for safe methods
    if request.method not in PROTECTED_METHODS:
        # Invoke the downstream handler exactly once and reuse its response;
        # calling call_next() a second time would execute the route twice.
        response = await call_next(request)

        # For GET requests, ensure a CSRF token cookie is set
        if request.method == "GET" and not csrf_protection.get_csrf_cookie_from_request(request):
            csrf_protection.set_csrf_cookie(response, csrf_protection.generate_csrf_token())

        return response

    # Skip CSRF check for exempt paths
    if csrf_protection.is_exempt(request.url.path):
        return await call_next(request)

    # Get tokens
    header_token = csrf_protection.get_csrf_token_from_request(request)
    cookie_token = csrf_protection.get_csrf_cookie_from_request(request)

    # Verify CSRF token
    if not csrf_protection.verify_csrf_token(header_token, cookie_token):
        # Lazy %-formatting: the message is only built if the record is emitted.
        logger.warning("CSRF token validation failed for %s", request.url.path)
        return JSONResponse(status_code=403, content={"detail": "CSRF token validation failed"})

    # Process request
    response = await call_next(request)

    # Re-set the same token so the cookie's max_age window starts over
    if cookie_token:
        csrf_protection.set_csrf_cookie(response, cookie_token)

    return response

124 

125 

def get_csrf_token(request: Request) -> str:
    """Return the request's CSRF cookie value, or a newly generated token.

    A missing or empty cookie results in a fresh token from the shared
    ``csrf_protection`` instance.
    """
    existing = csrf_protection.get_csrf_cookie_from_request(request)
    return existing or csrf_protection.generate_csrf_token()

132 

133 

# Dependency for endpoints that hand the CSRF token to the client
async def provide_csrf_token(request: Request, response: Response) -> dict:
    """Set the CSRF cookie on *response* and return the token in a dict.

    The token is the request's existing cookie value when present, otherwise
    a newly generated one. The returned mapping has the single key
    ``"csrf_token"``.
    """
    csrf_token = get_csrf_token(request)
    payload = {"csrf_token": csrf_token}
    csrf_protection.set_csrf_cookie(response, csrf_token)
    return payload