refactor(arch): use CustomerContext schema for dependency injection
Phase 5 of storefront restructure plan - fix direct model imports in API routes by using schemas for dependency injection. Created CustomerContext schema: - Lightweight Pydantic model for customer data in API routes - Populated from Customer DB model in auth dependency - Contains all fields needed by storefront routes - Includes from_db_model() factory method Updated app/api/deps.py: - _validate_customer_token now returns CustomerContext instead of Customer - Updated docstrings for all customer auth functions Updated module storefront routes: - customers: Uses CustomerContext for profile/address endpoints - orders: Uses CustomerContext for order history endpoints - checkout: Uses CustomerContext for order placement - messaging: Uses CustomerContext for messaging endpoints This enforces the layered architecture (Routes → Services → Models) by ensuring API routes never import database models directly. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -716,7 +716,7 @@ def get_current_vendor_api(
|
||||
|
||||
def _validate_customer_token(token: str, request: Request, db: Session):
|
||||
"""
|
||||
Validate customer JWT token and return Customer object.
|
||||
Validate customer JWT token and return CustomerContext schema.
|
||||
|
||||
Validates:
|
||||
1. Token signature and expiration
|
||||
@@ -730,7 +730,7 @@ def _validate_customer_token(token: str, request: Request, db: Session):
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Customer: Authenticated customer object
|
||||
CustomerContext: Authenticated customer context schema
|
||||
|
||||
Raises:
|
||||
InvalidTokenException: If token is invalid or expired
|
||||
@@ -741,6 +741,7 @@ def _validate_customer_token(token: str, request: Request, db: Session):
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.modules.customers.models.customer import Customer
|
||||
from app.modules.customers.schemas import CustomerContext
|
||||
|
||||
# Decode and validate customer JWT token
|
||||
try:
|
||||
@@ -800,7 +801,8 @@ def _validate_customer_token(token: str, request: Request, db: Session):
|
||||
|
||||
logger.debug(f"Customer authenticated: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
# Return CustomerContext schema instead of database model
|
||||
return CustomerContext.from_db_model(customer)
|
||||
|
||||
|
||||
def get_current_customer_from_cookie_or_header(
|
||||
@@ -828,7 +830,7 @@ def get_current_customer_from_cookie_or_header(
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Customer: Authenticated customer object
|
||||
CustomerContext: Authenticated customer context schema
|
||||
|
||||
Raises:
|
||||
InvalidTokenException: If no token or invalid token
|
||||
@@ -862,7 +864,7 @@ def get_current_customer_api(
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Customer: Authenticated customer object
|
||||
CustomerContext: Authenticated customer context schema
|
||||
|
||||
Raises:
|
||||
InvalidTokenException: If no token or invalid token
|
||||
@@ -1327,7 +1329,7 @@ def get_current_customer_optional(
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Customer: Authenticated customer if valid token exists
|
||||
CustomerContext: Authenticated customer context if valid token exists
|
||||
None: If no token, invalid token, or vendor mismatch
|
||||
"""
|
||||
token, source = _get_token_from_request(
|
||||
|
||||
@@ -26,10 +26,10 @@ from app.modules.checkout.schemas import (
|
||||
CheckoutSessionResponse,
|
||||
)
|
||||
from app.modules.checkout.services import checkout_service
|
||||
from app.modules.customers.schemas import CustomerContext
|
||||
from app.modules.orders.services import order_service
|
||||
from app.services.email_service import EmailService
|
||||
from middleware.vendor_context import require_vendor_context
|
||||
from models.database.customer import Customer
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.order import OrderCreate, OrderResponse
|
||||
|
||||
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
|
||||
def place_order(
|
||||
request: Request,
|
||||
order_data: OrderCreate,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -24,13 +24,13 @@ from app.api.deps import get_current_customer_api
|
||||
from app.core.database import get_db
|
||||
from app.core.environment import should_use_secure_cookies
|
||||
from app.exceptions import ValidationException, VendorNotFoundException
|
||||
from app.modules.customers.schemas import CustomerContext
|
||||
from app.modules.customers.services import (
|
||||
customer_address_service,
|
||||
customer_service,
|
||||
)
|
||||
from app.services.auth_service import AuthService
|
||||
from app.services.email_service import EmailService
|
||||
from models.database.customer import Customer
|
||||
from models.database.password_reset_token import PasswordResetToken
|
||||
from models.schema.auth import (
|
||||
LogoutResponse,
|
||||
@@ -371,7 +371,7 @@ def reset_password(
|
||||
|
||||
@router.get("/profile", response_model=CustomerResponse) # authenticated
|
||||
def get_profile(
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -390,7 +390,7 @@ def get_profile(
|
||||
@router.put("/profile", response_model=CustomerResponse)
|
||||
def update_profile(
|
||||
update_data: CustomerUpdate,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -444,7 +444,7 @@ def update_profile(
|
||||
@router.put("/profile/password", response_model=dict)
|
||||
def change_password(
|
||||
password_data: CustomerPasswordChange,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -492,7 +492,7 @@ def change_password(
|
||||
@router.get("/addresses", response_model=CustomerAddressListResponse) # authenticated
|
||||
def list_addresses(
|
||||
request: Request,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -528,7 +528,7 @@ def list_addresses(
|
||||
def get_address(
|
||||
request: Request,
|
||||
address_id: int = Path(..., description="Address ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -561,7 +561,7 @@ def get_address(
|
||||
def create_address(
|
||||
request: Request,
|
||||
address_data: CustomerAddressCreate,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -610,7 +610,7 @@ def update_address(
|
||||
request: Request,
|
||||
address_data: CustomerAddressUpdate,
|
||||
address_id: int = Path(..., description="Address ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -654,7 +654,7 @@ def update_address(
|
||||
def delete_address(
|
||||
request: Request,
|
||||
address_id: int = Path(..., description="Address ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -691,7 +691,7 @@ def delete_address(
|
||||
def set_address_default(
|
||||
request: Request,
|
||||
address_id: int = Path(..., description="Address ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -9,9 +9,11 @@ Usage:
|
||||
CustomerRegister,
|
||||
CustomerUpdate,
|
||||
CustomerResponse,
|
||||
CustomerContext,
|
||||
)
|
||||
"""
|
||||
|
||||
from app.modules.customers.schemas.context import CustomerContext
|
||||
from app.modules.customers.schemas.customer import (
|
||||
# Registration & Authentication
|
||||
CustomerRegister,
|
||||
@@ -41,6 +43,8 @@ from app.modules.customers.schemas.customer import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Context (for dependency injection)
|
||||
"CustomerContext",
|
||||
# Registration & Authentication
|
||||
"CustomerRegister",
|
||||
"CustomerUpdate",
|
||||
|
||||
99
app/modules/customers/schemas/context.py
Normal file
99
app/modules/customers/schemas/context.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# app/modules/customers/schemas/context.py
|
||||
"""
|
||||
Customer context schema for dependency injection in storefront routes.
|
||||
|
||||
This schema provides a clean interface for customer data in API routes,
|
||||
avoiding direct database model imports in the API layer.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class CustomerContext(BaseModel):
|
||||
"""
|
||||
Customer context for dependency injection in storefront routes.
|
||||
|
||||
This is a lightweight schema that contains the customer information
|
||||
needed by API routes. It's populated from the Customer database model
|
||||
in the authentication dependency.
|
||||
|
||||
Usage:
|
||||
@router.get("/profile")
|
||||
def get_profile(
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
):
|
||||
return {"email": customer.email}
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Core identification
|
||||
id: int
|
||||
vendor_id: int
|
||||
email: str
|
||||
customer_number: str
|
||||
|
||||
# Profile info
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
# Preferences
|
||||
marketing_consent: bool = False
|
||||
preferred_language: str | None = None
|
||||
|
||||
# Stats (for order placement)
|
||||
total_orders: int = 0
|
||||
total_spent: Decimal = Decimal("0.00")
|
||||
last_order_date: datetime | None = None
|
||||
|
||||
# Status
|
||||
is_active: bool = True
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
# Password hash (needed for password change endpoint)
|
||||
# This is included but should not be exposed in API responses
|
||||
hashed_password: str | None = None
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
"""Get customer full name."""
|
||||
if self.first_name and self.last_name:
|
||||
return f"{self.first_name} {self.last_name}"
|
||||
return self.email
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, customer) -> "CustomerContext":
|
||||
"""
|
||||
Create CustomerContext from a Customer database model.
|
||||
|
||||
Args:
|
||||
customer: Customer database model instance
|
||||
|
||||
Returns:
|
||||
CustomerContext: Pydantic schema instance
|
||||
"""
|
||||
return cls(
|
||||
id=customer.id,
|
||||
vendor_id=customer.vendor_id,
|
||||
email=customer.email,
|
||||
customer_number=customer.customer_number,
|
||||
first_name=customer.first_name,
|
||||
last_name=customer.last_name,
|
||||
phone=customer.phone,
|
||||
marketing_consent=customer.marketing_consent,
|
||||
preferred_language=customer.preferred_language,
|
||||
total_orders=customer.total_orders or 0,
|
||||
total_spent=customer.total_spent or Decimal("0.00"),
|
||||
last_order_date=customer.last_order_date,
|
||||
is_active=customer.is_active,
|
||||
created_at=customer.created_at,
|
||||
updated_at=customer.updated_at,
|
||||
hashed_password=customer.hashed_password,
|
||||
)
|
||||
@@ -33,6 +33,7 @@ from app.exceptions import (
|
||||
ConversationNotFoundException,
|
||||
VendorNotFoundException,
|
||||
)
|
||||
from app.modules.customers.schemas import CustomerContext
|
||||
from app.modules.messaging.models.message import ConversationType, ParticipantType
|
||||
from app.modules.messaging.schemas import (
|
||||
ConversationDetailResponse,
|
||||
@@ -45,7 +46,6 @@ from app.modules.messaging.services import (
|
||||
message_attachment_service,
|
||||
messaging_service,
|
||||
)
|
||||
from models.database.customer import Customer
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,7 +74,7 @@ def list_conversations(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
status: Optional[str] = Query(None, pattern="^(open|closed)$"),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -146,7 +146,7 @@ def list_conversations(
|
||||
@router.get("/messages/unread-count", response_model=UnreadCountResponse)
|
||||
def get_unread_count(
|
||||
request: Request,
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -171,7 +171,7 @@ def get_unread_count(
|
||||
def get_conversation(
|
||||
request: Request,
|
||||
conversation_id: int = Path(..., description="Conversation ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -261,7 +261,7 @@ async def send_message(
|
||||
conversation_id: int = Path(..., description="Conversation ID", gt=0),
|
||||
content: str = Form(..., min_length=1, max_length=10000),
|
||||
attachments: List[UploadFile] = File(default=[]),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -359,7 +359,7 @@ async def send_message(
|
||||
def mark_as_read(
|
||||
request: Request,
|
||||
conversation_id: int = Path(..., description="Conversation ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Mark conversation as read."""
|
||||
@@ -394,7 +394,7 @@ async def download_attachment(
|
||||
request: Request,
|
||||
conversation_id: int = Path(..., description="Conversation ID", gt=0),
|
||||
attachment_id: int = Path(..., description="Attachment ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -439,7 +439,7 @@ async def get_attachment_thumbnail(
|
||||
request: Request,
|
||||
conversation_id: int = Path(..., description="Conversation ID", gt=0),
|
||||
attachment_id: int = Path(..., description="Attachment ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -503,7 +503,7 @@ def _get_other_participant_name(conversation, customer_id: int) -> str:
|
||||
def _get_sender_name(message) -> str:
|
||||
"""Get sender name for a message."""
|
||||
if message.sender_type == ParticipantType.CUSTOMER:
|
||||
from models.database.customer import Customer
|
||||
from app.modules.customers.models import Customer
|
||||
|
||||
customer = (
|
||||
Customer.query.filter_by(id=message.sender_id).first()
|
||||
|
||||
@@ -22,9 +22,9 @@ from app.api.deps import get_current_customer_api
|
||||
from app.core.database import get_db
|
||||
from app.exceptions import OrderNotFoundException, VendorNotFoundException
|
||||
from app.exceptions.invoice import InvoicePDFNotFoundException
|
||||
from app.modules.customers.schemas import CustomerContext
|
||||
from app.modules.orders.services import order_service
|
||||
from app.services.invoice_service import invoice_service
|
||||
from models.database.customer import Customer
|
||||
from models.schema.order import (
|
||||
OrderDetailResponse,
|
||||
OrderListResponse,
|
||||
@@ -40,7 +40,7 @@ def get_my_orders(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -85,7 +85,7 @@ def get_my_orders(
|
||||
def get_order_details(
|
||||
request: Request,
|
||||
order_id: int = Path(..., description="Order ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
@@ -125,7 +125,7 @@ def get_order_details(
|
||||
def download_order_invoice(
|
||||
request: Request,
|
||||
order_id: int = Path(..., description="Order ID", gt=0),
|
||||
customer: Customer = Depends(get_current_customer_api),
|
||||
customer: CustomerContext = Depends(get_current_customer_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -314,7 +314,7 @@ After migrated to `app/modules/cart/services/cart_service.py`.
|
||||
2. **Phase 2** - Rename shop → storefront (terminology) ✅ COMPLETE
|
||||
3. **Phase 3** - Create new modules (cart, checkout, catalog) ✅ COMPLETE
|
||||
4. **Phase 4** - Move routes to modules ✅ COMPLETE
|
||||
5. **Phase 5** - Fix direct model imports
|
||||
5. **Phase 5** - Fix direct model imports ✅ COMPLETE
|
||||
6. **Phase 6** - Delete legacy files
|
||||
7. **Phase 7** - Update documentation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user