# middleware/decorators.py
"""
FastAPI decorators for cross-cutting concerns.

This module provides classes and functions for:
- Rate limiting decorators for endpoint protection
- Request throttling and abuse prevention
- Consistent error handling for rate limit violations
"""

import asyncio
import inspect
from functools import wraps

from starlette.requests import Request

from app.exceptions.base import RateLimitException
from middleware.cloudflare import get_real_client_ip
from middleware.rate_limiter import RateLimiter

# Initialize rate limiter instance.
# A single module-level instance is shared by every endpoint decorated with
# @rate_limit in this process.
rate_limiter = RateLimiter()

# Client key used when no Request can be located in the endpoint's arguments.
_ANONYMOUS_CLIENT = "anonymous"


def _find_request(*args, **kwargs) -> Request | None:
    """Extract a Request object from function args/kwargs.

    Keyword arguments are searched first (FastAPI usually passes ``request=``
    as a keyword), then positional arguments (e.g. ``self, request, ...``).

    Returns:
        The first ``Request`` instance found, or ``None`` if there is none.
    """
    for val in kwargs.values():
        if isinstance(val, Request):
            return val
    for val in args:
        if isinstance(val, Request):
            return val
    return None


def _enforce_rate_limit(
    args: tuple, kwargs: dict, max_requests: int, window_seconds: int
) -> None:
    """Raise ``RateLimitException`` if the calling client is over its quota.

    Shared by the sync and async wrappers so both paths apply exactly the
    same check.

    Args:
        args: Positional arguments the endpoint was called with.
        kwargs: Keyword arguments the endpoint was called with.
        max_requests: Quota forwarded to ``rate_limiter.allow_request``.
        window_seconds: Window length forwarded to
            ``rate_limiter.allow_request`` and reported as ``retry_after``.

    Raises:
        RateLimitException: If ``rate_limiter.allow_request`` returns a
            falsy value for this client.
    """
    request = _find_request(*args, **kwargs)
    # Starlette's Request implements the Mapping interface, so its truthiness
    # is not a reliable "was a request found" test; compare against None.
    client_id = (
        get_real_client_ip(request) if request is not None else _ANONYMOUS_CLIENT
    )
    # NOTE(review): allow_request is called synchronously, including from the
    # async wrapper. If RateLimiter performs blocking I/O, this stalls the
    # event loop. Confirm against middleware.rate_limiter.
    if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
        raise RateLimitException(
            message="Rate limit exceeded",
            # The full window length is reported, not the exact time remaining.
            retry_after=window_seconds,
        )


def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
    """Rate limiting decorator for FastAPI endpoints.

    Works with both sync and async endpoint functions. Extracts the real
    client IP from the Request object for per-client limiting.

    Per-client limiting only happens when the decorated endpoint receives a
    ``Request`` argument (e.g. it declares ``request: Request``). Otherwise
    every call is counted under the single ``"anonymous"`` client key.

    Args:
        max_requests: Request quota forwarded to
            ``rate_limiter.allow_request``. Defaults to 100.
        window_seconds: Window length in seconds forwarded to
            ``rate_limiter.allow_request``; also used as ``retry_after`` on
            rejection. Defaults to 3600.

    Returns:
        A decorator that wraps the endpoint while preserving its metadata and
        signature via ``functools.wraps``.

    Raises:
        RateLimitException: Raised by the wrapped endpoint, before the original
            function runs, when the client is over its limit.
    """

    def decorator(func):
        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _enforce_rate_limit(args, kwargs, max_requests, window_seconds)
                return await func(*args, **kwargs)

            return async_wrapper

        # @wraps sets __wrapped__, which inspect.signature follows, so
        # signature introspection still sees the original parameters.
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _enforce_rate_limit(args, kwargs, max_requests, window_seconds)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator