Files
orion/middleware/vendor_context.py

288 lines
9.9 KiB
Python

# middleware/vendor_context.py
import logging
from typing import Optional
from fastapi import Request
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from app.core.database import get_db
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing."""

    # Hosts that are never treated as custom vendor domains.
    _LOCAL_HOSTS = frozenset(
        {"localhost", "127.0.0.1", "[::1]", "admin.localhost", "admin.127.0.0.1"}
    )
    # First labels under the platform domain that are reserved (not vendors).
    _RESERVED_SUBDOMAINS = frozenset({"www", "admin", "api"})
    # Lower-case file extensions treated as static assets.
    _STATIC_EXTENSIONS = (
        ".ico", ".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".svg",
        ".woff", ".woff2", ".ttf", ".eot", ".webp", ".map", ".json",
        ".xml", ".txt", ".pdf", ".webmanifest",
    )
    # Path prefixes served as static content.
    _STATIC_PATH_PREFIXES = ("/static/", "/media/", "/assets/", "/.well-known/")

    @staticmethod
    def _normalize_host(raw_host: str) -> str:
        """Return *raw_host* lower-cased, without port and trailing dot.

        Host names are case-insensitive, so all comparisons below operate on
        the lower-cased form. Bracketed IPv6 literals (``[::1]:8000``) keep
        their address instead of being cut at the first colon.
        """
        host = (raw_host or "").strip().lower()
        if host.startswith("["):
            # IPv6 literal: "[addr]" or "[addr]:port".
            closing = host.find("]")
            return host[: closing + 1] if closing != -1 else host
        return host.partition(":")[0].rstrip(".")

    @staticmethod
    def detect_vendor_context(request: Request) -> Optional[dict]:
        """
        Detect vendor context from request.

        Priority order:
        1. Custom domain (customdomain1.com)
        2. Subdomain (vendor1.platform.com)
        3. Path-based (/vendor/vendor1/)

        Returns:
            A dict describing how the vendor was identified, or None.
            ``detection_method`` is one of ``"custom_domain"`` (with
            ``domain``, ``host``, ``original_host``), ``"subdomain"`` (with
            ``subdomain``, ``host``) or ``"path"`` (with ``subdomain``,
            ``path_prefix``, ``host``).
        """
        raw_host = request.headers.get("host", "")
        host = VendorContextManager._normalize_host(raw_host)
        path = request.url.path

        # Imported lazily; NOTE(review): presumably to avoid an import cycle
        # with app.core.config - confirm before hoisting to module level.
        from app.core.config import settings
        platform_domain = (
            getattr(settings, "platform_domain", None) or "platform.com"
        ).lower()
        platform_suffix = f".{platform_domain}"

        # Method 1: custom domain (highest priority) - anything that is not the
        # platform domain, one of its subdomains, a local host or an admin host.
        is_custom_domain = (
            bool(host)
            and not host.endswith(platform_suffix)
            and host != platform_domain
            and host not in VendorContextManager._LOCAL_HOSTS
            and not host.startswith("admin.")
        )
        if is_custom_domain:
            return {
                "domain": VendorDomain.normalize_domain(host),
                "detection_method": "custom_domain",
                "host": host,
                "original_host": raw_host,
            }

        # Method 2: subdomain of the platform domain (vendor1.platform.com).
        # Only hosts *under* the platform domain qualify, so neither the bare
        # platform domain nor an IP literal like 127.0.0.1 is mistaken for a
        # vendor subdomain.
        if host.endswith(platform_suffix):
            subdomain = host[: -len(platform_suffix)].split(".")[0]
            if subdomain and subdomain not in VendorContextManager._RESERVED_SUBDOMAINS:
                return {
                    "subdomain": subdomain,
                    "detection_method": "subdomain",
                    "host": host,
                }

        # Method 3: path-based detection (/vendor/vendorname/) - for development.
        if path.startswith("/vendor/"):
            # "/vendor/<name>/..." splits into ["", "vendor", "<name>", ...].
            subdomain = path.split("/")[2]
            if subdomain:
                return {
                    "subdomain": subdomain,
                    "detection_method": "path",
                    "path_prefix": f"/vendor/{subdomain}",
                    "host": host,
                }

        return None

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
        """
        Get vendor from database using context information.

        Supports three methods:
        1. Custom domain lookup (VendorDomain table, active + verified rows)
        2. Subdomain lookup (Vendor.subdomain, case-insensitive)
        3. Path-based lookup (Vendor.subdomain, case-insensitive)

        Returns:
            The matching active Vendor, or None when the context is empty or
            no active vendor matches.
        """
        if not context:
            return None

        # Method 1: custom domain lookup.
        if context.get("detection_method") == "custom_domain":
            domain = context.get("domain")
            if domain:
                vendor_domain = (
                    db.query(VendorDomain)
                    .filter(VendorDomain.domain == domain)
                    .filter(VendorDomain.is_active == True)  # noqa: E712 - SQL expression
                    .filter(VendorDomain.is_verified == True)  # noqa: E712 - SQL expression
                    .first()
                )
                if not vendor_domain:
                    logger.warning("No active vendor found for custom domain: %s", domain)
                    return None
                vendor = vendor_domain.vendor
                if not vendor or not vendor.is_active:
                    logger.warning("Vendor for domain %s is not active", domain)
                    return None
                logger.info(
                    "[OK] Vendor found via custom domain: %s -> %s", domain, vendor.name
                )
                return vendor

        # Methods 2 & 3: subdomain or path-based lookup.
        subdomain = context.get("subdomain")
        if not subdomain:
            return None

        vendor = (
            db.query(Vendor)
            .filter(func.lower(Vendor.subdomain) == subdomain.lower())
            .filter(Vendor.is_active == True)  # noqa: E712 - SQL expression
            .first()
        )
        if vendor:
            logger.info(
                "[OK] Vendor found via %s: %s -> %s",
                context.get("detection_method", "unknown"),
                subdomain,
                vendor.name,
            )
        else:
            logger.warning("No active vendor found for subdomain: %s", subdomain)
        return vendor

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
        """Extract clean path without vendor prefix for routing.

        Only path-based detection carries a prefix (``/vendor/<name>``); for
        every other context the request path is returned unchanged. An
        empty remainder is returned as ``"/"``.
        """
        path = request.url.path
        if not vendor_context or vendor_context.get("detection_method") != "path":
            return path
        path_prefix = vendor_context.get("path_prefix", "")
        if path.startswith(path_prefix):
            return path[len(path_prefix):] or "/"
        return path

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Check if request is for admin interface (admin.* host or /admin path)."""
        host = VendorContextManager._normalize_host(request.headers.get("host", ""))
        # NOTE(review): the "/admin" prefix also matches e.g. "/administrator";
        # kept as-is to preserve existing routing.
        return host.startswith("admin.") or request.url.path.startswith("/admin")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Check if request is for API endpoints."""
        return request.url.path.startswith("/api/")

    @staticmethod
    def is_static_file_request(request: Request) -> bool:
        """Check if request is for static files (by extension or directory)."""
        path = request.url.path.lower()
        if path.endswith(VendorContextManager._STATIC_EXTENSIONS):
            return True
        if path.startswith(VendorContextManager._STATIC_PATH_PREFIXES):
            return True
        # Special case: favicon.ico at any level.
        return "favicon.ico" in path
# System paths that never carry vendor context.
_VENDOR_SKIP_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})


async def vendor_context_middleware(request: Request, call_next):
    """
    Middleware to inject vendor context into request state.

    Handles three routing modes:
    1. Custom domains (customdomain1.com -> Vendor 1)
    2. Subdomains (vendor1.platform.com -> Vendor 1)
    3. Path-based (/vendor/vendor1/ -> Vendor 1)

    For non-skipped requests, ``request.state.vendor``,
    ``request.state.vendor_context`` and ``request.state.clean_path`` are
    always set, even when no vendor matches.
    """
    # Skip vendor detection for admin, API, static files, and system requests.
    if (
        VendorContextManager.is_admin_request(request)
        or VendorContextManager.is_api_request(request)
        or VendorContextManager.is_static_file_request(request)
        or request.url.path in _VENDOR_SKIP_PATHS
    ):
        return await call_next(request)

    vendor_context = VendorContextManager.detect_vendor_context(request)

    # Defaults for "no context" and "context but no vendor". Previously the
    # latter left clean_path unset.
    request.state.vendor = None
    request.state.vendor_context = vendor_context
    request.state.clean_path = request.url.path

    if vendor_context:
        db_gen = get_db()
        db = next(db_gen)
        try:
            vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
            if vendor:
                request.state.vendor = vendor
                request.state.clean_path = VendorContextManager.extract_clean_path(
                    request, vendor_context
                )
                logger.debug(
                    "[VENDOR] Vendor context: %s (%s) via %s",
                    vendor.name,
                    vendor.subdomain,
                    vendor_context["detection_method"],
                )
            else:
                logger.warning(
                    "[WARNING] Vendor not found for context: %s", vendor_context
                )
        finally:
            db.close()

    return await call_next(request)
def get_current_vendor(request: Request) -> Optional[Vendor]:
    """Return the vendor stored on ``request.state``, or None if absent."""
    state = request.state
    try:
        return state.vendor
    except AttributeError:
        # The middleware did not run for this request (or skipped it).
        return None
def require_vendor_context():
    """Build a dependency that requires vendor context in endpoints.

    Returns:
        A callable taking the ``Request``. It returns the current vendor, or
        raises ``HTTPException`` with status 404 when no (truthy) vendor is
        set on the request.
    """

    def dependency(request: Request):
        current = get_current_vendor(request)
        if current:
            return current
        from fastapi import HTTPException

        raise HTTPException(
            status_code=404,
            detail="Vendor not found or not active",
        )

    return dependency