# middleware/vendor_context.py
"""Vendor (tenant) detection middleware for multi-tenant routing.

A vendor is identified from the incoming request in one of three ways, in
priority order: a verified custom domain, a subdomain of the platform domain,
or a ``/vendor/<subdomain>/`` path prefix (development).
"""
import logging
from typing import Optional

from fastapi import HTTPException, Request
from sqlalchemy import func, or_
from sqlalchemy.orm import Session

from app.core.database import get_db
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain

logger = logging.getLogger(__name__)

# Hosts that never resolve to a custom-domain vendor.
_NON_VENDOR_HOSTS = frozenset(
    {"localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1"}
)
# First labels under the platform domain that belong to the platform itself.
_RESERVED_SUBDOMAINS = frozenset({"www", "admin", "api"})
# Paths served without any vendor resolution.
_SYSTEM_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})
# Prefix used for path-based (development) vendor routing.
_VENDOR_PATH_PREFIX = "/vendor/"
# Fallback when settings does not define platform_domain.
_DEFAULT_PLATFORM_DOMAIN = "platform.com"


def _normalize_host(raw_host: str) -> str:
    """Return *raw_host* lowercased, without its port and trailing dot.

    Host names are case-insensitive, so comparisons against the platform
    domain must not depend on the client's casing. Bracketed IPv6 literals
    (``[::1]:8000``) keep their brackets; a plain ``split(":")`` would cut
    them down to ``[``.
    """
    host = raw_host.strip().lower()
    if host.startswith("["):
        # IPv6 literal: keep everything up to and including the closing bracket.
        closing = host.find("]")
        return host[: closing + 1] if closing != -1 else host
    return host.split(":", 1)[0].rstrip(".")


