middleware fix for path-based vendor url

This commit is contained in:
2025-11-09 18:47:53 +01:00
parent 79dfcab09f
commit adbcee4ce3
13 changed files with 2078 additions and 810 deletions

View File

@@ -1,9 +1,22 @@
# middleware/vendor_context.py
"""
Vendor Context Middleware (Class-Based)
Detects vendor from host/domain/path and injects into request.state.
Handles three routing modes:
1. Custom domains (customdomain1.com → Vendor 1)
2. Subdomains (vendor1.platform.com → Vendor 1)
3. Path-based (/vendor/vendor1/ or /vendors/vendor1/ → Vendor 1)
Also extracts clean_path for nested routing patterns.
"""
import logging
from typing import Optional
from fastapi import Request
from sqlalchemy.orm import Session
from sqlalchemy import func, or_
from sqlalchemy import func
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
from app.core.database import get_db
from models.database.vendor import Vendor
@@ -23,7 +36,7 @@ class VendorContextManager:
Priority order:
1. Custom domain (customdomain1.com)
2. Subdomain (vendor1.platform.com)
3. Path-based (/vendor/vendor1/)
3. Path-based (/vendor/vendor1/ or /vendors/vendor1/)
Returns dict with vendor info or None if not found.
"""
@@ -48,7 +61,6 @@ class VendorContextManager:
)
if is_custom_domain:
# This could be a custom domain like customdomain1.com
normalized_domain = VendorDomain.normalize_domain(host)
return {
"domain": normalized_domain,
@@ -69,15 +81,23 @@ class VendorContextManager:
"host": host
}
# Method 3: Path-based detection (/vendor/vendorname/) - for development
if path.startswith("/vendor/"):
path_parts = path.split("/")
if len(path_parts) >= 3:
subdomain = path_parts[2]
# Method 3: Path-based detection (/vendor/vendorname/ or /vendors/vendorname/)
# Support BOTH patterns for flexibility
if path.startswith("/vendor/") or path.startswith("/vendors/"):
# Determine which pattern
if path.startswith("/vendors/"):
prefix_len = len("/vendors/")
else:
prefix_len = len("/vendor/")
path_parts = path[prefix_len:].split("/")
if len(path_parts) >= 1 and path_parts[0]:
vendor_code = path_parts[0]
return {
"subdomain": subdomain,
"subdomain": vendor_code,
"detection_method": "path",
"path_prefix": f"/vendor/{subdomain}",
"path_prefix": path[:prefix_len + len(vendor_code)],
"full_prefix": path[:prefix_len], # /vendor/ or /vendors/
"host": host
}
@@ -102,7 +122,6 @@ class VendorContextManager:
if context.get("detection_method") == "custom_domain":
domain = context.get("domain")
if domain:
# Look up vendor by custom domain
vendor_domain = (
db.query(VendorDomain)
.filter(VendorDomain.domain == domain)
@@ -113,12 +132,11 @@ class VendorContextManager:
if vendor_domain:
vendor = vendor_domain.vendor
# Check if vendor is active
if not vendor or not vendor.is_active:
logger.warning(f"Vendor for domain {domain} is not active")
return None
logger.info(f"[OK] Vendor found via custom domain: {domain} -> {vendor.name}")
logger.info(f"[OK] Vendor found via custom domain: {domain} {vendor.name}")
return vendor
else:
logger.warning(f"No active vendor found for custom domain: {domain}")
@@ -127,7 +145,6 @@ class VendorContextManager:
# Method 2 & 3: Subdomain or path-based lookup
if "subdomain" in context:
subdomain = context["subdomain"]
# Query vendor by subdomain (case-insensitive)
vendor = (
db.query(Vendor)
.filter(func.lower(Vendor.subdomain) == subdomain.lower())
@@ -137,7 +154,7 @@ class VendorContextManager:
if vendor:
method = context.get("detection_method", "unknown")
logger.info(f"[OK] Vendor found via {method}: {subdomain} -> {vendor.name}")
logger.info(f"[OK] Vendor found via {method}: {subdomain} {vendor.name}")
else:
logger.warning(f"No active vendor found for subdomain: {subdomain}")
@@ -145,14 +162,19 @@ class VendorContextManager:
@staticmethod
def extract_clean_path(request: Request, vendor_context: Optional[dict]) -> str:
"""Extract clean path without vendor prefix for routing."""
"""
Extract clean path without vendor prefix for routing.
Supports both /vendor/ and /vendors/ prefixes.
"""
if not vendor_context:
return request.url.path
# Only strip path prefix for path-based detection
if vendor_context.get("detection_method") == "path":
path_prefix = vendor_context.get("path_prefix", "")
path = request.url.path
path_prefix = vendor_context.get("path_prefix", "")
if path.startswith(path_prefix):
clean_path = path[len(path_prefix):]
return clean_path if clean_path else "/"
@@ -165,15 +187,12 @@ class VendorContextManager:
host = request.headers.get("host", "")
path = request.url.path
# Remove port from host
if ":" in host:
host = host.split(":")[0]
# Check for admin subdomain
if host.startswith("admin."):
return True
# Check for admin path
if path.startswith("/admin"):
return True
@@ -189,82 +208,118 @@ class VendorContextManager:
"""Check if request is for static files."""
path = request.url.path.lower()
# Static file extensions
static_extensions = (
'.ico', '.css', '.js', '.png', '.jpg', '.jpeg', '.gif', '.svg',
'.woff', '.woff2', '.ttf', '.eot', '.webp', '.map', '.json',
'.xml', '.txt', '.pdf', '.webmanifest'
)
# Static paths
static_paths = ('/static/', '/media/', '/assets/', '/.well-known/')
# Check if it's a static file by extension
if path.endswith(static_extensions):
return True
# Check if it's in a static directory
if any(path.startswith(static_path) for static_path in static_paths):
return True
# Special case: favicon.ico at any level
if 'favicon.ico' in path:
return True
return False
async def vendor_context_middleware(request: Request, call_next):
class VendorContextMiddleware(BaseHTTPMiddleware):
"""
Middleware to inject vendor context into request state.
Handles three routing modes:
1. Custom domains (customdomain1.com -> Vendor 1)
2. Subdomains (vendor1.platform.com -> Vendor 1)
3. Path-based (/vendor/vendor1/ -> Vendor 1)
Class-based middleware provides:
- Better state management
- Easier testing
- More organized code
- Standard ASGI pattern
Runs FIRST in middleware chain.
Sets:
request.state.vendor: Vendor object
request.state.vendor_context: Detection metadata
request.state.clean_path: Path without vendor prefix
"""
# Skip vendor detection for admin, API, static files, and system requests
if (VendorContextManager.is_admin_request(request) or
VendorContextManager.is_api_request(request) or
VendorContextManager.is_static_file_request(request) or
request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]):
async def dispatch(self, request: Request, call_next):
"""
Detect and inject vendor context.
"""
# Skip vendor detection for admin, API, static files, and system requests
if (
VendorContextManager.is_admin_request(request) or
VendorContextManager.is_api_request(request) or
VendorContextManager.is_static_file_request(request) or
request.url.path in ["/", "/health", "/docs", "/redoc", "/openapi.json"]
):
logger.debug(
f"[VENDOR] Skipping vendor detection: {request.url.path}",
extra={"path": request.url.path, "reason": "admin/api/static/system"}
)
request.state.vendor = None
request.state.vendor_context = None
request.state.clean_path = request.url.path
return await call_next(request)
# Detect vendor context
vendor_context = VendorContextManager.detect_vendor_context(request)
if vendor_context:
db_gen = get_db()
db = next(db_gen)
try:
vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
if vendor:
request.state.vendor = vendor
request.state.vendor_context = vendor_context
request.state.clean_path = VendorContextManager.extract_clean_path(
request, vendor_context
)
logger.debug(
f"[VENDOR_CONTEXT] Vendor detected",
extra={
"vendor_id": vendor.id,
"vendor_name": vendor.name,
"vendor_subdomain": vendor.subdomain,
"detection_method": vendor_context.get("detection_method"),
"original_path": request.url.path,
"clean_path": request.state.clean_path,
}
)
else:
logger.warning(
f"[WARNING] Vendor context detected but vendor not found",
extra={
"context": vendor_context,
"detection_method": vendor_context.get("detection_method"),
}
)
request.state.vendor = None
request.state.vendor_context = vendor_context
request.state.clean_path = request.url.path
finally:
db.close()
else:
logger.debug(
f"[VENDOR] No vendor context detected",
extra={
"path": request.url.path,
"host": request.headers.get("host", ""),
}
)
request.state.vendor = None
request.state.vendor_context = None
request.state.clean_path = request.url.path
# Continue to next middleware
return await call_next(request)
# Detect vendor context
vendor_context = VendorContextManager.detect_vendor_context(request)
if vendor_context:
db_gen = get_db()
db = next(db_gen)
try:
vendor = VendorContextManager.get_vendor_from_context(db, vendor_context)
if vendor:
request.state.vendor = vendor
request.state.vendor_context = vendor_context
request.state.clean_path = VendorContextManager.extract_clean_path(
request, vendor_context
)
logger.debug(
f"[VENDOR] Vendor context: {vendor.name} ({vendor.subdomain}) "
f"via {vendor_context['detection_method']}"
)
else:
logger.warning(
f"[WARNING] Vendor not found for context: {vendor_context}"
)
request.state.vendor = None
request.state.vendor_context = vendor_context
finally:
db.close()
else:
request.state.vendor = None
request.state.vendor_context = None
request.state.clean_path = request.url.path
return await call_next(request)
def get_current_vendor(request: Request) -> Optional[Vendor]:
"""Helper function to get current vendor from request state."""