# app/modules/billing/routes/api/merchant.py
"""
Merchant billing API endpoints for the merchant portal.

Provides subscription management and billing operations for merchant owners:
- View subscriptions across all platforms
- Subscription detail and tier info per platform
- Available tiers and tier changes per platform
- Stripe checkout session creation
- Invoice history

Authentication: merchant_token cookie or Authorization header
(get_current_merchant_from_cookie_or_header). Every endpoint then resolves
the user's active merchant via _get_user_merchant, which responds with 404
when the user owns no active merchant.

Auto-discovered by the route system (merchant.py in routes/api/ triggers
registration under /api/v1/merchants/billing/*).
"""

import logging

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.api.deps import get_current_merchant_from_cookie_or_header
from app.core.database import get_db
from app.modules.billing.schemas import (
    CheckoutRequest,
    CheckoutResponse,
    MerchantSubscriptionResponse,
    TierInfo,
)
from app.modules.billing.services.billing_service import billing_service
from app.modules.billing.services.subscription_service import subscription_service
from app.modules.tenancy.models import Merchant
from models.schema.auth import UserContext

logger = logging.getLogger(__name__)

ROUTE_CONFIG = {
    "prefix": "/billing",
}

router = APIRouter()


# ============================================================================
# Helpers
# ============================================================================


def _get_user_merchant(db: Session, user_context: UserContext) -> Merchant:
    """
    Get the first active merchant owned by the current user.

    Results are ordered by ``Merchant.id`` so that a user owning several
    active merchants always resolves to the same one. Without an ORDER BY,
    ``.first()`` returns whichever matching row the database yields first.

    Args:
        db: Database session
        user_context: Authenticated user context

    Returns:
        Merchant: The user's active merchant with the lowest id

    Raises:
        HTTPException 404: If the user has no active merchants
    """
    merchant = (
        db.query(Merchant)
        .filter(
            Merchant.owner_user_id == user_context.id,
            # `== True` builds a SQL expression; `is True` would not.
            Merchant.is_active == True,  # noqa: E712
        )
        .order_by(Merchant.id)
        .first()
    )

    if not merchant:
        raise HTTPException(status_code=404, detail="No active merchant found")

    return merchant


def _serialize_subscription(subscription) -> dict:
    """
    Build the response dict for a subscription ORM object.

    Starts from the fields of MerchantSubscriptionResponse and adds the
    display fields the portal needs:

    - ``tier``: tier code, or None when no tier is attached
    - ``tier_name``: tier name, or None when no tier is attached
    - ``platform_name``: platform name, or "" when no platform is attached
    """
    data = MerchantSubscriptionResponse.model_validate(subscription).model_dump()
    tier = subscription.tier
    platform = subscription.platform
    data["tier"] = tier.code if tier else None
    data["tier_name"] = tier.name if tier else None
    data["platform_name"] = platform.name if platform else ""
    return data


# ============================================================================
# Subscription Endpoints
# ============================================================================


