# app/modules/billing/tests/unit/test_subscription_service.py
"""Unit tests for SubscriptionService (merchant-level subscription management)."""

from datetime import UTC, datetime, timedelta

import pytest

from app.modules.billing.exceptions import SubscriptionNotFoundException
from app.modules.billing.models import (
    MerchantSubscription,
    SubscriptionStatus,
    SubscriptionTier,
)
from app.modules.billing.services.subscription_service import SubscriptionService

# Merchant/platform ID that no fixture creates; used for the "not found" paths.
_NONEXISTENT_ID = 99999


def _assert_close_in_time(actual, expected, tolerance_seconds=1.0):
    """Assert that two datetimes differ by less than ``tolerance_seconds``.

    Used for period-boundary checks, where a small tolerance is allowed
    instead of exact equality.
    """
    assert abs((actual - expected).total_seconds()) < tolerance_seconds


# ============================================================================
# Tier Information
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceTierLookup:
    """Tests for tier lookup methods."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_tier_by_code_returns_tier(self, db, billing_tier_essential):
        """Existing tier code returns the tier object."""
        tier = self.service.get_tier_by_code(db, "essential")
        assert tier is not None
        assert tier.code == "essential"
        assert tier.id == billing_tier_essential.id

    def test_get_tier_by_code_nonexistent_returns_none(self, db):
        """Nonexistent tier code returns None (does not raise)."""
        tier = self.service.get_tier_by_code(db, "nonexistent")
        assert tier is None

    def test_get_tier_id_returns_id(self, db, billing_tier_essential):
        """get_tier_id returns the integer ID for a valid code."""
        tier_id = self.service.get_tier_id(db, "essential")
        assert tier_id == billing_tier_essential.id

    def test_get_tier_id_nonexistent_returns_none(self, db):
        """get_tier_id returns None for nonexistent code."""
        tier_id = self.service.get_tier_id(db, "nonexistent")
        assert tier_id is None


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetAllTiers:
    """Tests for get_all_tiers with platform filtering."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_all_tiers_returns_active_public_tiers(self, db, billing_tiers):
        """Returns only active + public tiers, ordered by display_order."""
        tiers = self.service.get_all_tiers(db)
        assert len(tiers) == 3
        assert tiers[0].code == "essential"
        assert tiers[1].code == "professional"
        assert tiers[2].code == "business"

    def test_get_all_tiers_excludes_inactive(self, db, billing_tiers):
        """Inactive tiers are excluded."""
        # Deactivate one
        billing_tiers[0].is_active = False
        db.flush()

        tiers = self.service.get_all_tiers(db)
        assert len(tiers) == 2
        assert all(t.code != "essential" for t in tiers)

    def test_get_all_tiers_excludes_non_public(self, db, billing_tiers):
        """Non-public tiers are excluded."""
        billing_tiers[2].is_public = False
        db.flush()

        tiers = self.service.get_all_tiers(db)
        assert len(tiers) == 2
        assert all(t.code != "business" for t in tiers)

    def test_get_all_tiers_with_platform_filter(self, db, test_platform, billing_tiers):
        """Platform filter returns platform-specific tiers + global tiers."""
        # The billing_tiers fixture assigns every tier to test_platform, so the
        # global tiers must be made global explicitly; otherwise this test never
        # exercises the "global tiers are included" branch.
        billing_tiers[0].platform_id = test_platform.id  # Platform-specific
        billing_tiers[1].platform_id = None  # Global
        billing_tiers[2].platform_id = None  # Global
        db.flush()

        tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
        codes = {t.code for t in tiers}
        assert "essential" in codes  # Platform-specific
        assert "professional" in codes  # Global
        assert "business" in codes  # Global

    def test_get_all_tiers_platform_filter_excludes_other_platform(
        self, db, test_platform, another_platform, billing_tiers
    ):
        """Tiers belonging to another platform are excluded."""
        billing_tiers[0].platform_id = another_platform.id
        db.flush()

        tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
        codes = {t.code for t in tiers}
        assert "essential" not in codes
        assert len(tiers) == 2


