# app/modules/billing/tests/unit/test_billing_service.py
"""Unit tests for BillingService."""

from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest

from app.modules.billing.exceptions import (
    NoActiveSubscriptionException,
    PaymentSystemNotConfiguredException,
    StripePriceNotConfiguredException,
    SubscriptionNotCancelledException,
    TierNotFoundException,
)
from app.modules.billing.models import (
    AddOnProduct,
    BillingHistory,
    MerchantSubscription,
    SubscriptionStatus,
    SubscriptionTier,
)
from app.modules.billing.services.billing_service import BillingService

# ============================================================================
# Tier Lookup
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceTiers:
    """Test suite for BillingService tier operations."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_get_tier_by_code_found(self, db, bs_tier_essential):
        """Returns the active tier."""
        tier = self.service.get_tier_by_code(db, "essential")

        assert tier.code == "essential"

    def test_get_tier_by_code_not_found(self, db):
        """Nonexistent tier raises TierNotFoundException."""
        with pytest.raises(TierNotFoundException) as exc_info:
            self.service.get_tier_by_code(db, "nonexistent")

        # The exception carries the code that was looked up.
        assert exc_info.value.tier_code == "nonexistent"

    def test_get_tier_by_code_inactive_not_returned(self, db, bs_tier_essential):
        """Inactive tier raises TierNotFoundException (only active tiers returned)."""
        bs_tier_essential.is_active = False
        db.flush()

        with pytest.raises(TierNotFoundException):
            self.service.get_tier_by_code(db, "essential")


# ============================================================================
# Available Tiers with Upgrade/Downgrade Flags
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceAvailableTiers:
    """Test suite for get_available_tiers (upgrade/downgrade detection)."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_get_available_tiers_returns_all(self, db, bs_tiers):
        """Returns all active public tiers."""
        # Only the tier list matters here; the order map has its own test.
        tier_list, _ = self.service.get_available_tiers(db, None)

        assert len(tier_list) == 3

    def test_get_available_tiers_marks_current(self, db, bs_tiers):
        """Current tier is marked with is_current=True."""
        # bs_tiers[1] is the "professional" tier.
        tier_list, _ = self.service.get_available_tiers(db, bs_tiers[1].id)

        current = [t for t in tier_list if t["is_current"]]
        assert len(current) == 1
        assert current[0]["code"] == "professional"

    def test_get_available_tiers_upgrade_flags(self, db, bs_tiers):
        """Tiers with higher display_order have can_upgrade=True."""
        # Current tier is the lowest one ("essential").
        tier_list, _ = self.service.get_available_tiers(db, bs_tiers[0].id)

        essential = next(t for t in tier_list if t["code"] == "essential")
        professional = next(t for t in tier_list if t["code"] == "professional")
        business = next(t for t in tier_list if t["code"] == "business")

        assert essential["is_current"] is True
        assert essential["can_upgrade"] is False
        assert professional["can_upgrade"] is True
        assert business["can_upgrade"] is True

    def test_get_available_tiers_downgrade_flags(self, db, bs_tiers):
        """Tiers with lower display_order have can_downgrade=True."""
        # Current tier is the highest one ("business").
        tier_list, _ = self.service.get_available_tiers(db, bs_tiers[2].id)

        essential = next(t for t in tier_list if t["code"] == "essential")
        professional = next(t for t in tier_list if t["code"] == "professional")
        business = next(t for t in tier_list if t["code"] == "business")

        assert essential["can_downgrade"] is True
        assert professional["can_downgrade"] is True
        assert business["is_current"] is True
        assert business["can_downgrade"] is False

    def test_get_available_tiers_no_current_tier(self, db, bs_tiers):
        """When current_tier_id is None, no tier is marked current."""
        tier_list, _ = self.service.get_available_tiers(db, None)

        assert all(t["is_current"] is False for t in tier_list)

    def test_get_available_tiers_returns_tier_order_map(self, db, bs_tiers):
        """Returns tier_order map of code -> display_order."""
        _, tier_order = self.service.get_available_tiers(db, None)

        assert tier_order["essential"] == 1
        assert tier_order["professional"] == 2
        assert tier_order["business"] == 3


