# app/services/subscription_service.py
"""
Subscription service for tier-based access control.

Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating

Usage:
    from app.services.subscription_service import subscription_service

    # Check if vendor can create an order
    can_create, message = subscription_service.can_create_order(db, vendor_id)

    # Increment order counter after successful order
    subscription_service.increment_order_count(db, vendor_id)
"""

import logging
from datetime import UTC, datetime, timedelta
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from models.database.product import Product
from models.database.subscription import (
    SubscriptionStatus,
    SubscriptionTier,
    TIER_LIMITS,
    TierCode,
    VendorSubscription,
)
from models.database.vendor import Vendor, VendorUser
from models.schema.subscription import (
    SubscriptionCreate,
    SubscriptionUpdate,
    SubscriptionUsage,
    TierInfo,
    TierLimits,
    UsageSummary,
)

logger = logging.getLogger(__name__)


class SubscriptionNotFoundException(Exception):
    """Raised when a vendor has no subscription record."""


class TierLimitExceededException(Exception):
    """
    Raised when a tier limit is exceeded.

    Attributes:
        limit_type: Which limit was hit ("orders", "products" or "team_members").
        current: The usage count at the time of the check.
        limit: The configured limit for the vendor's tier (None if the
            subscription reported no limit).
    """

    def __init__(
        self, message: str, limit_type: str, current: int, limit: int | None
    ):
        super().__init__(message)
        self.limit_type = limit_type
        self.current = current
        self.limit = limit


class FeatureNotAvailableException(Exception):
    """
    Raised when a feature is not available in the vendor's current tier.

    Attributes:
        feature: The feature that was requested.
        current_tier: The vendor's current tier code.
        required_tier: Display name of a tier offering the feature, or "higher"
            if none could be determined.
    """

    def __init__(self, feature: str, current_tier: str, required_tier: str):
        message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
        super().__init__(message)
        self.feature = feature
        self.current_tier = current_tier
        self.required_tier = required_tier


