# app/modules/billing/routes/api/merchant.py
"""
Merchant billing API endpoints for the merchant portal.

Provides subscription management and billing operations for merchant owners:
- View subscriptions across all platforms
- Subscription detail and tier info per platform
- Stripe checkout session creation
- Invoice history

Authentication: Authorization header (API-only, no cookies for CSRF safety).
The user must own at least one active merchant (validated by
get_merchant_for_current_user).

Auto-discovered by the route system (merchant.py in routes/api/ triggers
registration under /api/v1/merchants/billing/*).
"""

import logging

from fastapi import APIRouter, Depends, Path, Query, Request
from sqlalchemy.orm import Session

from app.api.deps import get_merchant_for_current_user
from app.core.database import get_db
from app.modules.billing.schemas import (
    ChangeTierRequest,
    ChangeTierResponse,
    CheckoutRequest,
    CheckoutResponse,
    MerchantPortalAvailableTiersResponse,
    MerchantPortalInvoiceListResponse,
    MerchantPortalSubscriptionDetailResponse,
    MerchantPortalSubscriptionItem,
    MerchantPortalSubscriptionListResponse,
    MerchantSubscriptionResponse,
    TierInfo,
)
from app.modules.billing.services.billing_service import billing_service
from app.modules.billing.services.subscription_service import subscription_service

logger = logging.getLogger(__name__)

ROUTE_CONFIG = {
    "prefix": "/billing",
}

router = APIRouter()


# ============================================================================
# Serialization helpers
# ============================================================================


def _build_subscription_item(subscription) -> MerchantPortalSubscriptionItem:
    """
    Convert a subscription ORM object into a merchant-portal subscription item.

    Starts from the ``MerchantSubscriptionResponse`` representation and adds
    display fields taken from the related tier and platform. ``tier`` and
    ``tier_name`` are None when the subscription has no tier, and
    ``platform_name`` is "" when it has no platform.

    Shared by the list and detail endpoints so both return identical items.
    """
    data = MerchantSubscriptionResponse.model_validate(subscription).model_dump()
    tier = subscription.tier
    platform = subscription.platform
    data["tier"] = tier.code if tier else None
    data["tier_name"] = tier.name if tier else None
    data["platform_name"] = platform.name if platform else ""
    return MerchantPortalSubscriptionItem(**data)


def _serialize_invoice(inv) -> dict:
    """
    Convert an invoice/billing-history record into the portal's invoice dict.

    Dates are rendered with ``isoformat()``. ``due_date`` and ``created_at``
    are nullable and map to None when missing. ``invoice_date`` is assumed to
    always be set, as in the original endpoint.
    """
    return {
        "id": inv.id,
        "invoice_number": inv.invoice_number,
        "invoice_date": inv.invoice_date.isoformat(),
        "due_date": inv.due_date.isoformat() if inv.due_date else None,
        "subtotal_cents": inv.subtotal_cents,
        "tax_cents": inv.tax_cents,
        "total_cents": inv.total_cents,
        "amount_paid_cents": inv.amount_paid_cents,
        "currency": inv.currency,
        "status": inv.status,
        "pdf_url": inv.invoice_pdf_url,
        "hosted_url": inv.hosted_invoice_url,
        "description": inv.description,
        "created_at": inv.created_at.isoformat() if inv.created_at else None,
    }


# ============================================================================
# Subscription Endpoints
# ============================================================================


