From b717c237871597e8b0173132031b27541853b6e5 Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Fri, 26 Dec 2025 07:35:53 +0100 Subject: [PATCH] feat: update subscription service to use tier_id relationship MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add get_tier_by_code and get_tier_id helper methods - Update get_or_create_subscription to set tier_id - Update create_subscription to set tier_id - Update update_subscription to sync tier_id when tier changes - Update upgrade_tier to set tier_id All subscription CRUD operations now maintain the tier_id FK relationship in sync with the tier code string. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- app/services/subscription_service.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/app/services/subscription_service.py b/app/services/subscription_service.py index 5f1e8b6f..0f01ecac 100644 --- a/app/services/subscription_service.py +++ b/app/services/subscription_service.py @@ -28,6 +28,7 @@ from sqlalchemy.orm import Session from models.database.product import Product from models.database.subscription import ( SubscriptionStatus, + SubscriptionTier, TIER_LIMITS, TierCode, VendorSubscription, @@ -107,6 +108,19 @@ class SubscriptionService: for tier in TierCode ] + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: + """Get subscription tier by code.""" + return ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == tier_code) + .first() + ) + + def get_tier_id(self, db: Session, tier_code: str) -> int | None: + """Get tier ID from tier code. Returns None if tier not found.""" + tier = self.get_tier_by_code(db, tier_code) + return tier.id if tier else None + # ========================================================================= # Subscription CRUD # ========================================================================= @@ -152,9 +166,13 @@ class SubscriptionService: now = datetime.now(UTC) trial_end = now + timedelta(days=trial_days) + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, tier) + subscription = VendorSubscription( vendor_id=vendor_id, tier=tier, + tier_id=tier_id, status=SubscriptionStatus.TRIAL.value, period_start=now, period_end=trial_end, @@ -201,9 +219,13 @@ class SubscriptionService: status = SubscriptionStatus.TRIAL.value period_end = trial_ends_at + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, data.tier) + subscription = VendorSubscription( vendor_id=vendor_id, tier=data.tier, + tier_id=tier_id, status=status, period_start=now, period_end=period_end, @@ -228,6 +250,12 @@ class SubscriptionService: subscription = self.get_subscription_or_raise(db, vendor_id) update_data = data.model_dump(exclude_unset=True) + + # If tier is being updated, also update tier_id + if "tier" in update_data: + tier_id = self.get_tier_id(db, update_data["tier"]) + update_data["tier_id"] = tier_id + for key, value in update_data.items(): setattr(subscription, key, value) @@ -249,6 +277,7 @@ class SubscriptionService: old_tier = subscription.tier subscription.tier = new_tier + subscription.tier_id = self.get_tier_id(db, new_tier) subscription.updated_at = datetime.now(UTC) # If upgrading from trial, mark as active