# tests/unit/middleware/test_storefront_access.py
"""
Unit tests for StorefrontAccessMiddleware.

Tests cover:
- Passthrough for non-storefront frontend types (ADMIN, STORE, PLATFORM, MERCHANT)
- Passthrough for static file requests
- Blocking when no store is detected (not_found)
- Blocking when store has no subscription (not_activated)
- Blocking when subscription is inactive (not_activated)
- Passthrough when subscription is active (TRIAL, ACTIVE, PAST_DUE, CANCELLED)
- Multi-platform subscription resolution (platform-specific, no fallback)
- API requests return JSON 403
- Page requests return HTML 403
- Language detection for unavailable page
- request.state.subscription and subscription_tier are set on passthrough
"""

import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from fastapi import Request
from starlette.responses import JSONResponse

from app.modules.enums import FrontendType
from middleware.storefront_access import (
    MESSAGES,
    StorefrontAccessMiddleware,
    _is_static_request,
)

# =============================================================================
# Helper: build a mock Request with the right state attributes
# =============================================================================


def _make_request(
    path="/storefront/products",
    frontend_type=FrontendType.STOREFRONT,
    store=None,
    platform=None,
    language="en",
    theme=None,
) -> Mock:
    """Create a mock Request with pre-set state attributes.

    The mock is spec'd against ``fastapi.Request``; ``url.path`` and every
    ``request.state`` attribute the tests rely on are populated from the
    keyword arguments.
    """
    request = Mock(spec=Request)
    request.url = Mock(path=path)
    request.state = Mock()
    request.state.frontend_type = frontend_type
    request.state.store = store
    request.state.platform = platform
    request.state.language = language
    request.state.theme = theme
    return request


def _make_store(store_id=1, subdomain="testshop", merchant_id=10) -> Mock:
    """Create a mock Store object with id, subdomain, merchant_id and name."""
    store = Mock()
    store.id = store_id
    store.subdomain = subdomain
    store.merchant_id = merchant_id
    store.name = "Test Shop"
    return store


def _make_platform(platform_id=1) -> Mock:
    """Create a mock Platform object with the given id and code "oms"."""
    platform = Mock()
    platform.id = platform_id
    platform.code = "oms"
    return platform


def _make_subscription(is_active=True, tier_code="essential") -> Mock:
    """Create a mock MerchantSubscription with an attached mock tier.

    ``status`` mirrors ``is_active``: "active" when true, "expired" otherwise.
    """
    tier = Mock()
    tier.code = tier_code
    tier.id = 1

    sub = Mock()
    sub.is_active = is_active
    sub.tier = tier
    sub.status = "active" if is_active else "expired"
    return sub


# =============================================================================
# Static request detection
# =============================================================================


@pytest.mark.unit
class TestIsStaticRequest:
    """Test suite for _is_static_request helper."""

    @pytest.mark.parametrize(
        "path",
        [
            "/static/css/style.css",
            "/static/js/app.js",
            "/uploads/images/photo.jpg",
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/storefront/favicon.ico",
            "/some/path/favicon.ico",
            "/static/storefront/css/tailwind.output.css",
        ],
    )
    def test_static_paths_detected(self, path):
        """Test that static/system paths are correctly detected."""
        assert _is_static_request(path) is True

    @pytest.mark.parametrize(
        "path",
        [
            "/storefront/products",
            "/storefront/",
            "/api/v1/storefront/cart",
            "/storefront/category/shoes",
        ],
    )
    def test_non_static_paths_not_detected(self, path):
        """Test that real storefront paths are not flagged as static."""
        assert _is_static_request(path) is False

    @pytest.mark.parametrize(
        "path",
        [
            "/storefront/logo.png",
            "/some/path/font.woff2",
            "/assets/icon.svg",
            "/file.map",
        ],
    )
    def test_static_extensions_detected(self, path):
        """Test that paths ending with static extensions are detected."""
        assert _is_static_request(path) is True

    def test_case_insensitive(self):
        """Test detection is case-insensitive."""
        assert _is_static_request("/STATIC/CSS/STYLE.CSS") is True
        assert _is_static_request("/Uploads/Image.PNG") is True