# ============================================================================
# Invoices
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceInvoices:
    """Test suite for BillingService invoice operations."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_get_invoices_empty(self, db, test_merchant):
        """Returns empty list and zero total when no invoices exist."""
        invoices, total = self.service.get_invoices(db, test_merchant.id)

        assert invoices == []
        assert total == 0

    def test_get_invoices_with_data(self, db, bs_billing_history):
        """Returns invoices for the merchant."""
        merchant_id = bs_billing_history[0].merchant_id

        invoices, total = self.service.get_invoices(db, merchant_id)

        assert total == 3
        assert len(invoices) == 3

    def test_get_invoices_pagination(self, db, bs_billing_history):
        """Pagination limits and offsets results."""
        merchant_id = bs_billing_history[0].merchant_id

        # First page: two of the three records, total still reports all three.
        invoices, total = self.service.get_invoices(db, merchant_id, skip=0, limit=2)
        assert total == 3
        assert len(invoices) == 2

        # Second page: only the remaining record.
        invoices2, _ = self.service.get_invoices(db, merchant_id, skip=2, limit=2)
        assert len(invoices2) == 1

    def test_get_invoices_ordered_by_date_desc(self, db, bs_billing_history):
        """Invoices are returned newest first."""
        merchant_id = bs_billing_history[0].merchant_id

        invoices, _ = self.service.get_invoices(db, merchant_id)

        dates = [inv.invoice_date for inv in invoices]
        # Guard against a vacuous pass: an empty list is trivially "sorted".
        assert dates
        assert dates == sorted(dates, reverse=True)


# ============================================================================
# Add-ons
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceAddons:
    """Test suite for BillingService addon operations."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_get_available_addons_empty(self, db):
        """Returns empty when no addons exist."""
        addons = self.service.get_available_addons(db)

        assert addons == []

    def test_get_available_addons_with_data(self, db, test_addon_products):
        """Returns all active addons."""
        addons = self.service.get_available_addons(db)

        assert len(addons) == 3
        assert all(addon.is_active for addon in addons)

    def test_get_available_addons_by_category(self, db, test_addon_products):
        """Filters by category."""
        domain_addons = self.service.get_available_addons(db, category="domain")

        assert len(domain_addons) == 1
        assert domain_addons[0].category == "domain"

    def test_get_store_addons_empty(self, db, test_store):
        """Returns empty when store has no purchased addons."""
        addons = self.service.get_store_addons(db, test_store.id)

        assert addons == []


# ============================================================================
# Subscription with Tier
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceSubscriptionWithTier:
    """Tests for get_subscription_with_tier."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_get_subscription_with_tier_existing(self, db, bs_subscription):
        """Returns (subscription, tier) tuple for existing subscription."""
        sub, tier = self.service.get_subscription_with_tier(
            db, bs_subscription.merchant_id, bs_subscription.platform_id
        )

        assert sub.id == bs_subscription.id
        assert tier is not None
        assert tier.code == "essential"

    def test_get_subscription_with_tier_creates_if_missing(
        self, db, test_merchant, test_platform, bs_tier_essential
    ):
        """Creates a trial subscription when none exists (via get_or_create)."""
        # Only the subscription is asserted on; the tier half is ignored.
        sub, _ = self.service.get_subscription_with_tier(
            db, test_merchant.id, test_platform.id
        )

        assert sub is not None
        assert sub.status == SubscriptionStatus.TRIAL.value


# Dotted path of the stripe_service name as looked up by billing_service;
# every test that stubs Stripe patches this single target.
STRIPE_SERVICE_TARGET = "app.modules.billing.services.billing_service.stripe_service"


# ============================================================================
# Change Tier
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceChangeTier:
    """Tests for change_tier (the tier upgrade/downgrade flow)."""

    # NOTE(review): bs_stripe_subscription depends on bs_tier_essential, so tests
    # that also request bs_tiers / bs_tiers_with_stripe end up with two
    # "essential" tier rows -- confirm tier codes are not meant to be unique.

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_change_tier_no_subscription_raises(self, db, bs_tiers):
        """Raises NoActiveSubscriptionException when no subscription exists."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.change_tier(db, 99999, 99999, "professional", False)

    def test_change_tier_no_stripe_subscription_raises(
        self, db, bs_subscription, bs_tiers
    ):
        """Raises when subscription has no stripe_subscription_id."""
        # bs_subscription has no Stripe IDs
        with pytest.raises(NoActiveSubscriptionException):
            self.service.change_tier(
                db,
                bs_subscription.merchant_id,
                bs_subscription.platform_id,
                "professional",
                False,
            )

    def test_change_tier_nonexistent_tier_raises(self, db, bs_stripe_subscription):
        """Raises TierNotFoundException for nonexistent tier."""
        with pytest.raises(TierNotFoundException):
            self.service.change_tier(
                db,
                bs_stripe_subscription.merchant_id,
                bs_stripe_subscription.platform_id,
                "nonexistent",
                False,
            )

    def test_change_tier_no_price_id_raises(
        self, db, bs_stripe_subscription, bs_tiers
    ):
        """Raises StripePriceNotConfiguredException when tier has no Stripe price."""
        # bs_tiers have no stripe_price_* set
        with pytest.raises(StripePriceNotConfiguredException):
            self.service.change_tier(
                db,
                bs_stripe_subscription.merchant_id,
                bs_stripe_subscription.platform_id,
                "professional",
                False,
            )

    @patch(STRIPE_SERVICE_TARGET)
    def test_change_tier_success(
        self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
    ):
        """Successful tier change updates local subscription and calls Stripe."""
        mock_stripe.is_configured = True
        mock_stripe.update_subscription = MagicMock()

        result = self.service.change_tier(
            db,
            bs_stripe_subscription.merchant_id,
            bs_stripe_subscription.platform_id,
            "professional",
            False,
        )

        assert result["new_tier"] == "professional"
        assert result["effective_immediately"] is True
        # bs_tiers_with_stripe[1] is the "professional" tier.
        assert bs_stripe_subscription.tier_id == bs_tiers_with_stripe[1].id
        mock_stripe.update_subscription.assert_called_once()

    @patch(STRIPE_SERVICE_TARGET)
    def test_change_tier_annual_uses_annual_price(
        self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
    ):
        """Annual billing selects stripe_price_annual_id."""
        mock_stripe.is_configured = True
        mock_stripe.update_subscription = MagicMock()

        self.service.change_tier(
            db,
            bs_stripe_subscription.merchant_id,
            bs_stripe_subscription.platform_id,
            "professional",
            True,
        )

        call_args = mock_stripe.update_subscription.call_args
        assert call_args.kwargs["new_price_id"] == "price_pro_annual"


