# middleware/vendor_context.py
"""
Vendor Context Middleware (Class-Based)

Detects the vendor from host/domain/path and injects it into request.state.
Handles three routing modes:
    1. Custom domains (customdomain1.com → Vendor 1)
    2. Subdomains (vendor1.platform.com → Vendor 1)
    3. Path-based (/vendor/vendor1/ or /vendors/vendor1/ → Vendor 1)

Also extracts clean_path (the path without the vendor prefix) for nested
routing patterns.
"""

import logging
from contextlib import closing
from urllib.parse import urlparse

from fastapi import Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.config import settings
from app.core.database import get_db
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain

logger = logging.getLogger(__name__)

# URL prefixes whose next path segment is the vendor code.
_VENDOR_PATH_PREFIXES = ("/vendors/", "/vendor/")

# Leftmost host labels that never identify a vendor.
_RESERVED_SUBDOMAINS = frozenset({"www", "admin", "api"})

# Development hosts that must never be treated as vendor custom domains.
_LOCAL_HOSTS = frozenset(
    {"localhost", "127.0.0.1", "::1", "admin.localhost", "admin.127.0.0.1"}
)

# detection_method values whose context carries a "domain" key: the first is
# produced by detect_vendor_context, the second by extract_vendor_from_referer.
_DOMAIN_DETECTION_METHODS = frozenset({"custom_domain", "referer_custom_domain"})

# Request paths that bypass vendor detection entirely.
_SYSTEM_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

# Suffixes / prefixes identifying static-file requests (compared lower-cased).
_STATIC_EXTENSIONS = (
    ".ico",
    ".css",
    ".js",
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".svg",
    ".woff",
    ".woff2",
    ".ttf",
    ".eot",
    ".webp",
    ".map",
    ".json",
    ".xml",
    ".txt",
    ".pdf",
    ".webmanifest",
)
_STATIC_PATH_PREFIXES = ("/static/", "/media/", "/assets/", "/.well-known/")