# =============================================================================
# Middleware passthrough tests (non-storefront)
# =============================================================================


@pytest.mark.unit
class TestStorefrontAccessMiddlewarePassthrough:
    """Test that non-storefront requests pass through without checks."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "frontend_type",
        [
            FrontendType.ADMIN,
            FrontendType.STORE,
            FrontendType.PLATFORM,
            FrontendType.MERCHANT,
        ],
    )
    async def test_non_storefront_passes_through(self, frontend_type):
        """Test non-storefront frontend types are not gated."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(frontend_type=frontend_type)
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_no_frontend_type_non_storefront_passes_through(self):
        """Test request with no frontend_type on non-storefront path passes through."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/admin/dashboard")
        request.state.frontend_type = None
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_no_frontend_type_storefront_path_blocked(self):
        """Test request with no frontend_type on storefront path is blocked (fail-closed)."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/products")
        request.state.frontend_type = None
        request.state.store = None
        call_next = AsyncMock()

        # Page requests render a template; stub it so no real template
        # environment is needed.
        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            await middleware.dispatch(request, call_next)

        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_frontend_type_storefront_api_blocked_json(self):
        """Test request with no frontend_type on storefront API returns JSON 403."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/api/v1/storefront/cart")
        request.state.frontend_type = None
        call_next = AsyncMock()

        response = await middleware.dispatch(request, call_next)

        call_next.assert_not_called()
        assert isinstance(response, JSONResponse)
        assert response.status_code == 403

    @pytest.mark.asyncio
    async def test_no_frontend_type_static_passes_through(self):
        """Test request with no frontend_type on static path passes through."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/static/css/style.css")
        request.state.frontend_type = None
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_static_file_passes_through(self):
        """Test storefront static file requests pass through."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/static/css/style.css")
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_favicon_passes_through(self):
        """Test favicon requests pass through."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/favicon.ico")
        call_next = AsyncMock(return_value=Mock())

        await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)


# =============================================================================
# Blocking tests (no store / no subscription)
# =============================================================================


@pytest.mark.unit
class TestStorefrontAccessMiddlewareBlocking:
    """Test that requests are blocked when store/subscription is missing."""

    @pytest.mark.asyncio
    async def test_no_store_returns_not_found(self):
        """Test 'not_found' when no store is detected."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(store=None)
        call_next = AsyncMock()

        with patch.object(middleware, "_render_unavailable") as mock_render:
            mock_render.return_value = Mock()
            await middleware.dispatch(request, call_next)

        mock_render.assert_called_once_with(request, "not_found")
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_subscription_returns_not_activated(self):
        """Test 'not_activated' when store exists but no subscription."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(store=store, platform=platform)
        call_next = AsyncMock()
        mock_db = MagicMock()

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware,
                "_get_subscription",
                return_value=None,
            ),
            patch.object(middleware, "_render_unavailable") as mock_render,
        ):
            mock_render.return_value = Mock()
            await middleware.dispatch(request, call_next)

        mock_render.assert_called_once_with(request, "not_activated", store)
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_inactive_subscription_returns_not_activated(self):
        """Test 'not_activated' when subscription exists but is inactive."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(store=store, platform=platform)
        call_next = AsyncMock()
        mock_db = MagicMock()
        inactive_sub = _make_subscription(is_active=False)

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware,
                "_get_subscription",
                return_value=inactive_sub,
            ),
            patch.object(middleware, "_render_unavailable") as mock_render,
        ):
            mock_render.return_value = Mock()
            await middleware.dispatch(request, call_next)

        mock_render.assert_called_once_with(request, "not_activated", store)
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_db_session_closed_on_block(self):
        """Test database session is closed even when request is blocked."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        request = _make_request(store=store)
        call_next = AsyncMock()
        mock_db = MagicMock()

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(middleware, "_get_subscription", return_value=None),
            patch.object(
                middleware, "_render_unavailable", return_value=Mock()
            ),
        ):
            await middleware.dispatch(request, call_next)

        mock_db.close.assert_called_once()