@router.get("/subscriptions")
def list_merchant_subscriptions(
    request: Request,
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    List all subscriptions for the current merchant.

    Returns subscriptions across all platforms the merchant is subscribed to,
    including tier information and status, plus a ``total`` count.
    """
    merchant = _get_user_merchant(db, current_user)
    subscriptions = subscription_service.get_merchant_subscriptions(db, merchant.id)

    items = [_serialize_subscription(sub) for sub in subscriptions]

    return {"subscriptions": items, "total": len(items)}


@router.get("/subscriptions/{platform_id}")
def get_merchant_subscription(
    request: Request,
    platform_id: int = Path(..., description="Platform ID"),
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    Get subscription detail for a specific platform.

    Returns the serialized subscription and, when a tier is attached, a
    TierInfo describing it (``tier`` is None otherwise).

    Raises:
        HTTPException 404: If the merchant has no subscription on the platform
    """
    merchant = _get_user_merchant(db, current_user)
    subscription = subscription_service.get_merchant_subscription(
        db, merchant.id, platform_id
    )

    if not subscription:
        raise HTTPException(
            status_code=404,
            detail=f"No subscription found for platform {platform_id}",
        )

    sub_data = _serialize_subscription(subscription)

    tier_info = None
    tier = subscription.tier
    if tier:
        tier_info = TierInfo(
            code=tier.code,
            name=tier.name,
            description=tier.description,
            price_monthly_cents=tier.price_monthly_cents,
            price_annual_cents=tier.price_annual_cents,
            # Not every tier model exposes feature codes; default to none.
            feature_codes=(
                tier.get_feature_codes() if hasattr(tier, "get_feature_codes") else []
            ),
        )

    return {
        "subscription": sub_data,
        "tier": tier_info,
    }


@router.get("/subscriptions/{platform_id}/tiers")
def get_available_tiers(
    request: Request,
    platform_id: int = Path(..., description="Platform ID"),
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    Get available tiers for upgrade on a specific platform.

    Returns the tier list built by billing_service.get_available_tiers for the
    merchant's current tier (None when unsubscribed). Also returns the code of
    that current tier (None when there is no subscription or tier).
    """
    merchant = _get_user_merchant(db, current_user)
    subscription = subscription_service.get_merchant_subscription(
        db, merchant.id, platform_id
    )

    current_tier_id = subscription.tier_id if subscription else None
    # The service also returns a tier ordering, which this endpoint does not use.
    tier_list, _ = billing_service.get_available_tiers(
        db, current_tier_id, platform_id
    )

    current_tier_code = None
    if subscription and subscription.tier:
        current_tier_code = subscription.tier.code

    return {
        "tiers": tier_list,
        "current_tier": current_tier_code,
    }


class ChangeTierRequest(BaseModel):
    """Request for changing subscription tier."""

    # Code of the tier to switch to.
    tier_code: str
    # Bill annually instead of monthly.
    is_annual: bool = False


@router.post("/subscriptions/{platform_id}/change-tier")
def change_subscription_tier(
    request: Request,
    tier_data: ChangeTierRequest,
    platform_id: int = Path(..., description="Platform ID"),
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    Change the subscription tier for a specific platform.

    Delegates to billing_service.change_tier, which handles both
    Stripe-connected and non-Stripe subscriptions. Commits the session and
    returns the service result unchanged.
    """
    merchant = _get_user_merchant(db, current_user)

    result = billing_service.change_tier(
        db, merchant.id, platform_id, tier_data.tier_code, tier_data.is_annual
    )
    db.commit()

    logger.info(
        "Merchant %s (%s) changed tier to %s on platform=%s",
        merchant.id,
        merchant.name,
        tier_data.tier_code,
        platform_id,
    )

    return result


@router.post(
    "/subscriptions/{platform_id}/checkout",
    response_model=CheckoutResponse,
)
def create_checkout_session(
    request: Request,
    checkout_data: CheckoutRequest,
    platform_id: int = Path(..., description="Platform ID"),
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    Create a Stripe checkout session for the merchant's subscription.

    Starts a new subscription or upgrades an existing one to the requested
    tier. Success and cancel URLs point back to the portal's subscription
    page for this platform, tagged with ``?checkout=success`` or
    ``?checkout=cancelled``.
    """
    merchant = _get_user_merchant(db, current_user)

    # Derive redirect URLs from the URL the request came in on.
    base_url = str(request.base_url).rstrip("/")
    success_url = f"{base_url}/merchants/billing/subscriptions/{platform_id}?checkout=success"
    cancel_url = f"{base_url}/merchants/billing/subscriptions/{platform_id}?checkout=cancelled"

    result = billing_service.create_checkout_session(
        db=db,
        merchant_id=merchant.id,
        platform_id=platform_id,
        tier_code=checkout_data.tier_code,
        is_annual=checkout_data.is_annual,
        success_url=success_url,
        cancel_url=cancel_url,
    )
    db.commit()

    logger.info(
        "Merchant %s (%s) created checkout session for tier=%s on platform=%s",
        merchant.id,
        merchant.name,
        checkout_data.tier_code,
        platform_id,
    )

    return CheckoutResponse(
        checkout_url=result["checkout_url"],
        session_id=result["session_id"],
    )


# ============================================================================
# Invoice Endpoints
# ============================================================================


@router.get("/invoices")
def get_invoices(
    request: Request,
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(20, ge=1, le=100, description="Max records to return"),
    current_user: UserContext = Depends(get_current_merchant_from_cookie_or_header),
    db: Session = Depends(get_db),
):
    """
    Get invoice history for the current merchant.

    Returns a page of invoices from billing_service.get_invoices together with
    the total count and the pagination parameters used. Dates are ISO-8601
    strings, or None when unset.
    """
    merchant = _get_user_merchant(db, current_user)

    invoices, total = billing_service.get_invoices(
        db, merchant.id, skip=skip, limit=limit
    )

    return {
        "invoices": [
            {
                "id": inv.id,
                "invoice_number": inv.invoice_number,
                # Guarded like the other date fields so one incomplete row
                # cannot fail the whole page.
                "invoice_date": inv.invoice_date.isoformat() if inv.invoice_date else None,
                "due_date": inv.due_date.isoformat() if inv.due_date else None,
                "subtotal_cents": inv.subtotal_cents,
                "tax_cents": inv.tax_cents,
                "total_cents": inv.total_cents,
                "amount_paid_cents": inv.amount_paid_cents,
                "currency": inv.currency,
                "status": inv.status,
                "pdf_url": inv.invoice_pdf_url,
                "hosted_url": inv.hosted_invoice_url,
                "description": inv.description,
                "created_at": inv.created_at.isoformat() if inv.created_at else None,
            }
            for inv in invoices
        ],
        "total": total,
        "skip": skip,
        "limit": limit,
    }