# ============================================================================
# _is_upgrade
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceIsUpgrade:
    """Tests for _is_upgrade helper."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_is_upgrade_true(self, db, bs_tiers):
        """Higher display_order is an upgrade."""
        assert self.service._is_upgrade(db, bs_tiers[0].id, bs_tiers[2].id) is True

    def test_is_upgrade_false_downgrade(self, db, bs_tiers):
        """Lower display_order is not an upgrade."""
        assert self.service._is_upgrade(db, bs_tiers[2].id, bs_tiers[0].id) is False

    def test_is_upgrade_same_tier(self, db, bs_tiers):
        """Same tier is not an upgrade."""
        assert self.service._is_upgrade(db, bs_tiers[1].id, bs_tiers[1].id) is False

    def test_is_upgrade_none_ids(self, db):
        """None tier IDs return False."""
        assert self.service._is_upgrade(db, None, None) is False
        assert self.service._is_upgrade(db, None, 1) is False
        assert self.service._is_upgrade(db, 1, None) is False


# ============================================================================
# Cancel Subscription
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceCancel:
    """Tests for cancel_subscription."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_cancel_no_subscription_raises(self, db):
        """Raises when no subscription found."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.cancel_subscription(db, 99999, 99999, None, False)

    def test_cancel_no_stripe_id_raises(self, db, bs_subscription):
        """Raises when subscription has no stripe_subscription_id."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.cancel_subscription(
                db,
                bs_subscription.merchant_id,
                bs_subscription.platform_id,
                "reason",
                False,
            )

    @patch(STRIPE_SERVICE_TARGET)
    def test_cancel_success(self, mock_stripe, db, bs_stripe_subscription):
        """Cancellation records timestamp and reason."""
        mock_stripe.is_configured = True
        mock_stripe.cancel_subscription = MagicMock()

        result = self.service.cancel_subscription(
            db,
            bs_stripe_subscription.merchant_id,
            bs_stripe_subscription.platform_id,
            "Too expensive",
            False,
        )

        assert result["message"] == "Subscription cancelled successfully"
        assert bs_stripe_subscription.cancelled_at is not None
        assert bs_stripe_subscription.cancellation_reason == "Too expensive"
        mock_stripe.cancel_subscription.assert_called_once()


