unit test bug fixes

This commit is contained in:
2025-11-19 20:21:11 +01:00
parent c38da2780a
commit 21bd390685
3 changed files with 52 additions and 94 deletions

View File

@@ -1,74 +0,0 @@
# tests/test_middleware.py
from unittest.mock import Mock, patch
import pytest
from middleware.auth import AuthManager
from middleware.rate_limiter import RateLimiter
@pytest.mark.unit
@pytest.mark.auth # for auth manager tests
class TestRateLimiter:
def test_rate_limiter_allows_requests(self):
"""Test rate limiter allows requests within limit"""
limiter = RateLimiter()
client_id = "test_client"
# Should allow first request
assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
# Should allow subsequent requests within limit
for _ in range(5):
assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
def test_rate_limiter_blocks_excess_requests(self):
"""Test rate limiter blocks requests exceeding limit"""
limiter = RateLimiter()
client_id = "test_client_blocked"
max_requests = 3
# Use up the allowed requests
for _ in range(max_requests):
assert limiter.allow_request(client_id, max_requests, 3600) is True
# Next request should be blocked
assert limiter.allow_request(client_id, max_requests, 3600) is False
@pytest.mark.unit
@pytest.mark.auth # for auth manager tests
class TestAuthManager:
def test_password_hashing_and_verification(self):
"""Test password hashing and verification"""
auth_manager = AuthManager()
password = "test_password_123"
# Hash password
hashed = auth_manager.hash_password(password)
# Verify correct password
assert auth_manager.verify_password(password, hashed) is True
# Verify incorrect password
assert auth_manager.verify_password("wrong_password", hashed) is False
def test_jwt_token_creation_and_validation(self, test_user):
"""Test JWT token creation and validation"""
auth_manager = AuthManager()
# Create token
token_data = auth_manager.create_access_token(test_user)
assert "access_token" in token_data
assert token_data["token_type"] == "bearer"
assert isinstance(token_data["expires_in"], int)
# Token should be a string
assert isinstance(token_data["access_token"], str)
assert len(token_data["access_token"]) > 50 # JWT tokens are long

View File

@@ -474,10 +474,13 @@ class TestRateLimiterEdgeCases:
def test_cleanup_updates_last_cleanup_time(self): def test_cleanup_updates_last_cleanup_time(self):
"""Test that cleanup updates last_cleanup timestamp.""" """Test that cleanup updates last_cleanup timestamp."""
limiter = RateLimiter() limiter = RateLimiter()
old_cleanup_time = limiter.last_cleanup
# Force cleanup # Set last_cleanup to past to ensure cleanup triggers
limiter.cleanup_interval = 0 old_cleanup_time = datetime.now(timezone.utc) - timedelta(hours=2)
limiter.last_cleanup = old_cleanup_time
limiter.cleanup_interval = 0 # Force cleanup on next request
# Make request (should trigger cleanup)
limiter.allow_request("test", 10, 3600) limiter.allow_request("test", 10, 3600)
# last_cleanup should be updated # last_cleanup should be updated

View File

