# middleware/rate_limiter.py
from typing import Dict, Tuple
from datetime import datetime, timedelta, timezone
import logging
from collections import defaultdict, deque

logger = logging.getLogger(__name__)


class RateLimiter:
    """In-memory rate limiter using a sliding-window log.

    For each client, the timestamps (naive UTC datetimes) of accepted requests
    are kept in a deque, oldest first. A request is allowed when fewer than
    ``max_requests`` of those timestamps fall inside the trailing
    ``window_seconds`` window.

    State lives only in this object; no locking is performed.
    NOTE(review): if one instance is shared across threads, confirm callers
    serialize access -- the periodic cleanup iterates ``self.clients`` while
    ``allow_request`` may insert new keys into it.
    """

    def __init__(self, cleanup_interval: int = 3600, retention_seconds: int = 86400):
        """Create an empty limiter.

        Args:
            cleanup_interval: Minimum number of seconds between periodic
                sweeps of expired timestamps (default: one hour).
            retention_seconds: Minimum age in seconds a timestamp must reach
                before a sweep may drop it (default: 24 hours, which keeps
                ``get_client_stats``'s "last day" figure meaningful). A sweep
                also never drops timestamps still inside the largest window
                ever passed to ``allow_request``.
        """
        # Accepted-request timestamps per client, oldest first.
        self.clients: Dict[str, deque] = defaultdict(deque)
        self.cleanup_interval = cleanup_interval
        self.retention_seconds = retention_seconds
        self.last_cleanup = self._utcnow()
        # Largest window_seconds seen so far; cleanup must not discard
        # timestamps that a rate-limit window still needs to count.
        self._max_window_seconds = 0

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` (Python 3.12+),
        so stored timestamps stay naive as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def allow_request(self, client_id: str, max_requests: int, window_seconds: int) -> bool:
        """Check whether ``client_id`` may make a request, recording it if so.

        Sliding-window log: timestamps older than ``window_seconds`` are
        discarded, then the request is accepted (and its timestamp stored)
        only if fewer than ``max_requests`` remain. Rejected requests are not
        stored.

        Args:
            client_id: Key identifying the caller.
            max_requests: Maximum number of accepted requests per window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            True if the request is allowed, False if the limit is exceeded.
        """
        now = self._utcnow()
        window_start = now - timedelta(seconds=window_seconds)
        # Record before any cleanup so the sweep honours this window too.
        self._max_window_seconds = max(self._max_window_seconds, window_seconds)

        # total_seconds(), not .seconds: .seconds is only the sub-day part of
        # the timedelta, so e.g. 1 day 10 min elapsed would read as 600 and
        # the cleanup would be skipped.
        if (now - self.last_cleanup).total_seconds() > self.cleanup_interval:
            self._cleanup_old_entries()
            self.last_cleanup = now

        client_requests = self.clients[client_id]

        # Drop timestamps that have slid out of the window (oldest are first).
        while client_requests and client_requests[0] < window_start:
            client_requests.popleft()

        if len(client_requests) < max_requests:
            client_requests.append(now)
            return True

        logger.warning(
            "Rate limit exceeded for client %s: %s/%s",
            client_id, len(client_requests), max_requests,
        )
        return False

    def _cleanup_old_entries(self):
        """Drop expired timestamps and forget clients with none left.

        A timestamp is expired once it is older than both
        ``retention_seconds`` and the largest window passed to
        ``allow_request``, so the sweep never removes a request that an
        active window still counts. Bounds memory for clients that stop
        sending requests.
        """
        horizon_seconds = max(self.retention_seconds, self._max_window_seconds)
        cutoff_time = self._utcnow() - timedelta(seconds=horizon_seconds)
        clients_to_remove = []

        for client_id, requests in self.clients.items():
            while requests and requests[0] < cutoff_time:
                requests.popleft()
            if not requests:
                clients_to_remove.append(client_id)

        # Deleted after the loop: a dict must not change size while iterated.
        for client_id in clients_to_remove:
            del self.clients[client_id]

        logger.info(
            "Rate limiter cleanup completed. Removed %s inactive clients",
            len(clients_to_remove),
        )

    def get_client_stats(self, client_id: str) -> Dict[str, int]:
        """Return request counts for ``client_id`` without creating an entry.

        Counts only timestamps still tracked. ``allow_request`` discards
        timestamps that fall outside its window, so these figures never
        exceed what the most recent window retained.

        Returns:
            ``requests_last_hour``: tracked timestamps newer than one hour ago.
            ``requests_last_day``: tracked timestamps newer than one day ago.
            ``total_tracked_requests``: all timestamps currently held.
        """
        # .get() avoids the defaultdict inserting an empty deque for unknown ids.
        client_requests = self.clients.get(client_id, deque())
        now = self._utcnow()
        hour_ago = now - timedelta(hours=1)
        day_ago = now - timedelta(days=1)

        requests_last_hour = sum(1 for req_time in client_requests if req_time > hour_ago)
        requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago)

        return {
            "requests_last_hour": requests_last_hour,
            "requests_last_day": requests_last_day,
            "total_tracked_requests": len(client_requests),
        }