# tests/unit/middleware/test_context_middleware.py
"""
Comprehensive unit tests for ContextMiddleware and ContextManager.

Tests cover:
- Context detection for API, Admin, Vendor Dashboard, Shop, and Fallback
- Clean path usage for correct context detection
- Host and path-based context determination
- Middleware state injection
- Edge cases and error handling
"""

from unittest.mock import Mock, AsyncMock, patch

import pytest
from fastapi import Request

from middleware.context_middleware import (
    ContextManager,
    ContextMiddleware,
    RequestContext,
    get_request_context,
)


@pytest.mark.unit
class TestRequestContextEnum:
    """Test suite for RequestContext enum."""

    def test_request_context_values(self):
        """Test RequestContext enum has correct values."""
        assert RequestContext.API.value == "api"
        assert RequestContext.ADMIN.value == "admin"
        assert RequestContext.VENDOR_DASHBOARD.value == "vendor"
        assert RequestContext.SHOP.value == "shop"
        assert RequestContext.FALLBACK.value == "fallback"

    def test_request_context_types(self):
        """Test RequestContext enum values are strings."""
        for context in RequestContext:
            assert isinstance(context.value, str)


@pytest.mark.unit
class TestContextManagerDetection:
    """Test suite for ContextManager.detect_context()."""

    # ========================================================================
    # API Context Tests (Highest Priority)
    # ========================================================================

    def test_detect_api_context(self):
        """Test API context detection."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/v1/vendors")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/api/v1/vendors")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.API

    def test_detect_api_context_nested_path(self):
        """Test API context detection with nested path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/v1/vendors/123/products")
        request.headers = {"host": "platform.com"}
        request.state = Mock(clean_path="/api/v1/vendors/123/products")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.API

    def test_detect_api_context_with_clean_path(self):
        """Test API context detection uses clean_path when available."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/testvendor/api/products")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/api/products")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.API

    # ========================================================================
    # Admin Context Tests
    # ========================================================================

    def test_detect_admin_context_from_subdomain(self):
        """Test admin context detection from subdomain."""
        request = Mock(spec=Request)
        request.url = Mock(path="/dashboard")
        request.headers = {"host": "admin.platform.com"}
        request.state = Mock(clean_path="/dashboard")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.ADMIN

    def test_detect_admin_context_from_path(self):
        """Test admin context detection from path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/admin/dashboard")
        request.headers = {"host": "platform.com"}
        request.state = Mock(clean_path="/admin/dashboard")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.ADMIN

    def test_detect_admin_context_with_port(self):
        """Test admin context detection with port number."""
        request = Mock(spec=Request)
        request.url = Mock(path="/dashboard")
        request.headers = {"host": "admin.localhost:8000"}
        request.state = Mock(clean_path="/dashboard")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.ADMIN

    def test_detect_admin_context_nested_path(self):
        """Test admin context detection with nested admin path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/admin/vendors/list")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/admin/vendors/list")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.ADMIN

    # ========================================================================
    # Vendor Dashboard Context Tests
    # ========================================================================

    def test_detect_vendor_dashboard_context(self):
        """Test vendor dashboard context detection."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/testvendor/dashboard")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/dashboard")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.VENDOR_DASHBOARD

    def test_detect_vendor_dashboard_context_direct_path(self):
        """Test vendor dashboard with direct /vendor/ path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/settings")
        request.headers = {"host": "testvendor.platform.com"}
        request.state = Mock(clean_path="/vendor/settings")

        context = ContextManager.detect_context(request)

        assert context == RequestContext.VENDOR_DASHBOARD

    def test_not_detect_vendors_plural_as_dashboard(self):
        """Test that /vendors/ path is not detected as vendor dashboard."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendors/testvendor/shop")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/shop")

        # Should not be vendor dashboard
        context = ContextManager.detect_context(request)

        assert context != RequestContext.VENDOR_DASHBOARD

    # ========================================================================
    # Shop Context Tests
    # ========================================================================

    def test_detect_shop_context_with_vendor_state(self):
        """Test shop context detection when vendor exists in request state."""
        request = Mock(spec=Request)
        request.url = Mock(path="/products")
        request.headers = {"host": "testvendor.platform.com"}
        mock_vendor = Mock()
        mock_vendor.name = "Test Vendor"
        request.state = Mock(clean_path="/products", vendor=mock_vendor)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.SHOP

    def test_detect_shop_context_from_shop_path(self):
        """Test shop context detection from /shop/ path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/shop/products")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/shop/products", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.SHOP

    def test_detect_shop_context_custom_domain(self):
        """Test shop context with custom domain and vendor."""
        request = Mock(spec=Request)
        request.url = Mock(path="/products")
        request.headers = {"host": "customdomain.com"}
        # Mock(name=...) only sets the mock's repr name, not a ``.name``
        # attribute, so the vendor name is assigned after construction.
        mock_vendor = Mock()
        mock_vendor.name = "Custom Vendor"
        request.state = Mock(clean_path="/products", vendor=mock_vendor)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.SHOP

    # ========================================================================
    # Fallback Context Tests
    # ========================================================================

    def test_detect_fallback_context(self):
        """Test fallback context for unknown paths."""
        request = Mock(spec=Request)
        request.url = Mock(path="/random/path")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/random/path", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.FALLBACK

    def test_detect_fallback_context_root(self):
        """Test fallback context for root path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/")
        request.headers = {"host": "platform.com"}
        request.state = Mock(clean_path="/", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.FALLBACK

    def test_detect_fallback_context_no_vendor(self):
        """Test fallback context when no vendor context exists."""
        request = Mock(spec=Request)
        request.url = Mock(path="/about")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/about", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.FALLBACK

    # ========================================================================
    # Clean Path Tests
    # ========================================================================

    def test_uses_clean_path_when_available(self):
        """Test that clean_path is used over original path."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/testvendor/api/products")
        request.headers = {"host": "localhost"}
        # clean_path shows the rewritten path
        request.state = Mock(clean_path="/api/products")

        context = ContextManager.detect_context(request)

        # Should detect as API based on clean_path, not original path
        assert context == RequestContext.API

    def test_falls_back_to_original_path(self):
        """Test falls back to original path when clean_path not set."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/vendors")
        request.headers = {"host": "localhost"}
        request.state = Mock(spec=[])  # No clean_path attribute

        context = ContextManager.detect_context(request)

        assert context == RequestContext.API

    # ========================================================================
    # Priority Order Tests
    # ========================================================================

    def test_api_has_highest_priority(self):
        """Test API context takes precedence over admin."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/admin/users")
        request.headers = {"host": "admin.platform.com"}
        request.state = Mock(clean_path="/api/admin/users")

        context = ContextManager.detect_context(request)

        # API should win even though it's admin subdomain
        assert context == RequestContext.API

    def test_admin_has_priority_over_shop(self):
        """Test admin context takes precedence over shop."""
        request = Mock(spec=Request)
        request.url = Mock(path="/admin/shops")
        request.headers = {"host": "localhost"}
        mock_vendor = Mock()
        request.state = Mock(clean_path="/admin/shops", vendor=mock_vendor)

        context = ContextManager.detect_context(request)

        # Admin should win even though vendor exists
        assert context == RequestContext.ADMIN

    def test_vendor_dashboard_priority_over_shop(self):
        """Test vendor dashboard takes precedence over shop."""
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/settings")
        request.headers = {"host": "testvendor.platform.com"}
        mock_vendor = Mock()
        request.state = Mock(clean_path="/vendor/settings", vendor=mock_vendor)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.VENDOR_DASHBOARD


@pytest.mark.unit
class TestContextManagerHelpers:
    """Test suite for ContextManager helper methods."""

    def test_is_admin_context_from_subdomain(self):
        """Test _is_admin_context with admin subdomain."""
        request = Mock()

        assert ContextManager._is_admin_context(
            request, "admin.platform.com", "/dashboard"
        ) is True

    def test_is_admin_context_from_path(self):
        """Test _is_admin_context with admin path."""
        request = Mock()

        assert ContextManager._is_admin_context(
            request, "localhost", "/admin/users"
        ) is True

    def test_is_admin_context_both(self):
        """Test _is_admin_context with both subdomain and path."""
        request = Mock()

        assert ContextManager._is_admin_context(
            request, "admin.platform.com", "/admin/users"
        ) is True

    def test_is_not_admin_context(self):
        """Test _is_admin_context returns False for non-admin."""
        request = Mock()

        assert ContextManager._is_admin_context(
            request, "vendor.platform.com", "/shop"
        ) is False

    def test_is_vendor_dashboard_context(self):
        """Test _is_vendor_dashboard_context with /vendor/ path."""
        assert ContextManager._is_vendor_dashboard_context("/vendor/settings") is True

    def test_is_vendor_dashboard_context_nested(self):
        """Test _is_vendor_dashboard_context with nested vendor path."""
        assert ContextManager._is_vendor_dashboard_context(
            "/vendor/products/list"
        ) is True

    def test_is_not_vendor_dashboard_context_vendors_plural(self):
        """Test _is_vendor_dashboard_context excludes /vendors/ path."""
        assert ContextManager._is_vendor_dashboard_context(
            "/vendors/shop123/products"
        ) is False

    def test_is_not_vendor_dashboard_context(self):
        """Test _is_vendor_dashboard_context returns False for non-vendor paths."""
        assert ContextManager._is_vendor_dashboard_context("/shop/products") is False


@pytest.mark.unit
class TestContextMiddleware:
    """Test suite for ContextMiddleware."""

    @pytest.mark.asyncio
    async def test_middleware_sets_context(self):
        """Test middleware successfully sets context in request state."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/api/vendors")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/api/vendors", vendor=None)
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        # NOTE: a non-spec Mock auto-creates attributes on access, so this
        # hasattr check cannot fail; the equality check below is the
        # meaningful assertion.
        assert hasattr(request.state, 'context_type')
        assert request.state.context_type == RequestContext.API
        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_middleware_sets_admin_context(self):
        """Test middleware sets admin context."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/admin/dashboard")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/admin/dashboard")
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        assert request.state.context_type == RequestContext.ADMIN
        call_next.assert_called_once()

    @pytest.mark.asyncio
    async def test_middleware_sets_vendor_dashboard_context(self):
        """Test middleware sets vendor dashboard context."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/vendor/settings")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/vendor/settings")
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        assert request.state.context_type == RequestContext.VENDOR_DASHBOARD
        call_next.assert_called_once()

    @pytest.mark.asyncio
    async def test_middleware_sets_shop_context(self):
        """Test middleware sets shop context."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/products")
        request.headers = {"host": "shop.platform.com"}
        mock_vendor = Mock()
        request.state = Mock(clean_path="/products", vendor=mock_vendor)
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        assert request.state.context_type == RequestContext.SHOP
        call_next.assert_called_once()

    @pytest.mark.asyncio
    async def test_middleware_sets_fallback_context(self):
        """Test middleware sets fallback context."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/random")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/random", vendor=None)
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        assert request.state.context_type == RequestContext.FALLBACK
        call_next.assert_called_once()

    @pytest.mark.asyncio
    async def test_middleware_returns_response(self):
        """Test middleware returns response from call_next."""
        middleware = ContextMiddleware(app=None)
        request = Mock(spec=Request)
        request.url = Mock(path="/api/test")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/api/test")
        expected_response = Mock()
        call_next = AsyncMock(return_value=expected_response)

        response = await middleware.dispatch(request, call_next)

        assert response is expected_response