@router.get("/subscriptions", response_model=MerchantPortalSubscriptionListResponse)
def list_merchant_subscriptions(
    request: Request,
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    List all subscriptions for the current merchant.

    Returns subscriptions across all platforms the merchant is subscribed to,
    including tier information and status.
    """
    subscriptions = subscription_service.get_merchant_subscriptions(db, merchant.id)
    items = [_build_subscription_item(sub) for sub in subscriptions]
    return MerchantPortalSubscriptionListResponse(
        subscriptions=items, total=len(items)
    )


@router.get(
    "/subscriptions/{platform_id}",
    response_model=MerchantPortalSubscriptionDetailResponse,
)
def get_merchant_subscription(
    request: Request,
    platform_id: int = Path(..., description="Platform ID"),
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    Get subscription detail for a specific platform.

    Returns the subscription with tier information for the given platform.

    Raises:
        ResourceNotFoundException: if the merchant has no subscription on
            the platform (error_code "SUBSCRIPTION_NOT_FOUND").
    """
    subscription = subscription_service.get_merchant_subscription(
        db, merchant.id, platform_id
    )
    if not subscription:
        # Imported lazily, as in the original module (kept function-local,
        # presumably to avoid an import cycle — verify before hoisting).
        from app.exceptions.base import ResourceNotFoundException

        raise ResourceNotFoundException(
            resource_type="Subscription",
            identifier=f"merchant={merchant.id}, platform={platform_id}",
            error_code="SUBSCRIPTION_NOT_FOUND",
        )

    tier_info = None
    tier = subscription.tier
    if tier:
        tier_info = TierInfo(
            code=tier.code,
            name=tier.name,
            description=tier.description,
            price_monthly_cents=tier.price_monthly_cents,
            price_annual_cents=tier.price_annual_cents,
            # Not every tier object exposes feature codes; fall back to none.
            feature_codes=(
                tier.get_feature_codes() if hasattr(tier, "get_feature_codes") else []
            ),
        )

    return MerchantPortalSubscriptionDetailResponse(
        subscription=_build_subscription_item(subscription),
        tier=tier_info,
    )


@router.get(
    "/subscriptions/{platform_id}/tiers",
    response_model=MerchantPortalAvailableTiersResponse,
)
def get_available_tiers(
    request: Request,
    platform_id: int = Path(..., description="Platform ID"),
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    Get available tiers for upgrade on a specific platform.

    Returns all public tiers with upgrade/downgrade flags relative to
    the merchant's current tier.
    """
    subscription = subscription_service.get_merchant_subscription(
        db, merchant.id, platform_id
    )
    current_tier_id = subscription.tier_id if subscription else None

    # The service also returns a tier ordering, which this endpoint does not use.
    tier_list, _ = billing_service.get_available_tiers(
        db, current_tier_id, platform_id
    )

    current_tier_code = None
    if subscription and subscription.tier:
        current_tier_code = subscription.tier.code

    return MerchantPortalAvailableTiersResponse(
        tiers=tier_list,
        current_tier=current_tier_code,
    )


@router.post(
    "/subscriptions/{platform_id}/change-tier",
    response_model=ChangeTierResponse,
)
def change_subscription_tier(
    request: Request,
    tier_data: ChangeTierRequest,
    platform_id: int = Path(..., description="Platform ID"),
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    Change the subscription tier for a specific platform.

    Handles both Stripe-connected and non-Stripe subscriptions.
    The change is committed before the response is returned.
    """
    result = billing_service.change_tier(
        db, merchant.id, platform_id, tier_data.tier_code, tier_data.is_annual
    )
    db.commit()

    logger.info(
        "Merchant %s (%s) changed tier to %s on platform=%s",
        merchant.id,
        merchant.name,
        tier_data.tier_code,
        platform_id,
    )

    return result


@router.post(
    "/subscriptions/{platform_id}/checkout",
    response_model=CheckoutResponse,
)
def create_checkout_session(
    request: Request,
    checkout_data: CheckoutRequest,
    platform_id: int = Path(..., description="Platform ID"),
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    Create a Stripe checkout session for the merchant's subscription.

    Starts a new subscription or upgrades an existing one to the
    requested tier. Success/cancel redirect URLs point back at the
    portal's subscription page for this platform.
    """
    # Build success/cancel URLs from the incoming request's base URL.
    # NOTE(review): request.base_url is derived from the request's Host
    # header; behind a proxy, confirm it yields the public URL/scheme.
    base_url = str(request.base_url).rstrip("/")
    subscription_page = f"{base_url}/merchants/billing/subscriptions/{platform_id}"
    success_url = f"{subscription_page}?checkout=success"
    cancel_url = f"{subscription_page}?checkout=cancelled"

    result = billing_service.create_checkout_session(
        db=db,
        merchant_id=merchant.id,
        platform_id=platform_id,
        tier_code=checkout_data.tier_code,
        is_annual=checkout_data.is_annual,
        success_url=success_url,
        cancel_url=cancel_url,
    )
    db.commit()

    logger.info(
        "Merchant %s (%s) created checkout session for tier=%s on platform=%s",
        merchant.id,
        merchant.name,
        checkout_data.tier_code,
        platform_id,
    )

    return CheckoutResponse(
        checkout_url=result["checkout_url"],
        session_id=result["session_id"],
    )


# ============================================================================
# Invoice Endpoints
# ============================================================================


@router.get("/invoices", response_model=MerchantPortalInvoiceListResponse)
def get_invoices(
    request: Request,
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max records to return"),
    merchant=Depends(get_merchant_for_current_user),
    db: Session = Depends(get_db),
):
    """
    Get invoice history for the current merchant.

    Returns paginated billing history entries ordered by date descending.
    """
    invoices, total = billing_service.get_invoices(
        db, merchant.id, skip=skip, limit=limit
    )

    return MerchantPortalInvoiceListResponse(
        invoices=[_serialize_invoice(inv) for inv in invoices],
        total=total,
        skip=skip,
        limit=limit,
    )