diff --git a/alembic/versions/r6f7a8b9c0d1_add_country_iso_to_addresses.py b/alembic/versions/r6f7a8b9c0d1_add_country_iso_to_addresses.py new file mode 100644 index 00000000..5ba7916e --- /dev/null +++ b/alembic/versions/r6f7a8b9c0d1_add_country_iso_to_addresses.py @@ -0,0 +1,141 @@ +"""Add country_iso to customer_addresses + +Revision ID: r6f7a8b9c0d1 +Revises: q5e6f7a8b9c0 +Create Date: 2026-01-02 + +Adds country_iso field to customer_addresses table and renames +country to country_name for clarity. + +This migration is idempotent - it checks for existing columns before +making changes. +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "r6f7a8b9c0d1" +down_revision = "q5e6f7a8b9c0" +branch_labels = None +depends_on = None + + +# Country name to ISO code mapping for backfill +COUNTRY_ISO_MAP = { + "Luxembourg": "LU", + "Germany": "DE", + "France": "FR", + "Belgium": "BE", + "Netherlands": "NL", + "Austria": "AT", + "Italy": "IT", + "Spain": "ES", + "Portugal": "PT", + "Poland": "PL", + "Czech Republic": "CZ", + "Czechia": "CZ", + "Slovakia": "SK", + "Hungary": "HU", + "Romania": "RO", + "Bulgaria": "BG", + "Greece": "GR", + "Croatia": "HR", + "Slovenia": "SI", + "Estonia": "EE", + "Latvia": "LV", + "Lithuania": "LT", + "Finland": "FI", + "Sweden": "SE", + "Denmark": "DK", + "Ireland": "IE", + "Cyprus": "CY", + "Malta": "MT", + "United Kingdom": "GB", + "Switzerland": "CH", + "United States": "US", +} + + +def get_column_names(connection, table_name): + """Get list of column names for a table.""" + result = connection.execute(sa.text(f"PRAGMA table_info({table_name})")) + return [row[1] for row in result] + + +def upgrade() -> None: + connection = op.get_bind() + columns = get_column_names(connection, "customer_addresses") + + # Check if we need to do anything (idempotent check) + has_country = "country" in columns + has_country_name = "country_name" in columns + has_country_iso = "country_iso" in columns + + # If already has new columns, nothing to do + if has_country_name and has_country_iso: + print(" Columns country_name and country_iso already exist, skipping") + return + + # If has old 'country' column, rename it and add country_iso + if has_country and not has_country_name: + with op.batch_alter_table("customer_addresses") as batch_op: + batch_op.alter_column( + "country", + new_column_name="country_name", + ) + + # Add country_iso if it doesn't exist + if not has_country_iso: + with op.batch_alter_table("customer_addresses") as batch_op: + batch_op.add_column( + sa.Column("country_iso", sa.String(5), nullable=True) + ) + + # Backfill country_iso from country_name + for country_name, iso_code in COUNTRY_ISO_MAP.items(): + connection.execute( + sa.text( + "UPDATE customer_addresses SET country_iso = :iso " + "WHERE country_name = :name" + ), + {"iso": iso_code, "name": country_name}, + ) + + # Set default for any remaining NULL values + connection.execute( + sa.text( + "UPDATE customer_addresses SET country_iso = 'LU' " + "WHERE country_iso IS NULL" + ) + ) + + # Make country_iso NOT NULL using batch operation + with op.batch_alter_table("customer_addresses") as batch_op: + batch_op.alter_column( + "country_iso", + existing_type=sa.String(5), + nullable=False, + ) + + +def downgrade() -> None: + connection = op.get_bind() + columns = get_column_names(connection, "customer_addresses") + + has_country_name = "country_name" in columns + has_country_iso = "country_iso" in columns + has_country = "country" in columns + + # Only downgrade if in the new state + if has_country_name and not has_country: + with op.batch_alter_table("customer_addresses") as batch_op: + batch_op.alter_column( + "country_name", + new_column_name="country", + ) + + if has_country_iso: + with op.batch_alter_table("customer_addresses") as batch_op: + batch_op.drop_column("country_iso") diff --git a/app/api/v1/shop/__init__.py b/app/api/v1/shop/__init__.py index be4259b2..60d7ce84 100644 --- a/app/api/v1/shop/__init__.py +++ b/app/api/v1/shop/__init__.py @@ -21,7 +21,7 @@ Authentication: from fastapi import APIRouter # Import shop routers -from . import auth, carts, content_pages, messages, orders, products +from . import addresses, auth, carts, content_pages, messages, orders, products # Create shop router router = APIRouter() @@ -30,6 +30,9 @@ router = APIRouter() # SHOP API ROUTES (All vendor-context aware via middleware) # ============================================================================ +# Addresses (authenticated) +router.include_router(addresses.router, tags=["shop-addresses"]) + # Authentication (public) router.include_router(auth.router, tags=["shop-auth"]) diff --git a/app/api/v1/shop/addresses.py b/app/api/v1/shop/addresses.py new file mode 100644 index 00000000..76a78adc --- /dev/null +++ b/app/api/v1/shop/addresses.py @@ -0,0 +1,269 @@ +# app/api/v1/shop/addresses.py +""" +Shop Addresses API (Customer authenticated) + +Endpoints for managing customer addresses in shop frontend. +Uses vendor from request.state (injected by VendorContextMiddleware). +Requires customer authentication. +""" + +import logging + +from fastapi import APIRouter, Depends, Path, Request +from sqlalchemy.orm import Session + +from app.api.deps import get_current_customer_api +from app.core.database import get_db +from app.exceptions import VendorNotFoundException +from app.services.customer_address_service import customer_address_service +from models.database.customer import Customer +from models.schema.customer import ( + CustomerAddressCreate, + CustomerAddressListResponse, + CustomerAddressResponse, + CustomerAddressUpdate, +) + +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/addresses", response_model=CustomerAddressListResponse) +def list_addresses( + request: Request, + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + List all addresses for authenticated customer. + + Vendor is automatically determined from request context. + Returns all addresses sorted by default first, then by creation date. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] list_addresses for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "vendor_code": vendor.subdomain, + "customer_id": customer.id, + }, + ) + + addresses = customer_address_service.list_addresses( + db=db, vendor_id=vendor.id, customer_id=customer.id + ) + + return CustomerAddressListResponse( + addresses=[CustomerAddressResponse.model_validate(a) for a in addresses], + total=len(addresses), + ) + + +@router.get("/addresses/{address_id}", response_model=CustomerAddressResponse) +def get_address( + request: Request, + address_id: int = Path(..., description="Address ID", gt=0), + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + Get specific address by ID. + + Vendor is automatically determined from request context. + Customer can only access their own addresses. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] get_address {address_id} for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "customer_id": customer.id, + "address_id": address_id, + }, + ) + + address = customer_address_service.get_address( + db=db, vendor_id=vendor.id, customer_id=customer.id, address_id=address_id + ) + + return CustomerAddressResponse.model_validate(address) + + +@router.post("/addresses", response_model=CustomerAddressResponse, status_code=201) +def create_address( + request: Request, + address_data: CustomerAddressCreate, + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + Create new address for authenticated customer. + + Vendor is automatically determined from request context. + Maximum 10 addresses per customer. + If is_default=True, clears default flag on other addresses of same type. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] create_address for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "customer_id": customer.id, + "address_type": address_data.address_type, + }, + ) + + address = customer_address_service.create_address( + db=db, + vendor_id=vendor.id, + customer_id=customer.id, + address_data=address_data, + ) + db.commit() + + logger.info( + f"Created address {address.id} for customer {customer.id} " + f"(type={address_data.address_type})", + extra={ + "address_id": address.id, + "customer_id": customer.id, + "address_type": address_data.address_type, + }, + ) + + return CustomerAddressResponse.model_validate(address) + + +@router.put("/addresses/{address_id}", response_model=CustomerAddressResponse) +def update_address( + request: Request, + address_data: CustomerAddressUpdate, + address_id: int = Path(..., description="Address ID", gt=0), + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + Update existing address. + + Vendor is automatically determined from request context. + Customer can only update their own addresses. + If is_default=True, clears default flag on other addresses of same type. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] update_address {address_id} for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "customer_id": customer.id, + "address_id": address_id, + }, + ) + + address = customer_address_service.update_address( + db=db, + vendor_id=vendor.id, + customer_id=customer.id, + address_id=address_id, + address_data=address_data, + ) + db.commit() + + logger.info( + f"Updated address {address_id} for customer {customer.id}", + extra={"address_id": address_id, "customer_id": customer.id}, + ) + + return CustomerAddressResponse.model_validate(address) + + +@router.delete("/addresses/{address_id}", status_code=204) +def delete_address( + request: Request, + address_id: int = Path(..., description="Address ID", gt=0), + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + Delete address. + + Vendor is automatically determined from request context. + Customer can only delete their own addresses. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] delete_address {address_id} for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "customer_id": customer.id, + "address_id": address_id, + }, + ) + + customer_address_service.delete_address( + db=db, vendor_id=vendor.id, customer_id=customer.id, address_id=address_id + ) + db.commit() + + logger.info( + f"Deleted address {address_id} for customer {customer.id}", + extra={"address_id": address_id, "customer_id": customer.id}, + ) + + +@router.put("/addresses/{address_id}/default", response_model=CustomerAddressResponse) +def set_address_default( + request: Request, + address_id: int = Path(..., description="Address ID", gt=0), + customer: Customer = Depends(get_current_customer_api), + db: Session = Depends(get_db), +): + """ + Set address as default for its type. + + Vendor is automatically determined from request context. + Clears default flag on other addresses of the same type. + """ + vendor = getattr(request.state, "vendor", None) + if not vendor: + raise VendorNotFoundException("context", identifier_type="subdomain") + + logger.debug( + f"[SHOP_API] set_address_default {address_id} for customer {customer.id}", + extra={ + "vendor_id": vendor.id, + "customer_id": customer.id, + "address_id": address_id, + }, + ) + + address = customer_address_service.set_default( + db=db, vendor_id=vendor.id, customer_id=customer.id, address_id=address_id + ) + db.commit() + + logger.info( + f"Set address {address_id} as default for customer {customer.id}", + extra={ + "address_id": address_id, + "customer_id": customer.id, + "address_type": address.address_type, + }, + ) + + return CustomerAddressResponse.model_validate(address) diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py index ebb4e8b7..8656b180 100644 --- a/app/exceptions/__init__.py +++ b/app/exceptions/__init__.py @@ -6,6 +6,13 @@ This module provides frontend-friendly exceptions with consistent error codes, messages, and HTTP status mappings. """ +# Address exceptions +from .address import ( + AddressLimitExceededException, + AddressNotFoundException, + InvalidAddressTypeException, +) + # Admin exceptions from .admin import ( AdminOperationException, diff --git a/app/exceptions/address.py b/app/exceptions/address.py new file mode 100644 index 00000000..188f8ef1 --- /dev/null +++ b/app/exceptions/address.py @@ -0,0 +1,38 @@ +# app/exceptions/address.py +""" +Address-related custom exceptions. + +Used for customer address management operations. +""" + +from .base import BusinessLogicException, ResourceNotFoundException + + +class AddressNotFoundException(ResourceNotFoundException): + """Raised when a customer address is not found.""" + + def __init__(self, address_id: str | int): + super().__init__( + resource_type="Address", + identifier=str(address_id), + ) + + +class AddressLimitExceededException(BusinessLogicException): + """Raised when customer exceeds maximum number of addresses.""" + + def __init__(self, max_addresses: int = 10): + super().__init__( + message=f"Maximum number of addresses ({max_addresses}) reached", + error_code="ADDRESS_LIMIT_EXCEEDED", + ) + + +class InvalidAddressTypeException(BusinessLogicException): + """Raised when an invalid address type is provided.""" + + def __init__(self, address_type: str): + super().__init__( + message=f"Invalid address type '{address_type}'. Must be 'shipping' or 'billing'", + error_code="INVALID_ADDRESS_TYPE", + ) diff --git a/app/services/customer_address_service.py b/app/services/customer_address_service.py new file mode 100644 index 00000000..a1548b7c --- /dev/null +++ b/app/services/customer_address_service.py @@ -0,0 +1,314 @@ +# app/services/customer_address_service.py +""" +Customer Address Service + +Business logic for managing customer addresses with vendor isolation. +""" + +import logging + +from sqlalchemy.orm import Session + +from app.exceptions import ( + AddressLimitExceededException, + AddressNotFoundException, +) +from models.database.customer import CustomerAddress +from models.schema.customer import CustomerAddressCreate, CustomerAddressUpdate + +logger = logging.getLogger(__name__) + + +class CustomerAddressService: + """Service for managing customer addresses with vendor isolation.""" + + MAX_ADDRESSES_PER_CUSTOMER = 10 + + def list_addresses( + self, db: Session, vendor_id: int, customer_id: int + ) -> list[CustomerAddress]: + """ + Get all addresses for a customer. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + + Returns: + List of customer addresses + """ + return ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc()) + .all() + ) + + def get_address( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> CustomerAddress: + """ + Get a specific address with ownership validation. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Returns: + Customer address + + Raises: + AddressNotFoundException: If address not found or doesn't belong to customer + """ + address = ( + db.query(CustomerAddress) + .filter( + CustomerAddress.id == address_id, + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .first() + ) + + if not address: + raise AddressNotFoundException(address_id) + + return address + + def get_default_address( + self, db: Session, vendor_id: int, customer_id: int, address_type: str + ) -> CustomerAddress | None: + """ + Get the default address for a specific type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_type: 'shipping' or 'billing' + + Returns: + Default address or None if not set + """ + return ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + CustomerAddress.address_type == address_type, + CustomerAddress.is_default == True, # noqa: E712 + ) + .first() + ) + + def create_address( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_data: CustomerAddressCreate, + ) -> CustomerAddress: + """ + Create a new address for a customer. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_data: Address creation data + + Returns: + Created customer address + + Raises: + AddressLimitExceededException: If max addresses reached + """ + # Check address limit + current_count = ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .count() + ) + + if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER: + raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER) + + # If setting as default, clear other defaults of same type + if address_data.is_default: + self._clear_other_defaults( + db, vendor_id, customer_id, address_data.address_type + ) + + # Create the address + address = CustomerAddress( + vendor_id=vendor_id, + customer_id=customer_id, + address_type=address_data.address_type, + first_name=address_data.first_name, + last_name=address_data.last_name, + company=address_data.company, + address_line_1=address_data.address_line_1, + address_line_2=address_data.address_line_2, + city=address_data.city, + postal_code=address_data.postal_code, + country_name=address_data.country_name, + country_iso=address_data.country_iso, + is_default=address_data.is_default, + ) + + db.add(address) + db.flush() + + logger.info( + f"Created address {address.id} for customer {customer_id} " + f"(type={address_data.address_type}, default={address_data.is_default})" + ) + + return address + + def update_address( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_id: int, + address_data: CustomerAddressUpdate, + ) -> CustomerAddress: + """ + Update an existing address. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + address_data: Address update data + + Returns: + Updated customer address + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + # Update only provided fields + update_data = address_data.model_dump(exclude_unset=True) + + # Handle default flag - clear others if setting to default + if update_data.get("is_default") is True: + # Use updated type if provided, otherwise current type + address_type = update_data.get("address_type", address.address_type) + self._clear_other_defaults( + db, vendor_id, customer_id, address_type, exclude_id=address_id + ) + + for field, value in update_data.items(): + setattr(address, field, value) + + db.flush() + + logger.info(f"Updated address {address_id} for customer {customer_id}") + + return address + + def delete_address( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> None: + """ + Delete an address. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + db.delete(address) + db.flush() + + logger.info(f"Deleted address {address_id} for customer {customer_id}") + + def set_default( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> CustomerAddress: + """ + Set an address as the default for its type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Returns: + Updated customer address + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + # Clear other defaults of same type + self._clear_other_defaults( + db, vendor_id, customer_id, address.address_type, exclude_id=address_id + ) + + # Set this one as default + address.is_default = True + db.flush() + + logger.info( + f"Set address {address_id} as default {address.address_type} " + f"for customer {customer_id}" + ) + + return address + + def _clear_other_defaults( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_type: str, + exclude_id: int | None = None, + ) -> None: + """ + Clear the default flag on other addresses of the same type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_type: 'shipping' or 'billing' + exclude_id: Address ID to exclude from clearing + """ + query = db.query(CustomerAddress).filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + CustomerAddress.address_type == address_type, + CustomerAddress.is_default == True, # noqa: E712 + ) + + if exclude_id: + query = query.filter(CustomerAddress.id != exclude_id) + + query.update({"is_default": False}, synchronize_session=False) + + +# Singleton instance +customer_address_service = CustomerAddressService() diff --git a/app/templates/shop/account/addresses.html b/app/templates/shop/account/addresses.html index 6c4d4326..4f060592 100644 --- a/app/templates/shop/account/addresses.html +++ b/app/templates/shop/account/addresses.html @@ -1,15 +1,562 @@ {# app/templates/shop/account/addresses.html #} {% extends "shop/base.html" %} -{% block title %}My Addresses{% endblock %} +{% block title %}My Addresses - {{ vendor.name }}{% endblock %} + +{% block alpine_data %}addressesPage(){% endblock %} {% block content %}
-

