Some checks failed
- Fix rate limiter to extract real client IP and handle sync/async endpoints - Rate-limit public enrollment (10/min) and program info (30/min) endpoints - Add 409 Conflict to non-retryable status codes in retry decorator - Cache private key in get_save_url() to avoid re-reading JSON per call - Make update_class() return bool success status with error-level logging - Move Google Wallet config from core to loyalty module config - Document time.sleep() safety in retry decorator (threadpool execution) - Add per-card retry (1 retry, 2s delay) to wallet sync task - Add logo URL reachability check (HEAD request) to validate_config() - Add 26 comprehensive unit tests for GoogleWalletService Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
# middleware/decorators.py
|
|
"""
|
|
FastAPI decorators for cross-cutting concerns.
|
|
|
|
This module provides decorators and helper functions for:
|
|
- Rate limiting decorators for endpoint protection
|
|
- Request throttling and abuse prevention
|
|
- Consistent error handling for rate limit violations
|
|
"""
|
|
|
|
import asyncio
|
|
from functools import wraps
|
|
|
|
from starlette.requests import Request
|
|
|
|
from app.exceptions.base import RateLimitException
|
|
from middleware.cloudflare import get_real_client_ip
|
|
from middleware.rate_limiter import RateLimiter
|
|
|
|
# Initialize rate limiter instance
|
|
# Single module-level limiter: every wrapper produced by rate_limit() below
# calls allow_request() on this one instance.
rate_limiter = RateLimiter()
|
|
|
|
|
|
def _find_request(*args, **kwargs) -> Request | None:
    """Return the first ``Request`` found among the call arguments.

    Keyword argument values are scanned before positional arguments, each in
    the order they were supplied. Returns ``None`` when no argument is a
    ``Request`` instance.
    """
    # Keyword values come first (FastAPI usually passes request= as keyword);
    # positional arguments (e.g. self, request, ...) are only reached when no
    # keyword value matched.
    candidates = (*kwargs.values(), *args)
    return next(
        (item for item in candidates if isinstance(item, Request)),
        None,
    )
|
|
|
|
|
|
def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
    """Rate limiting decorator for FastAPI endpoints.

    Works with both sync and async endpoint functions: a coroutine function
    is wrapped by an ``async def`` wrapper, anything else by a plain one.
    The limit check runs before the endpoint body in both cases.

    The client key is the value returned by ``get_real_client_ip`` for the
    ``Request`` found among the endpoint's arguments. When the endpoint
    receives no ``Request``, the fixed key ``"anonymous"`` is used instead,
    so all such calls are checked against that one key.

    NOTE(review): the client key is the only identifier handed to
    ``rate_limiter.allow_request``; whether differently-limited endpoints
    share a counter depends on ``RateLimiter`` -- confirm.

    Args:
        max_requests: Request budget passed to the limiter for each key.
        window_seconds: Window length passed to the limiter; also reported
            as ``retry_after`` when the limit is exceeded.

    Returns:
        A decorator that wraps an endpoint function with the limit check.

    Raises:
        RateLimitException: Raised by the wrapped endpoint (at call time,
            not at decoration time) when the limiter rejects the request.
    """
    # Function-scope import keeps this block self-contained.
    # inspect.iscoroutinefunction is the supported replacement for
    # asyncio.iscoroutinefunction, which is deprecated since Python 3.14.
    import inspect

    def _enforce_limit(args, kwargs) -> None:
        """Run the limit check for one call; raise if the limiter rejects it."""
        request = _find_request(*args, **kwargs)
        # Explicit None check: the fallback key is meant for "no Request was
        # passed", not for whatever truthiness a Request object reports.
        client_id = (
            get_real_client_ip(request) if request is not None else "anonymous"
        )

        if not rate_limiter.allow_request(
            client_id, max_requests, window_seconds
        ):
            raise RateLimitException(
                message="Rate limit exceeded",
                retry_after=window_seconds,
            )

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _enforce_limit(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _enforce_limit(args, kwargs)
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator
|