class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing."""

    @staticmethod
    def detect_vendor_context(request: Request) -> Optional[dict]:
        """
        Detect vendor context from request.

        Priority order:
        1. Custom domain (customdomain1.com)
        2. Subdomain (vendor1.platform.com)
        3. Path-based (/vendor/vendor1/)

        Args:
            request: Incoming request; only the ``host`` header and URL path
                are read.

        Returns:
            A dict describing how the vendor was identified, keyed by
            ``detection_method`` (``"custom_domain"``, ``"subdomain"`` or
            ``"path"``), or ``None`` when the request carries no vendor hint.
        """
        raw_host = request.headers.get("host", "")
        host = _normalize_host(raw_host)
        path = request.url.path

        # Kept as a function-level import, as in the original module.
        from app.core.config import settings

        platform_domain = (
            getattr(settings, "platform_domain", None) or _DEFAULT_PLATFORM_DOMAIN
        ).lower()
        platform_suffix = f".{platform_domain}"

        # Method 1: custom domain (HIGHEST PRIORITY) - any host that is not the
        # platform domain, one of its subdomains, a local host or an admin host.
        is_custom_domain = (
            bool(host)
            and host != platform_domain
            and not host.endswith(platform_suffix)
            and host not in _NON_VENDOR_HOSTS
            and not host.startswith("admin.")
        )
        if is_custom_domain:
            return {
                "domain": VendorDomain.normalize_domain(host),
                "detection_method": "custom_domain",
                "host": host,
                "original_host": raw_host,
            }

        # Method 2: subdomain of the platform domain (vendor1.platform.com).
        # Only hosts *under* the platform domain qualify; otherwise the bare
        # platform domain would map to subdomain "platform" and 127.0.0.1 to
        # "127", which also blocked path-based routing on 127.0.0.1.
        if host.endswith(platform_suffix):
            subdomain = host[: -len(platform_suffix)].split(".")[0]
            if subdomain and subdomain not in _RESERVED_SUBDOMAINS:
                return {
                    "subdomain": subdomain,
                    "detection_method": "subdomain",
                    "host": host,
                }

        # Method 3: path-based (/vendor/<subdomain>/...) - for development.
        if path.startswith(_VENDOR_PATH_PREFIX):
            subdomain = path.split("/")[2]
            # "/vendor/" alone has an empty segment and names no vendor.
            if subdomain:
                return {
                    "subdomain": subdomain,
                    "detection_method": "path",
                    "path_prefix": f"/vendor/{subdomain}",
                    "host": host,
                }

        return None

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
        """
        Get vendor from database using context information.

        Supports three methods:
        1. Custom domain lookup (VendorDomain table)
        2. Subdomain lookup (Vendor.subdomain)
        3. Path-based lookup (Vendor.subdomain)

        Args:
            db: Open SQLAlchemy session used for the lookup.
            context: A dict as produced by ``detect_vendor_context``.

        Returns:
            The matching active ``Vendor``, or ``None`` when the context is
            empty or no active vendor matches.
        """
        if not context:
            return None

        if context.get("detection_method") == "custom_domain":
            domain = context.get("domain")
            if not domain:
                return None
            return VendorContextManager._vendor_by_custom_domain(db, domain)

        # Subdomain- and path-based contexts both carry a "subdomain" key.
        subdomain = context.get("subdomain")
        if not subdomain:
            return None
        method = context.get("detection_method", "unknown")
        return VendorContextManager._vendor_by_subdomain(db, subdomain, method)

    @staticmethod
    def _vendor_by_custom_domain(db: Session, domain: str) -> Optional[Vendor]:
        """Return the active vendor owning an active, verified custom *domain*."""
        vendor_domain = (
            db.query(VendorDomain)
            .filter(VendorDomain.domain == domain)
            .filter(VendorDomain.is_active == True)  # noqa: E712 - SQL expression
            .filter(VendorDomain.is_verified == True)  # noqa: E712 - SQL expression
            .first()
        )
        if not vendor_domain:
            logger.warning("No active vendor found for custom domain: %s", domain)
            return None

        vendor = vendor_domain.vendor
        # The domain row may be active while its owning vendor is not.
        if not vendor or not vendor.is_active:
            logger.warning("Vendor for domain %s is not active", domain)
            return None

        logger.info("✓ Vendor found via custom domain: %s → %s", domain, vendor.name)
        return vendor

    @staticmethod
    def _vendor_by_subdomain(
        db: Session, subdomain: str, method: str
    ) -> Optional[Vendor]:
        """Return the active vendor whose subdomain matches, case-insensitively."""
        vendor = (
            db.query(Vendor)
            .filter(func.lower(Vendor.subdomain) == subdomain.lower())
            .filter(Vendor.is_active == True)  # noqa: E712 - SQL expression
            .first()
        )
        if vendor:
            logger.info(
                "✓ Vendor found via %s: %s → %s", method, subdomain, vendor.name
            )
        else:
            logger.warning("No active vendor found for subdomain: %s", subdomain)
        return vendor

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
        """Extract clean path without vendor prefix for routing.

        Only path-based detection adds a ``/vendor/<subdomain>`` prefix; for
        every other context (or no context) the path is returned unchanged.
        A path that becomes empty after stripping is returned as ``"/"``.
        """
        path = request.url.path
        if not vendor_context or vendor_context.get("detection_method") != "path":
            return path

        path_prefix = vendor_context.get("path_prefix", "")
        if path.startswith(path_prefix):
            return path[len(path_prefix):] or "/"
        return path

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Check if request is for admin interface.

        True for an ``admin.`` host (case-insensitive, port ignored) or a
        path starting with ``/admin``.
        """
        host = _normalize_host(request.headers.get("host", ""))
        return host.startswith("admin.") or request.url.path.startswith("/admin")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Check if request is for API endpoints (path starts with ``/api/``)."""
        return request.url.path.startswith("/api/")


async def vendor_context_middleware(request: Request, call_next):
    """
    Middleware to inject vendor context into request state.

    Handles three routing modes:
    1. Custom domains (customdomain1.com → Vendor 1)
    2. Subdomains (vendor1.platform.com → Vendor 1)
    3. Path-based (/vendor/vendor1/ → Vendor 1)

    For every non-admin, non-API, non-system request this sets
    ``request.state.vendor`` (``Vendor`` or ``None``),
    ``request.state.vendor_context`` (dict or ``None``) and
    ``request.state.clean_path`` (path with any vendor prefix stripped).
    """
    path = request.url.path

    # Skip vendor detection for admin, API, and system requests.
    if (
        VendorContextManager.is_admin_request(request)
        or VendorContextManager.is_api_request(request)
        or path in _SYSTEM_PATHS
    ):
        return await call_next(request)

    vendor_context = VendorContextManager.detect_vendor_context(request)

    # Defaults for every vendor-routed request, so downstream code can rely on
    # all three attributes (clean_path was previously left unset when a
    # context was detected but no vendor matched).
    request.state.vendor = None
    request.state.vendor_context = vendor_context
    request.state.clean_path = path

    if vendor_context:
        # NOTE(review): this is a synchronous DB lookup inside an async
        # middleware, and the session is closed before the route runs, so the
        # Vendor stored on request.state is used after its session closes —
        # confirm lazy-loaded relationships are not accessed downstream.
        db_gen = get_db()
        db = next(db_gen)
        try:
            vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
            if vendor:
                request.state.vendor = vendor
                request.state.clean_path = VendorContextManager.extract_clean_path(
                    request, vendor_context
                )
                logger.debug(
                    "🏪 Vendor context: %s (%s) via %s",
                    vendor.name,
                    vendor.subdomain,
                    vendor_context["detection_method"],
                )
            else:
                logger.warning("⚠️ Vendor not found for context: %s", vendor_context)
        finally:
            db.close()

    return await call_next(request)


def get_current_vendor(request: Request) -> Optional[Vendor]:
    """Helper function to get current vendor from request state (or ``None``)."""
    return getattr(request.state, "vendor", None)


def require_vendor_context():
    """Dependency to require vendor context in endpoints.

    Returns a dependency callable that yields the current vendor, or raises
    ``HTTPException(404)`` when the request has no resolved vendor.
    """

    def dependency(request: Request):
        vendor = get_current_vendor(request)
        if not vendor:
            raise HTTPException(
                status_code=404,
                detail="Vendor not found or not active",
            )
        return vendor

    return dependency