major refactoring adding vendor and customer features
This commit is contained in:
166
middleware/vendor_context.py
Normal file
166
middleware/vendor_context.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# middleware/vendor_context.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.core.database import get_db
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VendorContextManager:
|
||||
"""Manages vendor context detection for multi-tenant routing."""
|
||||
|
||||
@staticmethod
|
||||
def detect_vendor_context(request: Request) -> Optional[dict]:
|
||||
"""
|
||||
Detect vendor context from request.
|
||||
|
||||
Returns dict with vendor info or None if not found.
|
||||
"""
|
||||
host = request.headers.get("host", "")
|
||||
path = request.url.path
|
||||
|
||||
# Method 1: Subdomain detection (production)
|
||||
if "." in host:
|
||||
parts = host.split(".")
|
||||
if len(parts) >= 2 and parts[0] not in ["www", "admin", "api"]:
|
||||
subdomain = parts[0]
|
||||
return {
|
||||
"subdomain": subdomain,
|
||||
"detection_method": "subdomain",
|
||||
"host": host
|
||||
}
|
||||
|
||||
# Method 2: Path-based detection (development)
|
||||
if path.startswith("/vendor/"):
|
||||
path_parts = path.split("/")
|
||||
if len(path_parts) >= 3:
|
||||
subdomain = path_parts[2]
|
||||
return {
|
||||
"subdomain": subdomain,
|
||||
"detection_method": "path",
|
||||
"path_prefix": f"/vendor/{subdomain}",
|
||||
"host": host
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_vendor_from_context(db: Session, context: dict) -> Optional[Vendor]:
|
||||
"""Get vendor from database using context information."""
|
||||
if not context or "subdomain" not in context:
|
||||
return None
|
||||
|
||||
# Query vendor by subdomain (case-insensitive)
|
||||
vendor = (
|
||||
db.query(Vendor)
|
||||
.filter(func.lower(Vendor.subdomain) == context["subdomain"].lower())
|
||||
.filter(Vendor.is_active == True) # Only active vendors
|
||||
.first()
|
||||
)
|
||||
|
||||
return vendor
|
||||
|
||||
@staticmethod
|
||||
def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
|
||||
"""Extract clean path without vendor prefix for routing."""
|
||||
if not vendor_context:
|
||||
return request.url.path
|
||||
|
||||
if vendor_context.get("detection_method") == "path":
|
||||
path_prefix = vendor_context.get("path_prefix", "")
|
||||
path = request.url.path
|
||||
if path.startswith(path_prefix):
|
||||
clean_path = path[len(path_prefix):]
|
||||
return clean_path if clean_path else "/"
|
||||
|
||||
return request.url.path
|
||||
|
||||
@staticmethod
|
||||
def is_admin_request(request: Request) -> bool:
|
||||
"""Check if request is for admin interface."""
|
||||
host = request.headers.get("host", "")
|
||||
path = request.url.path
|
||||
|
||||
if host.startswith("admin."):
|
||||
return True
|
||||
|
||||
if "/admin" in path:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_api_request(request: Request) -> bool:
|
||||
"""Check if request is for API endpoints."""
|
||||
return request.url.path.startswith("/api/")
|
||||
|
||||
|
||||
async def vendor_context_middleware(request: Request, call_next):
|
||||
"""
|
||||
Middleware to inject vendor context into request state.
|
||||
"""
|
||||
# Skip vendor detection for admin, API, and system requests
|
||||
if (VendorContextManager.is_admin_request(request) or
|
||||
VendorContextManager.is_api_request(request) or
|
||||
request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]):
|
||||
return await call_next(request)
|
||||
|
||||
# Detect vendor context
|
||||
vendor_context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
if vendor_context:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
|
||||
|
||||
if vendor:
|
||||
request.state.vendor = vendor
|
||||
request.state.vendor_context = vendor_context
|
||||
request.state.clean_path = VendorContextManager.extract_clean_path(
|
||||
request, vendor_context
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Vendor context: {vendor.name} ({vendor.subdomain}) "
|
||||
f"via {vendor_context['detection_method']}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Vendor not found for subdomain: {vendor_context['subdomain']}"
|
||||
)
|
||||
request.state.vendor = None
|
||||
request.state.vendor_context = vendor_context
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
request.state.vendor = None
|
||||
request.state.vendor_context = None
|
||||
request.state.clean_path = request.url.path
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def get_current_vendor(request: Request) -> Optional[Vendor]:
|
||||
"""Helper function to get current vendor from request state."""
|
||||
return getattr(request.state, "vendor", None)
|
||||
|
||||
|
||||
def require_vendor_context():
|
||||
"""Dependency to require vendor context in endpoints."""
|
||||
def dependency(request: Request):
|
||||
vendor = get_current_vendor(request)
|
||||
if not vendor:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Vendor not found or not active"
|
||||
)
|
||||
return vendor
|
||||
|
||||
return dependency
|
||||
Reference in New Issue
Block a user