# ============================================================================
# Reactivate Subscription
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceReactivate:
    """Tests for reactivate_subscription."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_reactivate_no_subscription_raises(self, db):
        """Raises when no subscription found."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.reactivate_subscription(db, 99999, 99999)

    def test_reactivate_not_cancelled_raises(self, db, bs_stripe_subscription):
        """Raises SubscriptionNotCancelledException when not cancelled."""
        with pytest.raises(SubscriptionNotCancelledException):
            self.service.reactivate_subscription(
                db,
                bs_stripe_subscription.merchant_id,
                bs_stripe_subscription.platform_id,
            )

    @patch(STRIPE_SERVICE_TARGET)
    def test_reactivate_success(self, mock_stripe, db, bs_stripe_subscription):
        """Reactivation clears cancellation and calls Stripe."""
        mock_stripe.is_configured = True
        mock_stripe.reactivate_subscription = MagicMock()

        # Cancel first
        bs_stripe_subscription.cancelled_at = datetime.now(UTC)
        bs_stripe_subscription.cancellation_reason = "Testing"
        db.flush()

        result = self.service.reactivate_subscription(
            db,
            bs_stripe_subscription.merchant_id,
            bs_stripe_subscription.platform_id,
        )

        assert result["message"] == "Subscription reactivated successfully"
        assert bs_stripe_subscription.cancelled_at is None
        assert bs_stripe_subscription.cancellation_reason is None
        mock_stripe.reactivate_subscription.assert_called_once()


# ============================================================================
# Checkout Session
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceCheckout:
    """Tests for create_checkout_session."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_checkout_stripe_not_configured_raises(self, db, bs_tiers_with_stripe):
        """Raises PaymentSystemNotConfiguredException when Stripe is off."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = False

            with pytest.raises(PaymentSystemNotConfiguredException):
                self.service.create_checkout_session(
                    db, 1, 1, "essential", False, "http://ok", "http://cancel"  # noqa: SEC034
                )

    def test_checkout_nonexistent_tier_raises(self, db):
        """Raises TierNotFoundException for nonexistent tier."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = True

            with pytest.raises(TierNotFoundException):
                self.service.create_checkout_session(
                    db, 1, 1, "nonexistent", False, "http://ok", "http://cancel"  # noqa: SEC034
                )


# ============================================================================
# Portal Session
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServicePortal:
    """Tests for create_portal_session."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_portal_stripe_not_configured_raises(self, db):
        """Raises PaymentSystemNotConfiguredException when Stripe is off."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = False

            with pytest.raises(PaymentSystemNotConfiguredException):
                self.service.create_portal_session(db, 1, 1, "http://return")  # noqa: SEC034

    def test_portal_no_subscription_raises(self, db):
        """Raises NoActiveSubscriptionException when no subscription found."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = True

            with pytest.raises(NoActiveSubscriptionException):
                self.service.create_portal_session(db, 99999, 99999, "http://return")  # noqa: SEC034

    def test_portal_no_customer_id_raises(self, db, bs_subscription):
        """Raises when subscription has no stripe_customer_id."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = True

            with pytest.raises(NoActiveSubscriptionException):
                self.service.create_portal_session(
                    db,
                    bs_subscription.merchant_id,
                    bs_subscription.platform_id,
                    "http://return",  # noqa: SEC034
                )


# ============================================================================
# Upcoming Invoice
# ============================================================================


@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceUpcomingInvoice:
    """Tests for get_upcoming_invoice."""

    def setup_method(self):
        """Give every test its own BillingService instance."""
        self.service = BillingService()

    def test_upcoming_invoice_no_subscription_raises(self, db):
        """Raises when no subscription exists."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.get_upcoming_invoice(db, 99999, 99999)

    def test_upcoming_invoice_no_customer_id_raises(self, db, bs_subscription):
        """Raises when subscription has no stripe_customer_id."""
        with pytest.raises(NoActiveSubscriptionException):
            self.service.get_upcoming_invoice(
                db, bs_subscription.merchant_id, bs_subscription.platform_id
            )

    def test_upcoming_invoice_stripe_not_configured_returns_empty(
        self, db, bs_stripe_subscription
    ):
        """Returns empty invoice when Stripe is not configured."""
        with patch(STRIPE_SERVICE_TARGET) as mock_stripe:
            mock_stripe.is_configured = False

            result = self.service.get_upcoming_invoice(
                db,
                bs_stripe_subscription.merchant_id,
                bs_stripe_subscription.platform_id,
            )

        assert result["amount_due_cents"] == 0
        assert result["line_items"] == []


# ============================================================================
# Fixtures
# ============================================================================


def _persist(db, *records):
    """Add the records to the session, commit, then refresh each of them."""
    db.add_all(records)
    db.commit()
    for record in records:
        db.refresh(record)


@pytest.fixture
def bs_tier_essential(db, test_platform):
    """Create essential subscription tier."""
    tier = SubscriptionTier(
        code="essential",
        name="Essential",
        description="Essential plan",
        price_monthly_cents=4900,
        price_annual_cents=49000,
        display_order=1,
        is_active=True,
        is_public=True,
        platform_id=test_platform.id,
    )
    _persist(db, tier)
    return tier