@pytest.mark.unit
class TestGetRequestContextHelper:
    """Test suite for get_request_context helper function."""

    def test_get_request_context_exists(self):
        """Test getting request context when it exists."""
        request = Mock(spec=Request)
        request.state.context_type = RequestContext.API

        context = get_request_context(request)

        assert context == RequestContext.API

    def test_get_request_context_default(self):
        """Test getting request context returns FALLBACK as default."""
        request = Mock(spec=Request)
        request.state = Mock(spec=[])  # No context_type attribute

        context = get_request_context(request)

        assert context == RequestContext.FALLBACK

    def test_get_request_context_for_all_types(self):
        """Test getting all context types."""
        for expected_context in RequestContext:
            request = Mock(spec=Request)
            request.state.context_type = expected_context

            context = get_request_context(request)

            assert context == expected_context


@pytest.mark.unit
class TestEdgeCases:
    """Test suite for edge cases and error scenarios."""

    def test_detect_context_empty_path(self):
        """Test context detection with empty path."""
        request = Mock(spec=Request)
        request.url = Mock(path="")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.FALLBACK

    def test_detect_context_missing_host(self):
        """Test context detection with missing host header."""
        request = Mock(spec=Request)
        request.url = Mock(path="/shop/products")
        request.headers = {}
        request.state = Mock(clean_path="/shop/products", vendor=None)

        context = ContextManager.detect_context(request)

        assert context == RequestContext.SHOP

    def test_detect_context_case_sensitivity(self):
        """Test that context detection is case-sensitive for paths."""
        request = Mock(spec=Request)
        request.url = Mock(path="/API/vendors")  # Uppercase
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/API/vendors")

        context = ContextManager.detect_context(request)

        # Should NOT match /api/ because it's case-sensitive
        assert context != RequestContext.API

    def test_detect_context_path_with_query_params(self):
        """Test context detection handles path with query parameters."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/vendors?page=1")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/api/vendors?page=1")

        # path.startswith should still work
        context = ContextManager.detect_context(request)

        assert context == RequestContext.API

    def test_detect_context_admin_substring(self):
        """Test that a path merely starting with '/admin' is treated as admin.

        This pins the current prefix-matching behaviour: '/administration'
        is classified as ADMIN even though it is not under '/admin/'.
        """
        request = Mock(spec=Request)
        request.url = Mock(path="/administration/docs")
        request.headers = {"host": "localhost"}
        request.state = Mock(clean_path="/administration/docs")

        context = ContextManager.detect_context(request)

        # Should match because path starts with /admin
        assert context == RequestContext.ADMIN

    def test_detect_context_no_state_attribute(self):
        """Test context detection raises when request has no state."""
        request = Mock(spec=Request)
        request.url = Mock(path="/api/vendors")
        request.headers = {"host": "localhost"}
        # Remove the state attribute entirely; after deletion the Mock
        # raises AttributeError on any access to ``request.state``.
        delattr(request, 'state')

        # There is no fallback to url.path in this case: the test expects
        # the AttributeError from accessing request.state to propagate.
        with pytest.raises(AttributeError):
            ContextManager.detect_context(request)