Clean up accumulated backward-compat shims, deprecated wrappers, unused aliases, and legacy code across the codebase. Since the platform is not live yet, this establishes a clean baseline.

Changes:
- Delete deprecated middleware/context.py (RequestContext, get_request_context)
- Remove unused factory get_store_email_settings_service()
- Remove deprecated pagination_full macro and the /admin/platform-homepage route
- Remove ConversationResponse and the unprefixed InvoiceSettings* aliases
- Simplify celery_config.py (remove empty LEGACY_TASK_MODULES)
- Standardize billing exceptions: *Error aliases → *Exception names
- Consolidate duplicate TierNotFoundError/FeatureNotFoundError classes
- Remove deprecated is_admin_request() from Store/PlatformContextManager
- Remove the is_platform_default field and MediaUploadResponse legacy flat fields
- Remove the MediaItemResponse.url alias; update JS to use file_url
- Update all affected tests and documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
998 lines
35 KiB
Python
998 lines
35 KiB
Python
# tests/unit/middleware/test_platform_context.py
|
|
"""
|
|
Comprehensive unit tests for PlatformContextMiddleware and PlatformContextManager.
|
|
|
|
Tests cover:
|
|
- Platform detection from domains (production) and path prefixes (development)
|
|
- Default platform for main marketing site
|
|
- Path rewriting for routing
|
|
- Database lookup and platform validation
|
|
- Admin, static file, and system endpoint skipping
|
|
- Edge cases and error handling
|
|
|
|
URL Structure:
|
|
- Main marketing site: localhost:9999/ (no prefix) -> 'main' platform
|
|
- Platform sites: localhost:9999/platforms/{code}/ -> specific platform
|
|
- Production: domain-based (oms.lu, loyalty.lu)
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
from fastapi import Request
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.frontend_detector import FrontendDetector
|
|
from middleware.platform_context import (
|
|
DEFAULT_PLATFORM_CODE,
|
|
PlatformContextManager,
|
|
PlatformContextMiddleware,
|
|
get_current_platform,
|
|
require_platform_context,
|
|
)
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.middleware
class TestPlatformContextManager:
    """Test suite for PlatformContextManager static methods.

    Covers context detection (domain, path prefix, default), admin and
    static-file skipping, database lookup, and clean-path extraction.
    """

    # ========================================================================
    # Helpers
    # ========================================================================

    @staticmethod
    def _make_request(path, host=None):
        """Return a ``Request`` mock whose ``url.path`` is *path*.

        A ``host`` header is set only when *host* is given; otherwise the
        ``headers`` attribute keeps its spec'd mock default (the static-file
        tests never read it).
        """
        request = Mock(spec=Request)
        if host is not None:
            request.headers = {"host": host}
        request.url = Mock(path=path)
        return request

    @staticmethod
    def _make_platform(platform_id, code, name):
        """Return an active platform mock with the given id, code, and name."""
        platform = Mock()
        platform.id = platform_id
        platform.code = code
        platform.name = name
        platform.is_active = True
        return platform

    @staticmethod
    def _db_returning(platform):
        """Return a Session mock whose ``query().filter().filter().first()`` yields *platform*."""
        mock_db = Mock(spec=Session)
        chain = mock_db.query.return_value.filter.return_value.filter.return_value
        chain.first.return_value = platform
        return mock_db

    @staticmethod
    def _db_returning_none():
        """Return a Session mock whose single- and double-filter lookups both yield None.

        Configuring both chain shapes ensures the primary Platform lookup and
        any StoreDomain/MerchantDomain fallbacks all find nothing.
        """
        mock_db = Mock(spec=Session)
        query_mock = mock_db.query.return_value
        query_mock.filter.return_value.first.return_value = None
        query_mock.filter.return_value.filter.return_value.first.return_value = None
        return mock_db

    # ========================================================================
    # Platform Context Detection Tests - Domain-Based (Production)
    # ========================================================================

    def test_detect_domain_based_platform(self):
        """Test domain-based platform detection for production (e.g., oms.lu)."""
        request = self._make_request("/pricing", host="oms.lu")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "domain"
        assert context["domain"] == "oms.lu"
        assert context["host"] == "oms.lu"
        assert context["original_path"] == "/pricing"

    def test_detect_domain_with_port(self):
        """Test domain detection with port number."""
        request = self._make_request("/features", host="loyalty.lu:8443")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "domain"
        # The port is stripped from the detected domain.
        assert context["domain"] == "loyalty.lu"

    def test_detect_domain_three_level_not_detected(self):
        """Test that three-level domains (subdomains) are not detected as platform domains."""
        request = self._make_request("/shop", host="store.oms.lu")

        context = PlatformContextManager.detect_platform_context(request)

        # Three-level domains should not be detected as platform domains;
        # they could be store subdomains.
        assert context is None

    # ========================================================================
    # Platform Context Detection Tests - Path-Based (Development)
    # ========================================================================

    def test_detect_path_based_oms_platform(self):
        """Test path-based detection for /platforms/oms/."""
        request = self._make_request("/platforms/oms/pricing", host="localhost:9999")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["original_path"] == "/platforms/oms/pricing"
        assert context["clean_path"] == "/pricing"

    def test_detect_path_based_loyalty_platform(self):
        """Test path-based detection for /platforms/loyalty/."""
        request = self._make_request("/platforms/loyalty/features", host="localhost:9999")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "loyalty"
        assert context["clean_path"] == "/features"

    def test_detect_path_based_platform_root(self):
        """Test path-based detection for platform root (e.g., /platforms/oms/)."""
        request = self._make_request("/platforms/oms/", host="localhost")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["clean_path"] == "/"

    def test_detect_path_based_platform_no_trailing_slash(self):
        """Test path-based detection without trailing slash."""
        request = self._make_request("/platforms/oms", host="localhost")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["clean_path"] == "/"

    def test_detect_path_based_case_insensitive(self):
        """Test that path-based detection is case-insensitive."""
        request = self._make_request("/platforms/OMS/pricing", host="localhost")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["path_prefix"] == "oms"  # Lowercased

    def test_detect_path_based_with_nested_path(self):
        """Test path-based detection with deeply nested paths."""
        request = self._make_request(
            "/platforms/oms/stores/wizamart/shop/products", host="localhost"
        )

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["clean_path"] == "/stores/wizamart/shop/products"

    # ========================================================================
    # Platform Context Detection Tests - Default (Main Marketing Site)
    # ========================================================================

    def test_detect_default_platform_localhost(self):
        """Test default platform detection for localhost (main marketing site)."""
        request = self._make_request("/about", host="localhost")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "default"
        assert context["path_prefix"] == DEFAULT_PLATFORM_CODE
        assert context["clean_path"] == "/about"  # No path rewrite for main site

    def test_detect_default_platform_127_0_0_1(self):
        """Test default platform detection for 127.0.0.1."""
        request = self._make_request("/faq", host="127.0.0.1:9999")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "default"
        assert context["path_prefix"] == DEFAULT_PLATFORM_CODE

    def test_detect_default_platform_root_path(self):
        """Test default platform detection for root path."""
        request = self._make_request("/", host="localhost:9999")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "default"
        assert context["clean_path"] == "/"

    # ========================================================================
    # Admin Request Skipping Tests
    # ========================================================================

    def test_skip_admin_subdomain(self):
        """Test that admin subdomain requests skip platform detection."""
        request = self._make_request("/dashboard", host="admin.localhost")

        # Precondition: the shared detector classifies this request as admin.
        assert FrontendDetector.is_admin("admin.localhost", "/dashboard") is True

        context = PlatformContextManager.detect_platform_context(request)
        assert context is None

    def test_skip_admin_path(self):
        """Test that /admin paths skip platform detection."""
        request = self._make_request("/admin/stores", host="localhost")

        # Precondition: the shared detector classifies this request as admin.
        assert FrontendDetector.is_admin("localhost", "/admin/stores") is True

        context = PlatformContextManager.detect_platform_context(request)
        assert context is None

    def test_skip_admin_path_with_port(self):
        """Test admin detection with port in host."""
        assert FrontendDetector.is_admin("admin.localhost:9999", "/dashboard") is True

    def test_not_admin_regular_path(self):
        """Test non-admin path is not detected as admin."""
        assert FrontendDetector.is_admin("localhost", "/shop/products") is False

    # ========================================================================
    # Static File Detection Tests
    # ========================================================================

    @pytest.mark.parametrize(
        "path",
        [
            "/static/css/style.css",
            "/static/js/app.js",
            "/media/images/product.png",
            "/assets/logo.svg",
            "/.well-known/security.txt",
            "/favicon.ico",
            "/image.jpg",
            "/style.css",
            "/app.webmanifest",
            "/static/fonts/font.woff2",
            "/media/uploads/file.pdf",
        ],
    )
    def test_is_static_file_request(self, path):
        """Test static file detection for various paths and extensions."""
        request = self._make_request(path)

        assert PlatformContextManager.is_static_file_request(request) is True

    @pytest.mark.parametrize(
        "path",
        [
            "/platforms/oms/pricing",
            "/shop/products",
            "/admin/dashboard",
            "/api/v1/stores",
            "/about",
            "/stores/wizamart/shop",
        ],
    )
    def test_is_not_static_file_request(self, path):
        """Test non-static file paths."""
        request = self._make_request(path)

        assert PlatformContextManager.is_static_file_request(request) is False

    # ========================================================================
    # Platform Database Lookup Tests
    # ========================================================================

    def test_get_platform_from_domain_context(self):
        """Test getting platform from domain context."""
        mock_platform = self._make_platform(1, "oms", "OMS Platform")
        mock_db = self._db_returning(mock_platform)
        context = {"detection_method": "domain", "domain": "oms.lu"}

        platform = PlatformContextManager.get_platform_from_context(mock_db, context)

        assert platform is mock_platform
        assert platform.code == "oms"

    def test_get_platform_from_domain_not_found(self):
        """Test domain lookup when platform not found."""
        mock_db = self._db_returning_none()
        context = {"detection_method": "domain", "domain": "unknown.lu"}

        platform = PlatformContextManager.get_platform_from_context(mock_db, context)

        assert platform is None

    def test_get_platform_from_path_prefix_context(self):
        """Test getting platform from path prefix context."""
        mock_platform = self._make_platform(2, "loyalty", "Loyalty Platform")
        mock_db = self._db_returning(mock_platform)
        context = {"detection_method": "path", "path_prefix": "loyalty"}

        platform = PlatformContextManager.get_platform_from_context(mock_db, context)

        assert platform is mock_platform
        assert platform.code == "loyalty"

    def test_get_platform_from_default_context(self):
        """Test getting default platform (main marketing site)."""
        mock_platform = self._make_platform(3, "main", "Main Marketing")
        mock_db = self._db_returning(mock_platform)
        context = {"detection_method": "default", "path_prefix": "main"}

        platform = PlatformContextManager.get_platform_from_context(mock_db, context)

        assert platform is mock_platform
        assert platform.code == "main"

    def test_get_platform_with_no_context(self):
        """Test getting platform with no context returns None without querying."""
        mock_db = Mock(spec=Session)

        platform = PlatformContextManager.get_platform_from_context(mock_db, None)

        assert platform is None
        # With nothing to look up, the database must not be touched.
        mock_db.query.assert_not_called()

    def test_get_platform_inactive_not_returned(self):
        """Test that a lookup yielding no active platform returns None.

        The ``is_active`` filter is applied inside the SQL query, which a
        Session mock cannot evaluate. This test therefore simulates the
        filtered query returning no row. It also checks that the database was
        actually queried, so a None produced before any lookup happened would
        not pass unnoticed.
        """
        mock_db = self._db_returning_none()
        context = {"detection_method": "domain", "domain": "inactive.lu"}

        platform = PlatformContextManager.get_platform_from_context(mock_db, context)

        assert platform is None
        mock_db.query.assert_called()

    # ========================================================================
    # Clean Path Extraction Tests
    # ========================================================================

    def test_extract_clean_path_from_path_context(self):
        """Test extracting clean path from path-based context."""
        request = self._make_request("/platforms/oms/pricing")
        context = {
            "detection_method": "path",
            "clean_path": "/pricing",
        }

        clean_path = PlatformContextManager.extract_clean_path(request, context)

        assert clean_path == "/pricing"

    def test_extract_clean_path_from_domain_context(self):
        """Test that domain-based context doesn't modify path."""
        request = self._make_request("/pricing")
        context = {"detection_method": "domain", "domain": "oms.lu"}

        clean_path = PlatformContextManager.extract_clean_path(request, context)

        assert clean_path == "/pricing"

    def test_extract_clean_path_no_context(self):
        """Test extracting clean path with no context."""
        request = self._make_request("/some/path")

        clean_path = PlatformContextManager.extract_clean_path(request, None)

        assert clean_path == "/some/path"
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.middleware
class TestPlatformContextMiddleware:
    """Test suite for PlatformContextMiddleware ASGI middleware."""

    # ========================================================================
    # Helpers
    # ========================================================================

    @staticmethod
    def _http_scope(path, host=b"localhost:9999", **extra):
        """Return a minimal ASGI HTTP scope for *path* carrying one host header.

        Extra keyword arguments (e.g. ``raw_path``) are merged into the scope.
        """
        scope = {"type": "http", "path": path, "headers": [(b"host", host)]}
        scope.update(extra)
        return scope

    @staticmethod
    def _make_platform(platform_id, code, name=None):
        """Return a platform mock with id and code; ``name`` is set only when given."""
        platform = Mock()
        platform.id = platform_id
        platform.code = code
        if name is not None:
            platform.name = name
        return platform

    @staticmethod
    async def _dispatch(middleware, scope, platform):
        """Await *middleware* on *scope* with DB access and platform lookup patched.

        ``get_db`` is patched to yield a MagicMock session, and
        ``PlatformContextManager.get_platform_from_context`` returns *platform*.
        """
        mock_db = MagicMock()
        with patch(
            "middleware.platform_context.get_db", return_value=iter([mock_db])
        ), patch.object(
            PlatformContextManager,
            "get_platform_from_context",
            return_value=platform,
        ):
            await middleware(scope, AsyncMock(), AsyncMock())

    # ========================================================================
    # Middleware Skip Conditions Tests
    # ========================================================================

    @pytest.mark.asyncio
    async def test_middleware_skips_static_files(self):
        """Test middleware skips platform detection for static files."""
        mock_app = AsyncMock()
        middleware = PlatformContextMiddleware(app=mock_app)
        scope = self._http_scope("/static/css/style.css")
        receive, send = AsyncMock(), AsyncMock()

        await middleware(scope, receive, send)

        assert scope["state"]["platform"] is None
        assert scope["state"]["platform_context"] is None
        assert scope["state"]["platform_clean_path"] == "/static/css/style.css"
        # Skipping detection must still forward the request to the wrapped app.
        mock_app.assert_awaited_once_with(scope, receive, send)

    @pytest.mark.asyncio
    async def test_middleware_skips_admin_routes(self):
        """Test middleware skips platform detection for admin routes."""
        mock_app = AsyncMock()
        middleware = PlatformContextMiddleware(app=mock_app)
        scope = self._http_scope("/admin/dashboard")
        receive, send = AsyncMock(), AsyncMock()

        await middleware(scope, receive, send)

        assert scope["state"]["platform"] is None
        assert scope["state"]["platform_clean_path"] == "/admin/dashboard"
        mock_app.assert_awaited_once_with(scope, receive, send)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "path",
        [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
        ],
    )
    async def test_middleware_skips_system_endpoints(self, path):
        """Test middleware skips platform detection for system endpoints."""
        mock_app = AsyncMock()
        middleware = PlatformContextMiddleware(app=mock_app)
        scope = self._http_scope(path)
        receive, send = AsyncMock(), AsyncMock()

        await middleware(scope, receive, send)

        assert scope["state"]["platform"] is None
        assert scope["state"]["platform_clean_path"] == path
        mock_app.assert_awaited_once_with(scope, receive, send)

    @pytest.mark.asyncio
    async def test_middleware_skips_non_http(self):
        """Test middleware skips non-HTTP requests (e.g., websocket)."""
        mock_app = AsyncMock()
        middleware = PlatformContextMiddleware(app=mock_app)
        scope = {
            "type": "websocket",
            "path": "/ws",
        }
        receive, send = AsyncMock(), AsyncMock()

        await middleware(scope, receive, send)

        # Should pass through without modification
        mock_app.assert_called_once_with(scope, receive, send)
        assert "state" not in scope  # No state added for websocket

    # ========================================================================
    # Path Rewriting Tests
    # ========================================================================

    @pytest.mark.asyncio
    async def test_middleware_rewrites_platform_path(self):
        """Test middleware rewrites /platforms/oms/pricing to /pricing."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        mock_platform = self._make_platform(1, "oms", "OMS")
        scope = self._http_scope("/platforms/oms/pricing")

        await self._dispatch(middleware, scope, mock_platform)

        # Path should be rewritten for routing; the original is kept in state.
        assert scope["path"] == "/pricing"
        assert scope["state"]["platform_original_path"] == "/platforms/oms/pricing"
        assert scope["state"]["platform_clean_path"] == "/pricing"
        assert scope["state"]["platform"] is mock_platform

    @pytest.mark.asyncio
    async def test_middleware_rewrites_raw_path(self):
        """Test middleware also rewrites raw_path when present."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        mock_platform = self._make_platform(1, "oms")
        scope = self._http_scope(
            "/platforms/oms/pricing", raw_path=b"/platforms/oms/pricing"
        )

        await self._dispatch(middleware, scope, mock_platform)

        assert scope["path"] == "/pricing"
        assert scope["raw_path"] == b"/pricing"

    @pytest.mark.asyncio
    async def test_middleware_no_rewrite_for_domain_based(self):
        """Test middleware doesn't rewrite path for domain-based detection."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        mock_platform = self._make_platform(1, "oms")
        scope = self._http_scope("/pricing", host=b"oms.lu")

        await self._dispatch(middleware, scope, mock_platform)

        # Path should NOT be rewritten for domain-based detection
        assert scope["path"] == "/pricing"

    @pytest.mark.asyncio
    async def test_middleware_no_rewrite_for_default_platform(self):
        """Test middleware doesn't rewrite path for default (main) platform."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        mock_platform = self._make_platform(3, "main")
        scope = self._http_scope("/about")

        await self._dispatch(middleware, scope, mock_platform)

        # Path should NOT be rewritten for main marketing site
        assert scope["path"] == "/about"

    # ========================================================================
    # Platform Detection and State Setting Tests
    # ========================================================================

    @pytest.mark.asyncio
    async def test_middleware_sets_platform_state(self):
        """Test middleware sets platform in scope state."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        mock_platform = self._make_platform(1, "oms", "OMS Platform")
        scope = self._http_scope("/platforms/oms/")

        await self._dispatch(middleware, scope, mock_platform)

        assert scope["state"]["platform"] is mock_platform
        assert scope["state"]["platform_context"] is not None
        assert scope["state"]["platform_context"]["detection_method"] == "path"

    @pytest.mark.asyncio
    async def test_middleware_platform_not_found_in_db(self):
        """Test middleware when platform code not found in database."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        scope = self._http_scope("/platforms/unknown/pricing")

        await self._dispatch(middleware, scope, None)  # Platform not found

        assert scope["state"]["platform"] is None
        assert scope["state"]["platform_context"] is None
        # Path should NOT be rewritten when platform not found
        assert scope["path"] == "/platforms/unknown/pricing"

    @pytest.mark.asyncio
    async def test_middleware_initializes_state_dict(self):
        """Test middleware initializes state dict if not present."""
        middleware = PlatformContextMiddleware(app=AsyncMock())
        # Deliberately built without a "state" key.
        scope = self._http_scope("/static/test.css", host=b"localhost")

        await middleware(scope, AsyncMock(), AsyncMock())

        assert "state" in scope
        assert scope["state"]["platform"] is None
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.middleware
class TestHelperFunctions:
    """Tests for get_current_platform, require_platform_context and DEFAULT_PLATFORM_CODE."""

    @staticmethod
    def _request_with_platform(platform):
        """Return a Request mock whose ``state.platform`` is *platform*."""
        req = Mock(spec=Request)
        req.state.platform = platform
        return req

    @staticmethod
    def _oms_platform():
        """Return a platform mock whose code is 'oms'."""
        oms = Mock()
        oms.code = "oms"
        return oms

    def test_get_current_platform_exists(self):
        """get_current_platform hands back the platform stored on request.state."""
        oms = self._oms_platform()

        found = get_current_platform(self._request_with_platform(oms))

        assert found is oms
        assert found.code == "oms"

    def test_get_current_platform_not_exists(self):
        """get_current_platform yields None when request.state has no platform attribute."""
        req = Mock(spec=Request)
        # An empty spec makes any attribute access on state raise AttributeError.
        req.state = Mock(spec=[])

        assert get_current_platform(req) is None

    def test_require_platform_context_success(self):
        """The require_platform_context dependency returns the platform when present."""
        oms = self._oms_platform()
        check = require_platform_context()

        assert check(self._request_with_platform(oms)) is oms

    def test_require_platform_context_failure(self):
        """The require_platform_context dependency raises a 404 HTTPException when no platform is set."""
        from fastapi import HTTPException

        check = require_platform_context()
        req = self._request_with_platform(None)

        with pytest.raises(HTTPException) as raised:
            check(req)

        err = raised.value
        assert err.status_code == 404
        assert "Platform not found" in err.detail

    def test_default_platform_code_is_main(self):
        """DEFAULT_PLATFORM_CODE names the main marketing platform."""
        assert DEFAULT_PLATFORM_CODE == "main"
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.middleware
class TestEdgeCases:
    """Test suite for edge cases and error scenarios."""

    @staticmethod
    def _make_request(path, headers=None):
        """Return a Request mock for *path*; ``headers`` is assigned only when given."""
        request = Mock(spec=Request)
        if headers is not None:
            request.headers = headers
        request.url = Mock(path=path)
        return request

    @staticmethod
    async def _dispatch(path, platform):
        """Run the middleware on a localhost HTTP request for *path*.

        ``get_db`` is patched to yield a MagicMock session, and
        ``PlatformContextManager.get_platform_from_context`` returns *platform*.
        Returns the session mock so callers can inspect it.
        """
        middleware = PlatformContextMiddleware(app=AsyncMock())
        scope = {
            "type": "http",
            "path": path,
            "headers": [(b"host", b"localhost")],
        }
        mock_db = MagicMock()
        with patch(
            "middleware.platform_context.get_db", return_value=iter([mock_db])
        ), patch.object(
            PlatformContextManager,
            "get_platform_from_context",
            return_value=platform,
        ):
            await middleware(scope, AsyncMock(), AsyncMock())
        return mock_db

    def test_detect_empty_host(self):
        """Test platform detection with empty host header."""
        request = self._make_request("/", {"host": ""})

        context = PlatformContextManager.detect_platform_context(request)

        # An empty host does not pass the localhost check that enables the
        # default platform, and "/" has no /platforms/{code}/ prefix, so no
        # context is detected.
        assert context is None

    def test_detect_missing_host_header(self):
        """Test platform detection with missing host header."""
        request = self._make_request("/platforms/oms/", {})

        context = PlatformContextManager.detect_platform_context(request)

        # Path-based detection does not depend on the host header.
        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["clean_path"] == "/"

    def test_detect_platforms_path_empty_code(self):
        """Test /platforms/ path without platform code."""
        request = self._make_request("/platforms/", {"host": "localhost"})

        context = PlatformContextManager.detect_platform_context(request)

        # With no code after /platforms/, the path does not match the
        # /platforms/{code}/ pattern and detection falls back to the default.
        assert context is not None
        assert context["detection_method"] == "default"
        assert context["path_prefix"] == DEFAULT_PLATFORM_CODE

    def test_detect_platforms_only(self):
        """Test /platforms path without trailing slash."""
        request = self._make_request("/platforms", {"host": "localhost"})

        context = PlatformContextManager.detect_platform_context(request)

        # /platforms (without trailing slash) doesn't start with /platforms/,
        # so it falls back to default platform detection.
        assert context is not None
        assert context["detection_method"] == "default"
        assert context["path_prefix"] == DEFAULT_PLATFORM_CODE

    @pytest.mark.asyncio
    async def test_middleware_closes_db_session(self):
        """Test middleware properly closes database session."""
        mock_platform = Mock()
        mock_platform.code = "oms"

        mock_db = await self._dispatch("/platforms/oms/test", mock_platform)

        mock_db.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_middleware_closes_db_on_platform_not_found(self):
        """Test middleware closes database even when platform not found."""
        mock_db = await self._dispatch("/platforms/unknown/test", None)

        mock_db.close.assert_called_once()

    def test_admin_subdomain_with_production_domain(self):
        """Test admin subdomain detection for production domains."""
        assert FrontendDetector.is_admin("admin.oms.lu", "/dashboard") is True

    def test_static_file_case_insensitive(self):
        """Test static file detection is case-insensitive."""
        request = self._make_request("/STATIC/CSS/STYLE.CSS")

        assert PlatformContextManager.is_static_file_request(request) is True

    def test_favicon_in_nested_path(self):
        """Test favicon detection in nested paths."""
        request = self._make_request("/some/path/favicon.ico")

        assert PlatformContextManager.is_static_file_request(request) is True
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.middleware
class TestURLRoutingSummary:
    """
    Test suite documenting the URL routing behavior.

    URL Structure Summary:
    - Main marketing: localhost:9999/ -> 'main' platform, path unchanged
    - OMS platform: localhost:9999/platforms/oms/pricing -> 'oms' platform, path=/pricing
    - Loyalty platform: localhost:9999/platforms/loyalty/features -> 'loyalty' platform, path=/features
    - Production OMS: oms.lu/pricing -> 'oms' platform, path=/pricing (no rewrite)
    """

    @staticmethod
    def _make_request(host, path):
        """Return a Request mock with the given host header and URL path."""
        request = Mock(spec=Request)
        request.headers = {"host": host}
        request.url = Mock(path=path)
        return request

    def test_main_marketing_site_routing(self):
        """Document: Main marketing site uses 'main' platform, no path rewrite."""
        request = self._make_request("localhost:9999", "/about")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "default"
        assert context["path_prefix"] == "main"
        assert context["clean_path"] == "/about"  # No rewrite

    def test_oms_platform_development_routing(self):
        """Document: OMS platform in dev mode rewrites path."""
        request = self._make_request(
            "localhost:9999", "/platforms/oms/stores/wizamart/shop"
        )

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "oms"
        assert context["clean_path"] == "/stores/wizamart/shop"  # Rewritten

    def test_loyalty_platform_development_routing(self):
        """Document: Loyalty platform in dev mode rewrites path."""
        request = self._make_request("localhost:9999", "/platforms/loyalty/rewards")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "path"
        assert context["path_prefix"] == "loyalty"
        assert context["clean_path"] == "/rewards"

    def test_production_domain_routing(self):
        """Document: Production domains don't rewrite path."""
        request = self._make_request("oms.lu", "/pricing")

        context = PlatformContextManager.detect_platform_context(request)

        assert context is not None
        assert context["detection_method"] == "domain"
        assert context["domain"] == "oms.lu"
        # Domain detection performs no rewrite, so the effective routing path
        # is the original request path.
        assert PlatformContextManager.extract_clean_path(request, context) == "/pricing"
|