@pytest.fixture
def bs_tiers(db, test_platform):
    """Create three tiers without Stripe config.

    Returned in display order: essential, professional, business.
    """
    tiers = [
        SubscriptionTier(
            code="essential",
            name="Essential",
            price_monthly_cents=4900,
            price_annual_cents=49000,
            display_order=1,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
        SubscriptionTier(
            code="professional",
            name="Professional",
            price_monthly_cents=9900,
            price_annual_cents=99000,
            display_order=2,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
        SubscriptionTier(
            code="business",
            name="Business",
            price_monthly_cents=19900,
            price_annual_cents=199000,
            display_order=3,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
    ]
    _persist(db, *tiers)
    return tiers


@pytest.fixture
def bs_tiers_with_stripe(db, test_platform):
    """Create tiers with Stripe price IDs configured.

    Returned in display order: essential, professional, business.
    """
    tiers = [
        SubscriptionTier(
            code="essential",
            name="Essential",
            price_monthly_cents=4900,
            price_annual_cents=49000,
            display_order=1,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
            stripe_product_id="prod_essential",
            stripe_price_monthly_id="price_ess_monthly",
            stripe_price_annual_id="price_ess_annual",
        ),
        SubscriptionTier(
            code="professional",
            name="Professional",
            price_monthly_cents=9900,
            price_annual_cents=99000,
            display_order=2,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
            stripe_product_id="prod_professional",
            stripe_price_monthly_id="price_pro_monthly",
            stripe_price_annual_id="price_pro_annual",
        ),
        SubscriptionTier(
            code="business",
            name="Business",
            price_monthly_cents=19900,
            price_annual_cents=199000,
            display_order=3,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
            stripe_product_id="prod_business",
            stripe_price_monthly_id="price_biz_monthly",
            stripe_price_annual_id="price_biz_annual",
        ),
    ]
    _persist(db, *tiers)
    return tiers


@pytest.fixture
def bs_subscription(db, test_merchant, test_platform, bs_tier_essential):
    """Create an active merchant subscription (no Stripe IDs)."""
    now = datetime.now(UTC)
    sub = MerchantSubscription(
        merchant_id=test_merchant.id,
        platform_id=test_platform.id,
        tier_id=bs_tier_essential.id,
        status=SubscriptionStatus.ACTIVE.value,
        period_start=now,
        period_end=now + timedelta(days=30),
    )
    _persist(db, sub)
    return sub


@pytest.fixture
def bs_stripe_subscription(db, test_merchant, test_platform, bs_tier_essential):
    """Create an active subscription with Stripe IDs."""
    now = datetime.now(UTC)
    sub = MerchantSubscription(
        merchant_id=test_merchant.id,
        platform_id=test_platform.id,
        tier_id=bs_tier_essential.id,
        status=SubscriptionStatus.ACTIVE.value,
        stripe_customer_id="cus_test123",
        stripe_subscription_id="sub_test123",
        period_start=now,
        period_end=now + timedelta(days=30),
    )
    _persist(db, sub)
    return sub


@pytest.fixture
def bs_billing_history(db, test_merchant):
    """Create billing history records for test_merchant.

    Three paid invoices dated 0, 30 and 60 days before a single reference
    timestamp, so the list is ordered newest first.
    """
    # Take the clock once so the invoices are spaced exactly 30 days apart.
    now = datetime.now(UTC)
    records = [
        BillingHistory(
            merchant_id=test_merchant.id,
            stripe_invoice_id=f"in_bs_test_{i}",
            invoice_number=f"BS-{i:03d}",
            invoice_date=now - timedelta(days=i * 30),
            subtotal_cents=4900,
            tax_cents=0,
            total_cents=4900,
            amount_paid_cents=4900,
            currency="EUR",
            status="paid",
        )
        for i in range(3)
    ]
    _persist(db, *records)
    return records


@pytest.fixture
def test_addon_products(db):
    """Create test addon products (one "domain" and two "email" add-ons)."""
    addons = [
        AddOnProduct(
            code="domain",
            name="Custom Domain",
            category="domain",
            price_cents=1500,
            billing_period="annual",
            display_order=1,
            is_active=True,
        ),
        AddOnProduct(
            code="email_5",
            name="5 Email Addresses",
            category="email",
            price_cents=500,
            billing_period="monthly",
            quantity_value=5,
            display_order=2,
            is_active=True,
        ),
        AddOnProduct(
            code="email_10",
            name="10 Email Addresses",
            category="email",
            price_cents=900,
            billing_period="monthly",
            quantity_value=10,
            display_order=3,
            is_active=True,
        ),
    ]
    _persist(db, *addons)
    return addons