Coverage for rate_limiter.py: 0.00%

52 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2026-04-13 14:11 +0000

1""" 

2Rate limiting middleware for the sports league backend. 

3""" 

4 

5import logging 

6import os 

7 

8import redis 

9from fastapi import Request 

10from slowapi import Limiter, _rate_limit_exceeded_handler 

11from slowapi.errors import RateLimitExceeded 

12from slowapi.middleware import SlowAPIMiddleware 

13from slowapi.util import get_remote_address 

14 

logger = logging.getLogger(__name__)

# Redis configuration for distributed rate limiting
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
USE_REDIS = os.getenv("USE_REDIS_RATE_LIMIT", "false").lower() == "true"

# Upper bound (seconds) on how long connecting to Redis may block. This module
# pings Redis at import time, so without a connect timeout an unreachable host
# could stall application startup.
REDIS_CONNECT_TIMEOUT_SECONDS = 2.0

# Create Redis client if enabled. Any failure is logged and leaves
# redis_client as None, which selects in-memory rate limiting instead.
redis_client = None
if USE_REDIS:
    try:
        redis_client = redis.from_url(
            REDIS_URL,
            decode_responses=True,
            socket_connect_timeout=REDIS_CONNECT_TIMEOUT_SECONDS,
        )
        redis_client.ping()
        logger.info("Redis connected for rate limiting")
    except Exception as e:  # deliberate best-effort: never fail startup on Redis
        logger.warning("Redis connection failed, falling back to in-memory: %s", e)
        redis_client = None

31 

32 

# Key function: combine the client address with the authenticated user, if any.
def get_rate_limit_key(request: Request) -> str:
    """Return the limiter bucket key for ``request``.

    The base key is the client address from slowapi's ``get_remote_address``.
    If ``request.state.user_id`` is present and truthy (presumably populated by
    the auth middleware -- confirm it runs before the limiter computes keys),
    the key becomes ``"<address>:<user_id>"``, so distinct users sharing one
    address are counted in separate buckets.
    """
    address = get_remote_address(request)
    current_user = getattr(request.state, "user_id", None)
    if not current_user:
        return address
    return f"{address}:{current_user}"

45 

46 

# Shared limiter instance. Counters are stored in Redis only when the
# module-level Redis connectivity check succeeded (redis_client is set);
# otherwise storage_uri is None and limits are kept in process memory.
limiter = Limiter(
    key_func=get_rate_limit_key,
    default_limits=["200 per hour", "50 per minute"],  # Global limits
    storage_uri=REDIS_URL if redis_client else None,
    # The startup check only covers Redis being down at import time. When Redis
    # storage is in use, let slowapi fall back to in-memory counting if Redis
    # becomes unreachable later, rather than erroring on every request.
    in_memory_fallback_enabled=redis_client is not None,
    headers_enabled=True,  # Include rate limit headers in responses
)

54 

# Rate-limit table: endpoint category -> endpoint name -> limit string.
# Limit strings use the same "N per period" form as the limiter's defaults.
RATE_LIMITS: dict[str, dict[str, str]] = {
    # Authentication endpoints - stricter limits
    "auth": {
        "login": "5 per minute",
        "signup": "3 per hour",
        "password_reset": "3 per hour",
    },
    # Public read endpoints - generous default, tighter for specific resources
    "public": {
        "default": "100 per minute",
        "standings": "30 per minute",
        "games": "30 per minute",
    },
    # Authenticated write endpoints - moderate limits
    "authenticated": {
        "default": "30 per minute",
        "create_game": "10 per minute",
        "update_game": "20 per minute",
    },
    # Admin endpoints - relaxed limits
    "admin": {
        "default": "100 per minute",
    },
}

70 

71 

72def get_endpoint_limit(path: str, method: str, user_role: str | None = None) -> str: 

73 """Determine rate limit based on endpoint and user role.""" 

74 

75 # Auth endpoints 

76 if path.startswith("/api/auth/"): 

77 if "login" in path: 

78 return RATE_LIMITS["auth"]["login"] 

79 elif "signup" in path: 

80 return RATE_LIMITS["auth"]["signup"] 

81 elif "password" in path: 

82 return RATE_LIMITS["auth"]["password_reset"] 

83 

84 # Admin users get higher limits 

85 if user_role == "admin": 

86 return RATE_LIMITS["admin"]["default"] 

87 

88 # Write operations (POST, PUT, DELETE) 

89 if method in ["POST", "PUT", "DELETE"]: 

90 if "games" in path: 

91 if method == "POST": 

92 return RATE_LIMITS["authenticated"]["create_game"] 

93 else: 

94 return RATE_LIMITS["authenticated"]["update_game"] 

95 return RATE_LIMITS["authenticated"]["default"] 

96 

97 # Public read operations 

98 if "standings" in path: 

99 return RATE_LIMITS["public"]["standings"] 

100 elif "games" in path: 

101 return RATE_LIMITS["public"]["games"] 

102 

103 return RATE_LIMITS["public"]["default"] 

104 

105 

def create_rate_limit_middleware(app):
    """Attach the module's rate limiter to ``app`` and return it.

    Steps:
    * stores the limiter on ``app.state.limiter``; slowapi's middleware and its
      default ``RateLimitExceeded`` handler look the limiter up there
      (per slowapi's FastAPI setup instructions);
    * registers slowapi's handler for ``RateLimitExceeded``;
    * adds ``SlowAPIMiddleware``.

    Args:
        app: The FastAPI/Starlette application to configure.

    Returns:
        The shared module-level ``limiter``, for use as a route decorator.
    """
    # Required by SlowAPIMiddleware and _rate_limit_exceeded_handler.
    app.state.limiter = limiter

    # Add error handler
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # Add middleware
    app.add_middleware(SlowAPIMiddleware)

    return limiter

116 

117 

# Decorator factory for applying a custom limit to a single endpoint.
def rate_limit(limit: str):
    """Build a decorator that applies ``limit`` to one endpoint.

    Delegates to ``limiter.limit`` on the module-level limiter. ``limit`` takes
    the same "N per period" strings used in ``RATE_LIMITS``, e.g.
    ``@rate_limit(RATE_LIMITS["auth"]["login"])``.

    NOTE(review): slowapi documents that decorated endpoints must accept a
    ``request`` argument (and a ``response`` argument when headers are
    enabled) -- confirm the decorated routes follow this.
    """
    apply_limit = limiter.limit
    return apply_limit(limit)