# app/modules/billing/services/admin_subscription_service.py
"""
Admin Subscription Service.

Handles subscription management operations for platform administrators:
- Subscription tier CRUD
- Merchant subscription management
- Billing history queries
- Subscription analytics
"""

import logging
from math import ceil

from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload

from app.exceptions import (
    BusinessLogicException,
    ConflictException,
    ResourceNotFoundException,
)
from app.modules.billing.exceptions import TierNotFoundException
from app.modules.billing.models import (
    BillingHistory,
    MerchantSubscription,
    SubscriptionStatus,
    SubscriptionTier,
)
from app.modules.tenancy.exceptions import PlatformNotFoundException

logger = logging.getLogger(__name__)


class AdminSubscriptionService:
    """Service for admin subscription management operations."""

    # =========================================================================
    # Internal helpers
    # =========================================================================

    @staticmethod
    def _normalize_paging(page: int, per_page: int) -> tuple[int, int]:
        """Clamp pagination inputs so both values are at least 1.

        Without this, ``per_page=0`` makes the page-count division raise
        ``ZeroDivisionError`` and ``page < 1`` yields a negative OFFSET.
        Valid inputs are returned unchanged.
        """
        return max(page, 1), max(per_page, 1)

    @staticmethod
    def _page_count(total: int, per_page: int) -> int:
        """Return the number of pages needed for ``total`` rows (0 if none)."""
        return ceil(total / per_page) if total > 0 else 0

    # =========================================================================
    # Stripe Tier Sync
    # =========================================================================

    @staticmethod
    def _sync_stripe_price(
        tier: SubscriptionTier,
        *,
        id_attr: str,
        amount_cents: int | None,
        interval: str,
        billing_period: str,
    ) -> None:
        """Ensure one recurring Stripe Price of ``tier`` matches ``amount_cents``.

        Args:
            tier: Tier whose ``stripe_product_id`` is already populated.
            id_attr: Name of the tier attribute holding the Stripe Price ID
                (e.g. ``"stripe_price_monthly_id"``).
            amount_cents: Desired unit amount; a falsy value (``None`` / 0)
                means there is nothing to sync and the method returns.
            interval: Stripe recurring interval (``"month"`` or ``"year"``).
            billing_period: Label used in Price metadata and log messages
                (``"monthly"`` or ``"annual"``).

        Assumes ``stripe.api_key`` has already been set by the caller.
        """
        if not amount_cents:
            return

        import stripe

        current_id = getattr(tier, id_attr)
        if current_id:
            # Verify price matches; if not, archive it so a new one is created
            try:
                existing = stripe.Price.retrieve(current_id)
                if existing.unit_amount != amount_cents:
                    # Price changed — archive old, create new
                    stripe.Price.modify(current_id, active=False)
                    setattr(tier, id_attr, None)
                    logger.info(
                        f"Archived old {billing_period} price for tier {tier.code}"
                    )
            except stripe.InvalidRequestError:
                # Stored ID is not usable in Stripe any more; recreate below
                setattr(tier, id_attr, None)

        if not getattr(tier, id_attr):
            price = stripe.Price.create(
                product=tier.stripe_product_id,
                unit_amount=amount_cents,
                currency="eur",
                recurring={"interval": interval},
                metadata={
                    "tier_code": tier.code,
                    "billing_period": billing_period,
                },
            )
            setattr(tier, id_attr, price.id)
            logger.info(
                f"Created Stripe {billing_period} price {price.id} "
                f"for tier {tier.code} ({amount_cents} cents)"
            )

    @staticmethod
    def _sync_tier_to_stripe(db: Session, tier: SubscriptionTier) -> None:
        """
        Sync a tier's Stripe product and prices.

        Creates or verifies the Stripe Product and Price objects, and
        populates the stripe_product_id, stripe_price_monthly_id, and
        stripe_price_annual_id fields on the tier.

        Skips gracefully if Stripe is not configured (dev mode).
        Stripe Prices are immutable — on price changes, new Prices are
        created and old ones archived.

        Args:
            db: Session used only to resolve the platform name.
            tier: Tier to sync; its Stripe ID fields are updated in place.
        """
        from app.core.config import settings

        if not settings.stripe_secret_key:
            logger.debug(
                f"Stripe not configured, skipping sync for tier {tier.code}"
            )
            return

        import stripe

        stripe.api_key = settings.stripe_secret_key

        # Resolve platform name for product naming
        platform_name = "Platform"
        if tier.platform_id:
            from app.modules.tenancy.services.platform_service import (
                platform_service,
            )

            try:
                platform = platform_service.get_platform_by_id(db, tier.platform_id)
                platform_name = platform.name
            except Exception as exc:  # noqa: EXC-003
                # Best effort: the name is cosmetic, so keep the generic
                # fallback, but leave a trace instead of failing silently.
                logger.warning(
                    f"Could not resolve platform {tier.platform_id} for tier "
                    f"{tier.code}, using default product name: {exc}"
                )

        # --- Product ---
        if tier.stripe_product_id:
            # Verify it still exists in Stripe
            try:
                stripe.Product.retrieve(tier.stripe_product_id)
            except stripe.InvalidRequestError:
                logger.warning(
                    f"Stripe product {tier.stripe_product_id} not found, "
                    f"recreating for tier {tier.code}"
                )
                tier.stripe_product_id = None

        if not tier.stripe_product_id:
            product = stripe.Product.create(
                name=f"{platform_name} - {tier.name}",
                metadata={
                    "tier_code": tier.code,
                    "platform_id": str(tier.platform_id or ""),
                },
            )
            tier.stripe_product_id = product.id
            logger.info(
                f"Created Stripe product {product.id} for tier {tier.code}"
            )

        # --- Monthly Price ---
        AdminSubscriptionService._sync_stripe_price(
            tier,
            id_attr="stripe_price_monthly_id",
            amount_cents=tier.price_monthly_cents,
            interval="month",
            billing_period="monthly",
        )

        # --- Annual Price ---
        AdminSubscriptionService._sync_stripe_price(
            tier,
            id_attr="stripe_price_annual_id",
            amount_cents=tier.price_annual_cents,
            interval="year",
            billing_period="annual",
        )

    # =========================================================================
    # Subscription Tiers
    # =========================================================================

    def get_tiers(
        self,
        db: Session,
        include_inactive: bool = False,
        platform_id: int | None = None,
    ) -> list[SubscriptionTier]:
        """Get all subscription tiers, optionally filtered by platform.

        When ``platform_id`` is given, tiers with no platform
        (``platform_id IS NULL``) are returned alongside that platform's
        tiers. Results are ordered by ``display_order``.
        """
        query = db.query(SubscriptionTier)

        if not include_inactive:
            query = query.filter(SubscriptionTier.is_active == True)  # noqa: E712

        if platform_id is not None:
            query = query.filter(
                (SubscriptionTier.platform_id == platform_id)
                | (SubscriptionTier.platform_id.is_(None))
            )

        return query.order_by(SubscriptionTier.display_order).all()

    def get_tier_by_code(
        self, db: Session, tier_code: str, platform_id: int | None = None
    ) -> SubscriptionTier:
        """Get a subscription tier by code, optionally scoped to a platform.

        Without ``platform_id`` the lookup is by code alone and the first
        matching row is returned.

        Raises:
            TierNotFoundException: If no tier matches.
        """
        query = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code)
        if platform_id is not None:
            query = query.filter(SubscriptionTier.platform_id == platform_id)

        tier = query.first()
        if not tier:
            raise TierNotFoundException(tier_code)
        return tier

    def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
        """Create a new subscription tier.

        The tier is added and flushed (not committed) and then synced to
        Stripe.

        Raises:
            ConflictException: If a tier with the same code already exists
                (the check is by code only, regardless of platform).
        """
        # Check for duplicate code
        existing = (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_data["code"])
            .first()
        )
        if existing:
            raise ConflictException(
                f"Tier with code '{tier_data['code']}' already exists"
            )

        tier = SubscriptionTier(**tier_data)
        db.add(tier)
        db.flush()  # Get tier.id before Stripe sync

        self._sync_tier_to_stripe(db, tier)

        logger.info(f"Created subscription tier: {tier.code}")
        return tier

    def update_tier(
        self, db: Session, tier_code: str, update_data: dict
    ) -> SubscriptionTier:
        """Update a subscription tier.

        Every key of ``update_data`` is set as an attribute on the tier.
        Stripe is re-synced only when a price actually changed or the tier
        has no Stripe product yet.

        Raises:
            TierNotFoundException: If the tier does not exist.
        """
        tier = self.get_tier_by_code(db, tier_code)

        # Track price changes to know if Stripe sync is needed
        price_changed = (
            "price_monthly_cents" in update_data
            and update_data["price_monthly_cents"] != tier.price_monthly_cents
        ) or (
            "price_annual_cents" in update_data
            and update_data["price_annual_cents"] != tier.price_annual_cents
        )

        for field, value in update_data.items():
            setattr(tier, field, value)

        if price_changed or not tier.stripe_product_id:
            self._sync_tier_to_stripe(db, tier)

        logger.info(f"Updated subscription tier: {tier.code}")
        return tier

    def deactivate_tier(self, db: Session, tier_code: str) -> None:
        """Soft-delete a subscription tier by setting ``is_active`` to False.

        Raises:
            TierNotFoundException: If the tier does not exist.
            BusinessLogicException: If any ACTIVE or TRIAL subscription
                still references the tier.
        """
        tier = self.get_tier_by_code(db, tier_code)

        # Check if any active subscriptions use this tier (by tier_id FK)
        active_subs = (
            db.query(MerchantSubscription)
            .filter(
                MerchantSubscription.tier_id == tier.id,
                MerchantSubscription.status.in_([
                    SubscriptionStatus.ACTIVE.value,
                    SubscriptionStatus.TRIAL.value,
                ]),
            )
            .count()
        )

        if active_subs > 0:
            raise BusinessLogicException(
                f"Cannot delete tier: {active_subs} active subscriptions are using it",
                "TIER_HAS_ACTIVE_SUBSCRIPTIONS",
            )

        tier.is_active = False

        logger.info(f"Soft-deleted subscription tier: {tier.code}")

    # =========================================================================
    # Merchant Subscriptions
    # =========================================================================

    def list_subscriptions(
        self,
        db: Session,
        page: int = 1,
        per_page: int = 20,
        status: str | None = None,
        tier: str | None = None,
        search: str | None = None,
    ) -> dict:
        """List merchant subscriptions with filtering and pagination.

        Args:
            db: Database session.
            page: 1-based page number (values below 1 are treated as 1).
            per_page: Page size (values below 1 are treated as 1).
            status: Optional exact subscription status filter.
            tier: Optional tier code filter.
            search: Optional merchant search term, resolved through
                ``merchant_service.get_merchants``.

        Returns:
            Dict with ``results`` (list of ``(subscription, merchant)``
            tuples), ``total``, ``page``, ``per_page`` and ``pages``.
        """
        page, per_page = self._normalize_paging(page, per_page)

        query = (
            db.query(MerchantSubscription)
            .join(MerchantSubscription.merchant)
            .options(joinedload(MerchantSubscription.merchant))
        )

        # Apply filters
        if status:
            query = query.filter(MerchantSubscription.status == status)
        if tier:
            query = query.join(
                SubscriptionTier, MerchantSubscription.tier_id == SubscriptionTier.id
            ).filter(SubscriptionTier.code == tier)
        if search:
            from app.modules.tenancy.services.merchant_service import merchant_service

            merchants, _ = merchant_service.get_merchants(
                db, search=search, limit=10000
            )
            merchant_ids = [m.id for m in merchants]
            if not merchant_ids:
                # No merchant matches the search: short-circuit with an empty page
                return {
                    "results": [],
                    "total": 0,
                    "page": page,
                    "per_page": per_page,
                    "pages": 0,
                }
            query = query.filter(MerchantSubscription.merchant_id.in_(merchant_ids))

        # Count total
        total = query.count()

        # Paginate
        offset = (page - 1) * per_page
        subs = (
            query.order_by(MerchantSubscription.created_at.desc())
            .offset(offset)
            .limit(per_page)
            .all()
        )

        # Return (sub, merchant) tuples for backward compatibility with callers
        results = [(sub, sub.merchant) for sub in subs]

        return {
            "results": results,
            "total": total,
            "page": page,
            "per_page": per_page,
            "pages": self._page_count(total, per_page),
        }

    def get_subscription(
        self, db: Session, merchant_id: int, platform_id: int
    ) -> tuple:
        """Get subscription for a specific merchant on a platform.

        Returns:
            ``(subscription, merchant)`` tuple.

        Raises:
            ResourceNotFoundException: If no subscription exists for the
                merchant/platform pair.
        """
        sub = (
            db.query(MerchantSubscription)
            .options(joinedload(MerchantSubscription.merchant))
            .filter(
                MerchantSubscription.merchant_id == merchant_id,
                MerchantSubscription.platform_id == platform_id,
            )
            .first()
        )

        if not sub:
            raise ResourceNotFoundException(
                "Subscription",
                f"merchant_id={merchant_id}, platform_id={platform_id}",
            )

        return sub, sub.merchant

    def update_subscription(
        self, db: Session, merchant_id: int, platform_id: int, update_data: dict
    ) -> tuple:
        """Update a merchant's subscription.

        ``tier_code`` in ``update_data`` is handled separately: if the
        subscription has a Stripe subscription ID the change is delegated to
        ``billing_service.change_tier``; otherwise the code is resolved to a
        tier on the same platform and ``tier_id`` is set directly. All other
        keys are set as attributes on the subscription.

        The caller's ``update_data`` dict is not modified.

        Returns:
            ``(subscription, merchant)`` tuple.

        Raises:
            ResourceNotFoundException: If the subscription does not exist.
            TierNotFoundException: If ``tier_code`` cannot be resolved.
        """
        sub, merchant = self.get_subscription(db, merchant_id, platform_id)

        # Work on a copy so popping tier_code does not mutate the caller's dict
        update_data = dict(update_data)

        # Handle tier_code separately: resolve to tier_id
        tier_code = update_data.pop("tier_code", None)
        if tier_code is not None:
            if sub.stripe_subscription_id:
                from app.modules.billing.services.billing_service import (
                    billing_service,
                )

                billing_service.change_tier(
                    db, merchant_id, platform_id, tier_code, sub.is_annual
                )
            else:
                tier = self.get_tier_by_code(db, tier_code, platform_id=platform_id)
                sub.tier_id = tier.id

        for field, value in update_data.items():
            setattr(sub, field, value)

        logger.info(
            f"Admin updated subscription for merchant {merchant_id} "
            f"on platform {platform_id}: {list(update_data.keys())}"
            + (f", tier_code={tier_code}" if tier_code else "")
        )

        return sub, merchant

    # =========================================================================
    # Billing History
    # =========================================================================

    def list_billing_history(
        self,
        db: Session,
        page: int = 1,
        per_page: int = 20,
        merchant_id: int | None = None,
        status: str | None = None,
    ) -> dict:
        """List billing history across all merchants.

        Args:
            db: Database session.
            page: 1-based page number (values below 1 are treated as 1).
            per_page: Page size (values below 1 are treated as 1).
            merchant_id: Optional merchant filter.
            status: Optional exact invoice status filter.

        Returns:
            Dict with ``results`` (list of ``(invoice, merchant_or_None)``
            tuples, newest invoice first), ``total``, ``page``, ``per_page``
            and ``pages``.
        """
        page, per_page = self._normalize_paging(page, per_page)

        query = db.query(BillingHistory)

        if merchant_id:
            query = query.filter(BillingHistory.merchant_id == merchant_id)
        if status:
            query = query.filter(BillingHistory.status == status)

        total = query.count()

        offset = (page - 1) * per_page
        invoices = (
            query.order_by(BillingHistory.invoice_date.desc())
            .offset(offset)
            .limit(per_page)
            .all()
        )

        # Look up each distinct merchant once (one service call per merchant
        # on this page) so names are available for display
        from app.modules.tenancy.services.merchant_service import merchant_service

        merchant_ids = {inv.merchant_id for inv in invoices if inv.merchant_id}
        merchants_map = {}
        for mid in merchant_ids:
            m = merchant_service.get_merchant_by_id_optional(db, mid)
            if m:
                merchants_map[mid] = m

        # Return (invoice, merchant) tuples for backward compatibility
        results = [
            (inv, merchants_map.get(inv.merchant_id))
            for inv in invoices
        ]

        return {
            "results": results,
            "total": total,
            "page": page,
            "per_page": per_page,
            "pages": self._page_count(total, per_page),
        }

    # =========================================================================
    # Platform Helpers
    # =========================================================================

    def get_platform_names_map(self, db: Session) -> dict[int, str]:
        """Get mapping of platform_id -> platform_name (inactive included)."""
        from app.modules.tenancy.services.platform_service import platform_service

        platforms = platform_service.list_platforms(db, include_inactive=True)
        return {p.id: p.name for p in platforms}

    def get_platform_name(self, db: Session, platform_id: int) -> str | None:
        """Get platform name by ID, or None if the platform does not exist."""
        from app.modules.tenancy.services.platform_service import platform_service

        try:
            p = platform_service.get_platform_by_id(db, platform_id)
            return p.name
        except PlatformNotFoundException:
            return None

    # =========================================================================
    # Merchant Subscriptions with Usage
    # =========================================================================

    def get_merchant_subscriptions_with_usage(
        self, db: Session, merchant_id: int
    ) -> list[dict]:
        """Get all subscriptions for a merchant with tier info and feature usage.

        Returns a list of dicts, each containing:
        - subscription: serialized MerchantSubscription
        - tier: tier info dict (code, name, feature_codes), or None when the
          subscription has no tier
        - features: list of quantitative usage metrics (enabled features only)
        - platform_id: int
        - platform_name: str (empty string when the platform is unknown)
        """
        from app.modules.billing.schemas import MerchantSubscriptionAdminResponse
        from app.modules.billing.services.feature_service import feature_service
        from app.modules.billing.services.subscription_service import (
            subscription_service,
        )

        subs = subscription_service.get_merchant_subscriptions(db, merchant_id)
        platforms_map = self.get_platform_names_map(db)

        results = []
        for sub in subs:
            features_summary = feature_service.get_merchant_features_summary(
                db, merchant_id, sub.platform_id
            )

            tier_info = None
            if sub.tier:
                tier_info = {
                    "code": sub.tier.code,
                    "name": sub.tier.name,
                    "feature_codes": [
                        fl.feature_code for fl in (sub.tier.feature_limits or [])
                    ],
                }

            usage_metrics = []
            for fs in features_summary:
                if fs.feature_type == "quantitative" and fs.enabled:
                    usage_metrics.append({
                        "name": fs.name_key.replace("_", " ").title(),
                        "current": fs.current or 0,
                        "limit": fs.limit,
                        "percentage": fs.percent_used or 0,
                        "is_unlimited": fs.limit is None,
                        "is_at_limit": (
                            fs.remaining == 0 if fs.remaining is not None else False
                        ),
                        # 80% usage is the threshold for "approaching"
                        "is_approaching_limit": (fs.percent_used or 0) >= 80,
                    })

            results.append({
                "subscription": MerchantSubscriptionAdminResponse.model_validate(
                    sub
                ).model_dump(),
                "tier": tier_info,
                "features": usage_metrics,
                "platform_id": sub.platform_id,
                "platform_name": platforms_map.get(sub.platform_id, ""),
            })

        return results

    def get_subscriptions_for_store(
        self, db: Session, store_id: int
    ) -> list[dict]:
        """Get subscriptions + feature usage for a store (resolves to merchant).

        Convenience method for admin store detail page. Resolves
        store -> merchant -> all platform subscriptions.

        Raises:
            ResourceNotFoundException: If the store does not exist or has no
                merchant.
        """
        from app.modules.tenancy.services.store_service import store_service

        store = store_service.get_store_by_id_optional(db, store_id)
        if not store or not store.merchant_id:
            raise ResourceNotFoundException("Store", str(store_id))

        return self.get_merchant_subscriptions_with_usage(db, store.merchant_id)

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self, db: Session) -> dict:
        """Get subscription statistics for admin dashboard.

        Returns:
            Dict with ``total_subscriptions``, one ``*_count`` per known
            status, ``tier_distribution`` (tier name -> number of ACTIVE or
            TRIAL subscriptions), ``mrr_cents`` and ``arr_cents``. Revenue
            figures only consider ACTIVE subscriptions.
        """
        # Count by status
        status_counts = (
            db.query(
                MerchantSubscription.status,
                func.count(MerchantSubscription.id),
            )
            .group_by(MerchantSubscription.status)
            .all()
        )

        stats = {
            "total_subscriptions": 0,
            "active_count": 0,
            "trial_count": 0,
            "past_due_count": 0,
            "cancelled_count": 0,
            "expired_count": 0,
        }

        for sub_status, count in status_counts:
            # Every status contributes to the total, even unrecognised ones
            stats["total_subscriptions"] += count
            if sub_status == SubscriptionStatus.ACTIVE.value:
                stats["active_count"] = count
            elif sub_status == SubscriptionStatus.TRIAL.value:
                stats["trial_count"] = count
            elif sub_status == SubscriptionStatus.PAST_DUE.value:
                stats["past_due_count"] = count
            elif sub_status == SubscriptionStatus.CANCELLED.value:
                stats["cancelled_count"] = count
            elif sub_status == SubscriptionStatus.EXPIRED.value:
                stats["expired_count"] = count

        # Count by tier (join with SubscriptionTier to get tier name)
        tier_counts = (
            db.query(SubscriptionTier.name, func.count(MerchantSubscription.id))
            .join(
                SubscriptionTier,
                MerchantSubscription.tier_id == SubscriptionTier.id,
            )
            .filter(
                MerchantSubscription.status.in_([
                    SubscriptionStatus.ACTIVE.value,
                    SubscriptionStatus.TRIAL.value,
                ])
            )
            .group_by(SubscriptionTier.name)
            .all()
        )

        tier_distribution = dict(tier_counts)

        # Calculate MRR (Monthly Recurring Revenue)
        mrr_cents = 0
        arr_cents = 0

        active_subs = (
            db.query(MerchantSubscription, SubscriptionTier)
            .join(
                SubscriptionTier,
                MerchantSubscription.tier_id == SubscriptionTier.id,
            )
            .filter(MerchantSubscription.status == SubscriptionStatus.ACTIVE.value)
            .all()
        )

        for sub, sub_tier in active_subs:
            if sub.is_annual and sub_tier.price_annual_cents:
                mrr_cents += sub_tier.price_annual_cents // 12
                arr_cents += sub_tier.price_annual_cents
            else:
                # A tier without a monthly price contributes nothing instead
                # of raising TypeError on ``int += None``.
                monthly_cents = sub_tier.price_monthly_cents or 0
                mrr_cents += monthly_cents
                arr_cents += monthly_cents * 12

        stats["tier_distribution"] = tier_distribution
        stats["mrr_cents"] = mrr_cents
        stats["arr_cents"] = arr_cents

        return stats


# Singleton instance
admin_subscription_service = AdminSubscriptionService()