167 lines
5.4 KiB
Python
167 lines
5.4 KiB
Python
# middleware/vendor_context.py
|
|
import logging
|
|
from typing import Optional
|
|
from fastapi import Request
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import func
|
|
|
|
from app.core.database import get_db
|
|
from models.database.vendor import Vendor
|
|
|
|
# Module-level logger named after this module, so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class VendorContextManager:
    """Manages vendor context detection for multi-tenant routing.

    A stateless collection of static helpers. They detect which vendor a
    request targets (from the Host header or a ``/vendor/<name>`` path
    prefix), load that vendor from the database, and classify admin/API
    requests.
    """

    # First host labels that are never treated as a vendor subdomain.
    RESERVED_SUBDOMAINS = frozenset({"www", "admin", "api"})

    @staticmethod
    def _hostname(host: str) -> str:
        """Return the lower-cased hostname part of a Host header value.

        Any ``:port`` suffix is dropped, and brackets are removed from IPv6
        literals such as ``"[::1]:8000"``.
        """
        host = host.strip().lower()
        if host.startswith("["):
            end = host.find("]")
            return host[1:end] if end != -1 else host[1:]
        return host.split(":", 1)[0]

    @staticmethod
    def _is_ip_address(hostname: str) -> bool:
        """Return True if *hostname* is a literal IPv4 or IPv6 address."""
        import ipaddress

        try:
            ipaddress.ip_address(hostname)
        except ValueError:
            return False
        return True

    @staticmethod
    def detect_vendor_context(request: Request) -> Optional[dict]:
        """Detect vendor context from request.

        Two strategies are tried in order:

        1. Subdomain (production): the first label of the Host header, e.g.
           ``acme`` for ``acme.example.com:443``. The following are ignored:
           labels in ``RESERVED_SUBDOMAINS``, hosts without a dot, and
           literal IP addresses. Development hosts such as ``127.0.0.1``
           therefore fall through to the path strategy.
        2. Path (development): ``/vendor/<name>/...``.

        Args:
            request: Incoming request. Only the ``host`` header and the URL
                path are read.

        Returns:
            A dict with ``subdomain``, ``detection_method`` (``"subdomain"``
            or ``"path"``) and ``host`` (the raw header value). Path
            detection also adds ``path_prefix``. Returns ``None`` when no
            non-empty vendor name is found.
        """
        host = request.headers.get("host", "")
        path = request.url.path

        # Method 1: Subdomain detection (production)
        hostname = VendorContextManager._hostname(host)
        if "." in hostname and not VendorContextManager._is_ip_address(
            hostname
        ):
            subdomain = hostname.split(".", 1)[0]
            # Hostnames are case-insensitive, so the label is compared (and
            # returned) lower-cased. An empty label (".example.com") is skipped.
            if (
                subdomain
                and subdomain not in VendorContextManager.RESERVED_SUBDOMAINS
            ):
                return {
                    "subdomain": subdomain,
                    "detection_method": "subdomain",
                    "host": host,
                }

        # Method 2: Path-based detection (development)
        prefix = "/vendor/"
        if path.startswith(prefix):
            subdomain = path[len(prefix):].split("/", 1)[0]
            # "/vendor/" or "/vendor//..." does not name a vendor.
            if subdomain:
                return {
                    "subdomain": subdomain,
                    "detection_method": "path",
                    "path_prefix": f"/vendor/{subdomain}",
                    "host": host,
                }

        return None

    @staticmethod
    def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
        """Get vendor from database using context information.

        Args:
            db: Database session used for the lookup.
            context: Dict produced by ``detect_vendor_context`` (may be None).

        Returns:
            The first active ``Vendor`` whose subdomain matches
            ``context["subdomain"]`` case-insensitively. Returns ``None`` when
            there is no context, the subdomain is missing or empty, or no
            active vendor matches.
        """
        if not context or not context.get("subdomain"):
            return None

        # Query vendor by subdomain (case-insensitive)
        return (
            db.query(Vendor)
            .filter(func.lower(Vendor.subdomain) == context["subdomain"].lower())
            # SQLAlchemy needs a column expression here, so "== True" is
            # intentional. Only active vendors are returned.
            .filter(Vendor.is_active == True)  # noqa: E712
            .first()
        )

    @staticmethod
    def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
        """Extract clean path without vendor prefix for routing.

        For path-based detection, the ``path_prefix`` (``/vendor/<name>``) is
        stripped from the request path. An empty remainder becomes ``"/"``.
        In every other case the request path is returned unchanged.
        """
        if not vendor_context:
            return request.url.path

        if vendor_context.get("detection_method") == "path":
            path_prefix = vendor_context.get("path_prefix", "")
            path = request.url.path
            if path.startswith(path_prefix):
                clean_path = path[len(path_prefix):]
                return clean_path if clean_path else "/"

        return request.url.path

    @staticmethod
    def is_admin_request(request: Request) -> bool:
        """Check if request is for admin interface.

        True when the Host header starts with ``admin.`` (case-insensitive),
        or when ``admin`` appears as a whole segment of the URL path.
        """
        host = request.headers.get("host", "")
        path = request.url.path

        if host.lower().startswith("admin."):
            return True

        # Match whole path segments only. A substring test would also catch
        # unrelated paths such as "/administrator" or "/products/admin-chair".
        return "admin" in path.split("/")

    @staticmethod
    def is_api_request(request: Request) -> bool:
        """Check if request is for API endpoints (path starts with ``/api/``)."""
        return request.url.path.startswith("/api/")
|
|
|
|
|
|
async def vendor_context_middleware(request: Request, call_next):
    """Middleware to inject vendor context into request state.

    Admin requests, API requests, and the fixed system paths ("/", "/health",
    "/docs", "/redoc", "/openapi.json") are forwarded without touching
    ``request.state``. For every other request, the following attributes are
    always set before forwarding:

    * ``request.state.vendor``: the active vendor matched from the host/path,
      or ``None``.
    * ``request.state.vendor_context``: the detection dict, or ``None``.
    * ``request.state.clean_path``: the path with the ``/vendor/<name>``
      prefix removed when a vendor was resolved, otherwise the raw path.

    Args:
        request: Incoming request.
        call_next: Callable that forwards the request down the middleware
            chain.

    Returns:
        The response produced by ``call_next``.
    """
    # Skip vendor detection for admin, API, and system requests
    if (
        VendorContextManager.is_admin_request(request)
        or VendorContextManager.is_api_request(request)
        or request.url.path in ("/", "/health", "/docs", "/redoc", "/openapi.json")
    ):
        return await call_next(request)

    vendor_context = VendorContextManager.detect_vendor_context(request)
    vendor = None
    clean_path = request.url.path

    if vendor_context:
        # NOTE(review): this runs a synchronous DB query inside an async
        # function, which presumably blocks the event loop while it runs.
        # Consider a threadpool if lookups become slow.
        db_gen = get_db()
        db = next(db_gen)
        try:
            vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)

            if vendor:
                clean_path = VendorContextManager.extract_clean_path(
                    request, vendor_context
                )
                logger.debug(
                    "Vendor context: %s (%s) via %s",
                    vendor.name,
                    vendor.subdomain,
                    vendor_context["detection_method"],
                )
            else:
                logger.warning(
                    "Vendor not found for subdomain: %s", vendor_context["subdomain"]
                )
        finally:
            db.close()
            # Also finalize the dependency generator, so that any cleanup
            # get_db performs after its yield runs now rather than at
            # garbage collection.
            db_gen.close()

    # Set all three attributes on every non-skipped path. Previously
    # clean_path was missing when a context was detected but no vendor found.
    request.state.vendor = vendor
    request.state.vendor_context = vendor_context
    request.state.clean_path = clean_path

    return await call_next(request)
|
|
|
|
|
|
def get_current_vendor(request: Request) -> Optional[Vendor]:
    """Return the vendor stored on ``request.state``.

    Returns ``None`` when no ``vendor`` attribute has been set on the state.
    """
    state = request.state
    try:
        return state.vendor
    except AttributeError:
        return None
|
|
|
|
|
|
def require_vendor_context():
    """Build a dependency callable that requires a resolved vendor.

    Returns:
        A function taking the request. It returns the vendor stored on
        ``request.state``, and raises ``fastapi.HTTPException`` with status
        404 when that vendor is missing or falsy.
    """

    def dependency(request: Request):
        current_vendor = get_current_vendor(request)
        if current_vendor:
            return current_vendor

        # Imported lazily, only on the failure path.
        from fastapi import HTTPException

        raise HTTPException(
            status_code=404,
            detail="Vendor not found or not active",
        )

    return dependency
|