unit test bug fixes
This commit is contained in:
@@ -1,74 +0,0 @@
|
|||||||
# tests/test_middleware.py
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from middleware.auth import AuthManager
|
|
||||||
from middleware.rate_limiter import RateLimiter
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
@pytest.mark.auth # for auth manager tests
|
|
||||||
class TestRateLimiter:
|
|
||||||
def test_rate_limiter_allows_requests(self):
|
|
||||||
"""Test rate limiter allows requests within limit"""
|
|
||||||
limiter = RateLimiter()
|
|
||||||
client_id = "test_client"
|
|
||||||
|
|
||||||
# Should allow first request
|
|
||||||
assert (
|
|
||||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should allow subsequent requests within limit
|
|
||||||
for _ in range(5):
|
|
||||||
assert (
|
|
||||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_rate_limiter_blocks_excess_requests(self):
|
|
||||||
"""Test rate limiter blocks requests exceeding limit"""
|
|
||||||
limiter = RateLimiter()
|
|
||||||
client_id = "test_client_blocked"
|
|
||||||
max_requests = 3
|
|
||||||
|
|
||||||
# Use up the allowed requests
|
|
||||||
for _ in range(max_requests):
|
|
||||||
assert limiter.allow_request(client_id, max_requests, 3600) is True
|
|
||||||
|
|
||||||
# Next request should be blocked
|
|
||||||
assert limiter.allow_request(client_id, max_requests, 3600) is False
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
@pytest.mark.auth # for auth manager tests
|
|
||||||
class TestAuthManager:
|
|
||||||
def test_password_hashing_and_verification(self):
|
|
||||||
"""Test password hashing and verification"""
|
|
||||||
auth_manager = AuthManager()
|
|
||||||
password = "test_password_123"
|
|
||||||
|
|
||||||
# Hash password
|
|
||||||
hashed = auth_manager.hash_password(password)
|
|
||||||
|
|
||||||
# Verify correct password
|
|
||||||
assert auth_manager.verify_password(password, hashed) is True
|
|
||||||
|
|
||||||
# Verify incorrect password
|
|
||||||
assert auth_manager.verify_password("wrong_password", hashed) is False
|
|
||||||
|
|
||||||
def test_jwt_token_creation_and_validation(self, test_user):
|
|
||||||
"""Test JWT token creation and validation"""
|
|
||||||
auth_manager = AuthManager()
|
|
||||||
|
|
||||||
# Create token
|
|
||||||
token_data = auth_manager.create_access_token(test_user)
|
|
||||||
|
|
||||||
assert "access_token" in token_data
|
|
||||||
assert token_data["token_type"] == "bearer"
|
|
||||||
assert isinstance(token_data["expires_in"], int)
|
|
||||||
|
|
||||||
# Token should be a string
|
|
||||||
assert isinstance(token_data["access_token"], str)
|
|
||||||
assert len(token_data["access_token"]) > 50 # JWT tokens are long
|
|
||||||
@@ -474,10 +474,13 @@ class TestRateLimiterEdgeCases:
|
|||||||
def test_cleanup_updates_last_cleanup_time(self):
|
def test_cleanup_updates_last_cleanup_time(self):
|
||||||
"""Test that cleanup updates last_cleanup timestamp."""
|
"""Test that cleanup updates last_cleanup timestamp."""
|
||||||
limiter = RateLimiter()
|
limiter = RateLimiter()
|
||||||
old_cleanup_time = limiter.last_cleanup
|
|
||||||
|
|
||||||
# Force cleanup
|
# Set last_cleanup to past to ensure cleanup triggers
|
||||||
limiter.cleanup_interval = 0
|
old_cleanup_time = datetime.now(timezone.utc) - timedelta(hours=2)
|
||||||
|
limiter.last_cleanup = old_cleanup_time
|
||||||
|
limiter.cleanup_interval = 0 # Force cleanup on next request
|
||||||
|
|
||||||
|
# Make request (should trigger cleanup)
|
||||||
limiter.allow_request("test", 10, 3600)
|
limiter.allow_request("test", 10, 3600)
|
||||||
|
|
||||||
# last_cleanup should be updated
|
# last_cleanup should be updated
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ See main.py for current implementation using app.include_router() with different
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import asyncio
|
||||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
import time
|
|
||||||
|
|
||||||
from middleware.theme_context import (
|
from middleware.theme_context import (
|
||||||
ThemeContextManager,
|
ThemeContextManager,
|
||||||
@@ -26,10 +26,22 @@ from middleware.theme_context import (
|
|||||||
get_current_theme,
|
get_current_theme,
|
||||||
)
|
)
|
||||||
from middleware.logging_middleware import LoggingMiddleware
|
from middleware.logging_middleware import LoggingMiddleware
|
||||||
from middleware.decorators import rate_limit
|
from middleware.decorators import rate_limit, rate_limiter
|
||||||
from app.exceptions.base import RateLimitException
|
from app.exceptions.base import RateLimitException
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_rate_limiter():
|
||||||
|
"""Reset rate limiter state before each test."""
|
||||||
|
rate_limiter.clients.clear()
|
||||||
|
yield
|
||||||
|
rate_limiter.clients.clear()
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Theme Context Tests
|
# Theme Context Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -89,12 +101,16 @@ class TestThemeContextManager:
|
|||||||
"""Test getting vendor-specific theme."""
|
"""Test getting vendor-specific theme."""
|
||||||
mock_db = Mock()
|
mock_db = Mock()
|
||||||
mock_theme = Mock()
|
mock_theme = Mock()
|
||||||
mock_theme.to_dict.return_value = {
|
|
||||||
|
# Mock to_dict to return actual dictionary
|
||||||
|
custom_theme_dict = {
|
||||||
"theme_name": "custom",
|
"theme_name": "custom",
|
||||||
"colors": {"primary": "#ff0000"}
|
"colors": {"primary": "#ff0000"}
|
||||||
}
|
}
|
||||||
|
mock_theme.to_dict.return_value = custom_theme_dict
|
||||||
|
|
||||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_theme
|
# Correct filter chain: query().filter().first() (single filter, not double)
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_theme
|
||||||
|
|
||||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||||
|
|
||||||
@@ -105,10 +121,13 @@ class TestThemeContextManager:
|
|||||||
def test_get_vendor_theme_fallback_to_default(self):
|
def test_get_vendor_theme_fallback_to_default(self):
|
||||||
"""Test falling back to default theme when no custom theme exists."""
|
"""Test falling back to default theme when no custom theme exists."""
|
||||||
mock_db = Mock()
|
mock_db = Mock()
|
||||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
# Correct filter chain: query().filter().first()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
|
||||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||||
|
|
||||||
|
# Verify it returns a dict (not a Mock)
|
||||||
|
assert isinstance(theme, dict)
|
||||||
assert theme["theme_name"] == "default"
|
assert theme["theme_name"] == "default"
|
||||||
assert "colors" in theme
|
assert "colors" in theme
|
||||||
assert "fonts" in theme
|
assert "fonts" in theme
|
||||||
@@ -116,11 +135,13 @@ class TestThemeContextManager:
|
|||||||
def test_get_vendor_theme_inactive_theme(self):
|
def test_get_vendor_theme_inactive_theme(self):
|
||||||
"""Test that inactive themes are not returned."""
|
"""Test that inactive themes are not returned."""
|
||||||
mock_db = Mock()
|
mock_db = Mock()
|
||||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
# Correct filter chain: query().filter().first()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
|
||||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||||
|
|
||||||
# Should return default theme
|
# Should return default theme (actual dict)
|
||||||
|
assert isinstance(theme, dict)
|
||||||
assert theme["theme_name"] == "default"
|
assert theme["theme_name"] == "default"
|
||||||
|
|
||||||
|
|
||||||
@@ -228,7 +249,12 @@ class TestLoggingMiddleware:
|
|||||||
request.url = Mock(path="/api/vendors")
|
request.url = Mock(path="/api/vendors")
|
||||||
request.client = Mock(host="127.0.0.1")
|
request.client = Mock(host="127.0.0.1")
|
||||||
|
|
||||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
# Create mock response with actual dict for headers
|
||||||
|
response = Mock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.headers = {} # Use actual dict
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||||
await middleware.dispatch(request, call_next)
|
await middleware.dispatch(request, call_next)
|
||||||
@@ -297,7 +323,11 @@ class TestLoggingMiddleware:
|
|||||||
request.url = Mock(path="/test")
|
request.url = Mock(path="/test")
|
||||||
request.client = None # No client info
|
request.client = None # No client info
|
||||||
|
|
||||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
response = Mock()
|
||||||
|
response.status_code = 200
|
||||||
|
response.headers = {}
|
||||||
|
|
||||||
|
call_next = AsyncMock(return_value=response)
|
||||||
|
|
||||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||||
await middleware.dispatch(request, call_next)
|
await middleware.dispatch(request, call_next)
|
||||||
@@ -342,7 +372,6 @@ class TestLoggingMiddleware:
|
|||||||
|
|
||||||
call_next = slow_call_next
|
call_next = slow_call_next
|
||||||
|
|
||||||
import asyncio
|
|
||||||
with patch('middleware.logging_middleware.logger'):
|
with patch('middleware.logging_middleware.logger'):
|
||||||
result = await middleware.dispatch(request, call_next)
|
result = await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
@@ -375,16 +404,16 @@ class TestRateLimitDecorator:
|
|||||||
async def test_decorator_blocks_exceeding_limit(self):
|
async def test_decorator_blocks_exceeding_limit(self):
|
||||||
"""Test decorator blocks requests exceeding rate limit."""
|
"""Test decorator blocks requests exceeding rate limit."""
|
||||||
@rate_limit(max_requests=2, window_seconds=3600)
|
@rate_limit(max_requests=2, window_seconds=3600)
|
||||||
async def test_endpoint():
|
async def test_endpoint_blocked():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
# First two should succeed
|
# First two should succeed
|
||||||
await test_endpoint()
|
await test_endpoint_blocked()
|
||||||
await test_endpoint()
|
await test_endpoint_blocked()
|
||||||
|
|
||||||
# Third should raise exception
|
# Third should raise exception
|
||||||
with pytest.raises(RateLimitException) as exc_info:
|
with pytest.raises(RateLimitException) as exc_info:
|
||||||
await test_endpoint()
|
await test_endpoint_blocked()
|
||||||
|
|
||||||
assert exc_info.value.status_code == 429
|
assert exc_info.value.status_code == 429
|
||||||
assert "Rate limit exceeded" in exc_info.value.message
|
assert "Rate limit exceeded" in exc_info.value.message
|
||||||
@@ -430,13 +459,13 @@ class TestRateLimitDecorator:
|
|||||||
async def test_decorator_exception_includes_retry_after(self):
|
async def test_decorator_exception_includes_retry_after(self):
|
||||||
"""Test rate limit exception includes retry_after."""
|
"""Test rate limit exception includes retry_after."""
|
||||||
@rate_limit(max_requests=1, window_seconds=60)
|
@rate_limit(max_requests=1, window_seconds=60)
|
||||||
async def test_endpoint():
|
async def test_endpoint_retry():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
await test_endpoint() # Use up limit
|
await test_endpoint_retry() # Use up limit
|
||||||
|
|
||||||
with pytest.raises(RateLimitException) as exc_info:
|
with pytest.raises(RateLimitException) as exc_info:
|
||||||
await test_endpoint()
|
await test_endpoint_retry()
|
||||||
|
|
||||||
assert exc_info.value.details.get("retry_after") == 60
|
assert exc_info.value.details.get("retry_after") == 60
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user