# app/modules/billing/services/subscription_service.py
"""
Subscription service for merchant-level subscription management.

Handles:
- MerchantSubscription creation and management
- Tier lookup and resolution
- Store → merchant → subscription resolution

Limit checks are now handled by feature_service.check_resource_limit().
Modules own their own limit checks (catalog, orders, tenancy, etc.).

Usage:
    from app.modules.billing.services import subscription_service

    # Get merchant subscription
    sub = subscription_service.get_merchant_subscription(db, merchant_id, platform_id)

    # Create merchant subscription
    sub = subscription_service.create_merchant_subscription(db, merchant_id, platform_id, tier_code)

    # Resolve store to merchant subscription
    sub = subscription_service.get_subscription_for_store(db, store_id)
"""

import logging
from datetime import UTC, datetime, timedelta

from sqlalchemy.orm import Session, joinedload

from app.exceptions import ResourceNotFoundException
from app.modules.billing.exceptions import SubscriptionNotFoundException
from app.modules.billing.models import (
    MerchantSubscription,
    SubscriptionStatus,
    SubscriptionTier,
    TierCode,
)

logger = logging.getLogger(__name__)


class SubscriptionService:
    """Service for merchant-level subscription management.

    The service holds no state of its own; every method receives the
    database session to operate on. Write methods only ``flush()`` the
    session and never commit, so the caller owns the transaction boundary.
    """

    # =========================================================================
    # Store Resolution
    # =========================================================================

    def resolve_store_to_merchant(self, db: Session, store_id: int) -> tuple[int, int]:
        """Resolve store_id to (merchant_id, platform_id).

        Args:
            db: Database session.
            store_id: The store ID to resolve.

        Returns:
            Tuple of (merchant_id, platform_id).

        Raises:
            ResourceNotFoundException: If the store does not exist, has no
                merchant_id, or has no StorePlatform row.
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import between the billing and tenancy modules - confirm.
        from app.modules.tenancy.models import Store, StorePlatform

        store = db.query(Store).filter(Store.id == store_id).first()
        if not store or not store.merchant_id:
            raise ResourceNotFoundException("Store", str(store_id))

        # Only the platform_id column is selected, so the result is a
        # one-element row rather than a StorePlatform instance.
        sp = db.query(StorePlatform.platform_id).filter(
            StorePlatform.store_id == store_id
        ).first()
        if not sp:
            raise ResourceNotFoundException("StorePlatform", f"store_id={store_id}")

        return store.merchant_id, sp[0]

    def get_store_code(self, db: Session, store_id: int) -> str:
        """Get the store_code for a given store_id.

        Args:
            db: Database session.
            store_id: The store ID to look up.

        Returns:
            The store's ``store_code``.

        Raises:
            ResourceNotFoundException: If store not found
        """
        from app.modules.tenancy.models import Store

        store = db.query(Store).filter(Store.id == store_id).first()
        if not store:
            raise ResourceNotFoundException("Store", str(store_id))
        return store.store_code

    # =========================================================================
    # Tier Information
    # =========================================================================

    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
        """Get subscription tier by code.

        Returns the first tier whose ``code`` matches, or None. The lookup
        filters on the code only (not on platform, active or public flags).
        """
        return (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_code)
            .first()
        )

    def get_tier_id(self, db: Session, tier_code: str) -> int | None:
        """Get tier ID from tier code. Returns None if tier not found."""
        tier = self.get_tier_by_code(db, tier_code)
        return tier.id if tier else None

    def get_all_tiers(
        self, db: Session, platform_id: int | None = None
    ) -> list[SubscriptionTier]:
        """
        Get all active, public tiers, ordered by display_order.

        If platform_id is provided, returns tiers for that platform
        plus global tiers (platform_id=NULL). If it is omitted, no platform
        filter is applied at all.
        """
        # `== True` (rather than `is True`) is required so SQLAlchemy builds
        # a SQL expression; hence the E712 suppressions.
        query = db.query(SubscriptionTier).filter(
            SubscriptionTier.is_active == True,  # noqa: E712
            SubscriptionTier.is_public == True,  # noqa: E712
        )

        if platform_id is not None:
            query = query.filter(
                (SubscriptionTier.platform_id == platform_id)
                | (SubscriptionTier.platform_id.is_(None))
            )

        return query.order_by(SubscriptionTier.display_order).all()

    # =========================================================================
    # Merchant Subscription CRUD
    # =========================================================================

    def get_merchant_subscription(
        self, db: Session, merchant_id: int, platform_id: int
    ) -> MerchantSubscription | None:
        """Get merchant subscription for a specific platform.

        The tier (with its feature limits) and the platform are eagerly
        loaded via joined loads.

        Returns:
            The matching subscription, or None if there is none.
        """
        return (
            db.query(MerchantSubscription)
            .options(
                joinedload(MerchantSubscription.tier)
                .joinedload(SubscriptionTier.feature_limits),
                joinedload(MerchantSubscription.platform),
            )
            .filter(
                MerchantSubscription.merchant_id == merchant_id,
                MerchantSubscription.platform_id == platform_id,
            )
            .first()
        )

    def get_merchant_subscriptions(
        self, db: Session, merchant_id: int
    ) -> list[MerchantSubscription]:
        """Get all subscriptions for a merchant across platforms.

        Tier and platform are eagerly loaded; unlike
        get_merchant_subscription(), the tier's feature limits are not.
        """
        return (
            db.query(MerchantSubscription)
            .options(
                joinedload(MerchantSubscription.tier),
                joinedload(MerchantSubscription.platform),
            )
            .filter(MerchantSubscription.merchant_id == merchant_id)
            .all()
        )

    def get_subscription_for_store(
        self, db: Session, store_id: int
    ) -> MerchantSubscription | None:
        """
        Convenience method that resolves the store -> merchant -> platform
        hierarchy and returns the associated merchant subscription.

        Looks up the store's merchant_id and platform_id, then delegates
        to get_merchant_subscription(). Unlike resolve_store_to_merchant(),
        this method never raises for a missing store or platform; it
        returns None instead.

        Args:
            db: Database session.
            store_id: The store ID to resolve.

        Returns:
            The merchant subscription, or None if the store, merchant,
            or platform cannot be resolved.
        """
        from app.modules.tenancy.models import Store

        store = db.query(Store).filter(Store.id == store_id).first()
        if not store:
            return None

        merchant_id = store.merchant_id
        if merchant_id is None:
            return None

        # Prefer a platform_id attribute on the store when one exists;
        # getattr() tolerates Store models that do not define it.
        platform_id = getattr(store, "platform_id", None)
        if platform_id is None:
            # Fall back to the StorePlatform association table.
            from app.modules.tenancy.models import StorePlatform

            sp = (
                db.query(StorePlatform.platform_id)
                .filter(StorePlatform.store_id == store_id)
                .first()
            )
            platform_id = sp[0] if sp else None

        if platform_id is None:
            return None

        return self.get_merchant_subscription(db, merchant_id, platform_id)

    def get_subscription_or_raise(
        self, db: Session, merchant_id: int, platform_id: int
    ) -> MerchantSubscription:
        """Get merchant subscription or raise exception.

        Raises:
            SubscriptionNotFoundException: If the merchant has no
                subscription on the given platform.
        """
        subscription = self.get_merchant_subscription(db, merchant_id, platform_id)
        if not subscription:
            raise SubscriptionNotFoundException(merchant_id)
        return subscription

    def create_merchant_subscription(
        self,
        db: Session,
        merchant_id: int,
        platform_id: int,
        tier_code: str = TierCode.ESSENTIAL.value,
        trial_days: int = 14,
        is_annual: bool = False,
    ) -> MerchantSubscription:
        """
        Create a new merchant subscription for a platform.

        The initial period is chosen as follows:
        - trial_days > 0: TRIAL status, period ends after ``trial_days``
          (this takes precedence over ``is_annual``)
        - otherwise, is_annual: ACTIVE status, 365-day period
        - otherwise: ACTIVE status, 30-day period

        Args:
            db: Database session
            merchant_id: Merchant ID (the billing entity)
            platform_id: Platform ID
            tier_code: Tier code (default: essential)
            trial_days: Trial period in days (0 = no trial)
            is_annual: Annual billing cycle

        Returns:
            New MerchantSubscription (flushed and refreshed, not committed)

        Raises:
            ValueError: If the merchant already has a subscription on
                this platform.
        """
        # Check for existing
        existing = self.get_merchant_subscription(db, merchant_id, platform_id)
        if existing:
            raise ValueError(
                f"Merchant {merchant_id} already has a subscription "
                f"on platform {platform_id}"
            )

        now = datetime.now(UTC)

        # Calculate period
        if trial_days > 0:
            period_end = now + timedelta(days=trial_days)
            trial_ends_at = period_end
            status = SubscriptionStatus.TRIAL.value
        elif is_annual:
            period_end = now + timedelta(days=365)
            trial_ends_at = None
            status = SubscriptionStatus.ACTIVE.value
        else:
            period_end = now + timedelta(days=30)
            trial_ends_at = None
            status = SubscriptionStatus.ACTIVE.value

        tier_id = self.get_tier_id(db, tier_code)
        if tier_id is None:
            # An unknown tier code used to be stored silently as
            # tier_id=None. Creation still proceeds for backward
            # compatibility, but the problem is now visible in the logs.
            logger.warning(
                f"Tier '{tier_code}' not found; creating subscription for "
                f"merchant {merchant_id} on platform {platform_id} without a tier"
            )

        subscription = MerchantSubscription(
            merchant_id=merchant_id,
            platform_id=platform_id,
            tier_id=tier_id,
            status=status,
            is_annual=is_annual,
            period_start=now,
            period_end=period_end,
            trial_ends_at=trial_ends_at,
        )

        db.add(subscription)
        # Flush to emit the INSERT, then refresh to load DB-generated values.
        db.flush()
        db.refresh(subscription)

        logger.info(
            f"Created subscription for merchant {merchant_id} on platform {platform_id} "
            f"(tier={tier_code}, status={status})"
        )

        return subscription

    def get_or_create_subscription(
        self,
        db: Session,
        merchant_id: int,
        platform_id: int,
        tier_code: str = TierCode.ESSENTIAL.value,
        trial_days: int = 14,
    ) -> MerchantSubscription:
        """Get existing subscription or create a new trial subscription.

        ``tier_code`` and ``trial_days`` are only used when a subscription
        has to be created; an existing subscription is returned unchanged.
        """
        subscription = self.get_merchant_subscription(db, merchant_id, platform_id)
        if subscription:
            return subscription
        # Keyword arguments keep this call correct even if the parameter
        # order of create_merchant_subscription() changes.
        return self.create_merchant_subscription(
            db,
            merchant_id,
            platform_id,
            tier_code=tier_code,
            trial_days=trial_days,
        )

    def upgrade_tier(
        self,
        db: Session,
        merchant_id: int,
        platform_id: int,
        new_tier_code: str,
    ) -> MerchantSubscription:
        """Upgrade merchant to a new tier.

        A subscription that is currently in TRIAL status is switched to
        ACTIVE as part of the tier change; other statuses are left as-is.

        Raises:
            SubscriptionNotFoundException: If the merchant has no
                subscription on the given platform.
            ValueError: If ``new_tier_code`` does not match any tier.
        """
        subscription = self.get_subscription_or_raise(db, merchant_id, platform_id)
        old_tier_id = subscription.tier_id

        new_tier = self.get_tier_by_code(db, new_tier_code)
        if not new_tier:
            raise ValueError(f"Tier '{new_tier_code}' not found")

        subscription.tier_id = new_tier.id
        subscription.updated_at = datetime.now(UTC)

        # If upgrading from trial, mark as active
        if subscription.status == SubscriptionStatus.TRIAL.value:
            subscription.status = SubscriptionStatus.ACTIVE.value

        db.flush()
        db.refresh(subscription)

        logger.info(
            f"Upgraded merchant {merchant_id} on platform {platform_id} "
            f"from tier_id={old_tier_id} to tier_id={new_tier.id} ({new_tier_code})"
        )
        return subscription

    def cancel_subscription(
        self,
        db: Session,
        merchant_id: int,
        platform_id: int,
        reason: str | None = None,
    ) -> MerchantSubscription:
        """Cancel a merchant subscription (access continues until period end).

        Only the status, cancellation fields and ``updated_at`` are
        changed; the billing period is left untouched.

        Raises:
            SubscriptionNotFoundException: If the merchant has no
                subscription on the given platform.
        """
        subscription = self.get_subscription_or_raise(db, merchant_id, platform_id)

        # One timestamp for both columns so they compare equal.
        now = datetime.now(UTC)
        subscription.status = SubscriptionStatus.CANCELLED.value
        subscription.cancelled_at = now
        subscription.cancellation_reason = reason
        subscription.updated_at = now

        db.flush()
        db.refresh(subscription)

        logger.info(
            f"Cancelled subscription for merchant {merchant_id} "
            f"on platform {platform_id}"
        )
        return subscription

    def reactivate_subscription(
        self,
        db: Session,
        merchant_id: int,
        platform_id: int,
    ) -> MerchantSubscription:
        """Reactivate a cancelled subscription.

        Sets the status to ACTIVE and clears the cancellation fields. The
        current status is not checked and the billing period is not
        modified.

        Raises:
            SubscriptionNotFoundException: If the merchant has no
                subscription on the given platform.
        """
        subscription = self.get_subscription_or_raise(db, merchant_id, platform_id)

        subscription.status = SubscriptionStatus.ACTIVE.value
        subscription.cancelled_at = None
        subscription.cancellation_reason = None
        subscription.updated_at = datetime.now(UTC)

        db.flush()
        db.refresh(subscription)

        logger.info(
            f"Reactivated subscription for merchant {merchant_id} "
            f"on platform {platform_id}"
        )
        return subscription


# Singleton instance
subscription_service = SubscriptionService()