class SubscriptionService:
    """
    Service for subscription and tier limit operations.

    Methods take an SQLAlchemy ``Session`` and only ``flush()`` changes;
    committing the transaction is left to the caller.
    """

    # =========================================================================
    # Tier Information
    # =========================================================================

    @staticmethod
    def _tier_info_from_db(tier: SubscriptionTier) -> TierInfo:
        """Build a TierInfo schema object from a SubscriptionTier row."""
        return TierInfo(
            code=tier.code,
            name=tier.name,
            price_monthly_cents=tier.price_monthly_cents,
            price_annual_cents=tier.price_annual_cents,
            limits=TierLimits(
                orders_per_month=tier.orders_per_month,
                products_limit=tier.products_limit,
                team_members=tier.team_members,
                order_history_months=tier.order_history_months,
            ),
            features=tier.features or [],
        )

    def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
        """
        Get full tier information.

        Queries the database if a session is provided and the tier exists
        there; otherwise falls back to the hardcoded TIER_LIMITS.
        """
        if db is not None:
            db_tier = self.get_tier_by_code(db, tier_code)
            if db_tier:
                return self._tier_info_from_db(db_tier)

        # Fallback to hardcoded TIER_LIMITS
        return self._get_tier_from_legacy(tier_code)

    def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
        """
        Get tier info from hardcoded TIER_LIMITS (fallback).

        Unknown tier codes fall back to the ESSENTIAL tier; a warning is
        logged so misconfigured codes do not go unnoticed.
        """
        try:
            tier = TierCode(tier_code)
        except ValueError:
            logger.warning(
                "Unknown tier code %r; falling back to %s",
                tier_code,
                TierCode.ESSENTIAL.value,
            )
            tier = TierCode.ESSENTIAL

        limits = TIER_LIMITS[tier]
        return TierInfo(
            code=tier.value,
            name=limits["name"],
            price_monthly_cents=limits["price_monthly_cents"],
            price_annual_cents=limits.get("price_annual_cents"),
            limits=TierLimits(
                orders_per_month=limits.get("orders_per_month"),
                products_limit=limits.get("products_limit"),
                team_members=limits.get("team_members"),
                order_history_months=limits.get("order_history_months"),
            ),
            features=limits.get("features", []),
        )

    def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
        """
        Get information for all tiers.

        With a session, returns the active, public database tiers ordered by
        ``display_order``. If no session is given or the query returns no
        rows, falls back to one entry per TierCode from TIER_LIMITS.
        """
        if db is not None:
            db_tiers = (
                db.query(SubscriptionTier)
                .filter(
                    SubscriptionTier.is_active == True,  # noqa: E712
                    SubscriptionTier.is_public == True,  # noqa: E712
                )
                .order_by(SubscriptionTier.display_order)
                .all()
            )
            if db_tiers:
                return [self._tier_info_from_db(t) for t in db_tiers]

        # Fallback to hardcoded
        return [self._get_tier_from_legacy(tier.value) for tier in TierCode]

    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
        """Get subscription tier by code, or None if no row matches."""
        return (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_code)
            .first()
        )

    def get_tier_id(self, db: Session, tier_code: str) -> int | None:
        """Get tier ID from tier code. Returns None if tier not found."""
        tier = self.get_tier_by_code(db, tier_code)
        return tier.id if tier else None

    # =========================================================================
    # Subscription CRUD
    # =========================================================================

    def get_subscription(
        self, db: Session, vendor_id: int
    ) -> VendorSubscription | None:
        """Get the vendor's subscription, or None if it has none."""
        return (
            db.query(VendorSubscription)
            .filter(VendorSubscription.vendor_id == vendor_id)
            .first()
        )

    def get_subscription_or_raise(
        self, db: Session, vendor_id: int
    ) -> VendorSubscription:
        """
        Get the vendor's subscription.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription(db, vendor_id)
        if not subscription:
            raise SubscriptionNotFoundException(
                f"No subscription found for vendor {vendor_id}"
            )
        return subscription

    def get_current_tier(
        self, db: Session, vendor_id: int
    ) -> TierCode | None:
        """
        Get the vendor's current subscription tier code.

        Returns None if the vendor has no subscription or its stored tier
        string is not a valid TierCode.
        """
        subscription = self.get_subscription(db, vendor_id)
        if subscription:
            try:
                return TierCode(subscription.tier)
            except ValueError:
                return None
        return None

    def get_or_create_subscription(
        self,
        db: Session,
        vendor_id: int,
        tier: str = TierCode.ESSENTIAL.value,
        trial_days: int = 14,
    ) -> VendorSubscription:
        """
        Get existing subscription or create a new trial subscription.

        Used when a vendor first accesses the system. A newly created
        subscription has TRIAL status and a period ending ``trial_days`` from
        now; it is flushed (not committed) to the session.
        """
        subscription = self.get_subscription(db, vendor_id)
        if subscription:
            return subscription

        # Create new trial subscription
        now = datetime.now(UTC)
        trial_end = now + timedelta(days=trial_days)

        # Lookup tier_id from tier code
        tier_id = self.get_tier_id(db, tier)

        subscription = VendorSubscription(
            vendor_id=vendor_id,
            tier=tier,
            tier_id=tier_id,
            status=SubscriptionStatus.TRIAL.value,
            period_start=now,
            period_end=trial_end,
            trial_ends_at=trial_end,
            is_annual=False,
        )
        db.add(subscription)
        db.flush()
        db.refresh(subscription)

        logger.info(
            "Created trial subscription for vendor %s (tier=%s, trial_ends=%s)",
            vendor_id,
            tier,
            trial_end,
        )
        return subscription

    def create_subscription(
        self,
        db: Session,
        vendor_id: int,
        data: SubscriptionCreate,
    ) -> VendorSubscription:
        """
        Create a subscription for a vendor.

        The billing period is approximated as 365 days (annual) or 30 days
        (monthly). If ``data.trial_days > 0`` the subscription starts in TRIAL
        status and the period ends when the trial does.

        Raises:
            ValueError: If the vendor already has a subscription.
        """
        # Check if subscription exists
        existing = self.get_subscription(db, vendor_id)
        if existing:
            raise ValueError("Vendor already has a subscription")

        now = datetime.now(UTC)

        # Calculate period end based on billing cycle
        if data.is_annual:
            period_end = now + timedelta(days=365)
        else:
            period_end = now + timedelta(days=30)

        # Handle trial: it overrides the regular billing period
        trial_ends_at = None
        status = SubscriptionStatus.ACTIVE.value
        if data.trial_days > 0:
            trial_ends_at = now + timedelta(days=data.trial_days)
            status = SubscriptionStatus.TRIAL.value
            period_end = trial_ends_at

        # Lookup tier_id from tier code
        tier_id = self.get_tier_id(db, data.tier)

        subscription = VendorSubscription(
            vendor_id=vendor_id,
            tier=data.tier,
            tier_id=tier_id,
            status=status,
            period_start=now,
            period_end=period_end,
            trial_ends_at=trial_ends_at,
            is_annual=data.is_annual,
        )
        db.add(subscription)
        db.flush()
        db.refresh(subscription)

        logger.info("Created subscription for vendor %s: %s", vendor_id, data.tier)
        return subscription

    def update_subscription(
        self,
        db: Session,
        vendor_id: int,
        data: SubscriptionUpdate,
    ) -> VendorSubscription:
        """
        Update a vendor subscription with the fields explicitly set in ``data``.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)

        update_data = data.model_dump(exclude_unset=True)

        # Keep tier_id in sync with the tier code
        if "tier" in update_data:
            update_data["tier_id"] = self.get_tier_id(db, update_data["tier"])

        for key, value in update_data.items():
            setattr(subscription, key, value)

        subscription.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(subscription)

        logger.info("Updated subscription for vendor %s", vendor_id)
        return subscription

    def upgrade_tier(
        self,
        db: Session,
        vendor_id: int,
        new_tier: str,
    ) -> VendorSubscription:
        """
        Move a vendor to ``new_tier``.

        A TRIAL subscription becomes ACTIVE. Note that no ordering check is
        performed, so this also accepts a lower tier.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)

        old_tier = subscription.tier
        subscription.tier = new_tier
        subscription.tier_id = self.get_tier_id(db, new_tier)
        subscription.updated_at = datetime.now(UTC)

        # If upgrading from trial, mark as active
        if subscription.status == SubscriptionStatus.TRIAL.value:
            subscription.status = SubscriptionStatus.ACTIVE.value

        db.flush()
        db.refresh(subscription)

        logger.info("Upgraded vendor %s from %s to %s", vendor_id, old_tier, new_tier)
        return subscription

    def cancel_subscription(
        self,
        db: Session,
        vendor_id: int,
        reason: str | None = None,
    ) -> VendorSubscription:
        """
        Cancel a vendor subscription (access intended until period end).

        ``period_end`` is left untouched; only the status, cancellation
        timestamp and reason are recorded.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)

        # One timestamp so cancelled_at and updated_at agree exactly.
        now = datetime.now(UTC)
        subscription.status = SubscriptionStatus.CANCELLED.value
        subscription.cancelled_at = now
        subscription.cancellation_reason = reason
        subscription.updated_at = now

        db.flush()
        db.refresh(subscription)

        logger.info("Cancelled subscription for vendor %s", vendor_id)
        return subscription

    # =========================================================================
    # Usage Tracking
    # =========================================================================

    def _count_products(self, db: Session, vendor_id: int) -> int:
        """Count all products belonging to the vendor (0 if none)."""
        return (
            db.query(func.count(Product.id))
            .filter(Product.vendor_id == vendor_id)
            .scalar()
            or 0
        )

    def _count_active_team_members(self, db: Session, vendor_id: int) -> int:
        """Count the vendor's team members with ``is_active`` set (0 if none)."""
        return (
            db.query(func.count(VendorUser.id))
            .filter(
                VendorUser.vendor_id == vendor_id,
                VendorUser.is_active == True,  # noqa: E712
            )
            .scalar()
            or 0
        )

    @staticmethod
    def _remaining(current: int, limit: int | None) -> int | None:
        """Units left before ``limit`` (never negative); None means unlimited."""
        if limit is None:
            return None
        return max(0, limit - current)

    @staticmethod
    def _percent_used(current: int, limit: int | None) -> float | None:
        """Usage as a percentage capped at 100; None for no/zero limit."""
        if limit is None or limit == 0:
            return None
        return min(100.0, (current / limit) * 100)

    def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
        """
        Get current subscription usage statistics.

        Creates a trial subscription if the vendor has none. Product and team
        counts are queried live; the order count is the subscription's
        per-period counter.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)

        products_count = self._count_products(db, vendor_id)
        team_count = self._count_active_team_members(db, vendor_id)

        orders_used = subscription.orders_this_period
        orders_limit = subscription.orders_limit
        products_limit = subscription.products_limit
        team_limit = subscription.team_members_limit

        return SubscriptionUsage(
            orders_used=orders_used,
            orders_limit=orders_limit,
            orders_remaining=self._remaining(orders_used, orders_limit),
            orders_percent_used=self._percent_used(orders_used, orders_limit),
            products_used=products_count,
            products_limit=products_limit,
            products_remaining=self._remaining(products_count, products_limit),
            products_percent_used=self._percent_used(products_count, products_limit),
            team_members_used=team_count,
            team_members_limit=team_limit,
            team_members_remaining=self._remaining(team_count, team_limit),
            team_members_percent_used=self._percent_used(team_count, team_limit),
        )

    def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
        """
        Get usage summary for billing page display.

        Like ``get_usage`` but without percentages; also creates a trial
        subscription if the vendor has none.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)

        products_count = self._count_products(db, vendor_id)
        team_count = self._count_active_team_members(db, vendor_id)

        orders_used = subscription.orders_this_period
        orders_limit = subscription.orders_limit
        products_limit = subscription.products_limit
        team_limit = subscription.team_members_limit

        return UsageSummary(
            orders_this_period=orders_used,
            orders_limit=orders_limit,
            orders_remaining=self._remaining(orders_used, orders_limit),
            products_count=products_count,
            products_limit=products_limit,
            products_remaining=self._remaining(products_count, products_limit),
            team_count=team_count,
            team_limit=team_limit,
            team_remaining=self._remaining(team_count, team_limit),
        )

    def increment_order_count(self, db: Session, vendor_id: int) -> None:
        """
        Increment the order counter for the current period.

        Call this after successfully creating/importing an order. Creates a
        trial subscription if the vendor has none.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        subscription.increment_order_count()
        db.flush()

    def reset_period_counters(self, db: Session, vendor_id: int) -> None:
        """
        Reset counters for a new billing period.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)
        subscription.reset_period_counters()
        db.flush()
        logger.info("Reset period counters for vendor %s", vendor_id)

    # =========================================================================
    # Limit Checks
    # =========================================================================

    def can_create_order(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can create/import another order.

        Returns:
            (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_create_order()

    def check_order_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check order limit and raise exception if exceeded.

        Use this in order creation flows.

        Raises:
            TierLimitExceededException: If no further orders are allowed.
        """
        can_create, message = self.can_create_order(db, vendor_id)
        if not can_create:
            subscription = self.get_subscription(db, vendor_id)
            raise TierLimitExceededException(
                message=message or "Order limit exceeded",
                limit_type="orders",
                current=subscription.orders_this_period if subscription else 0,
                limit=subscription.orders_limit if subscription else 0,
            )

    def can_add_product(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can add another product.

        Returns:
            (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_add_product(self._count_products(db, vendor_id))

    def check_product_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check product limit and raise exception if exceeded.

        Use this in product creation flows.

        Raises:
            TierLimitExceededException: If no further products are allowed.
        """
        can_add, message = self.can_add_product(db, vendor_id)
        if not can_add:
            subscription = self.get_subscription(db, vendor_id)
            raise TierLimitExceededException(
                message=message or "Product limit exceeded",
                limit_type="products",
                current=self._count_products(db, vendor_id),
                limit=subscription.products_limit if subscription else 0,
            )

    def can_add_team_member(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can add another team member.

        Returns:
            (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_add_team_member(
            self._count_active_team_members(db, vendor_id)
        )

    def check_team_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check team member limit and raise exception if exceeded.

        Use this in team member invitation flows.

        Raises:
            TierLimitExceededException: If no further team members are allowed.
        """
        can_add, message = self.can_add_team_member(db, vendor_id)
        if not can_add:
            subscription = self.get_subscription(db, vendor_id)
            raise TierLimitExceededException(
                message=message or "Team member limit exceeded",
                limit_type="team_members",
                current=self._count_active_team_members(db, vendor_id),
                limit=subscription.team_members_limit if subscription else 0,
            )

    # =========================================================================
    # Feature Gating
    # =========================================================================

    def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
        """Check if vendor has access to a feature (creates trial if needed)."""
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.has_feature(feature)

    def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
        """
        Check feature access and raise exception if not available.

        Use this to gate premium features.

        Raises:
            FeatureNotAvailableException: If the vendor's tier lacks the
                feature. ``required_tier`` is the display name of the first
                TIER_LIMITS entry (in dict order) listing the feature, or
                "higher" if none does.
        """
        if not self.has_feature(db, vendor_id, feature):
            subscription = self.get_or_create_subscription(db, vendor_id)

            # Find which tier has this feature
            required_tier = next(
                (
                    limits["name"]
                    for limits in TIER_LIMITS.values()
                    if feature in limits.get("features", [])
                ),
                None,
            )

            raise FeatureNotAvailableException(
                feature=feature,
                current_tier=subscription.tier,
                required_tier=required_tier or "higher",
            )

    def get_feature_tier(self, feature: str) -> str | None:
        """
        Get the minimum tier code required for a feature.

        Tiers are checked in the fixed order ESSENTIAL, PROFESSIONAL,
        BUSINESS, ENTERPRISE against the legacy TIER_LIMITS table. Returns
        None if no tier lists the feature.
        """
        for tier_code in (
            TierCode.ESSENTIAL,
            TierCode.PROFESSIONAL,
            TierCode.BUSINESS,
            TierCode.ENTERPRISE,
        ):
            # .get() so a tier absent from TIER_LIMITS is skipped instead of
            # raising KeyError.
            if feature in TIER_LIMITS.get(tier_code, {}).get("features", []):
                return tier_code.value
        return None


# Singleton instance
subscription_service = SubscriptionService()