# =============================================================================
# Active subscription passthrough
# =============================================================================


@pytest.mark.unit
class TestStorefrontAccessMiddlewareActiveSubscription:
    """Test passthrough and state injection for active subscriptions."""

    @pytest.mark.asyncio
    async def test_active_subscription_passes_through(self):
        """Test active subscription lets request through."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(store=store, platform=platform)
        call_next = AsyncMock(return_value=Mock())
        mock_db = MagicMock()
        active_sub = _make_subscription(is_active=True, tier_code="professional")

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware, "_get_subscription", return_value=active_sub
            ),
        ):
            await middleware.dispatch(request, call_next)

        call_next.assert_called_once_with(request)

    @pytest.mark.asyncio
    async def test_sets_subscription_on_request_state(self):
        """Test subscription and tier are stored on request.state."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(store=store, platform=platform)
        call_next = AsyncMock(return_value=Mock())
        mock_db = MagicMock()
        active_sub = _make_subscription(is_active=True, tier_code="essential")

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware, "_get_subscription", return_value=active_sub
            ),
        ):
            await middleware.dispatch(request, call_next)

        assert request.state.subscription is active_sub
        assert request.state.subscription_tier is active_sub.tier

    @pytest.mark.asyncio
    async def test_db_session_closed_on_success(self):
        """Test database session is closed after successful passthrough."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        request = _make_request(store=store)
        call_next = AsyncMock(return_value=Mock())
        mock_db = MagicMock()
        active_sub = _make_subscription(is_active=True)

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware, "_get_subscription", return_value=active_sub
            ),
        ):
            await middleware.dispatch(request, call_next)

        mock_db.close.assert_called_once()


# =============================================================================
# Multi-platform subscription resolution
# =============================================================================


@pytest.mark.unit
class TestGetSubscription:
    """Test _get_subscription multi-platform resolution logic."""

    def test_uses_detected_platform(self):
        """Test subscription is fetched for the detected platform."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store(merchant_id=10)
        platform = _make_platform(platform_id=2)
        request = _make_request(store=store, platform=platform)
        mock_db = MagicMock()
        expected_sub = _make_subscription()

        with patch(
            "app.modules.billing.services.subscription_service.subscription_service"
        ) as mock_svc:
            mock_svc.get_merchant_subscription.return_value = expected_sub
            result = middleware._get_subscription(mock_db, store, request)

        # Lookup is keyed by (merchant_id, platform_id) of the detected pair.
        mock_svc.get_merchant_subscription.assert_called_once_with(
            mock_db, 10, 2
        )
        assert result is expected_sub

    def test_no_fallback_when_merchant_subscription_none(self):
        """Test no fallback when get_merchant_subscription returns None."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store(store_id=5, merchant_id=10)
        platform = _make_platform(platform_id=2)
        request = _make_request(store=store, platform=platform)
        mock_db = MagicMock()

        with patch(
            "app.modules.billing.services.subscription_service.subscription_service"
        ) as mock_svc:
            mock_svc.get_merchant_subscription.return_value = None
            result = middleware._get_subscription(mock_db, store, request)

        mock_svc.get_merchant_subscription.assert_called_once_with(
            mock_db, 10, 2
        )
        assert result is None

    def test_no_platform_returns_none(self):
        """Test when no platform is detected, returns None (no fallback)."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store(store_id=7)
        request = _make_request(store=store, platform=None)
        mock_db = MagicMock()

        with patch(
            "app.modules.billing.services.subscription_service.subscription_service"
        ) as mock_svc:
            result = middleware._get_subscription(mock_db, store, request)

        mock_svc.get_merchant_subscription.assert_not_called()
        assert result is None


# =============================================================================
# Response rendering tests
# =============================================================================


