Files
orion/middleware/vendor_context.py

343 lines
12 KiB
Python

# middleware/vendor_context.py
"""
Vendor Context Middleware (Class-Based)
Detects vendor from host/domain/path and injects into request.state.
Handles three routing modes:
1. Custom domains (customdomain1.com → Vendor 1)
2. Subdomains (vendor1.platform.com → Vendor 1)
3. Path-based (/vendor/vendor1/ or /vendors/vendor1/ → Vendor 1)
Also extracts clean_path for nested routing patterns.
"""
import logging
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
from app.core.database import get_db
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
logger = logging.getLogger(__name__)
class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing.

    Every method is static; the class only groups the helpers used by
    ``VendorContextMiddleware``: host/path based detection, the database
    lookup, clean-path extraction and request classification.
    """

    # Hosts that are never treated as a vendor custom domain (local development).
    _LOCAL_HOSTS = frozenset(
        {"localhost", "127.0.0.1", "admin.localhost", "admin.127.0.0.1", "[::1]"}
    )
    # Left-most labels of a platform host that never denote a vendor.
    _RESERVED_SUBDOMAINS = frozenset({"www", "admin", "api"})
    # URL prefixes for path-based routing; both spellings are supported.
    _VENDOR_PATH_PREFIXES = ("/vendors/", "/vendor/")
    # Suffixes / prefixes (checked on the lowercased path) that mark static assets.
    _STATIC_EXTENSIONS = (
        '.ico', '.css', '.js', '.png', '.jpg', '.jpeg', '.gif', '.svg',
        '.woff', '.woff2', '.ttf', '.eot', '.webp', '.map', '.json',
        '.xml', '.txt', '.pdf', '.webmanifest',
    )
    _STATIC_PATH_PREFIXES = ('/static/', '/media/', '/assets/', '/.well-known/')

    @staticmethod
    def _normalize_host(raw_host: str) -> str:
        """Return the lowercased host name from a Host header value, port removed.

        Host names are case-insensitive, so all comparisons use the lowercased
        form. IPv6 literals such as ``[::1]:8000`` keep their brackets and lose
        only the port; a trailing FQDN dot (``example.com.``) is dropped.
        """
        host = raw_host.strip().lower()
        if host.startswith("["):
            closing = host.find("]")
            return host[: closing + 1] if closing != -1 else host
        # e.g. localhost:8000 -> localhost
        return host.split(":", 1)[0].rstrip(".")

    @staticmethod
    def detect_vendor_context(request: Request) -> Optional[dict]:
        """
        Detect vendor context from request.

        Priority order:
        1. Custom domain (customdomain1.com)
        2. Subdomain of the platform domain (vendor1.platform.com)
        3. Path-based (/vendor/vendor1/ or /vendors/vendor1/)

        Returns:
            A dict describing how the vendor was detected (``detection_method``
            is ``"custom_domain"``, ``"subdomain"`` or ``"path"``), or None when
            no vendor hint is present. No database access happens here.
        """
        raw_host = request.headers.get("host", "")
        host = VendorContextManager._normalize_host(raw_host)
        path = request.url.path

        # Imported lazily as in the original, presumably to avoid an import
        # cycle at module load time — TODO confirm.
        from app.core.config import settings
        platform_domain = (getattr(settings, 'platform_domain', None) or 'platform.com').lower()
        platform_suffix = f".{platform_domain}"

        # Method 1: Custom domain detection (HIGHEST PRIORITY).
        # Anything that is not the platform domain, one of its subdomains,
        # a local development host or an admin host is a custom domain.
        is_custom_domain = (
            bool(host)
            and host != platform_domain
            and not host.endswith(platform_suffix)
            and host not in VendorContextManager._LOCAL_HOSTS
            and not host.startswith("admin.")
        )
        if is_custom_domain:
            return {
                "domain": VendorDomain.normalize_domain(host),
                "detection_method": "custom_domain",
                "host": host,
                "original_host": raw_host,
            }

        # Method 2: Subdomain detection (vendor1.platform.com).
        # Only hosts under the platform domain qualify; the bare platform
        # domain and local IPs must fall through to path-based detection.
        if host.endswith(platform_suffix):
            subdomain = host[: -len(platform_suffix)].split(".")[0]
            if subdomain and subdomain not in VendorContextManager._RESERVED_SUBDOMAINS:
                return {
                    "subdomain": subdomain,
                    "detection_method": "subdomain",
                    "host": host,
                }

        # Method 3: Path-based detection (/vendor/vendorname/ or /vendors/vendorname/).
        for prefix in VendorContextManager._VENDOR_PATH_PREFIXES:
            if not path.startswith(prefix):
                continue
            vendor_code = path[len(prefix):].split("/", 1)[0]
            if vendor_code:
                return {
                    "subdomain": vendor_code,
                    "detection_method": "path",
                    "path_prefix": prefix + vendor_code,  # e.g. /vendor/vendor1
                    "full_prefix": prefix,  # /vendor/ or /vendors/
                    "host": host,
                }
            # The two prefixes are mutually exclusive, so stop after a match.
            break

        return None

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
        """
        Get vendor from database using context information.

        Supports three methods:
        1. Custom domain lookup (VendorDomain rows that are active and verified)
        2. Subdomain lookup (Vendor.subdomain, case-insensitive, active vendors)
        3. Path-based lookup (same query as the subdomain lookup)

        Args:
            db: Open SQLAlchemy session used for the queries.
            context: Dict produced by ``detect_vendor_context``.

        Returns:
            The matching active ``Vendor``, or None when ``context`` is empty
            or no active match exists.
        """
        if not context:
            return None

        # Method 1: Custom domain lookup
        if context.get("detection_method") == "custom_domain":
            domain = context.get("domain")
            if domain:
                vendor_domain = (
                    db.query(VendorDomain)
                    .filter(VendorDomain.domain == domain)
                    # SQLAlchemy needs a column expression here, hence "== True".
                    .filter(VendorDomain.is_active == True)  # noqa: E712
                    .filter(VendorDomain.is_verified == True)  # noqa: E712
                    .first()
                )
                if not vendor_domain:
                    logger.warning("No active vendor found for custom domain: %s", domain)
                    return None

                vendor = vendor_domain.vendor
                if not vendor or not vendor.is_active:
                    logger.warning("Vendor for domain %s is not active", domain)
                    return None

                logger.info("[OK] Vendor found via custom domain: %s -> %s", domain, vendor.name)
                return vendor

        # Method 2 & 3: Subdomain or path-based lookup
        if "subdomain" not in context:
            return None

        subdomain = context["subdomain"]
        vendor = (
            db.query(Vendor)
            .filter(func.lower(Vendor.subdomain) == subdomain.lower())
            .filter(Vendor.is_active == True)  # noqa: E712
            .first()
        )
        if vendor:
            method = context.get("detection_method", "unknown")
            logger.info("[OK] Vendor found via %s: %s -> %s", method, subdomain, vendor.name)
        else:
            logger.warning("No active vendor found for subdomain: %s", subdomain)
        return vendor

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
        """
        Extract clean path without vendor prefix for routing.

        Only path-based detection rewrites the path: ``/vendor/v1/shop`` becomes
        ``/shop`` (or ``/`` when nothing follows the prefix). For every other
        detection method the request path is returned unchanged.
        """
        path = request.url.path
        if not vendor_context or vendor_context.get("detection_method") != "path":
            return path

        path_prefix = vendor_context.get("path_prefix", "")
        if path.startswith(path_prefix):
            return path[len(path_prefix):] or "/"
        return path

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Return True for the admin interface (``admin.*`` host or ``/admin`` path).

        NOTE(review): ``/admin`` is a plain prefix test, so ``/administrator``
        also matches; tighten to ``/admin`` + ``/admin/`` if that is unintended.
        """
        host = VendorContextManager._normalize_host(request.headers.get("host", ""))
        return host.startswith("admin.") or request.url.path.startswith("/admin")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Return True when the path is under ``/api/``."""
        return request.url.path.startswith("/api/")

    @staticmethod
    def is_static_file_request(request: Request) -> bool:
        """Return True for static assets (by extension, asset directory or favicon)."""
        path = request.url.path.lower()
        return (
            path.endswith(VendorContextManager._STATIC_EXTENSIONS)
            or path.startswith(VendorContextManager._STATIC_PATH_PREFIXES)
            or 'favicon.ico' in path
        )
class VendorContextMiddleware(BaseHTTPMiddleware):
    """
    Middleware to inject vendor context into request state.

    Intended to run FIRST in the middleware chain so later middleware and
    endpoints can rely on the attributes below.

    Sets:
        request.state.vendor: Vendor object, or None
        request.state.vendor_context: Detection metadata dict, or None
        request.state.clean_path: Path without vendor prefix
    """

    # Exact paths that never carry vendor context.
    # NOTE(review): "/" is skipped for every host, so a vendor's root page on
    # a custom domain or subdomain gets vendor=None — confirm this is intended.
    _SYSTEM_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

    @staticmethod
    def _set_state(request: Request, vendor, vendor_context, clean_path: str) -> None:
        """Store the three vendor attributes on ``request.state``."""
        request.state.vendor = vendor
        request.state.vendor_context = vendor_context
        request.state.clean_path = clean_path

    @staticmethod
    def _lookup_vendor(vendor_context: dict) -> Optional[Vendor]:
        """Resolve the vendor in a fresh DB session and always release it.

        Synchronous (SQLAlchemy ORM); ``dispatch`` runs it in a worker thread
        so the query does not block the event loop.
        """
        db_gen = get_db()
        db = next(db_gen)
        try:
            return VendorContextManager.get_vendor_from_context(db, vendor_context)
        finally:
            db.close()
            # Close the generator so get_db()'s own cleanup runs now rather
            # than whenever the abandoned generator is garbage-collected.
            db_gen.close()

    async def dispatch(self, request: Request, call_next):
        """
        Detect vendor context, populate ``request.state`` and continue the chain.
        """
        path = request.url.path

        # Skip vendor detection for admin, API, static files, and system requests
        if (
            VendorContextManager.is_admin_request(request)
            or VendorContextManager.is_api_request(request)
            or VendorContextManager.is_static_file_request(request)
            or path in self._SYSTEM_PATHS
        ):
            logger.debug(
                "[VENDOR] Skipping vendor detection: %s",
                path,
                extra={"path": path, "reason": "admin/api/static/system"},
            )
            self._set_state(request, None, None, path)
            return await call_next(request)

        vendor_context = VendorContextManager.detect_vendor_context(request)
        if not vendor_context:
            logger.debug(
                "[VENDOR] No vendor context detected",
                extra={"path": path, "host": request.headers.get("host", "")},
            )
            self._set_state(request, None, None, path)
            return await call_next(request)

        # Local import, matching the file's existing lazy-import style.
        from starlette.concurrency import run_in_threadpool

        vendor = await run_in_threadpool(self._lookup_vendor, vendor_context)
        if vendor:
            clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
            self._set_state(request, vendor, vendor_context, clean_path)
            logger.debug(
                "[VENDOR_CONTEXT] Vendor detected",
                extra={
                    "vendor_id": vendor.id,
                    "vendor_name": vendor.name,
                    "vendor_subdomain": vendor.subdomain,
                    "detection_method": vendor_context.get("detection_method"),
                    "original_path": path,
                    "clean_path": clean_path,
                },
            )
        else:
            logger.warning(
                "[WARNING] Vendor context detected but vendor not found",
                extra={
                    "context": vendor_context,
                    "detection_method": vendor_context.get("detection_method"),
                },
            )
            # Keep the detection metadata so downstream code can tell
            # "unknown vendor" apart from "no vendor hint at all".
            self._set_state(request, None, vendor_context, path)

        # Continue to next middleware
        return await call_next(request)
def get_current_vendor(request: Request) -> Optional[Vendor]:
    """Return the vendor placed on ``request.state`` by the middleware.

    Yields None when the ``vendor`` attribute was never set, e.g. when the
    middleware did not run for this request.
    """
    try:
        return request.state.vendor
    except AttributeError:
        return None
def require_vendor_context():
    """Build a FastAPI dependency that insists on a resolved vendor.

    Returns:
        A callable taking the current ``Request``. It returns the vendor from
        ``get_current_vendor``, or raises ``HTTPException`` (status 404) when
        no vendor is available.
    """

    def dependency(request: Request):
        current_vendor = get_current_vendor(request)
        if current_vendor:
            return current_vendor
        from fastapi import HTTPException
        raise HTTPException(
            status_code=404,
            detail="Vendor not found or not active",
        )

    return dependency