My Addresses

+ +
+
+

My Addresses

+

Manage your shipping and billing addresses

+
+ +
- {# TODO: Implement address management #} -
-

Address management coming soon...

+ +
+ + + + +
+ + +
+
+ + + +

+
+
+ + +
+ + + + +

No addresses yet

+

Add your first address to speed up checkout.

+ +
+ + +
+ +
+
+ + + + + + {% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/app/templates/shop/checkout.html b/app/templates/shop/checkout.html index 41320bff..ac8577d2 100644 --- a/app/templates/shop/checkout.html +++ b/app/templates/shop/checkout.html @@ -1,15 +1,926 @@ {# app/templates/shop/checkout.html #} {% extends "shop/base.html" %} -{% block title %}Checkout{% endblock %} +{% block title %}Checkout - {{ vendor.name }}{% endblock %} + +{% block alpine_data %}checkoutPage(){% endblock %} {% block content %}
-

Checkout

- {# TODO: Implement checkout process #} -
-

Checkout process coming soon...

+ {# Breadcrumbs #} + + + {# Page Header #} +

Checkout

+ + {# Loading State #} +
+
+ + {# Empty Cart #} +
+ + + +

Your cart is empty

+

Add some products before checking out.

+ + Browse Products + +
+ + {# Checkout Form #} +
+
+ + {# Left Column - Forms #} +
+ + {# Step Indicator #} +
+
+
1
+ Information +
+
+
+
2
+ Shipping +
+
+
+
3
+ Review +
+
+ + {# Error Message #} +
+
+ + + +

+
+
+ + {# Step 1: Contact & Shipping Address #} +
+ + {# Contact Information #} +
+

Contact Information

+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+
+ + {# Shipping Address #} +
+

Shipping Address

+ + {# Saved Addresses Selector (only shown for logged in customers) #} +
+ + +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ + {# Save Address Checkbox (only for new addresses when logged in) #} +
+ +
+
+
+ + {# Billing Address #} +
+
+

Billing Address

+ +
+ + {# Saved Addresses Selector (only shown for logged in customers when not same as shipping) #} +
+ + +
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ + {# Save Address Checkbox (only for new addresses when logged in) #} +
+ +
+
+
+ +
+ +
+
+ + {# Step 2: Shipping Method #} +
+
+

Shipping Method

+ +
+ + + +
+
+ + {# Order Notes #} +
+

Order Notes (Optional)

+ +
+ +
+ + +
+
+ + {# Step 3: Review & Place Order #} +
+ {# Review Contact Info #} +
+
+

Contact Information

+ +
+

+

+

+
+ + {# Review Addresses #} +
+
+
+

Shipping Address

+ +
+
+

+

+

+

+

+

+
+
+ +
+
+

Billing Address

+ +
+
+ + +
+
+
+ + {# Review Shipping #} +
+
+

Shipping Method

+ +
+

+
+ + {# Order Items Review #} +
+

Order Items

+
+ +
+
+ +
+ + +
+
+ +
+ + {# Right Column - Order Summary #} +
+
+

+ Order Summary +

+ + {# Cart Items Preview #} +
+ +
+ + {# Totals #} +
+
+ Subtotal + +
+
+ Shipping + +
+
+ Tax (incl.) + +
+
+ Total + +
+
+ +

+ Free shipping on orders over €50 +

+
+
+ +
+
+
{% endblock %} + +{% block extra_scripts %} + +{% endblock %} diff --git a/models/database/customer.py b/models/database/customer.py index 1f4e24a2..4583ddd4 100644 --- a/models/database/customer.py +++ b/models/database/customer.py @@ -70,7 +70,8 @@ class CustomerAddress(Base, TimestampMixin): address_line_2 = Column(String(255)) city = Column(String(100), nullable=False) postal_code = Column(String(20), nullable=False) - country = Column(String(100), nullable=False) + country_name = Column(String(100), nullable=False) + country_iso = Column(String(5), nullable=False) is_default = Column(Boolean, default=False) # Relationships diff --git a/models/schema/customer.py b/models/schema/customer.py index 018b308b..d326edc2 100644 --- a/models/schema/customer.py +++ b/models/schema/customer.py @@ -126,7 +126,8 @@ class CustomerAddressCreate(BaseModel): address_line_2: str | None = Field(None, max_length=255) city: str = Field(..., min_length=1, max_length=100) postal_code: str = Field(..., min_length=1, max_length=20) - country: str = Field(..., min_length=2, max_length=100) + country_name: str = Field(..., min_length=2, max_length=100) + country_iso: str = Field(..., min_length=2, max_length=5) is_default: bool = Field(default=False) @@ -141,7 +142,8 @@ class CustomerAddressUpdate(BaseModel): address_line_2: str | None = Field(None, max_length=255) city: str | None = Field(None, min_length=1, max_length=100) postal_code: str | None = Field(None, min_length=1, max_length=20) - country: str | None = Field(None, min_length=2, max_length=100) + country_name: str | None = Field(None, min_length=2, max_length=100) + country_iso: str | None = Field(None, min_length=2, max_length=5) is_default: bool | None = None @@ -159,7 +161,8 @@ class CustomerAddressResponse(BaseModel): address_line_2: str | None city: str postal_code: str - country: str + country_name: str + country_iso: str is_default: bool created_at: datetime updated_at: datetime @@ -167,6 +170,13 @@ class CustomerAddressResponse(BaseModel): model_config = {"from_attributes": True} +class CustomerAddressListResponse(BaseModel): + """Schema for customer address list response.""" + + addresses: list[CustomerAddressResponse] + total: int + + # ============================================================================ # Customer Preferences # ============================================================================ diff --git a/tests/fixtures/customer_fixtures.py b/tests/fixtures/customer_fixtures.py index 96ba92b2..cfa17f51 100644 --- a/tests/fixtures/customer_fixtures.py +++ b/tests/fixtures/customer_fixtures.py @@ -32,7 +32,7 @@ def test_customer(db, test_vendor): @pytest.fixture def test_customer_address(db, test_vendor, test_customer): - """Create a test customer address.""" + """Create a test customer shipping address.""" address = CustomerAddress( vendor_id=test_vendor.id, customer_id=test_customer.id, @@ -42,7 +42,8 @@ def test_customer_address(db, test_vendor, test_customer): address_line_1="123 Main St", city="Luxembourg", postal_code="L-1234", - country="Luxembourg", + country_name="Luxembourg", + country_iso="LU", is_default=True, ) db.add(address) @@ -51,6 +52,55 @@ def test_customer_address(db, test_vendor, test_customer): return address +@pytest.fixture +def test_customer_billing_address(db, test_vendor, test_customer): + """Create a test customer billing address.""" + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="billing", + first_name="John", + last_name="Doe", + company="Test Company S.A.", + address_line_1="456 Business Ave", + city="Luxembourg", + postal_code="L-5678", + country_name="Luxembourg", + country_iso="LU", + is_default=True, + ) + db.add(address) + db.commit() + db.refresh(address) + return address + + +@pytest.fixture +def test_customer_multiple_addresses(db, test_vendor, test_customer): + """Create multiple addresses for testing limits and listing.""" + addresses = [] + for i in range(3): + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="shipping" if i % 2 == 0 else "billing", + first_name=f"Name{i}", + last_name="Test", + address_line_1=f"{i}00 Test Street", + city="Luxembourg", + postal_code=f"L-{1000+i}", + country_name="Luxembourg", + country_iso="LU", + is_default=(i == 0), + ) + db.add(address) + addresses.append(address) + db.commit() + for addr in addresses: + db.refresh(addr) + return addresses + + @pytest.fixture def test_order(db, test_vendor, test_customer, test_customer_address): """Create a test order with customer/address snapshots.""" diff --git a/tests/integration/api/v1/shop/__init__.py b/tests/integration/api/v1/shop/__init__.py new file mode 100644 index 00000000..2b82f539 --- /dev/null +++ b/tests/integration/api/v1/shop/__init__.py @@ -0,0 +1 @@ +# Shop API integration tests diff --git a/tests/integration/api/v1/shop/test_addresses.py b/tests/integration/api/v1/shop/test_addresses.py new file mode 100644 index 00000000..e406436f --- /dev/null +++ b/tests/integration/api/v1/shop/test_addresses.py @@ -0,0 +1,621 @@ +# tests/integration/api/v1/shop/test_addresses.py +"""Integration tests for shop addresses API endpoints. + +Tests the /api/v1/shop/addresses/* endpoints. +All endpoints require customer JWT authentication with vendor context. +""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest +from jose import jwt + +from models.database.customer import Customer, CustomerAddress + + +@pytest.fixture +def shop_customer(db, test_vendor): + """Create a test customer for shop API tests.""" + from middleware.auth import AuthManager + + auth_manager = AuthManager() + customer = Customer( + vendor_id=test_vendor.id, + email="shopcustomer@example.com", + hashed_password=auth_manager.hash_password("testpass123"), + first_name="Shop", + last_name="Customer", + customer_number="SHOP001", + is_active=True, + ) + db.add(customer) + db.commit() + db.refresh(customer) + return customer + + +@pytest.fixture +def shop_customer_token(shop_customer, test_vendor): + """Create JWT token for shop customer.""" + from middleware.auth import AuthManager + + auth_manager = AuthManager() + + expires_delta = timedelta(minutes=auth_manager.token_expire_minutes) + expire = datetime.now(UTC) + expires_delta + + payload = { + "sub": str(shop_customer.id), + "email": shop_customer.email, + "vendor_id": test_vendor.id, + "type": "customer", + "exp": expire, + "iat": datetime.now(UTC), + } + + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) + return token + + +@pytest.fixture +def shop_customer_headers(shop_customer_token): + """Get authentication headers for shop customer.""" + return {"Authorization": f"Bearer {shop_customer_token}"} + + +@pytest.fixture +def customer_address(db, test_vendor, shop_customer): + """Create a test address for shop customer.""" + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + address_type="shipping", + first_name="Ship", + last_name="Address", + address_line_1="123 Shipping St", + city="Luxembourg", + postal_code="L-1234", + country_name="Luxembourg", + country_iso="LU", + is_default=True, + ) + db.add(address) + db.commit() + db.refresh(address) + return address + + +@pytest.fixture +def customer_billing_address(db, test_vendor, shop_customer): + """Create a billing address for shop customer.""" + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + address_type="billing", + first_name="Bill", + last_name="Address", + company="Test Company", + address_line_1="456 Billing Ave", + city="Esch-sur-Alzette", + postal_code="L-5678", + country_name="Luxembourg", + country_iso="LU", + is_default=True, + ) + db.add(address) + db.commit() + db.refresh(address) + return address + + +@pytest.fixture +def other_customer(db, test_vendor): + """Create another customer for testing access controls.""" + from middleware.auth import AuthManager + + auth_manager = AuthManager() + customer = Customer( + vendor_id=test_vendor.id, + email="othercustomer@example.com", + hashed_password=auth_manager.hash_password("otherpass123"), + first_name="Other", + last_name="Customer", + customer_number="OTHER001", + is_active=True, + ) + db.add(customer) + db.commit() + db.refresh(customer) + return customer + + +@pytest.fixture +def other_customer_address(db, test_vendor, other_customer): + """Create an address for another customer.""" + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=other_customer.id, + address_type="shipping", + first_name="Other", + last_name="Address", + address_line_1="999 Other St", + city="Differdange", + postal_code="L-9999", + country_name="Luxembourg", + country_iso="LU", + is_default=True, + ) + db.add(address) + db.commit() + db.refresh(address) + return address + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressesListAPI: + """Test shop addresses list endpoint at /api/v1/shop/addresses.""" + + def test_list_addresses_requires_authentication(self, client, test_vendor): + """Test that listing addresses requires authentication.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + response = client.get("/api/v1/shop/addresses") + assert response.status_code in [401, 403] + + def test_list_addresses_success( + self, + client, + shop_customer_headers, + customer_address, + test_vendor, + shop_customer, + ): + """Test listing customer addresses successfully.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "addresses" in data + assert "total" in data + assert data["total"] == 1 + assert data["addresses"][0]["first_name"] == "Ship" + + def test_list_addresses_empty( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test listing addresses when customer has none.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 0 + assert data["addresses"] == [] + + def test_list_addresses_multiple_types( + self, + client, + shop_customer_headers, + customer_address, + customer_billing_address, + test_vendor, + shop_customer, + ): + """Test listing addresses includes both shipping and billing.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 2 + + types = {addr["address_type"] for addr in data["addresses"]} + assert "shipping" in types + assert "billing" in types + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressDetailAPI: + """Test shop address detail endpoint at /api/v1/shop/addresses/{address_id}.""" + + def test_get_address_success( + self, + client, + shop_customer_headers, + customer_address, + test_vendor, + shop_customer, + ): + """Test getting address details successfully.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/addresses/{customer_address.id}", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == customer_address.id + assert data["first_name"] == "Ship" + assert data["country_iso"] == "LU" + assert data["country_name"] == "Luxembourg" + + def test_get_address_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test getting non-existent address returns 404.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/addresses/99999", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + def test_get_address_other_customer( + self, + client, + shop_customer_headers, + other_customer_address, + test_vendor, + shop_customer, + ): + """Test cannot access another customer's address.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/addresses/{other_customer_address.id}", + headers=shop_customer_headers, + ) + + # Should return 404 to prevent enumeration + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressCreateAPI: + """Test shop address creation at POST /api/v1/shop/addresses.""" + + def test_create_address_success( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test creating a new address.""" + address_data = { + "address_type": "shipping", + "first_name": "New", + "last_name": "Address", + "address_line_1": "789 New St", + "city": "Luxembourg", + "postal_code": "L-1111", + "country_name": "Luxembourg", + "country_iso": "LU", + "is_default": False, + } + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.post( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + json=address_data, + ) + + assert response.status_code == 201 + data = response.json() + assert data["first_name"] == "New" + assert data["last_name"] == "Address" + assert data["country_iso"] == "LU" + assert "id" in data + + def test_create_address_with_company( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test creating address with company name.""" + address_data = { + "address_type": "billing", + "first_name": "Business", + "last_name": "Address", + "company": "Acme Corp", + "address_line_1": "100 Business Park", + "city": "Luxembourg", + "postal_code": "L-2222", + "country_name": "Luxembourg", + "country_iso": "LU", + "is_default": True, + } + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.post( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + json=address_data, + ) + + assert response.status_code == 201 + data = response.json() + assert data["company"] == "Acme Corp" + + def test_create_address_validation_error( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test validation error for missing required fields.""" + address_data = { + "address_type": "shipping", + "first_name": "Test", + # Missing required fields + } + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.post( + "/api/v1/shop/addresses", + headers=shop_customer_headers, + json=address_data, + ) + + assert response.status_code == 422 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressUpdateAPI: + """Test shop address update at PUT /api/v1/shop/addresses/{address_id}.""" + + def test_update_address_success( + self, + client, + shop_customer_headers, + customer_address, + test_vendor, + shop_customer, + ): + """Test updating an address.""" + update_data = { + "first_name": "Updated", + "city": "Esch-sur-Alzette", + } + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + f"/api/v1/shop/addresses/{customer_address.id}", + headers=shop_customer_headers, + json=update_data, + ) + + assert response.status_code == 200 + data = response.json() + assert data["first_name"] == "Updated" + assert data["city"] == "Esch-sur-Alzette" + + def test_update_address_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test updating non-existent address returns 404.""" + update_data = {"first_name": "Test"} + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + "/api/v1/shop/addresses/99999", + headers=shop_customer_headers, + json=update_data, + ) + + assert response.status_code == 404 + + def test_update_address_other_customer( + self, + client, + shop_customer_headers, + other_customer_address, + test_vendor, + shop_customer, + ): + """Test cannot update another customer's address.""" + update_data = {"first_name": "Hacked"} + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + f"/api/v1/shop/addresses/{other_customer_address.id}", + headers=shop_customer_headers, + json=update_data, + ) + + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressDeleteAPI: + """Test shop address deletion at DELETE /api/v1/shop/addresses/{address_id}.""" + + def test_delete_address_success( + self, + client, + shop_customer_headers, + customer_address, + test_vendor, + shop_customer, + ): + """Test deleting an address.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.delete( + f"/api/v1/shop/addresses/{customer_address.id}", + headers=shop_customer_headers, + ) + + assert response.status_code == 204 + + def test_delete_address_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test deleting non-existent address returns 404.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.delete( + "/api/v1/shop/addresses/99999", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + def test_delete_address_other_customer( + self, + client, + shop_customer_headers, + other_customer_address, + test_vendor, + shop_customer, + ): + """Test cannot delete another customer's address.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.delete( + f"/api/v1/shop/addresses/{other_customer_address.id}", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopAddressSetDefaultAPI: + """Test set address as default at PUT /api/v1/shop/addresses/{address_id}/default.""" + + def test_set_default_success( + self, + client, + shop_customer_headers, + customer_address, + test_vendor, + shop_customer, + db, + ): + """Test setting address as default.""" + # Create a second non-default address + second_address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + address_type="shipping", + first_name="Second", + last_name="Address", + address_line_1="222 Second St", + city="Dudelange", + postal_code="L-3333", + country_name="Luxembourg", + country_iso="LU", + is_default=False, + ) + db.add(second_address) + db.commit() + db.refresh(second_address) + + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + f"/api/v1/shop/addresses/{second_address.id}/default", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["is_default"] is True + + def test_set_default_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test setting default on non-existent address returns 404.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + "/api/v1/shop/addresses/99999/default", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + def test_set_default_other_customer( + self, + client, + shop_customer_headers, + other_customer_address, + test_vendor, + shop_customer, + ): + """Test cannot set default on another customer's address.""" + with patch("app.api.v1.shop.addresses.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.put( + f"/api/v1/shop/addresses/{other_customer_address.id}/default", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 diff --git a/tests/integration/api/v1/shop/test_orders.py b/tests/integration/api/v1/shop/test_orders.py new file mode 100644 index 00000000..3b680817 --- /dev/null +++ b/tests/integration/api/v1/shop/test_orders.py @@ -0,0 +1,557 @@ +# tests/integration/api/v1/shop/test_orders.py +"""Integration tests for shop orders API endpoints. + +Tests the /api/v1/shop/orders/* endpoints. +All endpoints require customer JWT authentication with vendor context. +""" + +from datetime import UTC, datetime, timedelta +from decimal import Decimal +from unittest.mock import patch, MagicMock + +import pytest +from jose import jwt + +from models.database.customer import Customer +from models.database.invoice import Invoice, InvoiceStatus, VendorInvoiceSettings +from models.database.order import Order, OrderItem + + +@pytest.fixture +def shop_customer(db, test_vendor): + """Create a test customer for shop API tests.""" + from middleware.auth import AuthManager + auth_manager = AuthManager() + customer = Customer( + vendor_id=test_vendor.id, + email="shopcustomer@example.com", + hashed_password=auth_manager.hash_password("testpass123"), + first_name="Shop", + last_name="Customer", + customer_number="SHOP001", + is_active=True, + ) + db.add(customer) + db.commit() + db.refresh(customer) + return customer + + +@pytest.fixture +def shop_customer_token(shop_customer, test_vendor): + """Create JWT token for shop customer.""" + from middleware.auth import AuthManager + auth_manager = AuthManager() + + expires_delta = timedelta(minutes=auth_manager.token_expire_minutes) + expire = datetime.now(UTC) + expires_delta + + payload = { + "sub": str(shop_customer.id), + "email": shop_customer.email, + "vendor_id": test_vendor.id, + "type": "customer", + "exp": expire, + "iat": datetime.now(UTC), + } + + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) + return token + + +@pytest.fixture +def shop_customer_headers(shop_customer_token): + """Get authentication headers for shop customer.""" + return {"Authorization": f"Bearer {shop_customer_token}"} + + +@pytest.fixture +def shop_order(db, test_vendor, shop_customer): + """Create a test order for shop customer.""" + order = Order( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + order_number="SHOP-ORD-001", + status="pending", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=10000, + tax_amount_cents=1700, + shipping_amount_cents=500, + total_amount_cents=12200, + currency="EUR", + customer_email=shop_customer.email, + customer_first_name=shop_customer.first_name, + customer_last_name=shop_customer.last_name, + ship_first_name=shop_customer.first_name, + ship_last_name=shop_customer.last_name, + ship_address_line_1="123 Shop St", + ship_city="Luxembourg", + ship_postal_code="L-1234", + ship_country_iso="LU", + bill_first_name=shop_customer.first_name, + bill_last_name=shop_customer.last_name, + bill_address_line_1="123 Shop St", + bill_city="Luxembourg", + bill_postal_code="L-1234", + bill_country_iso="LU", + # VAT fields + vat_regime="domestic", + vat_rate=Decimal("17.00"), + vat_rate_label="Luxembourg VAT 17.00%", + vat_destination_country=None, + ) + db.add(order) + db.commit() + db.refresh(order) + return order + + +@pytest.fixture +def shop_order_processing(db, test_vendor, shop_customer): + """Create a test order with processing status (eligible for invoice).""" + order = Order( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + order_number="SHOP-ORD-002", + status="processing", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=20000, + tax_amount_cents=3400, + shipping_amount_cents=500, + total_amount_cents=23900, + currency="EUR", + customer_email=shop_customer.email, + customer_first_name=shop_customer.first_name, + customer_last_name=shop_customer.last_name, + ship_first_name=shop_customer.first_name, + ship_last_name=shop_customer.last_name, + ship_address_line_1="456 Shop Ave", + ship_city="Luxembourg", + ship_postal_code="L-5678", + ship_country_iso="LU", + bill_first_name=shop_customer.first_name, + bill_last_name=shop_customer.last_name, + bill_address_line_1="456 Shop Ave", + bill_city="Luxembourg", + bill_postal_code="L-5678", + bill_country_iso="LU", + # VAT fields + vat_regime="domestic", + vat_rate=Decimal("17.00"), + vat_rate_label="Luxembourg VAT 17.00%", + ) + db.add(order) + db.flush() + + # Add order item + item = OrderItem( + order_id=order.id, + product_id=1, + product_sku="TEST-SKU-001", + product_name="Test Product", + quantity=2, + unit_price_cents=10000, + total_price_cents=20000, + ) + db.add(item) + db.commit() + db.refresh(order) + return order + + +@pytest.fixture +def shop_invoice_settings(db, test_vendor): + """Create invoice settings for the vendor.""" + settings = VendorInvoiceSettings( + vendor_id=test_vendor.id, + company_name="Shop Test Company S.A.", + company_address="123 Business St", + company_city="Luxembourg", + company_postal_code="L-1234", + company_country="LU", + vat_number="LU12345678", + invoice_prefix="INV", + invoice_next_number=1, + default_vat_rate=Decimal("17.00"), + ) + db.add(settings) + db.commit() + db.refresh(settings) + return settings + + +@pytest.fixture +def shop_order_with_invoice(db, test_vendor, shop_customer, shop_invoice_settings): + """Create an order with an existing invoice.""" + order = Order( + vendor_id=test_vendor.id, + customer_id=shop_customer.id, + order_number="SHOP-ORD-003", + status="shipped", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=15000, + tax_amount_cents=2550, + shipping_amount_cents=500, + total_amount_cents=18050, + currency="EUR", + customer_email=shop_customer.email, + customer_first_name=shop_customer.first_name, + customer_last_name=shop_customer.last_name, + ship_first_name=shop_customer.first_name, + ship_last_name=shop_customer.last_name, + ship_address_line_1="789 Shop Blvd", + ship_city="Luxembourg", + ship_postal_code="L-9999", + ship_country_iso="LU", + bill_first_name=shop_customer.first_name, + bill_last_name=shop_customer.last_name, + bill_address_line_1="789 Shop Blvd", + bill_city="Luxembourg", + bill_postal_code="L-9999", + bill_country_iso="LU", + vat_regime="domestic", + vat_rate=Decimal("17.00"), + ) + db.add(order) + db.flush() + + # Create invoice for this order + invoice = Invoice( + vendor_id=test_vendor.id, + order_id=order.id, + invoice_number="INV00001", + invoice_date=datetime.now(UTC), + status=InvoiceStatus.ISSUED.value, + seller_details={"company_name": "Shop Test Company S.A."}, + buyer_details={"name": f"{shop_customer.first_name} {shop_customer.last_name}"}, + line_items=[], + vat_rate=Decimal("17.00"), + subtotal_cents=15000, + vat_amount_cents=2550, + total_cents=18050, + ) + db.add(invoice) + db.commit() + db.refresh(order) + db.refresh(invoice) + return order, invoice + + +@pytest.fixture +def other_customer(db, test_vendor): + """Create another customer for testing access controls.""" + from middleware.auth import AuthManager + auth_manager = AuthManager() + customer = Customer( + vendor_id=test_vendor.id, + email="othercustomer@example.com", + hashed_password=auth_manager.hash_password("otherpass123"), + first_name="Other", + last_name="Customer", + customer_number="OTHER001", + is_active=True, + ) + db.add(customer) + db.commit() + db.refresh(customer) + return customer + + +@pytest.fixture +def other_customer_order(db, test_vendor, other_customer): + """Create an order for another customer.""" + order = Order( + vendor_id=test_vendor.id, + customer_id=other_customer.id, + order_number="OTHER-ORD-001", + status="processing", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=5000, + tax_amount_cents=850, + total_amount_cents=5850, + currency="EUR", + customer_email=other_customer.email, + customer_first_name=other_customer.first_name, + customer_last_name=other_customer.last_name, + ship_first_name=other_customer.first_name, + ship_last_name=other_customer.last_name, + ship_address_line_1="Other St", + ship_city="Other City", + ship_postal_code="00000", + ship_country_iso="LU", + bill_first_name=other_customer.first_name, + bill_last_name=other_customer.last_name, + bill_address_line_1="Other St", + bill_city="Other City", + bill_postal_code="00000", + bill_country_iso="LU", + ) + db.add(order) + db.commit() + db.refresh(order) + return order + + +# Note: Shop API endpoints require vendor context from VendorContextMiddleware. +# In integration tests, we mock the middleware to inject the vendor. + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopOrdersListAPI: + """Test shop orders list endpoint at /api/v1/shop/orders.""" + + def test_list_orders_requires_authentication(self, client, test_vendor): + """Test that listing orders requires authentication.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + response = client.get("/api/v1/shop/orders") + # Without token, should get 401 or 403 + assert response.status_code in [401, 403] + + def test_list_orders_success( + self, client, shop_customer_headers, shop_order, test_vendor, shop_customer + ): + """Test listing customer orders successfully.""" + # Mock vendor context and customer auth + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + # Mock the dependency to return our customer + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/orders", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert "orders" in data + assert "total" in data + + def test_list_orders_empty( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test listing orders when customer has none.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/orders", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 0 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopOrderDetailAPI: + """Test shop order detail endpoint at /api/v1/shop/orders/{order_id}.""" + + def test_get_order_detail_success( + self, client, shop_customer_headers, shop_order, test_vendor, shop_customer + ): + """Test getting order details successfully.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/orders/{shop_order.id}", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["order_number"] == "SHOP-ORD-001" + assert data["status"] == "pending" + # Check VAT fields are present + assert "vat_regime" in data + assert "vat_rate" in data + + def test_get_order_detail_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test getting non-existent order returns 404.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/orders/99999", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + def test_get_order_detail_other_customer( + self, + client, + shop_customer_headers, + other_customer_order, + test_vendor, + shop_customer, + ): + """Test cannot access another customer's order.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/orders/{other_customer_order.id}", + headers=shop_customer_headers, + ) + + # Should return 404 (not 403) to prevent enumeration + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopOrderInvoiceDownloadAPI: + """Test shop order invoice download at /api/v1/shop/orders/{order_id}/invoice.""" + + def test_download_invoice_pending_order_rejected( + self, client, shop_customer_headers, shop_order, test_vendor, shop_customer + ): + """Test cannot download invoice for pending orders.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/orders/{shop_order.id}/invoice", + headers=shop_customer_headers, + ) + + # Pending orders should not allow invoice download + assert response.status_code == 422 + + @pytest.mark.skip(reason="Requires PDF generation infrastructure") + def test_download_invoice_processing_order_creates_invoice( + self, + client, + shop_customer_headers, + shop_order_processing, + shop_invoice_settings, + test_vendor, + shop_customer, + ): + """Test downloading invoice for processing order creates it if needed.""" + # This test requires actual PDF generation which may not be available + # in all environments. The logic is tested via: + # 1. test_download_invoice_pending_order_rejected - validates status check + # 2. Direct service tests for invoice creation + pass + + @pytest.mark.skip(reason="Requires PDF generation infrastructure") + def test_download_invoice_existing_invoice( + self, + client, + shop_customer_headers, + shop_order_with_invoice, + test_vendor, + shop_customer, + ): + """Test downloading invoice when one already exists.""" + # This test requires PDF file to exist on disk + # The service layer handles invoice retrieval properly + pass + + def test_download_invoice_other_customer( + self, + client, + shop_customer_headers, + other_customer_order, + test_vendor, + shop_customer, + ): + """Test cannot download invoice for another customer's order.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/orders/{other_customer_order.id}/invoice", + headers=shop_customer_headers, + ) + + # Should return 404 to prevent enumeration + assert response.status_code == 404 + + def test_download_invoice_not_found( + self, client, shop_customer_headers, test_vendor, shop_customer + ): + """Test downloading invoice for non-existent order.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/orders/99999/invoice", + headers=shop_customer_headers, + ) + + assert response.status_code == 404 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.shop +class TestShopOrderVATFields: + """Test VAT fields in order responses.""" + + def test_order_includes_vat_fields( + self, client, shop_customer_headers, shop_order, test_vendor, shop_customer + ): + """Test order response includes VAT fields.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + f"/api/v1/shop/orders/{shop_order.id}", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + + # Verify VAT fields + assert data.get("vat_regime") == "domestic" + assert data.get("vat_rate") == 17.0 + assert "Luxembourg VAT" in (data.get("vat_rate_label") or "") + + def test_order_list_includes_vat_fields( + self, client, shop_customer_headers, shop_order, test_vendor, shop_customer + ): + """Test order list includes VAT fields.""" + with patch("app.api.v1.shop.orders.getattr") as mock_getattr: + mock_getattr.return_value = test_vendor + with patch("app.api.deps._validate_customer_token") as mock_validate: + mock_validate.return_value = shop_customer + response = client.get( + "/api/v1/shop/orders", + headers=shop_customer_headers, + ) + + assert response.status_code == 200 + data = response.json() + + if data["orders"]: + order = data["orders"][0] + assert "vat_regime" in order + assert "vat_rate" in order diff --git a/tests/unit/services/test_customer_address_service.py b/tests/unit/services/test_customer_address_service.py new file mode 100644 index 00000000..935ea1fd --- /dev/null +++ b/tests/unit/services/test_customer_address_service.py @@ -0,0 +1,453 @@ +# tests/unit/services/test_customer_address_service.py +""" +Unit tests for CustomerAddressService. +""" + +import pytest + +from app.exceptions import AddressLimitExceededException, AddressNotFoundException +from app.services.customer_address_service import CustomerAddressService +from models.database.customer import CustomerAddress +from models.schema.customer import CustomerAddressCreate, CustomerAddressUpdate + + +@pytest.fixture +def address_service(): + """Create CustomerAddressService instance.""" + return CustomerAddressService() + + +@pytest.fixture +def multiple_addresses(db, test_vendor, test_customer): + """Create multiple addresses for testing.""" + addresses = [] + for i in range(3): + address = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="shipping" if i < 2 else "billing", + first_name=f"First{i}", + last_name=f"Last{i}", + address_line_1=f"{i+1} Test Street", + city="Luxembourg", + postal_code=f"L-{1000+i}", + country_name="Luxembourg", + country_iso="LU", + is_default=(i == 0), # First shipping is default + ) + db.add(address) + addresses.append(address) + + db.commit() + for a in addresses: + db.refresh(a) + + return addresses + + +@pytest.mark.unit +class TestCustomerAddressServiceList: + """Tests for list_addresses method.""" + + def test_list_addresses_empty(self, db, address_service, test_vendor, test_customer): + """Test listing addresses when none exist.""" + addresses = address_service.list_addresses( + db, vendor_id=test_vendor.id, customer_id=test_customer.id + ) + + assert addresses == [] + + def test_list_addresses_basic( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test basic address listing.""" + addresses = address_service.list_addresses( + db, vendor_id=test_vendor.id, customer_id=test_customer.id + ) + + assert len(addresses) == 1 + assert addresses[0].id == test_customer_address.id + + def test_list_addresses_ordered_by_default( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test addresses are ordered by default flag first.""" + addresses = address_service.list_addresses( + db, vendor_id=test_vendor.id, customer_id=test_customer.id + ) + + # Default address should be first + assert addresses[0].is_default is True + + def test_list_addresses_vendor_isolation( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test addresses are isolated by vendor.""" + # Query with different vendor ID + addresses = address_service.list_addresses( + db, vendor_id=99999, customer_id=test_customer.id + ) + + assert addresses == [] + + +@pytest.mark.unit +class TestCustomerAddressServiceGet: + """Tests for get_address method.""" + + def test_get_address_success( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test getting address by ID.""" + address = address_service.get_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=test_customer_address.id, + ) + + assert address.id == test_customer_address.id + assert address.first_name == test_customer_address.first_name + + def test_get_address_not_found( + self, db, address_service, test_vendor, test_customer + ): + """Test error when address not found.""" + with pytest.raises(AddressNotFoundException): + address_service.get_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=99999, + ) + + def test_get_address_wrong_customer( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test cannot get another customer's address.""" + with pytest.raises(AddressNotFoundException): + address_service.get_address( + db, + vendor_id=test_vendor.id, + customer_id=99999, # Different customer + address_id=test_customer_address.id, + ) + + +@pytest.mark.unit +class TestCustomerAddressServiceGetDefault: + """Tests for get_default_address method.""" + + def test_get_default_address_exists( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test getting default shipping address.""" + address = address_service.get_default_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="shipping", + ) + + assert address is not None + assert address.is_default is True + assert address.address_type == "shipping" + + def test_get_default_address_not_set( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test getting default billing when none is set.""" + # Remove default from billing (none was set as default) + address = address_service.get_default_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="billing", + ) + + # The billing address exists but is not default + assert address is None + + +@pytest.mark.unit +class TestCustomerAddressServiceCreate: + """Tests for create_address method.""" + + def test_create_address_success( + self, db, address_service, test_vendor, test_customer + ): + """Test creating a new address.""" + address_data = CustomerAddressCreate( + address_type="shipping", + first_name="John", + last_name="Doe", + address_line_1="123 New Street", + city="Luxembourg", + postal_code="L-1234", + country_name="Luxembourg", + country_iso="LU", + is_default=False, + ) + + address = address_service.create_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_data=address_data, + ) + db.commit() + + assert address.id is not None + assert address.first_name == "John" + assert address.last_name == "Doe" + assert address.country_iso == "LU" + assert address.country_name == "Luxembourg" + + def test_create_address_with_company( + self, db, address_service, test_vendor, test_customer + ): + """Test creating address with company name.""" + address_data = CustomerAddressCreate( + address_type="billing", + first_name="Jane", + last_name="Doe", + company="Acme Corp", + address_line_1="456 Business Ave", + city="Luxembourg", + postal_code="L-5678", + country_name="Luxembourg", + country_iso="LU", + is_default=False, + ) + + address = address_service.create_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_data=address_data, + ) + db.commit() + + assert address.company == "Acme Corp" + + def test_create_address_default_clears_others( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test creating default address clears other defaults of same type.""" + # First address is default shipping + assert multiple_addresses[0].is_default is True + + address_data = CustomerAddressCreate( + address_type="shipping", + first_name="New", + last_name="Default", + address_line_1="789 Main St", + city="Luxembourg", + postal_code="L-9999", + country_name="Luxembourg", + country_iso="LU", + is_default=True, + ) + + new_address = address_service.create_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_data=address_data, + ) + db.commit() + + # New address should be default + assert new_address.is_default is True + + # Old default should be cleared + db.refresh(multiple_addresses[0]) + assert multiple_addresses[0].is_default is False + + def test_create_address_limit_exceeded( + self, db, address_service, test_vendor, test_customer + ): + """Test error when max addresses reached.""" + # Create 10 addresses (max limit) + for i in range(10): + addr = CustomerAddress( + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_type="shipping", + first_name=f"Test{i}", + last_name="User", + address_line_1=f"{i} Street", + city="City", + postal_code="12345", + country_name="Luxembourg", + country_iso="LU", + ) + db.add(addr) + db.commit() + + # Try to create 11th address + address_data = CustomerAddressCreate( + address_type="shipping", + first_name="Eleventh", + last_name="User", + address_line_1="11 Street", + city="City", + postal_code="12345", + country_name="Luxembourg", + country_iso="LU", + is_default=False, + ) + + with pytest.raises(AddressLimitExceededException): + address_service.create_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_data=address_data, + ) + + +@pytest.mark.unit +class TestCustomerAddressServiceUpdate: + """Tests for update_address method.""" + + def test_update_address_success( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test updating an address.""" + update_data = CustomerAddressUpdate( + first_name="Updated", + city="New City", + ) + + address = address_service.update_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=test_customer_address.id, + address_data=update_data, + ) + db.commit() + + assert address.first_name == "Updated" + assert address.city == "New City" + # Unchanged fields should remain + assert address.last_name == test_customer_address.last_name + + def test_update_address_set_default( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test setting address as default clears others.""" + # Second address is not default + assert multiple_addresses[1].is_default is False + + update_data = CustomerAddressUpdate(is_default=True) + + address = address_service.update_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=multiple_addresses[1].id, + address_data=update_data, + ) + db.commit() + + assert address.is_default is True + + # Old default should be cleared + db.refresh(multiple_addresses[0]) + assert multiple_addresses[0].is_default is False + + def test_update_address_not_found( + self, db, address_service, test_vendor, test_customer + ): + """Test error when address not found.""" + update_data = CustomerAddressUpdate(first_name="Test") + + with pytest.raises(AddressNotFoundException): + address_service.update_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=99999, + address_data=update_data, + ) + + +@pytest.mark.unit +class TestCustomerAddressServiceDelete: + """Tests for delete_address method.""" + + def test_delete_address_success( + self, db, address_service, test_vendor, test_customer, test_customer_address + ): + """Test deleting an address.""" + address_id = test_customer_address.id + + address_service.delete_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=address_id, + ) + db.commit() + + # Address should be gone + with pytest.raises(AddressNotFoundException): + address_service.get_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=address_id, + ) + + def test_delete_address_not_found( + self, db, address_service, test_vendor, test_customer + ): + """Test error when deleting non-existent address.""" + with pytest.raises(AddressNotFoundException): + address_service.delete_address( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=99999, + ) + + +@pytest.mark.unit +class TestCustomerAddressServiceSetDefault: + """Tests for set_default method.""" + + def test_set_default_success( + self, db, address_service, test_vendor, test_customer, multiple_addresses + ): + """Test setting address as default.""" + # Second shipping address is not default + assert multiple_addresses[1].is_default is False + assert multiple_addresses[1].address_type == "shipping" + + address = address_service.set_default( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=multiple_addresses[1].id, + ) + db.commit() + + assert address.is_default is True + + # Old default should be cleared + db.refresh(multiple_addresses[0]) + assert multiple_addresses[0].is_default is False + + def test_set_default_not_found( + self, db, address_service, test_vendor, test_customer + ): + """Test error when address not found.""" + with pytest.raises(AddressNotFoundException): + address_service.set_default( + db, + vendor_id=test_vendor.id, + customer_id=test_customer.id, + address_id=99999, + )