diff --git a/app/services/subscription_service.py b/app/services/subscription_service.py index 5f1e8b6f..0f01ecac 100644 --- a/app/services/subscription_service.py +++ b/app/services/subscription_service.py @@ -28,6 +28,7 @@ from sqlalchemy.orm import Session from models.database.product import Product from models.database.subscription import ( SubscriptionStatus, + SubscriptionTier, TIER_LIMITS, TierCode, VendorSubscription, @@ -107,6 +108,19 @@ class SubscriptionService: for tier in TierCode ] + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: + """Get subscription tier by code.""" + return ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == tier_code) + .first() + ) + + def get_tier_id(self, db: Session, tier_code: str) -> int | None: + """Get tier ID from tier code. Returns None if tier not found.""" + tier = self.get_tier_by_code(db, tier_code) + return tier.id if tier else None + # ========================================================================= # Subscription CRUD # ========================================================================= @@ -152,9 +166,13 @@ class SubscriptionService: now = datetime.now(UTC) trial_end = now + timedelta(days=trial_days) + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, tier) + subscription = VendorSubscription( vendor_id=vendor_id, tier=tier, + tier_id=tier_id, status=SubscriptionStatus.TRIAL.value, period_start=now, period_end=trial_end, @@ -201,9 +219,13 @@ class SubscriptionService: status = SubscriptionStatus.TRIAL.value period_end = trial_ends_at + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, data.tier) + subscription = VendorSubscription( vendor_id=vendor_id, tier=data.tier, + tier_id=tier_id, status=status, period_start=now, period_end=period_end, @@ -228,6 +250,12 @@ class SubscriptionService: subscription = self.get_subscription_or_raise(db, vendor_id) update_data = data.model_dump(exclude_unset=True) + + # If tier is being updated, also update tier_id + if "tier" in update_data: + tier_id = self.get_tier_id(db, update_data["tier"]) + update_data["tier_id"] = tier_id + for key, value in update_data.items(): setattr(subscription, key, value) @@ -249,6 +277,7 @@ class SubscriptionService: old_tier = subscription.tier subscription.tier = new_tier + subscription.tier_id = self.get_tier_id(db, new_tier) subscription.updated_at = datetime.now(UTC) # If upgrading from trial, mark as active