@@ -16,9 +16,9 @@ See main.py for current implementation using app.include_router() with different
""" """
import pytest import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, MagicMock, patch from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request from fastapi import Request
import time
from middleware.theme_context import ( from middleware.theme_context import (
ThemeContextManager, ThemeContextManager,
@@ -26,10 +26,22 @@ from middleware.theme_context import (
get_current_theme, get_current_theme,
) )
from middleware.logging_middleware import LoggingMiddleware from middleware.logging_middleware import LoggingMiddleware
from middleware.decorators import rate_limit from middleware.decorators import rate_limit, rate_limiter
from app.exceptions.base import RateLimitException from app.exceptions.base import RateLimitException
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(autouse=True)
def reset_rate_limiter():
"""Reset rate limiter state before each test."""
rate_limiter.clients.clear()
yield
rate_limiter.clients.clear()
# ============================================================================= # =============================================================================
# Theme Context Tests # Theme Context Tests
# ============================================================================= # =============================================================================
@@ -89,12 +101,16 @@ class TestThemeContextManager:
"""Test getting vendor-specific theme.""" """Test getting vendor-specific theme."""
mock_db = Mock() mock_db = Mock()
mock_theme = Mock() mock_theme = Mock()
mock_theme.to_dict.return_value = {
# Mock to_dict to return actual dictionary
custom_theme_dict = {
"theme_name": "custom", "theme_name": "custom",
"colors": {"primary": "#ff0000"} "colors": {"primary": "#ff0000"}
} }
mock_theme.to_dict.return_value = custom_theme_dict
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_theme # Correct filter chain: query().filter().first() (single filter, not double)
mock_db.query.return_value.filter.return_value.first.return_value = mock_theme
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
@@ -105,10 +121,13 @@ class TestThemeContextManager:
def test_get_vendor_theme_fallback_to_default(self): def test_get_vendor_theme_fallback_to_default(self):
"""Test falling back to default theme when no custom theme exists.""" """Test falling back to default theme when no custom theme exists."""
mock_db = Mock() mock_db = Mock()
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None # Correct filter chain: query().filter().first()
mock_db.query.return_value.filter.return_value.first.return_value = None
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
# Verify it returns a dict (not a Mock)
assert isinstance(theme, dict)
assert theme["theme_name"] == "default" assert theme["theme_name"] == "default"
assert "colors" in theme assert "colors" in theme
assert "fonts" in theme assert "fonts" in theme
@@ -116,11 +135,13 @@ class TestThemeContextManager:
def test_get_vendor_theme_inactive_theme(self): def test_get_vendor_theme_inactive_theme(self):
"""Test that inactive themes are not returned.""" """Test that inactive themes are not returned."""
mock_db = Mock() mock_db = Mock()
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None # Correct filter chain: query().filter().first()
mock_db.query.return_value.filter.return_value.first.return_value = None
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
# Should return default theme # Should return default theme (actual dict)
assert isinstance(theme, dict)
assert theme["theme_name"] == "default" assert theme["theme_name"] == "default"
@@ -228,7 +249,12 @@ class TestLoggingMiddleware:
request.url = Mock(path="/api/vendors") request.url = Mock(path="/api/vendors")
request.client = Mock(host="127.0.0.1") request.client = Mock(host="127.0.0.1")
call_next = AsyncMock(return_value=Mock(status_code=200)) # Create mock response with actual dict for headers
response = Mock()
response.status_code = 200
response.headers = {} # Use actual dict
call_next = AsyncMock(return_value=response)
with patch('middleware.logging_middleware.logger') as mock_logger: with patch('middleware.logging_middleware.logger') as mock_logger:
await middleware.dispatch(request, call_next) await middleware.dispatch(request, call_next)
@@ -297,7 +323,11 @@ class TestLoggingMiddleware:
request.url = Mock(path="/test") request.url = Mock(path="/test")
request.client = None # No client info request.client = None # No client info
call_next = AsyncMock(return_value=Mock(status_code=200)) response = Mock()
response.status_code = 200
response.headers = {}
call_next = AsyncMock(return_value=response)
with patch('middleware.logging_middleware.logger') as mock_logger: with patch('middleware.logging_middleware.logger') as mock_logger:
await middleware.dispatch(request, call_next) await middleware.dispatch(request, call_next)
@@ -342,7 +372,6 @@ class TestLoggingMiddleware:
call_next = slow_call_next call_next = slow_call_next
import asyncio
with patch('middleware.logging_middleware.logger'): with patch('middleware.logging_middleware.logger'):
result = await middleware.dispatch(request, call_next) result = await middleware.dispatch(request, call_next)
@@ -375,16 +404,16 @@ class TestRateLimitDecorator:
async def test_decorator_blocks_exceeding_limit(self): async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit.""" """Test decorator blocks requests exceeding rate limit."""
@rate_limit(max_requests=2, window_seconds=3600) @rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint(): async def test_endpoint_blocked():
return {"status": "ok"} return {"status": "ok"}
# First two should succeed # First two should succeed
await test_endpoint() await test_endpoint_blocked()
await test_endpoint() await test_endpoint_blocked()
# Third should raise exception # Third should raise exception
with pytest.raises(RateLimitException) as exc_info: with pytest.raises(RateLimitException) as exc_info:
await test_endpoint() await test_endpoint_blocked()
assert exc_info.value.status_code == 429 assert exc_info.value.status_code == 429
assert "Rate limit exceeded" in exc_info.value.message assert "Rate limit exceeded" in exc_info.value.message
@@ -430,13 +459,13 @@ class TestRateLimitDecorator:
async def test_decorator_exception_includes_retry_after(self): async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after.""" """Test rate limit exception includes retry_after."""
@rate_limit(max_requests=1, window_seconds=60) @rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint(): async def test_endpoint_retry():
return {"status": "ok"} return {"status": "ok"}
await test_endpoint() # Use up limit await test_endpoint_retry() # Use up limit
with pytest.raises(RateLimitException) as exc_info: with pytest.raises(RateLimitException) as exc_info:
await test_endpoint() await test_endpoint_retry()
assert exc_info.value.details.get("retry_after") == 60 assert exc_info.value.details.get("retry_after") == 60