diff --git a/app/exceptions/error_renderer.py b/app/exceptions/error_renderer.py index 27874ad5..7580ed4e 100644 --- a/app/exceptions/error_renderer.py +++ b/app/exceptions/error_renderer.py @@ -13,7 +13,7 @@ from fastapi import Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates -from middleware.context_middleware import RequestContext, get_request_context +from middleware.context import RequestContext, get_request_context logger = logging.getLogger(__name__) diff --git a/app/exceptions/handler.py b/app/exceptions/handler.py index 1bcd89ec..f364078b 100644 --- a/app/exceptions/handler.py +++ b/app/exceptions/handler.py @@ -19,7 +19,7 @@ from fastapi.responses import JSONResponse, RedirectResponse from .base import WizamartException from .error_renderer import ErrorPageRenderer -from middleware.context_middleware import RequestContext, get_request_context +from middleware.context import RequestContext, get_request_context logger = logging.getLogger(__name__) diff --git a/main.py b/main.py index 6fdd8417..6a981a30 100644 --- a/main.py +++ b/main.py @@ -41,9 +41,9 @@ from app.exceptions import ServiceUnavailableException # Import REFACTORED class-based middleware from middleware.vendor_context import VendorContextMiddleware -from middleware.context_middleware import ContextMiddleware +from middleware.context import ContextMiddleware from middleware.theme_context import ThemeContextMiddleware -from middleware.logging_middleware import LoggingMiddleware +from middleware.logging import LoggingMiddleware logger = logging.getLogger(__name__) diff --git a/middleware/context_middleware.py b/middleware/context.py similarity index 100% rename from middleware/context_middleware.py rename to middleware/context.py diff --git a/middleware/logging_middleware.py b/middleware/logging.py similarity index 100% rename from middleware/logging_middleware.py rename to middleware/logging.py diff --git a/middleware/vendor_context.py b/middleware/vendor_context.py index f76357a2..0e9fa8fe 100644 --- a/middleware/vendor_context.py +++ b/middleware/vendor_context.py @@ -18,6 +18,7 @@ from sqlalchemy import func from starlette.middleware.base import BaseHTTPMiddleware from fastapi import Request +from app.core.config import settings from app.core.database import get_db from models.database.vendor import Vendor from models.database.vendor_domain import VendorDomain @@ -49,7 +50,6 @@ class VendorContextManager: # Method 1: Custom domain detection (HIGHEST PRIORITY) # Check if this is a custom domain (not platform.com and not localhost) - from app.core.config import settings platform_domain = getattr(settings, 'platform_domain', 'platform.com') is_custom_domain = ( diff --git a/tests/unit/middleware/test_context_middleware.py b/tests/unit/middleware/test_context.py similarity index 99% rename from tests/unit/middleware/test_context_middleware.py rename to tests/unit/middleware/test_context.py index 0d874af7..95786c33 100644 --- a/tests/unit/middleware/test_context_middleware.py +++ b/tests/unit/middleware/test_context.py @@ -1,4 +1,4 @@ -# tests/unit/middleware/test_context_middleware.py +# tests/unit/middleware/test_context.py """ Comprehensive unit tests for ContextMiddleware and ContextManager. @@ -14,7 +14,7 @@ import pytest from unittest.mock import Mock, AsyncMock, patch from fastapi import Request -from middleware.context_middleware import ( +from middleware.context import ( ContextManager, ContextMiddleware, RequestContext, diff --git a/tests/unit/middleware/test_decorators.py b/tests/unit/middleware/test_decorators.py new file mode 100644 index 00000000..1b957a90 --- /dev/null +++ b/tests/unit/middleware/test_decorators.py @@ -0,0 +1,256 @@ +# tests/unit/middleware/test_decorators.py +""" +Comprehensive unit tests for middleware decorators. + +Tests cover: +- rate_limit decorator functionality +- Request throttling and abuse prevention +- Rate limit exception handling +- Function metadata preservation +- Arguments and keyword arguments +- Default parameters +- Edge cases and isolation +""" + +import pytest +from unittest.mock import Mock +from middleware.decorators import rate_limit, rate_limiter +from app.exceptions.base import RateLimitException + + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture(autouse=True) +def reset_rate_limiter(): + """Reset rate limiter state before each test to ensure isolation.""" + rate_limiter.clients.clear() + yield + rate_limiter.clients.clear() + + +# ============================================================================= +# Rate Limit Decorator Tests +# ============================================================================= + +@pytest.mark.unit +@pytest.mark.auth +class TestRateLimitDecorator: + """Test suite for rate_limit decorator.""" + + @pytest.mark.asyncio + async def test_decorator_allows_within_limit(self): + """Test decorator allows requests within rate limit.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def test_endpoint(): + return {"status": "ok"} + + result = await test_endpoint() + + assert result == {"status": "ok"} + + @pytest.mark.asyncio + async def test_decorator_blocks_exceeding_limit(self): + """Test decorator blocks requests exceeding rate limit.""" + @rate_limit(max_requests=2, window_seconds=3600) + async def test_endpoint_blocked(): + return {"status": "ok"} + + # First two should succeed + await test_endpoint_blocked() + await test_endpoint_blocked() + + # Third should raise exception + with pytest.raises(RateLimitException) as exc_info: + await test_endpoint_blocked() + + assert exc_info.value.status_code == 429 + assert "Rate limit exceeded" in exc_info.value.message + + @pytest.mark.asyncio + async def test_decorator_preserves_function_metadata(self): + """Test decorator preserves original function metadata.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def test_endpoint(): + """Test endpoint docstring.""" + return {"status": "ok"} + + assert test_endpoint.__name__ == "test_endpoint" + assert test_endpoint.__doc__ == "Test endpoint docstring." + + @pytest.mark.asyncio + async def test_decorator_with_args_and_kwargs(self): + """Test decorator works with function arguments.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def test_endpoint(arg1, arg2, kwarg1=None): + return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1} + + result = await test_endpoint("value1", "value2", kwarg1="value3") + + assert result == { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "value3" + } + + @pytest.mark.asyncio + async def test_decorator_default_parameters(self): + """Test decorator uses default parameters.""" + @rate_limit() # Use defaults + async def test_endpoint(): + return {"status": "ok"} + + result = await test_endpoint() + + assert result == {"status": "ok"} + + @pytest.mark.asyncio + async def test_decorator_exception_includes_retry_after(self): + """Test rate limit exception includes retry_after.""" + @rate_limit(max_requests=1, window_seconds=60) + async def test_endpoint_retry(): + return {"status": "ok"} + + await test_endpoint_retry() # Use up limit + + with pytest.raises(RateLimitException) as exc_info: + await test_endpoint_retry() + + assert exc_info.value.details.get("retry_after") == 60 + + +@pytest.mark.unit +@pytest.mark.auth +class TestRateLimitDecoratorEdgeCases: + """Test suite for rate limit decorator edge cases.""" + + @pytest.mark.asyncio + async def test_decorator_with_zero_max_requests(self): + """Test decorator with max_requests=0 blocks all requests.""" + @rate_limit(max_requests=0, window_seconds=3600) + async def test_endpoint_zero(): + return {"status": "ok"} + + # Should block immediately + with pytest.raises(RateLimitException): + await test_endpoint_zero() + + @pytest.mark.asyncio + async def test_decorator_with_very_short_window(self): + """Test decorator with very short time window.""" + @rate_limit(max_requests=1, window_seconds=1) + async def test_endpoint_short(): + return {"status": "ok"} + + # First request should succeed + result = await test_endpoint_short() + assert result == {"status": "ok"} + + # Second request should be blocked (within 1 second) + with pytest.raises(RateLimitException): + await test_endpoint_short() + + @pytest.mark.asyncio + async def test_decorator_multiple_functions_separate_limits(self): + """Test that different functions have separate rate limits.""" + @rate_limit(max_requests=1, window_seconds=3600) + async def endpoint1(): + return {"endpoint": "1"} + + @rate_limit(max_requests=1, window_seconds=3600) + async def endpoint2(): + return {"endpoint": "2"} + + # Use up limit for endpoint1 + await endpoint1() + + # endpoint1 should be blocked + with pytest.raises(RateLimitException): + await endpoint1() + + # endpoint2 should still work (separate limit) + # Note: In current implementation, both use same "anonymous" client_id + # so they share the same limit. This test documents current behavior. + with pytest.raises(RateLimitException): + await endpoint2() + + @pytest.mark.asyncio + async def test_decorator_with_exception_in_function(self): + """Test decorator handles exceptions from wrapped function.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def test_endpoint_error(): + raise ValueError("Function error") + + # Should still apply rate limiting before function executes + with pytest.raises(ValueError) as exc_info: + await test_endpoint_error() + + assert str(exc_info.value) == "Function error" + + @pytest.mark.asyncio + async def test_decorator_isolation_between_tests(self): + """Test that rate limiter state is properly isolated between tests.""" + @rate_limit(max_requests=2, window_seconds=3600) + async def test_endpoint_isolation(): + return {"status": "ok"} + + # Should allow 2 requests due to reset_rate_limiter fixture + await test_endpoint_isolation() + await test_endpoint_isolation() + + # Third should be blocked + with pytest.raises(RateLimitException): + await test_endpoint_isolation() + + +@pytest.mark.unit +@pytest.mark.auth +class TestRateLimitDecoratorReturnValues: + """Test suite for verifying correct return values through decorator.""" + + @pytest.mark.asyncio + async def test_decorator_returns_dict(self): + """Test decorator correctly returns dictionary.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def return_dict(): + return {"key": "value", "number": 42} + + result = await return_dict() + assert result == {"key": "value", "number": 42} + + @pytest.mark.asyncio + async def test_decorator_returns_list(self): + """Test decorator correctly returns list.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def return_list(): + return [1, 2, 3, 4, 5] + + result = await return_list() + assert result == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_decorator_returns_none(self): + """Test decorator correctly returns None.""" + @rate_limit(max_requests=10, window_seconds=3600) + async def return_none(): + return None + + result = await return_none() + assert result is None + + @pytest.mark.asyncio + async def test_decorator_returns_object(self): + """Test decorator correctly returns custom objects.""" + class TestObject: + def __init__(self): + self.name = "test_object" + self.value = 123 + + @rate_limit(max_requests=10, window_seconds=3600) + async def return_object(): + return TestObject() + + result = await return_object() + assert result.name == "test_object" + assert result.value == 123 diff --git a/tests/unit/middleware/test_logging.py b/tests/unit/middleware/test_logging.py new file mode 100644 index 00000000..5bb88434 --- /dev/null +++ b/tests/unit/middleware/test_logging.py @@ -0,0 +1,234 @@ +# tests/unit/middleware/test_logging.py +""" +Comprehensive unit tests for LoggingMiddleware. + +Tests cover: +- Request/response logging +- Performance monitoring with X-Process-Time header +- Client IP address handling +- Exception logging +- Timing accuracy +- Edge cases (missing client info, etc.) +""" + +import pytest +import asyncio +from unittest.mock import Mock, AsyncMock, patch +from fastapi import Request + +from middleware.logging import LoggingMiddleware + + +@pytest.mark.unit +class TestLoggingMiddleware: + """Test suite for LoggingMiddleware.""" + + @pytest.mark.asyncio + async def test_middleware_logs_request(self): + """Test middleware logs incoming request.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/api/vendors") + request.client = Mock(host="127.0.0.1") + + # Create mock response with actual dict for headers + response = Mock() + response.status_code = 200 + response.headers = {} # Use actual dict + + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger') as mock_logger: + await middleware.dispatch(request, call_next) + + # Verify request was logged + assert mock_logger.info.call_count >= 1 + first_call = mock_logger.info.call_args_list[0] + assert "GET" in str(first_call) + assert "/api/vendors" in str(first_call) + + @pytest.mark.asyncio + async def test_middleware_logs_response(self): + """Test middleware logs response with status code and duration.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "POST" + request.url = Mock(path="/api/products") + request.client = Mock(host="127.0.0.1") + + response = Mock() + response.status_code = 201 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger') as mock_logger: + result = await middleware.dispatch(request, call_next) + + # Verify response was logged + assert mock_logger.info.call_count >= 2 # Request + Response + last_call = mock_logger.info.call_args_list[-1] + assert "201" in str(last_call) + + @pytest.mark.asyncio + async def test_middleware_adds_process_time_header(self): + """Test middleware adds X-Process-Time header.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/test") + request.client = Mock(host="127.0.0.1") + + response = Mock() + response.status_code = 200 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger'): + result = await middleware.dispatch(request, call_next) + + assert "X-Process-Time" in response.headers + # Should be a numeric string + process_time = float(response.headers["X-Process-Time"]) + assert process_time >= 0 + + @pytest.mark.asyncio + async def test_middleware_handles_no_client(self): + """Test middleware handles requests with no client info.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/test") + request.client = None # No client info + + response = Mock() + response.status_code = 200 + response.headers = {} + + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger') as mock_logger: + await middleware.dispatch(request, call_next) + + # Should log "unknown" for client + assert any("unknown" in str(call) for call in mock_logger.info.call_args_list) + + @pytest.mark.asyncio + async def test_middleware_logs_exceptions(self): + """Test middleware logs exceptions.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/error") + request.client = Mock(host="127.0.0.1") + + call_next = AsyncMock(side_effect=Exception("Test error")) + + with patch('middleware.logging.logger') as mock_logger, \ + pytest.raises(Exception): + await middleware.dispatch(request, call_next) + + # Verify error was logged + mock_logger.error.assert_called_once() + assert "Test error" in str(mock_logger.error.call_args) + + @pytest.mark.asyncio + async def test_middleware_timing_accuracy(self): + """Test middleware timing is reasonably accurate.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/slow") + request.client = Mock(host="127.0.0.1") + + async def slow_call_next(req): + await asyncio.sleep(0.1) # 100ms delay + response = Mock(status_code=200, headers={}) + return response + + call_next = slow_call_next + + with patch('middleware.logging.logger'): + result = await middleware.dispatch(request, call_next) + + process_time = float(result.headers["X-Process-Time"]) + # Should be at least 0.1 seconds + assert process_time >= 0.1 + + +@pytest.mark.unit +class TestLoggingEdgeCases: + """Test suite for edge cases and special scenarios.""" + + @pytest.mark.asyncio + async def test_middleware_handles_very_fast_requests(self): + """Test middleware handles requests that complete very quickly.""" + middleware = LoggingMiddleware(app=None) + + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/fast") + request.client = Mock(host="127.0.0.1") + + response = Mock(status_code=200, headers={}) + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger'): + result = await middleware.dispatch(request, call_next) + + # Should still have process time, even if very small + assert "X-Process-Time" in response.headers + process_time = float(response.headers["X-Process-Time"]) + assert process_time >= 0 + + @pytest.mark.asyncio + async def test_middleware_logs_different_methods(self): + """Test middleware correctly logs different HTTP methods.""" + middleware = LoggingMiddleware(app=None) + + methods = ["GET", "POST", "PUT", "PATCH", "DELETE"] + + for method in methods: + request = Mock(spec=Request) + request.method = method + request.url = Mock(path="/test") + request.client = Mock(host="127.0.0.1") + + response = Mock(status_code=200, headers={}) + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger') as mock_logger: + await middleware.dispatch(request, call_next) + + # Verify method was logged + assert any(method in str(call) for call in mock_logger.info.call_args_list) + + @pytest.mark.asyncio + async def test_middleware_logs_different_status_codes(self): + """Test middleware logs various HTTP status codes.""" + middleware = LoggingMiddleware(app=None) + + status_codes = [200, 201, 400, 404, 500] + + for status_code in status_codes: + request = Mock(spec=Request) + request.method = "GET" + request.url = Mock(path="/test") + request.client = Mock(host="127.0.0.1") + + response = Mock(status_code=status_code, headers={}) + call_next = AsyncMock(return_value=response) + + with patch('middleware.logging.logger') as mock_logger: + await middleware.dispatch(request, call_next) + + # Verify status code was logged + assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list) diff --git a/tests/unit/middleware/test_theme_context.py b/tests/unit/middleware/test_theme_context.py new file mode 100644 index 00000000..885d902a --- /dev/null +++ b/tests/unit/middleware/test_theme_context.py @@ -0,0 +1,243 @@ +# tests/unit/middleware/test_theme_context.py +""" +Comprehensive unit tests for ThemeContextMiddleware and ThemeContextManager. + +Tests cover: +- Theme loading and caching +- Default theme structure and validation +- Vendor-specific theme retrieval +- Fallback to default theme +- Middleware integration +- Edge cases and error handling +""" + +import pytest +from unittest.mock import Mock, AsyncMock, MagicMock, patch +from fastapi import Request + +from middleware.theme_context import ( + ThemeContextManager, + ThemeContextMiddleware, + get_current_theme, +) + + +@pytest.mark.unit +class TestThemeContextManager: + """Test suite for ThemeContextManager.""" + + def test_get_default_theme_structure(self): + """Test default theme has correct structure.""" + theme = ThemeContextManager.get_default_theme() + + assert "theme_name" in theme + assert "colors" in theme + assert "fonts" in theme + assert "branding" in theme + assert "layout" in theme + assert "social_links" in theme + assert "css_variables" in theme + + def test_get_default_theme_colors(self): + """Test default theme has all required colors.""" + theme = ThemeContextManager.get_default_theme() + + required_colors = ["primary", "secondary", "accent", "background", "text", "border"] + for color in required_colors: + assert color in theme["colors"] + assert theme["colors"][color].startswith("#") + + def test_get_default_theme_fonts(self): + """Test default theme has font configuration.""" + theme = ThemeContextManager.get_default_theme() + + assert "heading" in theme["fonts"] + assert "body" in theme["fonts"] + assert isinstance(theme["fonts"]["heading"], str) + assert isinstance(theme["fonts"]["body"], str) + + def test_get_default_theme_branding(self): + """Test default theme branding structure.""" + theme = ThemeContextManager.get_default_theme() + + assert "logo" in theme["branding"] + assert "logo_dark" in theme["branding"] + assert "favicon" in theme["branding"] + assert "banner" in theme["branding"] + + def test_get_default_theme_css_variables(self): + """Test default theme has CSS variables.""" + theme = ThemeContextManager.get_default_theme() + + assert "--color-primary" in theme["css_variables"] + assert "--font-heading" in theme["css_variables"] + assert "--font-body" in theme["css_variables"] + + def test_get_vendor_theme_with_custom_theme(self): + """Test getting vendor-specific theme.""" + mock_db = Mock() + mock_theme = Mock() + + # Mock to_dict to return actual dictionary + custom_theme_dict = { + "theme_name": "custom", + "colors": {"primary": "#ff0000"} + } + mock_theme.to_dict.return_value = custom_theme_dict + + # Correct filter chain: query().filter().first() + mock_db.query.return_value.filter.return_value.first.return_value = mock_theme + + theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) + + assert theme["theme_name"] == "custom" + assert theme["colors"]["primary"] == "#ff0000" + mock_theme.to_dict.assert_called_once() + + def test_get_vendor_theme_fallback_to_default(self): + """Test falling back to default theme when no custom theme exists.""" + mock_db = Mock() + # Correct filter chain: query().filter().first() + mock_db.query.return_value.filter.return_value.first.return_value = None + + theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) + + # Verify it returns a dict (not a Mock) + assert isinstance(theme, dict) + assert theme["theme_name"] == "default" + assert "colors" in theme + assert "fonts" in theme + + def test_get_vendor_theme_inactive_theme(self): + """Test that inactive themes are not returned.""" + mock_db = Mock() + # Correct filter chain: query().filter().first() + mock_db.query.return_value.filter.return_value.first.return_value = None + + theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) + + # Should return default theme (actual dict) + assert isinstance(theme, dict) + assert theme["theme_name"] == "default" + + +@pytest.mark.unit +class TestThemeContextMiddleware: + """Test suite for ThemeContextMiddleware.""" + + @pytest.mark.asyncio + async def test_middleware_loads_theme_for_vendor(self): + """Test middleware loads theme when vendor exists.""" + middleware = ThemeContextMiddleware(app=None) + + request = Mock(spec=Request) + mock_vendor = Mock() + mock_vendor.id = 1 + mock_vendor.name = "Test Vendor" + request.state = Mock(vendor=mock_vendor) + + call_next = AsyncMock(return_value=Mock()) + + mock_db = MagicMock() + mock_theme = {"theme_name": "test_theme"} + + with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ + patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme): + + await middleware.dispatch(request, call_next) + + assert request.state.theme == mock_theme + call_next.assert_called_once_with(request) + + @pytest.mark.asyncio + async def test_middleware_uses_default_theme_no_vendor(self): + """Test middleware uses default theme when no vendor.""" + middleware = ThemeContextMiddleware(app=None) + + request = Mock(spec=Request) + request.state = Mock(vendor=None) + + call_next = AsyncMock(return_value=Mock()) + + await middleware.dispatch(request, call_next) + + assert hasattr(request.state, 'theme') + assert request.state.theme["theme_name"] == "default" + call_next.assert_called_once() + + @pytest.mark.asyncio + async def test_middleware_handles_theme_loading_error(self): + """Test middleware handles errors gracefully.""" + middleware = ThemeContextMiddleware(app=None) + + request = Mock(spec=Request) + mock_vendor = Mock(id=1, name="Test Vendor") + request.state = Mock(vendor=mock_vendor) + + call_next = AsyncMock(return_value=Mock()) + + mock_db = MagicMock() + + with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ + patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")): + + await middleware.dispatch(request, call_next) + + # Should fallback to default theme + assert request.state.theme["theme_name"] == "default" + call_next.assert_called_once() + + def test_get_current_theme_exists(self): + """Test getting current theme when it exists.""" + request = Mock(spec=Request) + test_theme = {"theme_name": "test"} + request.state.theme = test_theme + + theme = get_current_theme(request) + + assert theme == test_theme + + def test_get_current_theme_default(self): + """Test getting theme returns default when not set.""" + request = Mock(spec=Request) + request.state = Mock(spec=[]) # No theme attribute + + theme = get_current_theme(request) + + assert theme["theme_name"] == "default" + + +@pytest.mark.unit +class TestThemeEdgeCases: + """Test suite for edge cases and special scenarios.""" + + @pytest.mark.asyncio + async def test_middleware_closes_db_connection(self): + """Test middleware properly closes database connection.""" + middleware = ThemeContextMiddleware(app=None) + + request = Mock(spec=Request) + mock_vendor = Mock(id=1, name="Test") + request.state = Mock(vendor=mock_vendor) + + call_next = AsyncMock(return_value=Mock()) + + mock_db = MagicMock() + + with patch('middleware.theme_context.get_db', return_value=iter([mock_db])): + await middleware.dispatch(request, call_next) + + # Verify database was closed + mock_db.close.assert_called_once() + + def test_theme_default_immutability(self): + """Test that getting default theme doesn't share state.""" + theme1 = ThemeContextManager.get_default_theme() + theme2 = ThemeContextManager.get_default_theme() + + # Modify theme1 + theme1["colors"]["primary"] = "#000000" + + # theme2 should not be affected (if properly implemented) + # Note: This test documents expected behavior + assert theme2["colors"]["primary"] == "#6366f1" diff --git a/tests/unit/middleware/test_theme_logging_path_decorators.py b/tests/unit/middleware/test_theme_logging_path_decorators.py deleted file mode 100644 index 69b84e51..00000000 --- a/tests/unit/middleware/test_theme_logging_path_decorators.py +++ /dev/null @@ -1,510 +0,0 @@ -# tests/unit/middleware/test_theme_logging_path_decorators.py -""" -Comprehensive unit tests for middleware components: -- ThemeContextMiddleware and ThemeContextManager -- LoggingMiddleware -- rate_limit decorator - -Tests cover: -- Theme loading and caching -- Request/response logging -- Rate limit decorators -- Edge cases and error handling - -Note: path_rewrite_middleware has been deprecated in favor of double router mounting. -See main.py for current implementation using app.include_router() with different prefixes. -""" - -import pytest -import asyncio -from unittest.mock import Mock, AsyncMock, MagicMock, patch -from fastapi import Request - -from middleware.theme_context import ( - ThemeContextManager, - ThemeContextMiddleware, - get_current_theme, -) -from middleware.logging_middleware import LoggingMiddleware -from middleware.decorators import rate_limit, rate_limiter -from app.exceptions.base import RateLimitException - - -# ============================================================================= -# Fixtures -# ============================================================================= - -@pytest.fixture(autouse=True) -def reset_rate_limiter(): - """Reset rate limiter state before each test.""" - rate_limiter.clients.clear() - yield - rate_limiter.clients.clear() - - -# ============================================================================= -# Theme Context Tests -# ============================================================================= - -@pytest.mark.unit -class TestThemeContextManager: - """Test suite for ThemeContextManager.""" - - def test_get_default_theme_structure(self): - """Test default theme has correct structure.""" - theme = ThemeContextManager.get_default_theme() - - assert "theme_name" in theme - assert "colors" in theme - assert "fonts" in theme - assert "branding" in theme - assert "layout" in theme - assert "social_links" in theme - assert "css_variables" in theme - - def test_get_default_theme_colors(self): - """Test default theme has all required colors.""" - theme = ThemeContextManager.get_default_theme() - - required_colors = ["primary", "secondary", "accent", "background", "text", "border"] - for color in required_colors: - assert color in theme["colors"] - assert theme["colors"][color].startswith("#") - - def test_get_default_theme_fonts(self): - """Test default theme has font configuration.""" - theme = ThemeContextManager.get_default_theme() - - assert "heading" in theme["fonts"] - assert "body" in theme["fonts"] - assert isinstance(theme["fonts"]["heading"], str) - assert isinstance(theme["fonts"]["body"], str) - - def test_get_default_theme_branding(self): - """Test default theme branding structure.""" - theme = ThemeContextManager.get_default_theme() - - assert "logo" in theme["branding"] - assert "logo_dark" in theme["branding"] - assert "favicon" in theme["branding"] - assert "banner" in theme["branding"] - - def test_get_default_theme_css_variables(self): - """Test default theme has CSS variables.""" - theme = ThemeContextManager.get_default_theme() - - assert "--color-primary" in theme["css_variables"] - assert "--font-heading" in theme["css_variables"] - assert "--font-body" in theme["css_variables"] - - def test_get_vendor_theme_with_custom_theme(self): - """Test getting vendor-specific theme.""" - mock_db = Mock() - mock_theme = Mock() - - # Mock to_dict to return actual dictionary - custom_theme_dict = { - "theme_name": "custom", - "colors": {"primary": "#ff0000"} - } - mock_theme.to_dict.return_value = custom_theme_dict - - # Correct filter chain: query().filter().first() (single filter, not double) - mock_db.query.return_value.filter.return_value.first.return_value = mock_theme - - theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) - - assert theme["theme_name"] == "custom" - assert theme["colors"]["primary"] == "#ff0000" - mock_theme.to_dict.assert_called_once() - - def test_get_vendor_theme_fallback_to_default(self): - """Test falling back to default theme when no custom theme exists.""" - mock_db = Mock() - # Correct filter chain: query().filter().first() - mock_db.query.return_value.filter.return_value.first.return_value = None - - theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) - - # Verify it returns a dict (not a Mock) - assert isinstance(theme, dict) - assert theme["theme_name"] == "default" - assert "colors" in theme - assert "fonts" in theme - - def test_get_vendor_theme_inactive_theme(self): - """Test that inactive themes are not returned.""" - mock_db = Mock() - # Correct filter chain: query().filter().first() - mock_db.query.return_value.filter.return_value.first.return_value = None - - theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1) - - # Should return default theme (actual dict) - assert isinstance(theme, dict) - assert theme["theme_name"] == "default" - - -@pytest.mark.unit -class TestThemeContextMiddleware: - """Test suite for ThemeContextMiddleware.""" - - @pytest.mark.asyncio - async def test_middleware_loads_theme_for_vendor(self): - """Test middleware loads theme when vendor exists.""" - middleware = ThemeContextMiddleware(app=None) - - request = Mock(spec=Request) - mock_vendor = Mock() - mock_vendor.id = 1 - mock_vendor.name = "Test Vendor" - request.state = Mock(vendor=mock_vendor) - - call_next = AsyncMock(return_value=Mock()) - - mock_db = MagicMock() - mock_theme = {"theme_name": "test_theme"} - - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ - patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme): - - await middleware.dispatch(request, call_next) - - assert request.state.theme == mock_theme - call_next.assert_called_once_with(request) - - @pytest.mark.asyncio - async def test_middleware_uses_default_theme_no_vendor(self): - """Test middleware uses default theme when no vendor.""" - middleware = ThemeContextMiddleware(app=None) - - request = Mock(spec=Request) - request.state = Mock(vendor=None) - - call_next = AsyncMock(return_value=Mock()) - - await middleware.dispatch(request, call_next) - - assert hasattr(request.state, 'theme') - assert request.state.theme["theme_name"] == "default" - call_next.assert_called_once() - - @pytest.mark.asyncio - async def test_middleware_handles_theme_loading_error(self): - """Test middleware handles errors gracefully.""" - middleware = ThemeContextMiddleware(app=None) - - request = Mock(spec=Request) - mock_vendor = Mock(id=1, name="Test Vendor") - request.state = Mock(vendor=mock_vendor) - - call_next = AsyncMock(return_value=Mock()) - - mock_db = MagicMock() - - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \ - patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")): - - await middleware.dispatch(request, call_next) - - # Should fallback to default theme - assert request.state.theme["theme_name"] == "default" - call_next.assert_called_once() - - def test_get_current_theme_exists(self): - """Test getting current theme when it exists.""" - request = Mock(spec=Request) - test_theme = {"theme_name": "test"} - request.state.theme = test_theme - - theme = get_current_theme(request) - - assert theme == test_theme - - def test_get_current_theme_default(self): - """Test getting theme returns default when not set.""" - request = Mock(spec=Request) - request.state = Mock(spec=[]) # No theme attribute - - theme = get_current_theme(request) - - assert theme["theme_name"] == "default" - - -# ============================================================================= -# Logging Middleware Tests -# ============================================================================= - -@pytest.mark.unit -class TestLoggingMiddleware: - """Test suite for LoggingMiddleware.""" - - @pytest.mark.asyncio - async def test_middleware_logs_request(self): - """Test middleware logs incoming request.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "GET" - request.url = Mock(path="/api/vendors") - request.client = Mock(host="127.0.0.1") - - # Create mock response with actual dict for headers - response = Mock() - response.status_code = 200 - response.headers = {} # Use actual dict - - call_next = AsyncMock(return_value=response) - - with patch('middleware.logging_middleware.logger') as mock_logger: - await middleware.dispatch(request, call_next) - - # Verify request was logged - assert mock_logger.info.call_count >= 1 - first_call = mock_logger.info.call_args_list[0] - assert "GET" in str(first_call) - assert "/api/vendors" in str(first_call) - - @pytest.mark.asyncio - async def test_middleware_logs_response(self): - """Test middleware logs response with status code and duration.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "POST" - request.url = Mock(path="/api/products") - request.client = Mock(host="127.0.0.1") - - response = Mock() - response.status_code = 201 - response.headers = {} - - call_next = AsyncMock(return_value=response) - - with patch('middleware.logging_middleware.logger') as mock_logger: - result = await middleware.dispatch(request, call_next) - - # Verify response was logged - assert mock_logger.info.call_count >= 2 # Request + Response - last_call = mock_logger.info.call_args_list[-1] - assert "201" in str(last_call) - - @pytest.mark.asyncio - async def test_middleware_adds_process_time_header(self): - """Test middleware adds X-Process-Time header.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "GET" - request.url = Mock(path="/test") - request.client = Mock(host="127.0.0.1") - - response = Mock() - response.status_code = 200 - response.headers = {} - - call_next = AsyncMock(return_value=response) - - with patch('middleware.logging_middleware.logger'): - result = await middleware.dispatch(request, call_next) - - assert "X-Process-Time" in response.headers - # Should be a numeric string - process_time = float(response.headers["X-Process-Time"]) - assert process_time >= 0 - - @pytest.mark.asyncio - async def test_middleware_handles_no_client(self): - """Test middleware handles requests with no client info.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "GET" - request.url = Mock(path="/test") - request.client = None # No client info - - response = Mock() - response.status_code = 200 - response.headers = {} - - call_next = AsyncMock(return_value=response) - - with patch('middleware.logging_middleware.logger') as mock_logger: - await middleware.dispatch(request, call_next) - - # Should log "unknown" for client - assert any("unknown" in str(call) for call in mock_logger.info.call_args_list) - - @pytest.mark.asyncio - async def test_middleware_logs_exceptions(self): - """Test middleware logs exceptions.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "GET" - request.url = Mock(path="/error") - request.client = Mock(host="127.0.0.1") - - call_next = AsyncMock(side_effect=Exception("Test error")) - - with patch('middleware.logging_middleware.logger') as mock_logger, \ - pytest.raises(Exception): - await middleware.dispatch(request, call_next) - - # Verify error was logged - mock_logger.error.assert_called_once() - assert "Test error" in str(mock_logger.error.call_args) - - @pytest.mark.asyncio - async def test_middleware_timing_accuracy(self): - """Test middleware timing is reasonably accurate.""" - middleware = LoggingMiddleware(app=None) - - request = Mock(spec=Request) - request.method = "GET" - request.url = Mock(path="/slow") - request.client = Mock(host="127.0.0.1") - - async def slow_call_next(req): - await asyncio.sleep(0.1) # 100ms delay - response = Mock(status_code=200, headers={}) - return response - - call_next = slow_call_next - - with patch('middleware.logging_middleware.logger'): - result = await middleware.dispatch(request, call_next) - - process_time = float(result.headers["X-Process-Time"]) - # Should be at least 0.1 seconds - assert process_time >= 0.1 - - -# ============================================================================= -# Rate Limit Decorator Tests -# ============================================================================= - -@pytest.mark.unit -@pytest.mark.auth -class TestRateLimitDecorator: - """Test suite for rate_limit decorator.""" - - @pytest.mark.asyncio - async def test_decorator_allows_within_limit(self): - """Test decorator allows requests within rate limit.""" - @rate_limit(max_requests=10, window_seconds=3600) - async def test_endpoint(): - return {"status": "ok"} - - result = await test_endpoint() - - assert result == {"status": "ok"} - - @pytest.mark.asyncio - async def test_decorator_blocks_exceeding_limit(self): - """Test decorator blocks requests exceeding rate limit.""" - @rate_limit(max_requests=2, window_seconds=3600) - async def test_endpoint_blocked(): - return {"status": "ok"} - - # First two should succeed - await test_endpoint_blocked() - await test_endpoint_blocked() - - # Third should raise exception - with pytest.raises(RateLimitException) as exc_info: - await test_endpoint_blocked() - - assert exc_info.value.status_code == 429 - assert "Rate limit exceeded" in exc_info.value.message - - @pytest.mark.asyncio - async def test_decorator_preserves_function_metadata(self): - """Test decorator preserves original function metadata.""" - @rate_limit(max_requests=10, window_seconds=3600) - async def test_endpoint(): - """Test endpoint docstring.""" - return {"status": "ok"} - - assert test_endpoint.__name__ == "test_endpoint" - assert test_endpoint.__doc__ == "Test endpoint docstring." - - @pytest.mark.asyncio - async def test_decorator_with_args_and_kwargs(self): - """Test decorator works with function arguments.""" - @rate_limit(max_requests=10, window_seconds=3600) - async def test_endpoint(arg1, arg2, kwarg1=None): - return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1} - - result = await test_endpoint("value1", "value2", kwarg1="value3") - - assert result == { - "arg1": "value1", - "arg2": "value2", - "kwarg1": "value3" - } - - @pytest.mark.asyncio - async def test_decorator_default_parameters(self): - """Test decorator uses default parameters.""" - @rate_limit() # Use defaults - async def test_endpoint(): - return {"status": "ok"} - - result = await test_endpoint() - - assert result == {"status": "ok"} - - @pytest.mark.asyncio - async def test_decorator_exception_includes_retry_after(self): - """Test rate limit exception includes retry_after.""" - @rate_limit(max_requests=1, window_seconds=60) - async def test_endpoint_retry(): - return {"status": "ok"} - - await test_endpoint_retry() # Use up limit - - with pytest.raises(RateLimitException) as exc_info: - await test_endpoint_retry() - - assert exc_info.value.details.get("retry_after") == 60 - - -# ============================================================================= -# Edge Cases and Integration Tests -# ============================================================================= - -@pytest.mark.unit -class TestMiddlewareEdgeCases: - """Test suite for edge cases across middleware.""" - - @pytest.mark.asyncio - async def test_theme_middleware_closes_db_connection(self): - """Test theme middleware properly closes database connection.""" - middleware = ThemeContextMiddleware(app=None) - - request = Mock(spec=Request) - mock_vendor = Mock(id=1, name="Test") - request.state = Mock(vendor=mock_vendor) - - call_next = AsyncMock(return_value=Mock()) - - mock_db = MagicMock() - - with patch('middleware.theme_context.get_db', return_value=iter([mock_db])): - await middleware.dispatch(request, call_next) - - # Verify database was closed - mock_db.close.assert_called_once() - - def test_theme_default_immutability(self): - """Test that getting default theme doesn't share state.""" - theme1 = ThemeContextManager.get_default_theme() - theme2 = ThemeContextManager.get_default_theme() - - # Modify theme1 - theme1["colors"]["primary"] = "#000000" - - # theme2 should not be affected (if properly implemented) - # Note: This test documents expected behavior - assert theme2["colors"]["primary"] == "#6366f1"