unit test bug fixes
This commit is contained in:
@@ -1,74 +0,0 @@
|
||||
# tests/test_middleware.py
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.auth import AuthManager
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth # for auth manager tests
|
||||
class TestRateLimiter:
|
||||
def test_rate_limiter_allows_requests(self):
|
||||
"""Test rate limiter allows requests within limit"""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client"
|
||||
|
||||
# Should allow first request
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
is True
|
||||
)
|
||||
|
||||
# Should allow subsequent requests within limit
|
||||
for _ in range(5):
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_rate_limiter_blocks_excess_requests(self):
|
||||
"""Test rate limiter blocks requests exceeding limit"""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_blocked"
|
||||
max_requests = 3
|
||||
|
||||
# Use up the allowed requests
|
||||
for _ in range(max_requests):
|
||||
assert limiter.allow_request(client_id, max_requests, 3600) is True
|
||||
|
||||
# Next request should be blocked
|
||||
assert limiter.allow_request(client_id, max_requests, 3600) is False
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth # for auth manager tests
|
||||
class TestAuthManager:
|
||||
def test_password_hashing_and_verification(self):
|
||||
"""Test password hashing and verification"""
|
||||
auth_manager = AuthManager()
|
||||
password = "test_password_123"
|
||||
|
||||
# Hash password
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
# Verify correct password
|
||||
assert auth_manager.verify_password(password, hashed) is True
|
||||
|
||||
# Verify incorrect password
|
||||
assert auth_manager.verify_password("wrong_password", hashed) is False
|
||||
|
||||
def test_jwt_token_creation_and_validation(self, test_user):
|
||||
"""Test JWT token creation and validation"""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token
|
||||
token_data = auth_manager.create_access_token(test_user)
|
||||
|
||||
assert "access_token" in token_data
|
||||
assert token_data["token_type"] == "bearer"
|
||||
assert isinstance(token_data["expires_in"], int)
|
||||
|
||||
# Token should be a string
|
||||
assert isinstance(token_data["access_token"], str)
|
||||
assert len(token_data["access_token"]) > 50 # JWT tokens are long
|
||||
@@ -474,10 +474,13 @@ class TestRateLimiterEdgeCases:
|
||||
def test_cleanup_updates_last_cleanup_time(self):
|
||||
"""Test that cleanup updates last_cleanup timestamp."""
|
||||
limiter = RateLimiter()
|
||||
old_cleanup_time = limiter.last_cleanup
|
||||
|
||||
# Force cleanup
|
||||
limiter.cleanup_interval = 0
|
||||
# Set last_cleanup to past to ensure cleanup triggers
|
||||
old_cleanup_time = datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
limiter.last_cleanup = old_cleanup_time
|
||||
limiter.cleanup_interval = 0 # Force cleanup on next request
|
||||
|
||||
# Make request (should trigger cleanup)
|
||||
limiter.allow_request("test", 10, 3600)
|
||||
|
||||
# last_cleanup should be updated
|
||||
|
||||
@@ -16,9 +16,9 @@ See main.py for current implementation using app.include_router() with different
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from fastapi import Request
|
||||
import time
|
||||
|
||||
from middleware.theme_context import (
|
||||
ThemeContextManager,
|
||||
@@ -26,10 +26,22 @@ from middleware.theme_context import (
|
||||
get_current_theme,
|
||||
)
|
||||
from middleware.logging_middleware import LoggingMiddleware
|
||||
from middleware.decorators import rate_limit
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
from app.exceptions.base import RateLimitException
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limiter():
|
||||
"""Reset rate limiter state before each test."""
|
||||
rate_limiter.clients.clear()
|
||||
yield
|
||||
rate_limiter.clients.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Theme Context Tests
|
||||
# =============================================================================
|
||||
@@ -89,12 +101,16 @@ class TestThemeContextManager:
|
||||
"""Test getting vendor-specific theme."""
|
||||
mock_db = Mock()
|
||||
mock_theme = Mock()
|
||||
mock_theme.to_dict.return_value = {
|
||||
|
||||
# Mock to_dict to return actual dictionary
|
||||
custom_theme_dict = {
|
||||
"theme_name": "custom",
|
||||
"colors": {"primary": "#ff0000"}
|
||||
}
|
||||
mock_theme.to_dict.return_value = custom_theme_dict
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_theme
|
||||
# Correct filter chain: query().filter().first() (single filter, not double)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_theme
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
@@ -105,10 +121,13 @@ class TestThemeContextManager:
|
||||
def test_get_vendor_theme_fallback_to_default(self):
|
||||
"""Test falling back to default theme when no custom theme exists."""
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
# Correct filter chain: query().filter().first()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
# Verify it returns a dict (not a Mock)
|
||||
assert isinstance(theme, dict)
|
||||
assert theme["theme_name"] == "default"
|
||||
assert "colors" in theme
|
||||
assert "fonts" in theme
|
||||
@@ -116,11 +135,13 @@ class TestThemeContextManager:
|
||||
def test_get_vendor_theme_inactive_theme(self):
|
||||
"""Test that inactive themes are not returned."""
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
# Correct filter chain: query().filter().first()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
# Should return default theme
|
||||
# Should return default theme (actual dict)
|
||||
assert isinstance(theme, dict)
|
||||
assert theme["theme_name"] == "default"
|
||||
|
||||
|
||||
@@ -228,7 +249,12 @@ class TestLoggingMiddleware:
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
||||
# Create mock response with actual dict for headers
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.headers = {} # Use actual dict
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
@@ -297,7 +323,11 @@ class TestLoggingMiddleware:
|
||||
request.url = Mock(path="/test")
|
||||
request.client = None # No client info
|
||||
|
||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
@@ -342,7 +372,6 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = slow_call_next
|
||||
|
||||
import asyncio
|
||||
with patch('middleware.logging_middleware.logger'):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -375,16 +404,16 @@ class TestRateLimitDecorator:
|
||||
async def test_decorator_blocks_exceeding_limit(self):
|
||||
"""Test decorator blocks requests exceeding rate limit."""
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
async def test_endpoint_blocked():
|
||||
return {"status": "ok"}
|
||||
|
||||
# First two should succeed
|
||||
await test_endpoint()
|
||||
await test_endpoint()
|
||||
await test_endpoint_blocked()
|
||||
await test_endpoint_blocked()
|
||||
|
||||
# Third should raise exception
|
||||
with pytest.raises(RateLimitException) as exc_info:
|
||||
await test_endpoint()
|
||||
await test_endpoint_blocked()
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert "Rate limit exceeded" in exc_info.value.message
|
||||
@@ -430,13 +459,13 @@ class TestRateLimitDecorator:
|
||||
async def test_decorator_exception_includes_retry_after(self):
|
||||
"""Test rate limit exception includes retry_after."""
|
||||
@rate_limit(max_requests=1, window_seconds=60)
|
||||
async def test_endpoint():
|
||||
async def test_endpoint_retry():
|
||||
return {"status": "ok"}
|
||||
|
||||
await test_endpoint() # Use up limit
|
||||
await test_endpoint_retry() # Use up limit
|
||||
|
||||
with pytest.raises(RateLimitException) as exc_info:
|
||||
await test_endpoint()
|
||||
await test_endpoint_retry()
|
||||
|
||||
assert exc_info.value.details.get("retry_after") == 60
|
||||
|
||||
|
||||
Reference in New Issue
Block a user