@pytest.mark.unit
class TestRenderUnavailable:
    """Test _render_unavailable response generation."""

    def test_api_request_returns_json_403(self):
        """Test API requests get JSON 403 response."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/api/v1/storefront/cart")

        response = middleware._render_unavailable(request, "not_activated")

        assert isinstance(response, JSONResponse)
        assert response.status_code == 403
        assert response.body is not None

        # Decode JSON body
        body = json.loads(response.body)
        assert body["error"] == "storefront_not_available"
        assert body["reason"] == "not_activated"

    def test_api_not_found_returns_json_403(self):
        """Test API not_found also gets JSON 403."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/api/v1/storefront/products")

        response = middleware._render_unavailable(request, "not_found")

        assert isinstance(response, JSONResponse)
        assert response.status_code == 403

        body = json.loads(response.body)
        assert body["reason"] == "not_found"

    def test_page_request_renders_template(self):
        """Test page requests render the unavailable template."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/products", language="en")
        mock_template_response = Mock(status_code=403)

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = mock_template_response
            middleware._render_unavailable(
                request, "not_activated", store=_make_store()
            )

        mock_templates.TemplateResponse.assert_called_once()
        call_args = mock_templates.TemplateResponse.call_args
        # Positional args: (template name, context); status_code is a kwarg.
        assert call_args[0][0] == "storefront/unavailable.html"
        context = call_args[0][1]
        assert context["request"] is request
        assert context["reason"] == "not_activated"
        assert context["title"] == MESSAGES["not_activated"]["en"]["title"]
        assert context["message"] == MESSAGES["not_activated"]["en"]["message"]
        assert context["language"] == "en"
        assert call_args[1]["status_code"] == 403

    @pytest.mark.parametrize("language", ["en", "fr", "de", "lb"])
    def test_page_request_uses_correct_language(self, language):
        """Test unavailable page renders in the detected language."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/", language=language)

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            middleware._render_unavailable(request, "not_found")

        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["language"] == language
        assert context["title"] == MESSAGES["not_found"][language]["title"]

    def test_unsupported_language_falls_back_to_english(self):
        """Test unsupported language falls back to English."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/", language="pt")

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            middleware._render_unavailable(request, "not_activated")

        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["language"] == "en"
        assert context["title"] == MESSAGES["not_activated"]["en"]["title"]

    def test_page_request_includes_store_when_provided(self):
        """Test store object is passed to template when available."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        request = _make_request(path="/storefront/")

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            middleware._render_unavailable(request, "not_activated", store=store)

        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["store"] is store

    def test_page_request_store_none_for_not_found(self):
        """Test store is None for not_found case."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/")

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            middleware._render_unavailable(request, "not_found")

        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["store"] is None


# =============================================================================
# Messages dict validation
# =============================================================================


@pytest.mark.unit
class TestMessages:
    """Validate the MESSAGES dict structure."""

    def test_all_reasons_have_all_languages(self):
        """Test every reason has en, fr, de, lb translations."""
        for reason in ("not_found", "not_activated"):
            assert reason in MESSAGES
            for lang in ("en", "fr", "de", "lb"):
                assert lang in MESSAGES[reason], f"Missing {lang} for {reason}"
                assert "title" in MESSAGES[reason][lang]
                assert "message" in MESSAGES[reason][lang]

    def test_messages_are_non_empty_strings(self):
        """Test all message values are non-empty strings."""
        for reason, translations in MESSAGES.items():
            for lang, entry in translations.items():
                for field in ("title", "message"):
                    value = entry[field]
                    where = f"MESSAGES[{reason!r}][{lang!r}][{field!r}]"
                    assert isinstance(value, str), f"{where} is not a string"
                    assert value, f"{where} is empty"


# =============================================================================
# Full integration-style dispatch tests
# =============================================================================


@pytest.mark.unit
class TestStorefrontAccessMiddlewareDispatchIntegration:
    """End-to-end dispatch tests exercising the full middleware flow."""

    @pytest.mark.asyncio
    async def test_full_flow_active_subscription(self):
        """Test full dispatch: store + active subscription → passthrough."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store(merchant_id=10)
        platform = _make_platform(platform_id=2)
        request = _make_request(
            path="/storefront/products",
            store=store,
            platform=platform,
        )
        expected_response = Mock()
        call_next = AsyncMock(return_value=expected_response)
        mock_db = MagicMock()
        active_sub = _make_subscription(is_active=True, tier_code="professional")

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch(
                "app.modules.billing.services.subscription_service.subscription_service"
            ) as mock_svc,
        ):
            mock_svc.get_merchant_subscription.return_value = active_sub
            result = await middleware.dispatch(request, call_next)

        assert result is expected_response
        assert request.state.subscription is active_sub
        assert request.state.subscription_tier is active_sub.tier
        call_next.assert_called_once_with(request)
        mock_db.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_full_flow_no_subscription_page_request(self):
        """Test full dispatch: store + no subscription + page → HTML 403."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(
            path="/storefront/",
            store=store,
            platform=platform,
            language="fr",
        )
        call_next = AsyncMock()
        mock_db = MagicMock()

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch(
                "app.modules.billing.services.subscription_service.subscription_service"
            ) as mock_svc,
            patch("app.templates_config.templates") as mock_templates,
        ):
            mock_svc.get_merchant_subscription.return_value = None
            # NOTE(review): TestGetSubscription asserts there is no fallback
            # lookup, so this stub looks redundant; confirm before removing.
            mock_svc.get_subscription_for_store.return_value = None
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            await middleware.dispatch(request, call_next)

        call_next.assert_not_called()
        mock_templates.TemplateResponse.assert_called_once()
        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["language"] == "fr"
        assert context["reason"] == "not_activated"
        mock_db.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_full_flow_no_subscription_api_request(self):
        """Test full dispatch: store + no subscription + API → JSON 403."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        platform = _make_platform()
        request = _make_request(
            path="/api/v1/storefront/cart",
            store=store,
            platform=platform,
        )
        call_next = AsyncMock()
        mock_db = MagicMock()

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch(
                "app.modules.billing.services.subscription_service.subscription_service"
            ) as mock_svc,
        ):
            mock_svc.get_merchant_subscription.return_value = None
            # NOTE(review): TestGetSubscription asserts there is no fallback
            # lookup, so this stub looks redundant; confirm before removing.
            mock_svc.get_subscription_for_store.return_value = None
            result = await middleware.dispatch(request, call_next)

        call_next.assert_not_called()
        assert isinstance(result, JSONResponse)
        assert result.status_code == 403
        mock_db.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_full_flow_no_store_detected(self):
        """Test full dispatch: no store → not_found response."""
        middleware = StorefrontAccessMiddleware(app=None)
        request = _make_request(path="/storefront/", store=None)
        call_next = AsyncMock()

        with patch("app.templates_config.templates") as mock_templates:
            mock_templates.TemplateResponse.return_value = Mock(status_code=403)
            await middleware.dispatch(request, call_next)

        call_next.assert_not_called()
        context = mock_templates.TemplateResponse.call_args[0][1]
        assert context["reason"] == "not_found"
        assert context["store"] is None

    @pytest.mark.asyncio
    async def test_db_closed_on_exception(self):
        """Test database session is closed even when _get_subscription raises."""
        middleware = StorefrontAccessMiddleware(app=None)
        store = _make_store()
        request = _make_request(store=store)
        call_next = AsyncMock()
        mock_db = MagicMock()

        with (
            patch(
                "middleware.storefront_access.get_db",
                return_value=iter([mock_db]),
            ),
            patch.object(
                middleware,
                "_get_subscription",
                side_effect=Exception("db error"),
            ),
            pytest.raises(Exception, match="db error"),
        ):
            await middleware.dispatch(request, call_next)

        # Must run after the raises-block: anything placed after the raising
        # call inside the block would never execute.
        mock_db.close.assert_called_once()