288 lines
9.9 KiB
Python
288 lines
9.9 KiB
Python
# middleware/vendor_context.py
|
|
import logging
|
|
from typing import Optional
|
|
from fastapi import Request
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import func, or_
|
|
|
|
from app.core.database import get_db
|
|
from models.database.vendor import Vendor
|
|
from models.database.vendor_domain import VendorDomain
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing."""

    # First host labels that never identify a vendor storefront.
    _RESERVED_SUBDOMAINS = ("www", "admin", "api")

    # Hosts that are never treated as a vendor's custom domain.
    _NON_VENDOR_HOSTS = (
        "localhost",
        "127.0.0.1",
        "admin.localhost",
        "admin.127.0.0.1",
    )

    # File extensions that mark a request as a static asset.
    _STATIC_EXTENSIONS = (
        '.ico', '.css', '.js', '.png', '.jpg', '.jpeg', '.gif', '.svg',
        '.woff', '.woff2', '.ttf', '.eot', '.webp', '.map', '.json',
        '.xml', '.txt', '.pdf', '.webmanifest',
    )

    # Path prefixes that mark a request as a static asset.
    _STATIC_PATHS = ('/static/', '/media/', '/assets/', '/.well-known/')

    @staticmethod
    def _hostname(request: Request) -> str:
        """Return the Host header lowercased and with any ``:port`` removed.

        Host names are case-insensitive, so every comparison in this class is
        done on the lowercased value. A missing header yields ``""``.
        """
        raw_host = request.headers.get("host", "")
        return raw_host.split(":")[0].lower()

    @staticmethod
    def detect_vendor_context(request: Request) -> Optional[dict]:
        """
        Detect vendor context from request.

        Priority order:
        1. Custom domain (customdomain1.com)
        2. Subdomain (vendor1.platform.com)
        3. Path-based (/vendor/vendor1/)

        Returns a dict describing how the vendor was detected, or None if no
        mode matched. Every dict carries ``detection_method`` and ``host``;
        custom-domain results add ``domain`` and ``original_host``,
        subdomain/path results add ``subdomain``, and path results also add
        ``path_prefix``.
        """
        host = VendorContextManager._hostname(request)
        path = request.url.path

        # Function-level import kept from the original code (presumably to
        # avoid a circular import at module load time - TODO confirm).
        from app.core.config import settings
        platform_domain = str(
            getattr(settings, 'platform_domain', 'platform.com')
        ).lower()
        platform_suffix = f".{platform_domain}"

        # Method 1: Custom domain detection (HIGHEST PRIORITY).
        # A custom domain is any non-empty host that is neither the platform
        # domain, one of its subdomains, a local host, nor an admin host.
        is_custom_domain = bool(host) and not (
            host == platform_domain
            or host.endswith(platform_suffix)
            or host in VendorContextManager._NON_VENDOR_HOSTS
            or host.startswith("admin.")
        )

        if is_custom_domain:
            return {
                "domain": VendorDomain.normalize_domain(host),
                "detection_method": "custom_domain",
                "host": host,
                "original_host": request.headers.get("host", ""),
            }

        # Method 2: Subdomain detection (vendor1.platform.com).
        # Only hosts *under* the platform domain qualify. Checking merely for
        # a dot would turn "127.0.0.1" into subdomain "127" and the bare
        # platform domain into its own first label, which would block the
        # path-based fallback below.
        if host.endswith(platform_suffix):
            subdomain = host.split(".")[0]
            if (
                subdomain
                and subdomain not in VendorContextManager._RESERVED_SUBDOMAINS
            ):
                return {
                    "subdomain": subdomain,
                    "detection_method": "subdomain",
                    "host": host,
                }

        # Method 3: Path-based detection (/vendor/vendorname/) - for development
        if path.startswith("/vendor/"):
            # "/vendor/<name>/..." splits into ["", "vendor", "<name>", ...].
            subdomain = path.split("/")[2]
            # A bare "/vendor/" has an empty name segment and is no match.
            if subdomain:
                return {
                    "subdomain": subdomain,
                    "detection_method": "path",
                    "path_prefix": f"/vendor/{subdomain}",
                    "host": host,
                }

        return None

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
        """
        Get vendor from database using context information.

        Supports three methods:
        1. Custom domain lookup (VendorDomain table)
        2. Subdomain lookup (Vendor.subdomain)
        3. Path-based lookup (Vendor.subdomain)

        Returns the matching active vendor, or None when the context is
        empty or no active vendor matches. A custom domain only matches a
        VendorDomain row that is both active and verified.
        """
        if not context:
            return None

        # Method 1: Custom domain lookup
        if context.get("detection_method") == "custom_domain":
            domain = context.get("domain")
            if domain:
                # `== True` is intentional: it builds a SQL expression.
                vendor_domain = (
                    db.query(VendorDomain)
                    .filter(VendorDomain.domain == domain)
                    .filter(VendorDomain.is_active == True)  # noqa: E712
                    .filter(VendorDomain.is_verified == True)  # noqa: E712
                    .first()
                )

                if not vendor_domain:
                    logger.warning(f"No active vendor found for custom domain: {domain}")
                    return None

                vendor = vendor_domain.vendor
                # The domain row may be valid while its vendor is missing or
                # has been deactivated.
                if not vendor or not vendor.is_active:
                    logger.warning(f"Vendor for domain {domain} is not active")
                    return None

                logger.info(f"[OK] Vendor found via custom domain: {domain} -> {vendor.name}")
                return vendor

        # Method 2 & 3: Subdomain or path-based lookup
        vendor = None
        if "subdomain" in context:
            subdomain = context["subdomain"]
            # Query vendor by subdomain (case-insensitive)
            vendor = (
                db.query(Vendor)
                .filter(func.lower(Vendor.subdomain) == subdomain.lower())
                .filter(Vendor.is_active == True)  # noqa: E712
                .first()
            )

            if vendor:
                method = context.get("detection_method", "unknown")
                logger.info(f"[OK] Vendor found via {method}: {subdomain} -> {vendor.name}")
            else:
                logger.warning(f"No active vendor found for subdomain: {subdomain}")

        return vendor

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
        """Extract clean path without vendor prefix for routing.

        The prefix is stripped only for path-based detection and only on a
        path-segment boundary; in every other case the request path is
        returned unchanged. A path equal to the prefix becomes "/".
        """
        path = request.url.path
        if not vendor_context:
            return path

        # Only strip path prefix for path-based detection
        if vendor_context.get("detection_method") != "path":
            return path

        # Drop any trailing slash so the remainder keeps its leading "/".
        path_prefix = vendor_context.get("path_prefix", "").rstrip("/")
        if not path_prefix or not path.startswith(path_prefix):
            return path

        remainder = path[len(path_prefix):]
        if not remainder:
            return "/"
        # Guard against partial-segment matches such as prefix "/vendor/v1"
        # against path "/vendor/v1x/shop".
        return remainder if remainder.startswith("/") else path

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Check if request is for admin interface.

        True when the host (port removed, case-insensitive) starts with
        "admin." or the path starts with "/admin".
        """
        if VendorContextManager._hostname(request).startswith("admin."):
            return True
        return request.url.path.startswith("/admin")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Check if request is for API endpoints (path starts with "/api/")."""
        return request.url.path.startswith("/api/")

    @staticmethod
    def is_static_file_request(request: Request) -> bool:
        """Check if request is for static files.

        The lowercased path matches when it ends with a known static
        extension, starts with a static directory, or contains "favicon.ico".
        """
        path = request.url.path.lower()

        # Check if it's a static file by extension
        if path.endswith(VendorContextManager._STATIC_EXTENSIONS):
            return True

        # Check if it's in a static directory
        if path.startswith(VendorContextManager._STATIC_PATHS):
            return True

        # Special case: favicon.ico at any level
        return 'favicon.ico' in path
|
|
|
|
|
|
async def vendor_context_middleware(request: Request, call_next):
    """
    Middleware to inject vendor context into request state.

    Handles three routing modes:
    1. Custom domains (customdomain1.com -> Vendor 1)
    2. Subdomains (vendor1.platform.com -> Vendor 1)
    3. Path-based (/vendor/vendor1/ -> Vendor 1)

    For requests that are not skipped, three attributes are always set on
    ``request.state`` before the next handler runs:

    - ``vendor``: the resolved vendor, or None
    - ``vendor_context``: the detection dict, or None when nothing matched
    - ``clean_path``: the request path with any path-based vendor prefix
      removed (the unmodified path when no vendor was resolved)

    Admin, API, static-file and system requests are passed straight through
    and get none of these attributes.
    """
    # Skip vendor detection for admin, API, static files, and system requests
    if (
        VendorContextManager.is_admin_request(request)
        or VendorContextManager.is_api_request(request)
        or VendorContextManager.is_static_file_request(request)
        or request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]
    ):
        return await call_next(request)

    # Detect vendor context
    vendor_context = VendorContextManager.detect_vendor_context(request)

    if vendor_context:
        # get_db() is consumed manually with next() here; the session it
        # yields is closed explicitly in the finally block below.
        db_gen = get_db()
        db = next(db_gen)
        try:
            vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)

            if vendor:
                request.state.vendor = vendor
                request.state.vendor_context = vendor_context
                request.state.clean_path = VendorContextManager.extract_clean_path(
                    request, vendor_context
                )

                logger.debug(
                    f"[VENDOR] Vendor context: {vendor.name} ({vendor.subdomain}) "
                    f"via {vendor_context['detection_method']}"
                )
            else:
                logger.warning(
                    f"[WARNING] Vendor not found for context: {vendor_context}"
                )
                request.state.vendor = None
                request.state.vendor_context = vendor_context
                # Without a vendor there is no prefix to strip. Setting this
                # keeps the three state attributes present in every branch so
                # downstream code can read clean_path unconditionally.
                request.state.clean_path = request.url.path
        finally:
            db.close()
    else:
        request.state.vendor = None
        request.state.vendor_context = None
        request.state.clean_path = request.url.path

    return await call_next(request)
|
|
|
|
|
|
def get_current_vendor(request: Request) -> Optional[Vendor]:
    """Return the vendor stored on ``request.state``, or None if none is set."""
    state = request.state
    return getattr(state, "vendor", None)
|
|
|
|
|
|
def require_vendor_context():
    """Build a dependency that requires a vendor on the current request.

    The returned callable looks up the current vendor via
    ``get_current_vendor`` and returns it; when no vendor is present it
    raises ``HTTPException`` with status 404.
    """

    def dependency(request: Request):
        current_vendor = get_current_vendor(request)
        if current_vendor:
            return current_vendor

        # Imported only on the failure path, as in the original code.
        from fastapi import HTTPException
        raise HTTPException(
            status_code=404,
            detail="Vendor not found or not active"
        )

    return dependency
|