# app/services/admin_subscription_service.py
"""
Admin Subscription Service.

Handles subscription management operations for platform administrators:
- Subscription tier CRUD
- Vendor subscription management
- Billing history queries
- Subscription analytics

None of the methods here flush or commit the session; callers are
responsible for transaction handling.
"""

import logging
from math import ceil

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.exceptions import (
    BusinessLogicException,
    ConflictException,
    ResourceNotFoundException,
    TierNotFoundException,
)
from models.database.subscription import (
    BillingHistory,
    SubscriptionStatus,
    SubscriptionTier,
    VendorSubscription,
)
from models.database.vendor import Vendor

logger = logging.getLogger(__name__)

# Escape character used when building LIKE/ILIKE patterns from user input.
_LIKE_ESCAPE = "\\"


def _escape_like(value: str) -> str:
    """Escape LIKE metacharacters so *value* is matched literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. The result must be used together with
    ``escape=_LIKE_ESCAPE`` on the ``like``/``ilike`` call.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )


class AdminSubscriptionService:
    """Service for admin subscription management operations."""

    # =========================================================================
    # Subscription Tiers
    # =========================================================================

    def get_tiers(
        self, db: Session, include_inactive: bool = False
    ) -> list[SubscriptionTier]:
        """Return subscription tiers ordered by ``display_order``.

        Args:
            db: Database session.
            include_inactive: When False (default), only tiers with
                ``is_active`` set are returned.
        """
        query = db.query(SubscriptionTier)
        if not include_inactive:
            query = query.filter(SubscriptionTier.is_active == True)  # noqa: E712
        return query.order_by(SubscriptionTier.display_order).all()

    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
        """Return the subscription tier whose ``code`` equals *tier_code*.

        Raises:
            TierNotFoundException: If no tier has that code.
        """
        tier = (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_code)
            .first()
        )
        if not tier:
            raise TierNotFoundException(tier_code)
        return tier

    def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
        """Create a subscription tier from *tier_data* and add it to the session.

        Args:
            db: Database session.
            tier_data: Column values passed as keyword arguments to
                ``SubscriptionTier``; must contain ``"code"``.

        Raises:
            ConflictException: If a tier with the same code already exists.
        """
        existing = (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_data["code"])
            .first()
        )
        if existing:
            raise ConflictException(
                f"Tier with code '{tier_data['code']}' already exists"
            )

        tier = SubscriptionTier(**tier_data)
        db.add(tier)

        logger.info("Created subscription tier: %s", tier.code)
        return tier

    def update_tier(
        self, db: Session, tier_code: str, update_data: dict
    ) -> SubscriptionTier:
        """Set each key/value of *update_data* as an attribute on the tier.

        Raises:
            TierNotFoundException: If no tier has code *tier_code*.
        """
        tier = self.get_tier_by_code(db, tier_code)

        for field, value in update_data.items():
            setattr(tier, field, value)

        logger.info("Updated subscription tier: %s", tier.code)
        return tier

    def deactivate_tier(self, db: Session, tier_code: str) -> None:
        """Soft-delete a tier by setting ``is_active = False``.

        Refused while any ACTIVE or TRIAL vendor subscription still
        references the tier.

        Raises:
            TierNotFoundException: If no tier has code *tier_code*.
            BusinessLogicException: If ACTIVE/TRIAL subscriptions use the tier.
        """
        tier = self.get_tier_by_code(db, tier_code)

        active_subs = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.tier == tier_code,
                VendorSubscription.status.in_([
                    SubscriptionStatus.ACTIVE.value,
                    SubscriptionStatus.TRIAL.value,
                ]),
            )
            .count()
        )

        if active_subs > 0:
            raise BusinessLogicException(
                f"Cannot delete tier: {active_subs} active subscriptions are using it"
            )

        tier.is_active = False
        logger.info("Soft-deleted subscription tier: %s", tier.code)

    # =========================================================================
    # Pagination helper
    # =========================================================================

    @staticmethod
    def _paginate(query, order_by, page: int, per_page: int) -> dict:
        """Order, offset and limit *query*, returning the pagination envelope.

        *page* and *per_page* are clamped to at least 1 so that bad input
        cannot produce a negative OFFSET/LIMIT or a division by zero.
        """
        page = max(page, 1)
        per_page = max(per_page, 1)

        total = query.count()
        results = (
            query.order_by(order_by)
            .offset((page - 1) * per_page)
            .limit(per_page)
            .all()
        )

        return {
            "results": results,
            "total": total,
            "page": page,
            "per_page": per_page,
            "pages": ceil(total / per_page) if total > 0 else 0,
        }

    # =========================================================================
    # Vendor Subscriptions
    # =========================================================================

    def list_subscriptions(
        self,
        db: Session,
        page: int = 1,
        per_page: int = 20,
        status: str | None = None,
        tier: str | None = None,
        search: str | None = None,
    ) -> dict:
        """List vendor subscriptions, newest first, with filters and pagination.

        Args:
            db: Database session.
            page: 1-based page number (values below 1 are treated as 1).
            per_page: Page size (values below 1 are treated as 1).
            status: Exact subscription status to filter on.
            tier: Exact tier code to filter on.
            search: Case-insensitive substring matched literally against
                the vendor name.

        Returns:
            Dict with ``results`` (``(VendorSubscription, Vendor)`` rows),
            ``total``, ``page``, ``per_page`` and ``pages``.
        """
        query = db.query(VendorSubscription, Vendor).join(
            Vendor, VendorSubscription.vendor_id == Vendor.id
        )

        if status:
            query = query.filter(VendorSubscription.status == status)
        if tier:
            query = query.filter(VendorSubscription.tier == tier)
        if search:
            # Escape user-supplied % and _ so they are not treated as wildcards.
            query = query.filter(
                Vendor.name.ilike(f"%{_escape_like(search)}%", escape=_LIKE_ESCAPE)
            )

        return self._paginate(
            query, VendorSubscription.created_at.desc(), page, per_page
        )

    def get_subscription(self, db: Session, vendor_id: int) -> tuple:
        """Return ``(VendorSubscription, Vendor)`` for *vendor_id*.

        Raises:
            ResourceNotFoundException: If the vendor has no subscription row.
        """
        result = (
            db.query(VendorSubscription, Vendor)
            .join(Vendor, VendorSubscription.vendor_id == Vendor.id)
            .filter(VendorSubscription.vendor_id == vendor_id)
            .first()
        )

        if not result:
            raise ResourceNotFoundException("Subscription", str(vendor_id))

        return result

    def update_subscription(
        self, db: Session, vendor_id: int, update_data: dict
    ) -> tuple:
        """Set each key/value of *update_data* on the vendor's subscription.

        Returns:
            ``(VendorSubscription, Vendor)`` after the attributes are set.

        Raises:
            ResourceNotFoundException: If the vendor has no subscription row.
        """
        sub, vendor = self.get_subscription(db, vendor_id)

        for field, value in update_data.items():
            setattr(sub, field, value)

        logger.info(
            "Admin updated subscription for vendor %s: %s",
            vendor_id,
            list(update_data.keys()),
        )
        return sub, vendor

    # =========================================================================
    # Billing History
    # =========================================================================

    def list_billing_history(
        self,
        db: Session,
        page: int = 1,
        per_page: int = 20,
        vendor_id: int | None = None,
        status: str | None = None,
    ) -> dict:
        """List billing history across vendors, newest invoice first.

        Args:
            db: Database session.
            page: 1-based page number (values below 1 are treated as 1).
            per_page: Page size (values below 1 are treated as 1).
            vendor_id: Restrict to one vendor when not None (0 is honored).
            status: Exact billing status to filter on.

        Returns:
            Dict with ``results`` (``(BillingHistory, Vendor)`` rows),
            ``total``, ``page``, ``per_page`` and ``pages``.
        """
        query = db.query(BillingHistory, Vendor).join(
            Vendor, BillingHistory.vendor_id == Vendor.id
        )

        # `is not None` so that a vendor id of 0 is not silently ignored.
        if vendor_id is not None:
            query = query.filter(BillingHistory.vendor_id == vendor_id)
        if status:
            query = query.filter(BillingHistory.status == status)

        return self._paginate(
            query, BillingHistory.invoice_date.desc(), page, per_page
        )

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self, db: Session) -> dict:
        """Return subscription statistics for the admin dashboard.

        Returns:
            Dict with ``total_subscriptions``, per-status counts
            (``active_count``, ``trial_count``, ``past_due_count``,
            ``cancelled_count``, ``expired_count``), ``tier_distribution``
            (tier code -> ACTIVE/TRIAL count), ``mrr_cents`` and
            ``arr_cents`` (computed from ACTIVE subscriptions only).
        """
        stats = self._count_by_status(db)
        stats["tier_distribution"] = self._tier_distribution(db)
        mrr_cents, arr_cents = self._recurring_revenue(db)
        stats["mrr_cents"] = mrr_cents
        stats["arr_cents"] = arr_cents
        return stats

    @staticmethod
    def _count_by_status(db: Session) -> dict:
        """Count subscriptions overall and per known status."""
        status_keys = {
            SubscriptionStatus.ACTIVE.value: "active_count",
            SubscriptionStatus.TRIAL.value: "trial_count",
            SubscriptionStatus.PAST_DUE.value: "past_due_count",
            SubscriptionStatus.CANCELLED.value: "cancelled_count",
            SubscriptionStatus.EXPIRED.value: "expired_count",
        }
        stats = {
            "total_subscriptions": 0,
            "active_count": 0,
            "trial_count": 0,
            "past_due_count": 0,
            "cancelled_count": 0,
            "expired_count": 0,
        }

        status_counts = (
            db.query(VendorSubscription.status, func.count(VendorSubscription.id))
            .group_by(VendorSubscription.status)
            .all()
        )
        for status, count in status_counts:
            # Unknown statuses still contribute to the overall total.
            stats["total_subscriptions"] += count
            key = status_keys.get(status)
            if key is not None:
                stats[key] = count
        return stats

    @staticmethod
    def _tier_distribution(db: Session) -> dict:
        """Map tier code -> number of ACTIVE or TRIAL subscriptions."""
        tier_counts = (
            db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
            .filter(
                VendorSubscription.status.in_([
                    SubscriptionStatus.ACTIVE.value,
                    SubscriptionStatus.TRIAL.value,
                ])
            )
            .group_by(VendorSubscription.tier)
            .all()
        )
        return {tier: count for tier, count in tier_counts}

    @staticmethod
    def _recurring_revenue(db: Session) -> tuple[int, int]:
        """Return ``(mrr_cents, arr_cents)`` over ACTIVE subscriptions.

        Annual subscriptions whose tier has an annual price contribute that
        price to ARR and its floor-divided twelfth to MRR; all others use
        the tier's monthly price.
        """
        mrr_cents = 0
        arr_cents = 0

        active_subs = (
            db.query(VendorSubscription, SubscriptionTier)
            .join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
            .filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
            .all()
        )

        for sub, tier in active_subs:
            if sub.is_annual and tier.price_annual_cents:
                mrr_cents += tier.price_annual_cents // 12
                arr_cents += tier.price_annual_cents
            else:
                mrr_cents += tier.price_monthly_cents
                arr_cents += tier.price_monthly_cents * 12

        return mrr_cents, arr_cents


# Singleton instance
admin_subscription_service = AdminSubscriptionService()