Files
orion/tests/unit/middleware/test_rate_limiter.py
2025-11-19 20:21:11 +01:00

540 lines
19 KiB
Python

# tests/unit/middleware/test_rate_limiter.py
"""
Comprehensive unit tests for RateLimiter.
Tests cover:
- Request allowance within limits
- Request blocking when exceeding limits
- Sliding window algorithm
- Cleanup of old entries
- Client statistics
- Edge cases and rapid sequential requests from a single client
- Memory management (eviction of stale entries, client independence)
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from collections import deque
from middleware.rate_limiter import RateLimiter
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterBasic:
    """Baseline allow/deny behaviour of ``RateLimiter``."""

    def test_rate_limiter_initialization(self):
        """A fresh limiter exposes a client dict, a 3600 s interval and a cleanup timestamp."""
        rl = RateLimiter()

        assert isinstance(rl.clients, dict)
        assert rl.cleanup_interval == 3600
        assert isinstance(rl.last_cleanup, datetime)

    def test_allow_first_request(self):
        """The first request of an unseen client is accepted and recorded once."""
        rl = RateLimiter()
        cid = "test_client_1"

        allowed = rl.allow_request(cid, max_requests=10, window_seconds=3600)

        assert allowed is True
        assert cid in rl.clients
        assert len(rl.clients[cid]) == 1

    def test_allow_multiple_requests_within_limit(self):
        """Every request up to and including the limit is accepted."""
        rl = RateLimiter()
        cid = "test_client_2"
        limit = 10

        # Issue exactly `limit` requests; none of them may be rejected.
        for i in range(limit):
            allowed = rl.allow_request(cid, limit, 3600)
            assert allowed is True, f"Request {i+1} should be allowed"

        assert len(rl.clients[cid]) == limit

    def test_block_request_exceeding_limit(self):
        """The request right after the quota is exhausted is rejected and not recorded."""
        rl = RateLimiter()
        cid = "test_client_blocked"
        limit = 3

        # Burn through the whole quota first.
        for _ in range(limit):
            assert rl.allow_request(cid, limit, 3600) is True

        # One request too many.
        overflow = rl.allow_request(cid, limit, 3600)
        assert overflow is False

        # The rejected request must not have been added to the history.
        assert len(rl.clients[cid]) == limit

    def test_different_clients_separate_limits(self):
        """Exhausting one client's quota leaves another client's quota untouched."""
        rl = RateLimiter()
        first, second = "client_1", "client_2"
        limit = 5

        # The first client uses its full quota ...
        for _ in range(limit):
            assert rl.allow_request(first, limit, 3600) is True

        # ... and is rejected afterwards,
        assert rl.allow_request(first, limit, 3600) is False

        # while the second client is still served.
        assert rl.allow_request(second, limit, 3600) is True

        # Histories are tracked per client.
        assert len(rl.clients[first]) == limit
        assert len(rl.clients[second]) == 1
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterSlidingWindow:
    """Checks that only requests inside the time window count against the limit."""

    def test_sliding_window_removes_old_requests(self):
        """Timestamps older than the window are dropped when a new request arrives."""
        rl = RateLimiter()
        cid = "test_client_window"
        limit = 3
        window = 10

        # Seed two requests that are 15 s old, i.e. outside the 10 s window.
        expired = datetime.now(timezone.utc) - timedelta(seconds=15)
        history = rl.clients[cid]
        history.append(expired)
        history.append(expired)

        # The expired entries are discarded, so the new request passes.
        allowed = rl.allow_request(cid, limit, window)

        assert allowed is True
        assert len(rl.clients[cid]) == 1  # Only the new request

    def test_sliding_window_keeps_recent_requests(self):
        """Timestamps still inside the window keep counting against the limit."""
        rl = RateLimiter()
        cid = "test_client_recent"
        limit = 3
        window = 60

        # Seed two requests that are 30 s old, i.e. inside the 60 s window.
        fresh = datetime.now(timezone.utc) - timedelta(seconds=30)
        history = rl.clients[cid]
        history.append(fresh)
        history.append(fresh)

        # Two slots are taken already: exactly one more request fits ...
        third = rl.allow_request(cid, limit, window)
        assert third is True

        # ... and the next one hits the limit.
        fourth = rl.allow_request(cid, limit, window)
        assert fourth is False

    def test_sliding_window_mixed_old_and_recent(self):
        """Expired entries are ignored while in-window entries are still counted."""
        rl = RateLimiter()
        cid = "test_client_mixed"
        limit = 3
        window = 30

        # Two requests from 60 s ago (outside the window).
        expired = datetime.now(timezone.utc) - timedelta(seconds=60)
        rl.clients[cid].append(expired)
        rl.clients[cid].append(expired)

        # One request from 10 s ago (inside the window).
        fresh = datetime.now(timezone.utc) - timedelta(seconds=10)
        rl.clients[cid].append(fresh)

        # Only the fresh entry counts, which leaves room for two more requests.
        assert rl.allow_request(cid, limit, window) is True
        assert rl.allow_request(cid, limit, window) is True

        # The quota is now used up.
        assert rl.allow_request(cid, limit, window) is False

    def test_sliding_window_with_zero_window(self):
        """A request older than a very short window does not count.

        Despite the test name, the window used here is 1 second, not 0.
        """
        rl = RateLimiter()
        cid = "test_client_zero_window"
        limit = 5
        window = 1  # 1 second window

        # Seed a request that is 2 s old, i.e. outside the 1 s window.
        expired = datetime.now(timezone.utc) - timedelta(seconds=2)
        rl.clients[cid].append(expired)

        # The seeded entry is out of the window, so the request is allowed.
        allowed = rl.allow_request(cid, limit, window)
        assert allowed is True
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterCleanup:
    """Exercises ``_cleanup_old_entries`` and its automatic triggering."""

    def test_cleanup_removes_old_entries(self):
        """Clients whose only requests are 25 h old are dropped; a 1 h old client stays."""
        rl = RateLimiter()

        # Two clients with nothing but 25-hour-old requests.
        stale = datetime.now(timezone.utc) - timedelta(hours=25)
        rl.clients["old_client_1"].append(stale)
        rl.clients["old_client_2"].append(stale)

        # One client that was active an hour ago.
        fresh = datetime.now(timezone.utc) - timedelta(hours=1)
        rl.clients["recent_client"].append(fresh)

        rl._cleanup_old_entries()

        # Stale clients are gone ...
        assert "old_client_1" not in rl.clients
        assert "old_client_2" not in rl.clients
        # ... the recently active one is kept.
        assert "recent_client" in rl.clients

    def test_cleanup_removes_empty_clients(self):
        """Clients with an empty request history are dropped by cleanup."""
        rl = RateLimiter()

        # Two clients without any recorded request.
        rl.clients["empty_client_1"] = deque()
        rl.clients["empty_client_2"] = deque()

        # One client with a request made right now.
        rl.clients["active_client"].append(datetime.now(timezone.utc))

        rl._cleanup_old_entries()

        # Empty histories are removed ...
        assert "empty_client_1" not in rl.clients
        assert "empty_client_2" not in rl.clients
        # ... the active client survives.
        assert "active_client" in rl.clients

    def test_cleanup_partial_removal(self):
        """For a client with mixed history only the old requests are discarded."""
        rl = RateLimiter()
        cid = "mixed_client"

        # Two requests from 30 hours ago.
        stale = datetime.now(timezone.utc) - timedelta(hours=30)
        rl.clients[cid].append(stale)
        rl.clients[cid].append(stale)

        # Two requests from one hour ago.
        fresh = datetime.now(timezone.utc) - timedelta(hours=1)
        rl.clients[cid].append(fresh)
        rl.clients[cid].append(fresh)

        rl._cleanup_old_entries()

        # The client is kept, holding only its two recent requests.
        assert cid in rl.clients
        assert len(rl.clients[cid]) == 2

    def test_automatic_cleanup_triggers(self):
        """``allow_request`` runs cleanup once the cleanup interval has elapsed."""
        rl = RateLimiter()
        rl.cleanup_interval = 0  # Force immediate cleanup

        # Pretend the previous cleanup happened two hours ago.
        rl.last_cleanup = datetime.now(timezone.utc) - timedelta(hours=2)

        # A client whose only request is 25 hours old.
        stale = datetime.now(timezone.utc) - timedelta(hours=25)
        rl.clients["old_client"].append(stale)

        # Any request should now kick off the cleanup.
        rl.allow_request("new_client", 10, 3600)

        assert "old_client" not in rl.clients

    def test_cleanup_does_not_affect_active_clients(self):
        """Clients with requests from the last few hours all survive cleanup."""
        rl = RateLimiter()

        # Five clients, last seen 0 to 4 hours ago.
        reference = datetime.now(timezone.utc)
        for i in range(5):
            rl.clients[f"client_{i}"].append(reference - timedelta(hours=i))

        rl._cleanup_old_entries()

        # Everything is well inside 24 hours, so nothing is removed.
        assert len(rl.clients) == 5
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterStatistics:
    """Test suite for client statistics functionality (``get_client_stats``)."""

    def test_get_client_stats_empty(self):
        """Test getting stats for client with no requests."""
        limiter = RateLimiter()
        client_id = "new_client"

        stats = limiter.get_client_stats(client_id)

        assert stats["requests_last_hour"] == 0
        assert stats["requests_last_day"] == 0
        assert stats["total_tracked_requests"] == 0

    def test_get_client_stats_with_requests(self):
        """Test getting stats for client with requests spread over the last day."""
        limiter = RateLimiter()
        client_id = "active_client"

        # Add requests at different times
        now = datetime.now(timezone.utc)
        limiter.clients[client_id].append(now - timedelta(minutes=30))  # Within hour
        limiter.clients[client_id].append(now - timedelta(hours=2))  # Within day
        limiter.clients[client_id].append(now - timedelta(hours=12))  # Within day

        stats = limiter.get_client_stats(client_id)

        assert stats["requests_last_hour"] == 1
        assert stats["requests_last_day"] == 3
        assert stats["total_tracked_requests"] == 3

    def test_get_client_stats_old_requests(self):
        """Test stats exclude requests older than tracking period."""
        limiter = RateLimiter()
        client_id = "old_requests_client"

        # Add very old requests
        now = datetime.now(timezone.utc)
        limiter.clients[client_id].append(now - timedelta(days=2))
        limiter.clients[client_id].append(now - timedelta(days=3))

        stats = limiter.get_client_stats(client_id)

        assert stats["requests_last_hour"] == 0
        assert stats["requests_last_day"] == 0
        assert stats["total_tracked_requests"] == 2  # Still tracked, just not counted

    def test_get_client_stats_nonexistent_client(self):
        """Test getting stats for client that doesn't exist."""
        limiter = RateLimiter()

        stats = limiter.get_client_stats("nonexistent_client")

        assert stats["requests_last_hour"] == 0
        assert stats["requests_last_day"] == 0
        assert stats["total_tracked_requests"] == 0

    def test_get_client_stats_boundary_cases(self):
        """Test stats for requests one second outside the hour/day boundaries.

        Both timestamps sit one second *past* their boundary, so the expected
        counts do not depend on whether the limiter compares with ``>`` or
        ``>=``; the assertions can therefore be exact.
        """
        limiter = RateLimiter()
        client_id = "boundary_client"
        now = datetime.now(timezone.utc)

        # Just over 1 hour ago: outside the last hour, inside the last day
        limiter.clients[client_id].append(now - timedelta(hours=1, seconds=1))
        # Just over 24 hours ago: outside both the last hour and the last day
        limiter.clients[client_id].append(now - timedelta(days=1, seconds=1))

        stats = limiter.get_client_stats(client_id)

        assert stats["requests_last_hour"] == 0
        assert stats["requests_last_day"] == 1
        # Reading stats does not discard entries (see the old-requests test)
        assert stats["total_tracked_requests"] == 2
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterEdgeCases:
    """Unusual limits, windows and client identifiers."""

    def test_rate_limiter_with_zero_max_requests(self):
        """A limit of zero rejects even the first request."""
        rl = RateLimiter()
        cid = "zero_limit_client"

        allowed = rl.allow_request(cid, max_requests=0, window_seconds=3600)

        # Nothing is ever allowed with a zero quota.
        assert allowed is False

    def test_rate_limiter_with_negative_max_requests(self):
        """A negative limit rejects the request."""
        rl = RateLimiter()
        cid = "negative_limit_client"

        allowed = rl.allow_request(cid, max_requests=-1, window_seconds=3600)

        assert allowed is False

    def test_rate_limiter_with_large_max_requests(self):
        """A very large limit still accepts the first request."""
        rl = RateLimiter()
        cid = "large_limit_client"
        limit = 1000000

        allowed = rl.allow_request(cid, limit, 3600)

        assert allowed is True

    def test_rate_limiter_very_short_window(self):
        """A 1-second window accepts the first request."""
        rl = RateLimiter()
        cid = "short_window_client"

        allowed = rl.allow_request(cid, max_requests=1, window_seconds=1)

        assert allowed is True

    def test_rate_limiter_very_long_window(self):
        """A one-year window accepts the first request."""
        rl = RateLimiter()
        cid = "long_window_client"

        allowed = rl.allow_request(cid, max_requests=10, window_seconds=86400 * 365)

        assert allowed is True

    def test_rate_limiter_same_client_different_limits(self):
        """A stricter limit on a later call sees the client's earlier requests."""
        rl = RateLimiter()
        cid = "same_client"

        # First call under a generous limit is accepted.
        generous = rl.allow_request(cid, max_requests=10, window_seconds=3600)
        assert generous is True

        # One request is already on record, so a limit of 1 now rejects.
        strict = rl.allow_request(cid, max_requests=1, window_seconds=3600)
        assert strict is False

    def test_rate_limiter_unicode_client_id(self):
        """Non-ASCII client identifiers are accepted and tracked."""
        rl = RateLimiter()
        cid = "クライアント_123"

        allowed = rl.allow_request(cid, max_requests=5, window_seconds=3600)

        assert allowed is True
        assert cid in rl.clients

    def test_rate_limiter_special_characters_client_id(self):
        """Identifiers made of punctuation are accepted and tracked."""
        rl = RateLimiter()
        cid = "client!@#$%^&*()_+-=[]{}|;:,.<>?"

        allowed = rl.allow_request(cid, max_requests=5, window_seconds=3600)

        assert allowed is True
        assert cid in rl.clients

    def test_rate_limiter_empty_client_id(self):
        """The empty string is accepted as a client identifier and tracked."""
        rl = RateLimiter()
        cid = ""

        allowed = rl.allow_request(cid, max_requests=5, window_seconds=3600)

        assert allowed is True
        assert cid in rl.clients

    def test_rate_limiter_concurrent_same_client(self):
        """Back-to-back requests from one client: first three pass, the rest fail.

        The requests are issued sequentially; no threads are involved.
        """
        rl = RateLimiter()
        cid = "concurrent_client"
        limit = 3

        # Fire five requests in quick succession and collect the verdicts.
        verdicts = [rl.allow_request(cid, limit, 3600) for _ in range(5)]

        accepted, rejected = verdicts[:3], verdicts[3:]
        assert accepted == [True, True, True]
        assert rejected == [False, False]

    def test_cleanup_updates_last_cleanup_time(self):
        """A cleanup triggered by a request moves ``last_cleanup`` forward."""
        rl = RateLimiter()

        # Zero interval plus a back-dated timestamp: the next request must clean up.
        rl.cleanup_interval = 0
        previous_cleanup = datetime.now(timezone.utc) - timedelta(hours=2)
        rl.last_cleanup = previous_cleanup

        rl.allow_request("test", 10, 3600)

        # The timestamp has advanced past the back-dated value.
        assert rl.last_cleanup > previous_cleanup
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterMemoryManagement:
    """Test suite for memory management and performance."""

    def test_limiter_does_not_grow_indefinitely(self):
        """Test that stale clients are evicted so the client map stays bounded.

        The limiter is seeded with clients whose only request is 25 hours old,
        then fresh traffic arrives from a smaller set of clients. After cleanup
        only the fresh clients may remain.
        """
        limiter = RateLimiter()
        limiter.cleanup_interval = 0  # Force cleanup on every request
        # Back-date the last cleanup so the interval has definitely elapsed
        limiter.last_cleanup = datetime.now(timezone.utc) - timedelta(hours=2)

        num_stale = 100
        num_active = 10

        # Clients that have been idle for more than 24 hours
        stale_time = datetime.now(timezone.utc) - timedelta(hours=25)
        for i in range(num_stale):
            limiter.clients[f"stale_client_{i}"].append(stale_time)

        # Fresh traffic from a different set of clients
        for i in range(num_active):
            limiter.allow_request(f"client_{i}", max_requests=10, window_seconds=3600)

        # Force cleanup so the result does not hinge on the automatic trigger
        limiter._cleanup_old_entries()

        # Every stale client is gone; only the active ones are still tracked
        assert len(limiter.clients) == num_active
        assert not any(
            client_id.startswith("stale_client_") for client_id in limiter.clients
        )

    def test_deque_efficiency(self):
        """Test that a large backlog of expired requests is dropped in one call."""
        limiter = RateLimiter()
        client_id = "efficiency_test"

        # Add many old requests (2 hours old, outside the 1 hour window)
        old_time = datetime.now(timezone.utc) - timedelta(hours=2)
        for _ in range(1000):
            limiter.clients[client_id].append(old_time)

        # This should remove all old requests
        limiter.allow_request(client_id, max_requests=10, window_seconds=3600)

        # Should only have the new request
        assert len(limiter.clients[client_id]) == 1

    def test_multiple_clients_independence(self):
        """Test that multiple clients don't interfere with each other."""
        limiter = RateLimiter()
        num_clients = 100

        # Create many clients with one request each
        for i in range(num_clients):
            limiter.allow_request(f"client_{i}", max_requests=5, window_seconds=3600)

        # Each client should have exactly 1 request
        assert len(limiter.clients) == num_clients
        for i in range(num_clients):
            assert len(limiter.clients[f"client_{i}"]) == 1