# middleware/rate_limiter.py """Summary description .... This module provides classes and functions for: - .... - .... - .... """ import logging from collections import defaultdict, deque from datetime import UTC, datetime, timedelta logger = logging.getLogger(__name__) class RateLimiter: """In-memory rate limiter using sliding window.""" def __init__(self): """Class constructor.""" # Dictionary to store request timestamps for each client self.clients: dict[str, deque] = defaultdict(lambda: deque()) self.cleanup_interval = 3600 # Clean up old entries every hour self.last_cleanup = datetime.now(UTC) def allow_request( self, client_id: str, max_requests: int, window_seconds: int ) -> bool: """ Check if client is allowed to make a request. Uses sliding window algorithm """ now = datetime.now(UTC) window_start = now - timedelta(seconds=window_seconds) # Clean up old entries periodically if (now - self.last_cleanup).seconds > self.cleanup_interval: self._cleanup_old_entries() self.last_cleanup = now # Get client's request history client_requests = self.clients[client_id] # Remove requests outside the window while client_requests and client_requests[0] < window_start: client_requests.popleft() # Check if under rate limit if len(client_requests) < max_requests: client_requests.append(now) return True logger.warning( f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}" ) return False def _cleanup_old_entries(self): """Clean up old entries to prevent memory leaks.""" cutoff_time = datetime.now(UTC) - timedelta(hours=24) clients_to_remove = [] for client_id, requests in self.clients.items(): # Remove old requests while requests and requests[0] < cutoff_time: requests.popleft() # Mark empty clients for removal if not requests: clients_to_remove.append(client_id) # Remove empty clients for client_id in clients_to_remove: del self.clients[client_id] logger.info( f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients" ) def get_client_stats(self, client_id: str) -> dict[str, int]: """Get statistics for a specific client.""" client_requests = self.clients.get(client_id, deque()) now = datetime.now(UTC) hour_ago = now - timedelta(hours=1) day_ago = now - timedelta(days=1) requests_last_hour = sum( 1 for req_time in client_requests if req_time > hour_ago ) requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago) return { "requests_last_hour": requests_last_hour, "requests_last_day": requests_last_day, "total_tracked_requests": len(client_requests), }