# ============================================================================
# Merchant Subscription CRUD
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetMerchantSubscription:
    """Tests for subscription retrieval methods."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_merchant_subscription_found(self, db, billing_subscription):
        """Returns subscription when it exists."""
        sub = self.service.get_merchant_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub is not None
        assert sub.id == billing_subscription.id

    def test_get_merchant_subscription_not_found(self, db):
        """Returns None when no subscription exists."""
        sub = self.service.get_merchant_subscription(
            db, _NONEXISTENT_ID, _NONEXISTENT_ID
        )
        assert sub is None

    def test_get_merchant_subscription_eager_loads_tier(self, db, billing_subscription):
        """Subscription comes with tier relationship loaded."""
        sub = self.service.get_merchant_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        # NOTE(review): the session is still open here and the fixture already
        # loaded this row, so this presumably only proves the relationship
        # resolves, not that it was eager-loaded. Confirm if eager loading matters.
        assert sub.tier is not None
        assert sub.tier.code == "essential"

    def test_get_merchant_subscription_eager_loads_platform(
        self, db, billing_subscription
    ):
        """Subscription comes with platform relationship loaded."""
        sub = self.service.get_merchant_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.platform is not None

    def test_get_merchant_subscriptions_all(
        self, db, test_merchant, test_platform, another_platform, billing_tier_essential
    ):
        """Returns all subscriptions for a merchant across platforms."""
        now = datetime.now(UTC)
        sub1 = MerchantSubscription(
            merchant_id=test_merchant.id,
            platform_id=test_platform.id,
            tier_id=billing_tier_essential.id,
            status=SubscriptionStatus.ACTIVE.value,
            period_start=now,
            period_end=now + timedelta(days=30),
        )
        sub2 = MerchantSubscription(
            merchant_id=test_merchant.id,
            platform_id=another_platform.id,
            tier_id=billing_tier_essential.id,
            status=SubscriptionStatus.TRIAL.value,
            period_start=now,
            period_end=now + timedelta(days=14),
        )
        db.add_all([sub1, sub2])
        db.flush()

        subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
        assert len(subs) == 2

    def test_get_merchant_subscriptions_empty(self, db, test_merchant):
        """Returns empty list when merchant has no subscriptions."""
        subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
        assert subs == []


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrRaise:
    """Tests for get_subscription_or_raise."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_subscription_or_raise_found(self, db, billing_subscription):
        """Returns subscription when found."""
        sub = self.service.get_subscription_or_raise(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.id == billing_subscription.id

    def test_get_subscription_or_raise_not_found(self, db):
        """Raises SubscriptionNotFoundException when not found."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.get_subscription_or_raise(
                db, _NONEXISTENT_ID, _NONEXISTENT_ID
            )


# ============================================================================
# Create Subscription
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCreate:
    """Tests for create_merchant_subscription."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_create_trial_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Default creation starts a 14-day trial."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id
        )
        assert sub.merchant_id == test_merchant.id
        assert sub.platform_id == test_platform.id
        assert sub.status == SubscriptionStatus.TRIAL.value
        assert sub.tier_id == billing_tier_essential.id
        assert sub.trial_ends_at is not None
        assert sub.is_annual is False
        # The docstring promises 14 days; verify it the same way the explicit
        # trial_days test does (period end = period start + trial length).
        _assert_close_in_time(sub.period_end, sub.period_start + timedelta(days=14))

    def test_create_subscription_period_end_matches_trial(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Period end is set to trial_days from now."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=7
        )
        _assert_close_in_time(sub.period_end, sub.period_start + timedelta(days=7))

    def test_create_annual_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Annual subscription with no trial gets 365-day period."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=0, is_annual=True
        )
        assert sub.status == SubscriptionStatus.ACTIVE.value
        assert sub.is_annual is True
        assert sub.trial_ends_at is None
        _assert_close_in_time(sub.period_end, sub.period_start + timedelta(days=365))

    def test_create_monthly_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Monthly subscription with no trial gets 30-day period."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=0, is_annual=False
        )
        assert sub.status == SubscriptionStatus.ACTIVE.value
        assert sub.is_annual is False
        assert sub.trial_ends_at is None
        _assert_close_in_time(sub.period_end, sub.period_start + timedelta(days=30))

    def test_create_subscription_custom_tier(
        self, db, test_merchant, test_platform, billing_tiers
    ):
        """Can specify a custom tier code."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, tier_code="professional"
        )
        assert sub.tier_id == billing_tiers[1].id

    def test_create_duplicate_raises(self, db, billing_subscription):
        """Creating a second subscription for same merchant+platform raises ValueError."""
        with pytest.raises(ValueError, match="already has a subscription"):
            self.service.create_merchant_subscription(
                db,
                billing_subscription.merchant_id,
                billing_subscription.platform_id,
            )


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrCreate:
    """Tests for get_or_create_subscription."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_or_create_returns_existing(self, db, billing_subscription):
        """Returns existing subscription without creating a new one."""
        sub = self.service.get_or_create_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.id == billing_subscription.id

    def test_get_or_create_creates_new(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Creates new subscription when none exists."""
        sub = self.service.get_or_create_subscription(
            db, test_merchant.id, test_platform.id
        )
        assert sub is not None
        assert sub.merchant_id == test_merchant.id
        assert sub.status == SubscriptionStatus.TRIAL.value


# ============================================================================
# Upgrade / Cancel / Reactivate
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceUpgradeTier:
    """Tests for upgrade_tier."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_upgrade_tier_changes_tier_id(self, db, billing_subscription, billing_tiers):
        """Tier is updated to the new tier's ID."""
        sub = self.service.upgrade_tier(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
            "professional",
        )
        assert sub.tier_id == billing_tiers[1].id

    def test_upgrade_from_trial_activates(self, db, billing_subscription, billing_tiers):
        """Upgrading from trial status sets status to active."""
        billing_subscription.status = SubscriptionStatus.TRIAL.value
        db.flush()

        sub = self.service.upgrade_tier(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
            "professional",
        )
        assert sub.status == SubscriptionStatus.ACTIVE.value

    def test_upgrade_active_stays_active(self, db, billing_subscription, billing_tiers):
        """Upgrading an already-active subscription keeps status active."""
        billing_subscription.status = SubscriptionStatus.ACTIVE.value
        db.flush()

        sub = self.service.upgrade_tier(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
            "professional",
        )
        assert sub.status == SubscriptionStatus.ACTIVE.value

    def test_upgrade_nonexistent_tier_raises(self, db, billing_subscription):
        """Upgrading to a nonexistent tier raises ValueError."""
        with pytest.raises(ValueError, match="not found"):
            self.service.upgrade_tier(
                db,
                billing_subscription.merchant_id,
                billing_subscription.platform_id,
                "nonexistent_tier",
            )

    def test_upgrade_no_subscription_raises(self, db):
        """Upgrading when no subscription exists raises SubscriptionNotFoundException."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.upgrade_tier(
                db, _NONEXISTENT_ID, _NONEXISTENT_ID, "professional"
            )


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCancel:
    """Tests for cancel_subscription."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_cancel_sets_status(self, db, billing_subscription):
        """Cancellation sets status to cancelled."""
        sub = self.service.cancel_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.status == SubscriptionStatus.CANCELLED.value

    def test_cancel_records_timestamp(self, db, billing_subscription):
        """Cancellation records cancelled_at."""
        sub = self.service.cancel_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.cancelled_at is not None

    def test_cancel_with_reason(self, db, billing_subscription):
        """Cancellation stores the reason."""
        sub = self.service.cancel_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
            reason="Too expensive",
        )
        assert sub.cancellation_reason == "Too expensive"

    def test_cancel_no_subscription_raises(self, db):
        """Cancelling when no subscription exists raises."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.cancel_subscription(db, _NONEXISTENT_ID, _NONEXISTENT_ID)


@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceReactivate:
    """Tests for reactivate_subscription."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_reactivate_sets_active(self, db, billing_subscription):
        """Reactivation sets status back to active."""
        # First cancel
        self.service.cancel_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )

        sub = self.service.reactivate_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.status == SubscriptionStatus.ACTIVE.value

    def test_reactivate_clears_cancellation(self, db, billing_subscription):
        """Reactivation clears cancelled_at and cancellation_reason."""
        self.service.cancel_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
            reason="Testing",
        )

        sub = self.service.reactivate_subscription(
            db,
            billing_subscription.merchant_id,
            billing_subscription.platform_id,
        )
        assert sub.cancelled_at is None
        assert sub.cancellation_reason is None

    def test_reactivate_no_subscription_raises(self, db):
        """Reactivating when no subscription exists raises."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.reactivate_subscription(
                db, _NONEXISTENT_ID, _NONEXISTENT_ID
            )


# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def billing_tier_essential(db, test_platform):
    """Create essential subscription tier."""
    tier = SubscriptionTier(
        code="essential",
        name="Essential",
        description="Essential plan",
        price_monthly_cents=4900,
        price_annual_cents=49000,
        display_order=1,
        is_active=True,
        is_public=True,
        platform_id=test_platform.id,
    )
    db.add(tier)
    db.commit()
    db.refresh(tier)
    return tier


@pytest.fixture
def billing_tiers(db, test_platform, billing_tier_essential):
    """Create essential, professional, and business tiers (in display order).

    Element [0] is ``billing_tier_essential`` itself rather than a second
    "essential" row. Tests that request both this fixture and
    ``billing_subscription`` (which depends on ``billing_tier_essential``)
    therefore see exactly one tier per code on the platform.
    """
    tiers = [
        billing_tier_essential,
        SubscriptionTier(
            code="professional",
            name="Professional",
            price_monthly_cents=9900,
            price_annual_cents=99000,
            display_order=2,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
        SubscriptionTier(
            code="business",
            name="Business",
            price_monthly_cents=19900,
            price_annual_cents=199000,
            display_order=3,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
    ]
    # Only the new tiers need adding; the essential tier is already persisted.
    db.add_all(tiers[1:])
    db.commit()
    for tier in tiers:
        db.refresh(tier)
    return tiers


@pytest.fixture
def billing_subscription(db, test_merchant, test_platform, billing_tier_essential):
    """Create an active merchant subscription (essential tier)."""
    now = datetime.now(UTC)
    sub = MerchantSubscription(
        merchant_id=test_merchant.id,
        platform_id=test_platform.id,
        tier_id=billing_tier_essential.id,
        status=SubscriptionStatus.ACTIVE.value,
        period_start=now,
        period_end=now + timedelta(days=30),
    )
    db.add(sub)
    db.commit()
    db.refresh(sub)
    return sub