class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing."""

    # ------------------------------------------------------------------
    # Private helpers shared by direct-request and Referer detection
    # ------------------------------------------------------------------

    @staticmethod
    def _platform_domain() -> str:
        """Return the configured platform domain, lower-cased (default ``platform.com``)."""
        return (getattr(settings, "platform_domain", None) or "platform.com").lower()

    @staticmethod
    def _normalize_host(host: str) -> str:
        """Strip any ``:port`` suffix and lower-case ``host``.

        Host names are case-insensitive, so all suffix/prefix comparisons below
        operate on the lower-cased form.
        """
        return host.split(":", 1)[0].lower()

    @staticmethod
    def _split_vendor_path(path: str) -> tuple[str, str] | None:
        """Split a path-based vendor URL into ``(prefix, vendor_code)``.

        ``"/vendors/acme/shop"`` → ``("/vendors/", "acme")``. Returns None when
        ``path`` does not start with a vendor prefix or the code segment is empty.
        """
        for prefix in _VENDOR_PATH_PREFIXES:
            if path.startswith(prefix):
                vendor_code = path[len(prefix) :].split("/", 1)[0]
                return (prefix, vendor_code) if vendor_code else None
        return None

    @staticmethod
    def _platform_subdomain(host: str, platform_domain: str) -> str | None:
        """Return the vendor label of a platform subdomain host, else None.

        Only hosts *under* the platform domain qualify
        (``acme.platform.com`` → ``"acme"``). The bare platform domain, IP
        literals, reserved labels (www/admin/api) and empty labels yield None.
        """
        if not host.endswith(f".{platform_domain}"):
            return None
        label = host.split(".", 1)[0]
        if not label or label in _RESERVED_SUBDOMAINS:
            return None
        return label

    @staticmethod
    def _is_custom_domain(host: str, platform_domain: str) -> bool:
        """True when ``host`` is neither the platform, a platform subdomain,
        a local development host, nor an ``admin.`` host."""
        return bool(
            host
            and not host.endswith(f".{platform_domain}")
            and host != platform_domain
            and host not in _LOCAL_HOSTS
            and not host.startswith("admin.")
        )

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------

    @staticmethod
    def detect_vendor_context(request: Request) -> dict | None:
        """
        Detect vendor context from the request's Host header and path.

        Priority order:
            1. Custom domain (customdomain1.com)
            2. Subdomain of the platform domain (vendor1.platform.com)
            3. Path-based (/vendor/vendor1/ or /vendors/vendor1/)

        Returns:
            A dict describing how the vendor was identified (keys depend on
            ``detection_method``), or None if nothing identifies a vendor.
        """
        raw_host = request.headers.get("host", "")
        host = VendorContextManager._normalize_host(raw_host)
        path = request.url.path
        platform_domain = VendorContextManager._platform_domain()

        # Method 1: Custom domain detection (HIGHEST PRIORITY)
        if VendorContextManager._is_custom_domain(host, platform_domain):
            return {
                "domain": VendorDomain.normalize_domain(host),
                "detection_method": "custom_domain",
                "host": host,
                "original_host": raw_host,
            }

        # Method 2: Subdomain detection (vendor1.platform.com).
        # Restricted to *.platform_domain so the bare platform domain and IP
        # literals such as 127.0.0.1 fall through to path-based detection
        # instead of producing bogus subdomains like "platform" or "127".
        subdomain = VendorContextManager._platform_subdomain(host, platform_domain)
        if subdomain:
            return {
                "subdomain": subdomain,
                "detection_method": "subdomain",
                "host": host,
            }

        # Method 3: Path-based detection (/vendor/<code>/ or /vendors/<code>/)
        split = VendorContextManager._split_vendor_path(path)
        if split:
            prefix, vendor_code = split
            return {
                "subdomain": vendor_code,
                "detection_method": "path",
                "path_prefix": f"{prefix}{vendor_code}",  # e.g. /vendor/vendor1
                "full_prefix": prefix,  # /vendor/ or /vendors/
                "host": host,
            }

        return None

    # ------------------------------------------------------------------
    # Database lookup
    # ------------------------------------------------------------------

    @staticmethod
    def _get_vendor_by_domain(db: Session, domain: str | None) -> Vendor | None:
        """Return the active vendor owning an active, verified custom domain."""
        if not domain:
            return None

        vendor_domain = (
            db.query(VendorDomain)
            .filter(VendorDomain.domain == domain)
            .filter(VendorDomain.is_active == True)  # noqa: E712 - SQL expression
            .filter(VendorDomain.is_verified == True)  # noqa: E712 - SQL expression
            .first()
        )
        if not vendor_domain:
            logger.warning(f"No active vendor found for custom domain: {domain}")
            return None

        vendor = vendor_domain.vendor
        if not vendor or not vendor.is_active:
            logger.warning(f"Vendor for domain {domain} is not active")
            return None

        logger.info(f"[OK] Vendor found via custom domain: {domain} → {vendor.name}")
        return vendor

    @staticmethod
    def _get_vendor_by_subdomain(
        db: Session, subdomain: str, method: str
    ) -> Vendor | None:
        """Return the active vendor whose subdomain matches (case-insensitive)."""
        vendor = (
            db.query(Vendor)
            .filter(func.lower(Vendor.subdomain) == subdomain.lower())
            .filter(Vendor.is_active == True)  # noqa: E712 - SQL expression
            .first()
        )
        if vendor:
            logger.info(f"[OK] Vendor found via {method}: {subdomain} → {vendor.name}")
        else:
            logger.warning(f"No active vendor found for subdomain: {subdomain}")
        return vendor

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Vendor | None:
        """
        Get the vendor from the database using context information.

        Supports:
            1. Custom domain lookup (VendorDomain table) for both direct
               ("custom_domain") and Referer-derived ("referer_custom_domain")
               contexts.
            2. Subdomain lookup (Vendor.subdomain).
            3. Path-based lookup (Vendor.subdomain).

        Returns:
            The active Vendor, or None if the context is empty or no active
            vendor matches.
        """
        if not context:
            return None

        if context.get("detection_method") in _DOMAIN_DETECTION_METHODS:
            return VendorContextManager._get_vendor_by_domain(
                db, context.get("domain")
            )

        subdomain = context.get("subdomain")
        if not subdomain:
            return None

        return VendorContextManager._get_vendor_by_subdomain(
            db, subdomain, context.get("detection_method", "unknown")
        )

    # ------------------------------------------------------------------
    # Path / request classification
    # ------------------------------------------------------------------

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: dict | None) -> str:
        """
        Return the request path without the vendor prefix, for routing.

        Only path-based contexts carry a prefix; for every other context the
        path is returned unchanged. Supports both /vendor/ and /vendors/.
        """
        path = request.url.path
        if not vendor_context or vendor_context.get("detection_method") != "path":
            return path

        path_prefix = vendor_context.get("path_prefix", "")
        if not path.startswith(path_prefix):
            return path

        remainder = path[len(path_prefix) :]
        # Only strip on a whole-segment boundary, so a prefix of "/vendor/acme"
        # never turns "/vendor/acmeshop/x" into "shop/x".
        if remainder and not remainder.startswith("/"):
            return path
        return remainder or "/"

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Check whether the request targets the admin interface (host or path)."""
        host = VendorContextManager._normalize_host(request.headers.get("host", ""))
        return host.startswith("admin.") or request.url.path.startswith("/admin")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Check if request is for API endpoints."""
        return request.url.path.startswith("/api/")

    @staticmethod
    def is_shop_api_request(request: Request) -> bool:
        """Check if request is for shop API endpoints."""
        return request.url.path.startswith("/api/v1/shop/")

    @staticmethod
    def extract_vendor_from_referer(request: Request) -> dict | None:
        """
        Extract vendor context from the Referer (or, failing that, Origin) header.

        Used for shop API requests, where the vendor is implied by the page
        that made the API call (e.g. JavaScript on /vendors/wizamart/shop/products
        calling /api/v1/shop/products).

        Recognised Referer patterns:
            - http://localhost:8000/vendors/wizamart/shop/... → wizamart (path)
            - http://wizamart.platform.com/shop/...           → wizamart (subdomain)
            - http://custom-domain.com/shop/...               → custom-domain.com

        NOTE(review): Referer/Origin are client-controlled; the vendor derived
        here identifies which storefront the call claims to come from and
        should not be relied on for authorization decisions.

        Returns:
            A vendor context dict, or None if no vendor can be extracted.
        """
        referer = request.headers.get("referer") or request.headers.get("origin")
        if not referer:
            logger.debug("[VENDOR] No Referer/Origin header for shop API request")
            return None

        try:
            parsed = urlparse(referer)
            # ``hostname`` is already lower-cased and stripped of any port.
            referer_host = parsed.hostname or ""
            referer_path = parsed.path or ""

            logger.debug(
                "[VENDOR] Extracting vendor from Referer",
                extra={
                    "referer": referer,
                    "referer_host": referer_host,
                    "referer_path": referer_path,
                },
            )

            # Method 1: Path-based (/vendors/wizamart/shop/products → wizamart)
            split = VendorContextManager._split_vendor_path(referer_path)
            if split:
                prefix, vendor_code = split
                logger.debug(
                    f"[VENDOR] Extracted vendor from Referer path: {vendor_code}",
                    extra={"vendor_code": vendor_code, "method": "referer_path"},
                )
                # "path" is used as detection_method to stay consistent with
                # direct path detection (so cookie-path logic behaves the same).
                return {
                    "subdomain": vendor_code,
                    "detection_method": "path",
                    "path_prefix": f"{prefix}{vendor_code}",  # /vendor/vendor1
                    "full_prefix": prefix,  # /vendor/ or /vendors/
                    "host": referer_host,
                    "referer": referer,
                }

            platform_domain = VendorContextManager._platform_domain()

            # Method 2: Subdomain of the platform (wizamart.platform.com → wizamart)
            subdomain = VendorContextManager._platform_subdomain(
                referer_host, platform_domain
            )
            if subdomain:
                logger.debug(
                    f"[VENDOR] Extracted vendor from Referer subdomain: {subdomain}",
                    extra={
                        "subdomain": subdomain,
                        "method": "referer_subdomain",
                    },
                )
                return {
                    "subdomain": subdomain,
                    "detection_method": "referer_subdomain",
                    "host": referer_host,
                    "referer": referer,
                }

            # Method 3: Custom domain (custom-shop.com → custom-shop.com)
            if VendorContextManager._is_custom_domain(referer_host, platform_domain):
                normalized_domain = VendorDomain.normalize_domain(referer_host)
                logger.debug(
                    f"[VENDOR] Extracted vendor from Referer custom domain: {normalized_domain}",
                    extra={
                        "domain": normalized_domain,
                        "method": "referer_custom_domain",
                    },
                )
                return {
                    "domain": normalized_domain,
                    "detection_method": "referer_custom_domain",
                    "host": referer_host,
                    "referer": referer,
                }

        except Exception as e:
            # Best-effort: a malformed header must not break the request.
            logger.warning(
                f"[VENDOR] Failed to extract vendor from Referer: {e}",
                extra={"referer": referer, "error": str(e)},
            )

        return None

    @staticmethod
    def is_static_file_request(request: Request) -> bool:
        """Check if request is for static files (by extension, prefix or favicon)."""
        path = request.url.path.lower()
        return (
            path.endswith(_STATIC_EXTENSIONS)
            or path.startswith(_STATIC_PATH_PREFIXES)
            or "favicon.ico" in path
        )


class VendorContextMiddleware(BaseHTTPMiddleware):
    """
    Middleware to inject vendor context into request state.

    Intended to run first in the middleware chain. For every request it sets:
        request.state.vendor: Vendor object (or None)
        request.state.vendor_context: Detection metadata dict (or None)
        request.state.clean_path: Path without vendor prefix
    """

    @staticmethod
    def _set_state(
        request: Request,
        vendor: Vendor | None,
        vendor_context: dict | None,
        clean_path: str,
    ) -> None:
        """Write the three vendor attributes onto ``request.state``."""
        request.state.vendor = vendor
        request.state.vendor_context = vendor_context
        request.state.clean_path = clean_path

    @staticmethod
    def _lookup_vendor(vendor_context: dict) -> Vendor | None:
        """Resolve ``vendor_context`` to a Vendor using a short-lived DB session.

        NOTE(review): this is a synchronous DB call made from the async
        ``dispatch`` and therefore blocks the event loop while it runs.
        """
        # closing() finalises the get_db() generator so its own cleanup runs
        # now, not whenever the suspended generator is garbage-collected.
        with closing(get_db()) as db_gen:
            db = next(db_gen)
            try:
                return VendorContextManager.get_vendor_from_context(db, vendor_context)
            finally:
                db.close()

    @classmethod
    def _resolve_shop_api_vendor(cls, request: Request) -> None:
        """Populate request.state for shop API calls using the Referer header."""
        path = request.url.path
        logger.debug(
            f"[VENDOR] Shop API request detected: {path}",
            extra={
                "path": path,
                "referer": request.headers.get("referer", ""),
            },
        )

        vendor_context = VendorContextManager.extract_vendor_from_referer(request)
        if not vendor_context:
            logger.warning(
                "[VENDOR] Shop API request without Referer header",
                extra={"path": path},
            )
            cls._set_state(request, None, None, path)
            return

        vendor = cls._lookup_vendor(vendor_context)
        # API paths are never rewritten, so clean_path is always the raw path.
        cls._set_state(request, vendor, vendor_context, path)

        if vendor:
            logger.debug(
                "[VENDOR_CONTEXT] Vendor detected from Referer for shop API",
                extra={
                    "vendor_id": vendor.id,
                    "vendor_name": vendor.name,
                    "vendor_subdomain": vendor.subdomain,
                    "detection_method": vendor_context.get("detection_method"),
                    "api_path": path,
                    "referer": vendor_context.get("referer", ""),
                },
            )
        else:
            logger.warning(
                "[WARNING] Vendor context from Referer but vendor not found",
                extra={
                    "context": vendor_context,
                    "detection_method": vendor_context.get("detection_method"),
                    "api_path": path,
                },
            )

    @classmethod
    def _resolve_request_vendor(cls, request: Request) -> None:
        """Populate request.state from the request's own host/path."""
        path = request.url.path
        vendor_context = VendorContextManager.detect_vendor_context(request)

        if not vendor_context:
            logger.debug(
                "[VENDOR] No vendor context detected",
                extra={
                    "path": path,
                    "host": request.headers.get("host", ""),
                },
            )
            cls._set_state(request, None, None, path)
            return

        vendor = cls._lookup_vendor(vendor_context)
        if vendor:
            clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
            cls._set_state(request, vendor, vendor_context, clean_path)
            logger.debug(
                "[VENDOR_CONTEXT] Vendor detected",
                extra={
                    "vendor_id": vendor.id,
                    "vendor_name": vendor.name,
                    "vendor_subdomain": vendor.subdomain,
                    "detection_method": vendor_context.get("detection_method"),
                    "original_path": path,
                    "clean_path": clean_path,
                },
            )
        else:
            logger.warning(
                "[WARNING] Vendor context detected but vendor not found",
                extra={
                    "context": vendor_context,
                    "detection_method": vendor_context.get("detection_method"),
                },
            )
            # Keep the context so downstream code can see what was attempted.
            cls._set_state(request, None, vendor_context, path)

    async def dispatch(self, request: Request, call_next):
        """Detect the vendor, record it on ``request.state`` and continue."""
        path = request.url.path

        # Skip vendor detection for admin, static files, and system requests.
        if (
            VendorContextManager.is_admin_request(request)
            or VendorContextManager.is_static_file_request(request)
            or path in _SYSTEM_PATHS
        ):
            logger.debug(
                f"[VENDOR] Skipping vendor detection: {path}",
                extra={"path": path, "reason": "admin/static/system"},
            )
            self._set_state(request, None, None, path)
            return await call_next(request)

        # Shop API routes: the vendor comes from the calling page (Referer).
        if VendorContextManager.is_shop_api_request(request):
            self._resolve_shop_api_vendor(request)
            return await call_next(request)

        # Other API routes (admin API, vendor API) carry vendor_id in the URL.
        if VendorContextManager.is_api_request(request):
            logger.debug(
                f"[VENDOR] Skipping vendor detection for non-shop API: {path}",
                extra={"path": path, "reason": "api"},
            )
            self._set_state(request, None, None, path)
            return await call_next(request)

        self._resolve_request_vendor(request)

        # Continue to next middleware
        return await call_next(request)


def get_current_vendor(request: Request) -> Vendor | None:
    """Return the vendor stored on ``request.state`` by the middleware, or None."""
    return getattr(request.state, "vendor", None)


def require_vendor_context():
    """Build a dependency that returns the current vendor or raises.

    The returned callable raises ``VendorNotFoundException`` when the
    middleware did not resolve a vendor for the request.
    """

    def dependency(request: Request):
        vendor = get_current_vendor(request)
        if not vendor:
            # Local import, as in the original module (presumably to avoid an
            # import cycle — verify before hoisting).
            from app.exceptions import VendorNotFoundException

            raise VendorNotFoundException("unknown", identifier_type="context")
        return vendor

    return dependency