diff --git a/app/modules/billing/services/admin_subscription_service.py b/app/modules/billing/services/admin_subscription_service.py index 6a334a8e..8357a0ad 100644 --- a/app/modules/billing/services/admin_subscription_service.py +++ b/app/modules/billing/services/admin_subscription_service.py @@ -119,7 +119,8 @@ class AdminSubscriptionService: if active_subs > 0: raise BusinessLogicException( - f"Cannot delete tier: {active_subs} active subscriptions are using it" + f"Cannot delete tier: {active_subs} active subscriptions are using it", + "TIER_HAS_ACTIVE_SUBSCRIPTIONS", ) tier.is_active = False diff --git a/app/modules/billing/tests/unit/test_admin_subscription_service.py b/app/modules/billing/tests/unit/test_admin_subscription_service.py new file mode 100644 index 00000000..74411309 --- /dev/null +++ b/app/modules/billing/tests/unit/test_admin_subscription_service.py @@ -0,0 +1,642 @@ +# app/modules/billing/tests/unit/test_admin_subscription_service.py +"""Unit tests for AdminSubscriptionService (admin tier & subscription management).""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from app.exceptions import ( + BusinessLogicException, + ConflictException, + ResourceNotFoundException, +) +from app.modules.billing.exceptions import TierNotFoundException +from app.modules.billing.models import ( + BillingHistory, + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, +) +from app.modules.billing.services.admin_subscription_service import ( + AdminSubscriptionService, +) +from app.modules.tenancy.models import Merchant + + +# ============================================================================ +# Tier Management +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminGetTiers: + """Tests for get_tiers with filtering.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_get_tiers_active_only_by_default(self, db, admin_billing_tiers): + """Default call returns only active tiers.""" + tiers = self.service.get_tiers(db) + + assert len(tiers) == 3 + assert all(t.is_active for t in tiers) + + def test_get_tiers_include_inactive(self, db, admin_billing_tiers): + """include_inactive=True returns all tiers.""" + admin_billing_tiers[0].is_active = False + db.flush() + + tiers = self.service.get_tiers(db, include_inactive=True) + assert len(tiers) == 3 + + def test_get_tiers_active_excludes_inactive(self, db, admin_billing_tiers): + """Active filter excludes inactive tiers.""" + admin_billing_tiers[0].is_active = False + db.flush() + + tiers = self.service.get_tiers(db, include_inactive=False) + assert len(tiers) == 2 + assert all(t.code != "essential" for t in tiers) + + def test_get_tiers_ordered_by_display_order(self, db, admin_billing_tiers): + """Tiers are returned in display_order.""" + tiers = self.service.get_tiers(db) + orders = [t.display_order for t in tiers] + assert orders == sorted(orders) + + def test_get_tiers_platform_filter(self, db, test_platform, admin_billing_tiers): + """Platform filter returns platform-specific + global tiers.""" + admin_billing_tiers[0].platform_id = test_platform.id + db.flush() + + tiers = self.service.get_tiers(db, platform_id=test_platform.id) + codes = {t.code for t in tiers} + assert "essential" in codes # Platform-specific + assert "professional" in codes # Global + assert "business" in codes # Global + + def test_get_tiers_platform_filter_excludes_other( + self, db, test_platform, another_platform, admin_billing_tiers + ): + """Tiers from another platform are excluded.""" + admin_billing_tiers[0].platform_id = another_platform.id + db.flush() + + tiers = self.service.get_tiers(db, platform_id=test_platform.id) + assert all(t.code != "essential" for t in tiers) + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminGetTierByCode: + """Tests for get_tier_by_code.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_get_tier_by_code_found(self, db, admin_billing_tiers): + """Returns the tier when it exists.""" + tier = self.service.get_tier_by_code(db, "professional") + assert tier.code == "professional" + assert tier.name == "Professional" + + def test_get_tier_by_code_not_found_raises(self, db): + """Raises TierNotFoundException for nonexistent code.""" + with pytest.raises(TierNotFoundException): + self.service.get_tier_by_code(db, "nonexistent") + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminCreateTier: + """Tests for create_tier.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_create_tier_success(self, db): + """Creates a new tier with provided data.""" + data = { + "code": "starter", + "name": "Starter", + "description": "A starter plan", + "price_monthly_cents": 1900, + "price_annual_cents": 19000, + "display_order": 0, + "is_active": True, + "is_public": True, + } + tier = self.service.create_tier(db, data) + db.flush() + + assert tier.code == "starter" + assert tier.name == "Starter" + assert tier.price_monthly_cents == 1900 + + def test_create_tier_duplicate_raises(self, db, admin_billing_tiers): + """Duplicate tier code raises ConflictException.""" + data = { + "code": "essential", + "name": "Essential Duplicate", + "price_monthly_cents": 4900, + "display_order": 1, + "is_active": True, + "is_public": True, + } + with pytest.raises(ConflictException, match="already exists"): + self.service.create_tier(db, data) + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminUpdateTier: + """Tests for update_tier.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_update_tier_name(self, db, admin_billing_tiers): + """Updates tier fields.""" + tier = self.service.update_tier(db, "essential", {"name": "Essential Plus"}) + assert tier.name == "Essential Plus" + + def test_update_tier_price(self, db, admin_billing_tiers): + """Updates tier pricing.""" + tier = self.service.update_tier( + db, "essential", {"price_monthly_cents": 5900} + ) + assert tier.price_monthly_cents == 5900 + + def test_update_tier_nonexistent_raises(self, db): + """Updating a nonexistent tier raises.""" + with pytest.raises(TierNotFoundException): + self.service.update_tier(db, "nonexistent", {"name": "X"}) + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminDeactivateTier: + """Tests for deactivate_tier (soft-delete).""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_deactivate_tier_sets_inactive(self, db, admin_billing_tiers): + """Deactivation sets is_active to False.""" + self.service.deactivate_tier(db, "business") + db.flush() + + tier = db.query(SubscriptionTier).filter( + SubscriptionTier.code == "business" + ).first() + assert tier.is_active is False + + def test_deactivate_tier_with_active_subs_raises( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """Cannot deactivate a tier used by active subscriptions.""" + with pytest.raises(BusinessLogicException, match="active subscriptions"): + self.service.deactivate_tier(db, "essential") + + def test_deactivate_tier_with_cancelled_subs_ok( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """Can deactivate a tier used only by cancelled subscriptions.""" + admin_billing_subscription.status = SubscriptionStatus.CANCELLED.value + db.flush() + + # Should not raise + self.service.deactivate_tier(db, "essential") + + def test_deactivate_nonexistent_tier_raises(self, db): + """Deactivating a nonexistent tier raises.""" + with pytest.raises(TierNotFoundException): + self.service.deactivate_tier(db, "nonexistent") + + +# ============================================================================ +# Subscription Management +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminListSubscriptions: + """Tests for list_subscriptions with pagination and filtering.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_list_subscriptions_empty(self, db): + """Returns empty results when no subscriptions exist.""" + result = self.service.list_subscriptions(db) + assert result["total"] == 0 + assert result["results"] == [] + assert result["page"] == 1 + + def test_list_subscriptions_returns_results( + self, db, admin_billing_subscription + ): + """Returns subscriptions joined with merchant.""" + result = self.service.list_subscriptions(db) + assert result["total"] == 1 + sub, merchant = result["results"][0] + assert isinstance(sub, MerchantSubscription) + assert isinstance(merchant, Merchant) + + def test_list_subscriptions_filter_by_status( + self, db, admin_billing_subscription, admin_billing_subscription_trial + ): + """Status filter returns only matching subscriptions.""" + result = self.service.list_subscriptions(db, status="trial") + assert result["total"] == 1 + sub, _ = result["results"][0] + assert sub.status == "trial" + + def test_list_subscriptions_filter_by_tier( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """Tier filter returns only matching subscriptions.""" + result = self.service.list_subscriptions(db, tier="essential") + assert result["total"] == 1 + + result = self.service.list_subscriptions(db, tier="professional") + assert result["total"] == 0 + + def test_list_subscriptions_search_by_merchant_name( + self, db, admin_billing_subscription, test_merchant + ): + """Search filters by merchant name (case-insensitive).""" + # test_merchant has a name like "Test Merchant xxxxxxxx" + result = self.service.list_subscriptions(db, search="Test Merchant") + assert result["total"] == 1 + + result = self.service.list_subscriptions(db, search="NonExistent Corp") + assert result["total"] == 0 + + def test_list_subscriptions_pagination( + self, db, admin_billing_multiple_subscriptions + ): + """Pagination returns correct page with correct count.""" + result = self.service.list_subscriptions(db, page=1, per_page=2) + assert result["total"] == 3 + assert len(result["results"]) == 2 + assert result["pages"] == 2 + + result2 = self.service.list_subscriptions(db, page=2, per_page=2) + assert len(result2["results"]) == 1 + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminGetSubscription: + """Tests for get_subscription.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_get_subscription_found(self, db, admin_billing_subscription): + """Returns (subscription, merchant) tuple.""" + sub, merchant = self.service.get_subscription( + db, + admin_billing_subscription.merchant_id, + admin_billing_subscription.platform_id, + ) + assert sub.id == admin_billing_subscription.id + assert isinstance(merchant, Merchant) + + def test_get_subscription_not_found_raises(self, db): + """Raises ResourceNotFoundException when not found.""" + with pytest.raises(ResourceNotFoundException): + self.service.get_subscription(db, 99999, 99999) + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminUpdateSubscription: + """Tests for update_subscription — the tier_code→tier_id resolution bug area.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_update_subscription_status(self, db, admin_billing_subscription): + """Can update subscription status.""" + sub, _ = self.service.update_subscription( + db, + admin_billing_subscription.merchant_id, + admin_billing_subscription.platform_id, + {"status": SubscriptionStatus.PAST_DUE.value}, + ) + assert sub.status == SubscriptionStatus.PAST_DUE.value + + def test_update_subscription_tier_code_resolves_to_tier_id( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """ + tier_code in update_data is resolved to tier_id (FK). + + This is the exact bug that was fixed in commit d1fe358 — the admin form + sends tier_code (string) but the model uses tier_id (int FK). + """ + sub, _ = self.service.update_subscription( + db, + admin_billing_subscription.merchant_id, + admin_billing_subscription.platform_id, + {"tier_code": "professional"}, + ) + assert sub.tier_id == admin_billing_tiers[1].id + + def test_update_subscription_tier_code_nonexistent_raises( + self, db, admin_billing_subscription + ): + """Nonexistent tier_code raises TierNotFoundException.""" + with pytest.raises(TierNotFoundException): + self.service.update_subscription( + db, + admin_billing_subscription.merchant_id, + admin_billing_subscription.platform_id, + {"tier_code": "nonexistent"}, + ) + + def test_update_subscription_not_found_raises(self, db): + """Updating a nonexistent subscription raises.""" + with pytest.raises(ResourceNotFoundException): + self.service.update_subscription(db, 99999, 99999, {"status": "active"}) + + +# ============================================================================ +# Billing History +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminListBillingHistory: + """Tests for list_billing_history.""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_list_billing_history_empty(self, db): + """Returns empty when no records exist.""" + result = self.service.list_billing_history(db) + assert result["total"] == 0 + assert result["results"] == [] + + def test_list_billing_history_with_records( + self, db, admin_billing_history_records + ): + """Returns billing history joined with merchant.""" + result = self.service.list_billing_history(db) + assert result["total"] == 3 + bh, merchant = result["results"][0] + assert isinstance(bh, BillingHistory) + assert isinstance(merchant, Merchant) + + def test_list_billing_history_filter_by_merchant( + self, db, admin_billing_history_records, test_merchant + ): + """Merchant filter returns only matching records.""" + result = self.service.list_billing_history( + db, merchant_id=test_merchant.id + ) + assert result["total"] == 3 + + result = self.service.list_billing_history(db, merchant_id=99999) + assert result["total"] == 0 + + def test_list_billing_history_filter_by_status( + self, db, admin_billing_history_records + ): + """Status filter returns matching records.""" + result = self.service.list_billing_history(db, status="paid") + assert result["total"] == 3 + + result = self.service.list_billing_history(db, status="failed") + assert result["total"] == 0 + + def test_list_billing_history_pagination( + self, db, admin_billing_history_records + ): + """Pagination works correctly.""" + result = self.service.list_billing_history(db, page=1, per_page=2) + assert result["total"] == 3 + assert len(result["results"]) == 2 + assert result["pages"] == 2 + + +# ============================================================================ +# Statistics +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestAdminGetStats: + """Tests for get_stats (dashboard analytics).""" + + def setup_method(self): + self.service = AdminSubscriptionService() + + def test_get_stats_empty_db(self, db): + """Stats with no subscriptions return zeroes.""" + stats = self.service.get_stats(db) + + assert stats["total_subscriptions"] == 0 + assert stats["active_count"] == 0 + assert stats["trial_count"] == 0 + assert stats["mrr_cents"] == 0 + assert stats["arr_cents"] == 0 + assert stats["tier_distribution"] == {} + + def test_get_stats_counts_by_status( + self, db, admin_billing_subscription, admin_billing_subscription_trial + ): + """Status counts are accurate.""" + stats = self.service.get_stats(db) + + assert stats["total_subscriptions"] == 2 + assert stats["active_count"] == 1 + assert stats["trial_count"] == 1 + + def test_get_stats_tier_distribution( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """Tier distribution counts active/trial subscriptions by tier name.""" + stats = self.service.get_stats(db) + + assert "Essential" in stats["tier_distribution"] + assert stats["tier_distribution"]["Essential"] == 1 + + def test_get_stats_mrr_calculation_monthly( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """MRR for a monthly subscription equals the monthly price.""" + stats = self.service.get_stats(db) + # Essential tier is 4900 cents/month + assert stats["mrr_cents"] == 4900 + assert stats["arr_cents"] == 4900 * 12 + + def test_get_stats_mrr_calculation_annual( + self, db, admin_billing_tiers, admin_billing_subscription + ): + """MRR for an annual subscription is annual_price / 12.""" + admin_billing_subscription.is_annual = True + db.flush() + + stats = self.service.get_stats(db) + # Essential annual = 49000 cents → MRR = 49000 / 12 = 4083 + assert stats["mrr_cents"] == 49000 // 12 + assert stats["arr_cents"] == 49000 + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def admin_billing_tiers(db): + """Create essential, professional, business tiers for admin tests.""" + tiers = [ + SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + price_annual_cents=49000, + display_order=1, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="professional", + name="Professional", + price_monthly_cents=9900, + price_annual_cents=99000, + display_order=2, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="business", + name="Business", + price_monthly_cents=19900, + price_annual_cents=199000, + display_order=3, + is_active=True, + is_public=True, + ), + ] + db.add_all(tiers) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def admin_billing_subscription(db, test_merchant, test_platform, admin_billing_tiers): + """Create an active subscription on the essential tier.""" + now = datetime.now(UTC) + sub = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=admin_billing_tiers[0].id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def admin_billing_subscription_trial( + db, other_merchant, another_platform, admin_billing_tiers +): + """Create a trial subscription on another merchant/platform.""" + now = datetime.now(UTC) + sub = MerchantSubscription( + merchant_id=other_merchant.id, + platform_id=another_platform.id, + tier_id=admin_billing_tiers[0].id, + status=SubscriptionStatus.TRIAL.value, + period_start=now, + period_end=now + timedelta(days=14), + trial_ends_at=now + timedelta(days=14), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def admin_billing_multiple_subscriptions( + db, test_merchant, other_merchant, test_platform, another_platform, + admin_billing_tiers, platform_factory, +): + """Create 3 subscriptions across different merchants/platforms.""" + now = datetime.now(UTC) + third_platform = platform_factory() + + subs = [ + MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=admin_billing_tiers[0].id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ), + MerchantSubscription( + merchant_id=other_merchant.id, + platform_id=another_platform.id, + tier_id=admin_billing_tiers[1].id, + status=SubscriptionStatus.TRIAL.value, + period_start=now, + period_end=now + timedelta(days=14), + ), + MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=third_platform.id, + tier_id=admin_billing_tiers[2].id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ), + ] + db.add_all(subs) + db.commit() + for s in subs: + db.refresh(s) + return subs + + +@pytest.fixture +def admin_billing_history_records(db, test_merchant): + """Create 3 billing history records for test_merchant.""" + records = [] + for i in range(3): + record = BillingHistory( + merchant_id=test_merchant.id, + stripe_invoice_id=f"in_admin_test_{i}", + invoice_number=f"ADM-{i:03d}", + invoice_date=datetime.now(UTC) - timedelta(days=i * 30), + subtotal_cents=4900, + tax_cents=0, + total_cents=4900, + amount_paid_cents=4900, + currency="EUR", + status="paid", + ) + records.append(record) + db.add_all(records) + db.commit() + for r in records: + db.refresh(r) + return records diff --git a/app/modules/billing/tests/unit/test_billing_service.py b/app/modules/billing/tests/unit/test_billing_service.py index 5b9d8ad3..579bb9c8 100644 --- a/app/modules/billing/tests/unit/test_billing_service.py +++ b/app/modules/billing/tests/unit/test_billing_service.py @@ -1,12 +1,11 @@ -# tests/unit/services/test_billing_service.py +# app/modules/billing/tests/unit/test_billing_service.py """Unit tests for BillingService.""" -from datetime import datetime, timezone +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import pytest -from app.modules.tenancy.exceptions import StoreNotFoundException from app.modules.billing.services.billing_service import ( BillingService, NoActiveSubscriptionError, @@ -25,26 +24,107 @@ from app.modules.billing.models import ( ) +# ============================================================================ +# Tier Lookup +# ============================================================================ + + @pytest.mark.unit @pytest.mark.billing class TestBillingServiceTiers: """Test suite for BillingService tier operations.""" def setup_method(self): - """Initialize service instance before each test.""" self.service = BillingService() + def test_get_tier_by_code_found(self, db, bs_tier_essential): + """Returns the active tier.""" + tier = self.service.get_tier_by_code(db, "essential") + assert tier.code == "essential" + def test_get_tier_by_code_not_found(self, db): - """Test getting non-existent tier raises error.""" + """Nonexistent tier raises TierNotFoundError.""" with pytest.raises(TierNotFoundError) as exc_info: self.service.get_tier_by_code(db, "nonexistent") - assert exc_info.value.tier_code == "nonexistent" + def test_get_tier_by_code_inactive_not_returned(self, db, bs_tier_essential): + """Inactive tier raises TierNotFoundError (only active tiers returned).""" + bs_tier_essential.is_active = False + db.flush() + + with pytest.raises(TierNotFoundError): + self.service.get_tier_by_code(db, "essential") -# TestBillingServiceCheckout removed — depends on refactored store_id-based API -# TestBillingServicePortal removed — depends on refactored store_id-based API +# ============================================================================ +# Available Tiers with Upgrade/Downgrade Flags +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceAvailableTiers: + """Test suite for get_available_tiers (upgrade/downgrade detection).""" + + def setup_method(self): + self.service = BillingService() + + def test_get_available_tiers_returns_all(self, db, bs_tiers): + """Returns all active public tiers.""" + tier_list, tier_order = self.service.get_available_tiers(db, None) + assert len(tier_list) == 3 + + def test_get_available_tiers_marks_current(self, db, bs_tiers): + """Current tier is marked with is_current=True.""" + tier_list, _ = self.service.get_available_tiers(db, bs_tiers[1].id) + + current = [t for t in tier_list if t["is_current"]] + assert len(current) == 1 + assert current[0]["code"] == "professional" + + def test_get_available_tiers_upgrade_flags(self, db, bs_tiers): + """Tiers with higher display_order have can_upgrade=True.""" + tier_list, _ = self.service.get_available_tiers(db, bs_tiers[0].id) + + essential = next(t for t in tier_list if t["code"] == "essential") + professional = next(t for t in tier_list if t["code"] == "professional") + business = next(t for t in tier_list if t["code"] == "business") + + assert essential["is_current"] is True + assert essential["can_upgrade"] is False + assert professional["can_upgrade"] is True + assert business["can_upgrade"] is True + + def test_get_available_tiers_downgrade_flags(self, db, bs_tiers): + """Tiers with lower display_order have can_downgrade=True.""" + tier_list, _ = self.service.get_available_tiers(db, bs_tiers[2].id) + + essential = next(t for t in tier_list if t["code"] == "essential") + professional = next(t for t in tier_list if t["code"] == "professional") + business = next(t for t in tier_list if t["code"] == "business") + + assert essential["can_downgrade"] is True + assert professional["can_downgrade"] is True + assert business["is_current"] is True + assert business["can_downgrade"] is False + + def test_get_available_tiers_no_current_tier(self, db, bs_tiers): + """When current_tier_id is None, no tier is marked current.""" + tier_list, _ = self.service.get_available_tiers(db, None) + assert all(t["is_current"] is False for t in tier_list) + + def test_get_available_tiers_returns_tier_order_map(self, db, bs_tiers): + """Returns tier_order map of code → display_order.""" + _, tier_order = self.service.get_available_tiers(db, None) + assert tier_order["essential"] == 1 + assert tier_order["professional"] == 2 + assert tier_order["business"] == 3 + + +# ============================================================================ +# Invoices +# ============================================================================ @pytest.mark.unit @@ -53,17 +133,42 @@ class TestBillingServiceInvoices: """Test suite for BillingService invoice operations.""" def setup_method(self): - """Initialize service instance before each test.""" self.service = BillingService() - def test_get_invoices_empty(self, db, test_store): - """Test getting invoices when none exist.""" - invoices, total = self.service.get_invoices(db, test_store.id) - + def test_get_invoices_empty(self, db, test_merchant): + """Returns empty list and zero total when no invoices exist.""" + invoices, total = self.service.get_invoices(db, test_merchant.id) assert invoices == [] assert total == 0 - # test_get_invoices_with_data and test_get_invoices_pagination removed — fixture model mismatch after migration + def test_get_invoices_with_data(self, db, bs_billing_history): + """Returns invoices for the merchant.""" + merchant_id = bs_billing_history[0].merchant_id + invoices, total = self.service.get_invoices(db, merchant_id) + assert total == 3 + assert len(invoices) == 3 + + def test_get_invoices_pagination(self, db, bs_billing_history): + """Pagination limits and offsets results.""" + merchant_id = bs_billing_history[0].merchant_id + invoices, total = self.service.get_invoices(db, merchant_id, skip=0, limit=2) + assert total == 3 + assert len(invoices) == 2 + + invoices2, _ = self.service.get_invoices(db, merchant_id, skip=2, limit=2) + assert len(invoices2) == 1 + + def test_get_invoices_ordered_by_date_desc(self, db, bs_billing_history): + """Invoices are returned newest first.""" + merchant_id = bs_billing_history[0].merchant_id + invoices, _ = self.service.get_invoices(db, merchant_id) + dates = [inv.invoice_date for inv in invoices] + assert dates == sorted(dates, reverse=True) + + +# ============================================================================ +# Add-ons +# ============================================================================ @pytest.mark.unit @@ -72,55 +177,442 @@ class TestBillingServiceAddons: """Test suite for BillingService addon operations.""" def setup_method(self): - """Initialize service instance before each test.""" self.service = BillingService() def test_get_available_addons_empty(self, db): - """Test getting addons when none exist.""" + """Returns empty when no addons exist.""" addons = self.service.get_available_addons(db) assert addons == [] def test_get_available_addons_with_data(self, db, test_addon_products): - """Test getting all available addons.""" + """Returns all active addons.""" addons = self.service.get_available_addons(db) - assert len(addons) == 3 assert all(addon.is_active for addon in addons) def test_get_available_addons_by_category(self, db, test_addon_products): - """Test filtering addons by category.""" + """Filters by category.""" domain_addons = self.service.get_available_addons(db, category="domain") - assert len(domain_addons) == 1 assert domain_addons[0].category == "domain" def test_get_store_addons_empty(self, db, test_store): - """Test getting store addons when none purchased.""" + """Returns empty when store has no purchased addons.""" addons = self.service.get_store_addons(db, test_store.id) assert addons == [] - -# TestBillingServiceCancellation removed — depends on refactored store_id-based API -# TestBillingServiceStore removed — get_store method was removed from BillingService +# ============================================================================ +# Subscription with Tier +# ============================================================================ -# ==================== Fixtures ==================== +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceSubscriptionWithTier: + """Tests for get_subscription_with_tier.""" + + def setup_method(self): + self.service = BillingService() + + def test_get_subscription_with_tier_existing( + self, db, bs_subscription + ): + """Returns (subscription, tier) tuple for existing subscription.""" + sub, tier = self.service.get_subscription_with_tier( + db, bs_subscription.merchant_id, bs_subscription.platform_id + ) + assert sub.id == bs_subscription.id + assert tier is not None + assert tier.code == "essential" + + def test_get_subscription_with_tier_creates_if_missing( + self, db, test_merchant, test_platform, bs_tier_essential + ): + """Creates a trial subscription when none exists (via get_or_create).""" + sub, tier = self.service.get_subscription_with_tier( + db, test_merchant.id, test_platform.id + ) + assert sub is not None + assert sub.status == SubscriptionStatus.TRIAL.value + + +# ============================================================================ +# Change Tier +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceChangeTier: + """Tests for change_tier (the tier upgrade/downgrade flow).""" + + def setup_method(self): + self.service = BillingService() + + def test_change_tier_no_subscription_raises(self, db, bs_tiers): + """Raises NoActiveSubscriptionError when no subscription exists.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.change_tier(db, 99999, 99999, "professional", False) + + def test_change_tier_no_stripe_subscription_raises( + self, db, bs_subscription, bs_tiers + ): + """Raises when subscription has no stripe_subscription_id.""" + # bs_subscription has no Stripe IDs + with pytest.raises(NoActiveSubscriptionError): + self.service.change_tier( + db, + bs_subscription.merchant_id, + bs_subscription.platform_id, + "professional", + False, + ) + + def test_change_tier_nonexistent_tier_raises( + self, db, bs_stripe_subscription + ): + """Raises TierNotFoundError for nonexistent tier.""" + with pytest.raises(TierNotFoundError): + self.service.change_tier( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + "nonexistent", + False, + ) + + def test_change_tier_no_price_id_raises( + self, db, bs_stripe_subscription, bs_tiers + ): + """Raises StripePriceNotConfiguredError when tier has no Stripe price.""" + # bs_tiers have no stripe_price_* set + with pytest.raises(StripePriceNotConfiguredError): + self.service.change_tier( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + "professional", + False, + ) + + @patch("app.modules.billing.services.billing_service.stripe_service") + def test_change_tier_success( + self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe + ): + """Successful tier change updates local subscription and calls Stripe.""" + mock_stripe.is_configured = True + mock_stripe.update_subscription = MagicMock() + + result = self.service.change_tier( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + "professional", + False, + ) + + assert result["new_tier"] == "professional" + assert result["effective_immediately"] is True + assert bs_stripe_subscription.tier_id == bs_tiers_with_stripe[1].id + mock_stripe.update_subscription.assert_called_once() + + @patch("app.modules.billing.services.billing_service.stripe_service") + def test_change_tier_annual_uses_annual_price( + self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe + ): + """Annual billing selects stripe_price_annual_id.""" + mock_stripe.is_configured = True + mock_stripe.update_subscription = MagicMock() + + self.service.change_tier( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + "professional", + True, + ) + + call_args = mock_stripe.update_subscription.call_args + assert call_args.kwargs["new_price_id"] == "price_pro_annual" + + +# ============================================================================ +# _is_upgrade +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceIsUpgrade: + """Tests for _is_upgrade helper.""" + + def setup_method(self): + self.service = BillingService() + + def test_is_upgrade_true(self, db, bs_tiers): + """Higher display_order is an upgrade.""" + assert self.service._is_upgrade(db, bs_tiers[0].id, bs_tiers[2].id) is True + + def test_is_upgrade_false_downgrade(self, db, bs_tiers): + """Lower display_order is not an upgrade.""" + assert self.service._is_upgrade(db, bs_tiers[2].id, bs_tiers[0].id) is False + + def test_is_upgrade_same_tier(self, db, bs_tiers): + """Same tier is not an upgrade.""" + assert self.service._is_upgrade(db, bs_tiers[1].id, bs_tiers[1].id) is False + + def test_is_upgrade_none_ids(self, db): + """None tier IDs return False.""" + assert self.service._is_upgrade(db, None, None) is False + assert self.service._is_upgrade(db, None, 1) is False + assert self.service._is_upgrade(db, 1, None) is False + + +# ============================================================================ +# Cancel Subscription +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceCancel: + """Tests for cancel_subscription.""" + + def setup_method(self): + self.service = BillingService() + + def test_cancel_no_subscription_raises(self, db): + """Raises when no subscription found.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.cancel_subscription(db, 99999, 99999, None, False) + + def test_cancel_no_stripe_id_raises(self, db, bs_subscription): + """Raises when subscription has no stripe_subscription_id.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.cancel_subscription( + db, + bs_subscription.merchant_id, + bs_subscription.platform_id, + "reason", + False, + ) + + @patch("app.modules.billing.services.billing_service.stripe_service") + def test_cancel_success(self, mock_stripe, db, bs_stripe_subscription): + """Cancellation records timestamp and reason.""" + mock_stripe.is_configured = True + mock_stripe.cancel_subscription = MagicMock() + + result = self.service.cancel_subscription( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + "Too expensive", + False, + ) + + assert result["message"] == "Subscription cancelled successfully" + assert bs_stripe_subscription.cancelled_at is not None + assert bs_stripe_subscription.cancellation_reason == "Too expensive" + mock_stripe.cancel_subscription.assert_called_once() + + +# ============================================================================ +# Reactivate Subscription +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceReactivate: + """Tests for reactivate_subscription.""" + + def setup_method(self): + self.service = BillingService() + + def test_reactivate_no_subscription_raises(self, db): + """Raises when no subscription found.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.reactivate_subscription(db, 99999, 99999) + + def test_reactivate_not_cancelled_raises(self, db, bs_stripe_subscription): + """Raises SubscriptionNotCancelledError when not cancelled.""" + with pytest.raises(SubscriptionNotCancelledError): + self.service.reactivate_subscription( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + ) + + @patch("app.modules.billing.services.billing_service.stripe_service") + def test_reactivate_success(self, mock_stripe, db, bs_stripe_subscription): + """Reactivation clears cancellation and calls Stripe.""" + mock_stripe.is_configured = True + mock_stripe.reactivate_subscription = MagicMock() + + # Cancel first + bs_stripe_subscription.cancelled_at = datetime.now(UTC) + bs_stripe_subscription.cancellation_reason = "Testing" + db.flush() + + result = self.service.reactivate_subscription( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + ) + + assert result["message"] == "Subscription reactivated successfully" + assert bs_stripe_subscription.cancelled_at is None + assert bs_stripe_subscription.cancellation_reason is None + mock_stripe.reactivate_subscription.assert_called_once() + + +# ============================================================================ +# Checkout Session +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceCheckout: + """Tests for create_checkout_session.""" + + def setup_method(self): + self.service = BillingService() + + def test_checkout_stripe_not_configured_raises(self, db, bs_tiers_with_stripe): + """Raises PaymentSystemNotConfiguredError when Stripe is off.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = False + + with pytest.raises(PaymentSystemNotConfiguredError): + self.service.create_checkout_session( + db, 1, 1, "essential", False, "http://ok", "http://cancel" + ) + + def test_checkout_nonexistent_tier_raises(self, db): + """Raises TierNotFoundError for nonexistent tier.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = True + + with pytest.raises(TierNotFoundError): + self.service.create_checkout_session( + db, 1, 1, "nonexistent", False, "http://ok", "http://cancel" + ) + + +# ============================================================================ +# Portal Session +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServicePortal: + """Tests for create_portal_session.""" + + def setup_method(self): + self.service = BillingService() + + def test_portal_stripe_not_configured_raises(self, db): + """Raises PaymentSystemNotConfiguredError when Stripe is off.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = False + + with pytest.raises(PaymentSystemNotConfiguredError): + self.service.create_portal_session(db, 1, 1, "http://return") + + def test_portal_no_subscription_raises(self, db): + """Raises NoActiveSubscriptionError when no subscription found.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = True + + with pytest.raises(NoActiveSubscriptionError): + self.service.create_portal_session(db, 99999, 99999, "http://return") + + def test_portal_no_customer_id_raises(self, db, bs_subscription): + """Raises when subscription has no stripe_customer_id.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = True + + with pytest.raises(NoActiveSubscriptionError): + self.service.create_portal_session( + db, + bs_subscription.merchant_id, + bs_subscription.platform_id, + "http://return", + ) + + +# ============================================================================ +# Upcoming Invoice +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceUpcomingInvoice: + """Tests for get_upcoming_invoice.""" + + def setup_method(self): + self.service = BillingService() + + def test_upcoming_invoice_no_subscription_raises(self, db): + """Raises when no subscription exists.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.get_upcoming_invoice(db, 99999, 99999) + + def test_upcoming_invoice_no_customer_id_raises(self, db, bs_subscription): + """Raises when subscription has no stripe_customer_id.""" + with pytest.raises(NoActiveSubscriptionError): + self.service.get_upcoming_invoice( + db, bs_subscription.merchant_id, bs_subscription.platform_id + ) + + def test_upcoming_invoice_stripe_not_configured_returns_empty( + self, db, bs_stripe_subscription + ): + """Returns empty invoice when Stripe is not configured.""" + with patch( + "app.modules.billing.services.billing_service.stripe_service" + ) as mock_stripe: + mock_stripe.is_configured = False + + result = self.service.get_upcoming_invoice( + db, + bs_stripe_subscription.merchant_id, + bs_stripe_subscription.platform_id, + ) + + assert result["amount_due_cents"] == 0 + assert result["line_items"] == [] + + +# ============================================================================ +# Fixtures +# ============================================================================ @pytest.fixture -def test_subscription_tier(db): - """Create a basic subscription tier.""" +def bs_tier_essential(db): + """Create essential subscription tier.""" tier = SubscriptionTier( code="essential", name="Essential", description="Essential plan", price_monthly_cents=4900, price_annual_cents=49000, - orders_per_month=100, - products_limit=200, - team_members=1, - features=["basic_support"], display_order=1, is_active=True, is_public=True, @@ -132,39 +624,14 @@ def test_subscription_tier(db): @pytest.fixture -def test_subscription_tier_with_stripe(db): - """Create a subscription tier with Stripe configuration.""" - tier = SubscriptionTier( - code="essential", - name="Essential", - description="Essential plan", - price_monthly_cents=4900, - price_annual_cents=49000, - orders_per_month=100, - products_limit=200, - team_members=1, - features=["basic_support"], - display_order=1, - is_active=True, - is_public=True, - stripe_product_id="prod_test123", - stripe_price_monthly_id="price_test123", - stripe_price_annual_id="price_test456", - ) - db.add(tier) - db.commit() - db.refresh(tier) - return tier - - -@pytest.fixture -def test_subscription_tiers(db): - """Create multiple subscription tiers.""" +def bs_tiers(db): + """Create three tiers without Stripe config.""" tiers = [ SubscriptionTier( code="essential", name="Essential", price_monthly_cents=4900, + price_annual_cents=49000, display_order=1, is_active=True, is_public=True, @@ -173,6 +640,7 @@ def test_subscription_tiers(db): code="professional", name="Professional", price_monthly_cents=9900, + price_annual_cents=99000, display_order=2, is_active=True, is_public=True, @@ -181,6 +649,7 @@ def test_subscription_tiers(db): code="business", name="Business", price_monthly_cents=19900, + price_annual_cents=199000, display_order=3, is_active=True, is_public=True, @@ -188,136 +657,107 @@ def test_subscription_tiers(db): ] db.add_all(tiers) db.commit() - for tier in tiers: - db.refresh(tier) + for t in tiers: + db.refresh(t) return tiers @pytest.fixture -def test_subscription(db, test_store): - """Create a basic subscription for testing.""" - # Create tier first - tier = SubscriptionTier( - code="essential", - name="Essential", - price_monthly_cents=4900, - display_order=1, - is_active=True, - is_public=True, - ) - db.add(tier) +def bs_tiers_with_stripe(db): + """Create tiers with Stripe price IDs configured.""" + tiers = [ + SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + price_annual_cents=49000, + display_order=1, + is_active=True, + is_public=True, + stripe_product_id="prod_essential", + stripe_price_monthly_id="price_ess_monthly", + stripe_price_annual_id="price_ess_annual", + ), + SubscriptionTier( + code="professional", + name="Professional", + price_monthly_cents=9900, + price_annual_cents=99000, + display_order=2, + is_active=True, + is_public=True, + stripe_product_id="prod_professional", + stripe_price_monthly_id="price_pro_monthly", + stripe_price_annual_id="price_pro_annual", + ), + SubscriptionTier( + code="business", + name="Business", + price_monthly_cents=19900, + price_annual_cents=199000, + display_order=3, + is_active=True, + is_public=True, + stripe_product_id="prod_business", + stripe_price_monthly_id="price_biz_monthly", + stripe_price_annual_id="price_biz_annual", + ), + ] + db.add_all(tiers) db.commit() - - subscription = MerchantSubscription( - store_id=test_store.id, - tier="essential", - status=SubscriptionStatus.ACTIVE, - period_start=datetime.now(timezone.utc), - period_end=datetime.now(timezone.utc), - ) - db.add(subscription) - db.commit() - db.refresh(subscription) - return subscription + for t in tiers: + db.refresh(t) + return tiers @pytest.fixture -def test_active_subscription(db, test_store): +def bs_subscription(db, test_merchant, test_platform, bs_tier_essential): + """Create an active merchant subscription (no Stripe IDs).""" + now = datetime.now(UTC) + sub = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=bs_tier_essential.id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def bs_stripe_subscription(db, test_merchant, test_platform, bs_tier_essential): """Create an active subscription with Stripe IDs.""" - # Create tier first if not exists - tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() - if not tier: - tier = SubscriptionTier( - code="essential", - name="Essential", - price_monthly_cents=4900, - display_order=1, - is_active=True, - is_public=True, - ) - db.add(tier) - db.commit() - - subscription = MerchantSubscription( - store_id=test_store.id, - tier="essential", - status=SubscriptionStatus.ACTIVE, + now = datetime.now(UTC) + sub = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=bs_tier_essential.id, + status=SubscriptionStatus.ACTIVE.value, stripe_customer_id="cus_test123", stripe_subscription_id="sub_test123", - period_start=datetime.now(timezone.utc), - period_end=datetime.now(timezone.utc), + period_start=now, + period_end=now + timedelta(days=30), ) - db.add(subscription) + db.add(sub) db.commit() - db.refresh(subscription) - return subscription + db.refresh(sub) + return sub @pytest.fixture -def test_cancelled_subscription(db, test_store): - """Create a cancelled subscription.""" - # Create tier first if not exists - tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() - if not tier: - tier = SubscriptionTier( - code="essential", - name="Essential", - price_monthly_cents=4900, - display_order=1, - is_active=True, - is_public=True, - ) - db.add(tier) - db.commit() - - subscription = MerchantSubscription( - store_id=test_store.id, - tier="essential", - status=SubscriptionStatus.ACTIVE, - stripe_customer_id="cus_test123", - stripe_subscription_id="sub_test123", - period_start=datetime.now(timezone.utc), - period_end=datetime.now(timezone.utc), - cancelled_at=datetime.now(timezone.utc), - cancellation_reason="Too expensive", - ) - db.add(subscription) - db.commit() - db.refresh(subscription) - return subscription - - -@pytest.fixture -def test_billing_history(db, test_store): - """Create a billing history record.""" - record = BillingHistory( - store_id=test_store.id, - stripe_invoice_id="in_test123", - invoice_number="INV-001", - invoice_date=datetime.now(timezone.utc), - subtotal_cents=4900, - tax_cents=0, - total_cents=4900, - amount_paid_cents=4900, - currency="EUR", - status="paid", - ) - db.add(record) - db.commit() - db.refresh(record) - return record - - -@pytest.fixture -def test_multiple_invoices(db, test_store): - """Create multiple billing history records.""" +def bs_billing_history(db, test_merchant): + """Create billing history records for test_merchant.""" records = [] - for i in range(5): + for i in range(3): record = BillingHistory( - store_id=test_store.id, - stripe_invoice_id=f"in_test{i}", - invoice_number=f"INV-{i:03d}", - invoice_date=datetime.now(timezone.utc), + merchant_id=test_merchant.id, + stripe_invoice_id=f"in_bs_test_{i}", + invoice_number=f"BS-{i:03d}", + invoice_date=datetime.now(UTC) - timedelta(days=i * 30), subtotal_cents=4900, tax_cents=0, total_cents=4900, @@ -328,6 +768,8 @@ def test_multiple_invoices(db, test_store): records.append(record) db.add_all(records) db.commit() + for r in records: + db.refresh(r) return records diff --git a/app/modules/billing/tests/unit/test_subscription_service.py b/app/modules/billing/tests/unit/test_subscription_service.py new file mode 100644 index 00000000..019275e8 --- /dev/null +++ b/app/modules/billing/tests/unit/test_subscription_service.py @@ -0,0 +1,579 @@ +# app/modules/billing/tests/unit/test_subscription_service.py +"""Unit tests for SubscriptionService (merchant-level subscription management).""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from app.modules.billing.exceptions import SubscriptionNotFoundException +from app.modules.billing.models import ( + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, + TierCode, +) +from app.modules.billing.services.subscription_service import SubscriptionService + + +# ============================================================================ +# Tier Information +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceTierLookup: + """Tests for tier lookup methods.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_get_tier_by_code_returns_tier(self, db, billing_tier_essential): + """Existing tier code returns the tier object.""" + tier = self.service.get_tier_by_code(db, "essential") + + assert tier is not None + assert tier.code == "essential" + assert tier.id == billing_tier_essential.id + + def test_get_tier_by_code_nonexistent_returns_none(self, db): + """Nonexistent tier code returns None (does not raise).""" + tier = self.service.get_tier_by_code(db, "nonexistent") + assert tier is None + + def test_get_tier_id_returns_id(self, db, billing_tier_essential): + """get_tier_id returns the integer ID for a valid code.""" + tier_id = self.service.get_tier_id(db, "essential") + assert tier_id == billing_tier_essential.id + + def test_get_tier_id_nonexistent_returns_none(self, db): + """get_tier_id returns None for nonexistent code.""" + tier_id = self.service.get_tier_id(db, "nonexistent") + assert tier_id is None + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceGetAllTiers: + """Tests for get_all_tiers with platform filtering.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_get_all_tiers_returns_active_public_tiers(self, db, billing_tiers): + """Returns only active + public tiers, ordered by display_order.""" + tiers = self.service.get_all_tiers(db) + + assert len(tiers) == 3 + assert tiers[0].code == "essential" + assert tiers[1].code == "professional" + assert tiers[2].code == "business" + + def test_get_all_tiers_excludes_inactive(self, db, billing_tiers): + """Inactive tiers are excluded.""" + # Deactivate one + billing_tiers[0].is_active = False + db.flush() + + tiers = self.service.get_all_tiers(db) + assert len(tiers) == 2 + assert all(t.code != "essential" for t in tiers) + + def test_get_all_tiers_excludes_non_public(self, db, billing_tiers): + """Non-public tiers are excluded.""" + billing_tiers[2].is_public = False + db.flush() + + tiers = self.service.get_all_tiers(db) + assert len(tiers) == 2 + assert all(t.code != "business" for t in tiers) + + def test_get_all_tiers_with_platform_filter(self, db, test_platform, billing_tiers): + """Platform filter returns platform-specific tiers + global tiers.""" + # Make one tier platform-specific + billing_tiers[0].platform_id = test_platform.id + # billing_tiers[1] and [2] remain global (platform_id=NULL) + db.flush() + + tiers = self.service.get_all_tiers(db, platform_id=test_platform.id) + + codes = {t.code for t in tiers} + assert "essential" in codes # Platform-specific + assert "professional" in codes # Global + assert "business" in codes # Global + + def test_get_all_tiers_platform_filter_excludes_other_platform( + self, db, test_platform, another_platform, billing_tiers + ): + """Tiers belonging to another platform are excluded.""" + billing_tiers[0].platform_id = another_platform.id + db.flush() + + tiers = self.service.get_all_tiers(db, platform_id=test_platform.id) + + codes = {t.code for t in tiers} + assert "essential" not in codes + assert len(tiers) == 2 + + +# ============================================================================ +# Merchant Subscription CRUD +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceGetMerchantSubscription: + """Tests for subscription retrieval methods.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_get_merchant_subscription_found( + self, db, billing_subscription + ): + """Returns subscription when it exists.""" + sub = self.service.get_merchant_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub is not None + assert sub.id == billing_subscription.id + + def test_get_merchant_subscription_not_found(self, db): + """Returns None when no subscription exists.""" + sub = self.service.get_merchant_subscription(db, 99999, 99999) + assert sub is None + + def test_get_merchant_subscription_eager_loads_tier( + self, db, billing_subscription + ): + """Subscription comes with tier relationship loaded.""" + sub = self.service.get_merchant_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + # Accessing .tier should not trigger a lazy load error + assert sub.tier is not None + assert sub.tier.code == "essential" + + def test_get_merchant_subscription_eager_loads_platform( + self, db, billing_subscription + ): + """Subscription comes with platform relationship loaded.""" + sub = self.service.get_merchant_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.platform is not None + + def test_get_merchant_subscriptions_all( + self, db, test_merchant, test_platform, another_platform, billing_tier_essential + ): + """Returns all subscriptions for a merchant across platforms.""" + now = datetime.now(UTC) + sub1 = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=billing_tier_essential.id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ) + sub2 = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=another_platform.id, + tier_id=billing_tier_essential.id, + status=SubscriptionStatus.TRIAL.value, + period_start=now, + period_end=now + timedelta(days=14), + ) + db.add_all([sub1, sub2]) + db.flush() + + subs = self.service.get_merchant_subscriptions(db, test_merchant.id) + assert len(subs) == 2 + + def test_get_merchant_subscriptions_empty(self, db, test_merchant): + """Returns empty list when merchant has no subscriptions.""" + subs = self.service.get_merchant_subscriptions(db, test_merchant.id) + assert subs == [] + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceGetOrRaise: + """Tests for get_subscription_or_raise.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_get_subscription_or_raise_found(self, db, billing_subscription): + """Returns subscription when found.""" + sub = self.service.get_subscription_or_raise( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.id == billing_subscription.id + + def test_get_subscription_or_raise_not_found(self, db): + """Raises SubscriptionNotFoundException when not found.""" + with pytest.raises(SubscriptionNotFoundException): + self.service.get_subscription_or_raise(db, 99999, 99999) + + +# ============================================================================ +# Create Subscription +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceCreate: + """Tests for create_merchant_subscription.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_create_trial_subscription( + self, db, test_merchant, test_platform, billing_tier_essential + ): + """Default creation starts a 14-day trial.""" + sub = self.service.create_merchant_subscription( + db, test_merchant.id, test_platform.id + ) + + assert sub.merchant_id == test_merchant.id + assert sub.platform_id == test_platform.id + assert sub.status == SubscriptionStatus.TRIAL.value + assert sub.tier_id == billing_tier_essential.id + assert sub.trial_ends_at is not None + assert sub.is_annual is False + + def test_create_subscription_period_end_matches_trial( + self, db, test_merchant, test_platform, billing_tier_essential + ): + """Period end is set to trial_days from now.""" + sub = self.service.create_merchant_subscription( + db, test_merchant.id, test_platform.id, trial_days=7 + ) + + expected_end = sub.period_start + timedelta(days=7) + # Allow 1 second tolerance + assert abs((sub.period_end - expected_end).total_seconds()) < 1 + + def test_create_annual_subscription( + self, db, test_merchant, test_platform, billing_tier_essential + ): + """Annual subscription with no trial gets 365-day period.""" + sub = self.service.create_merchant_subscription( + db, test_merchant.id, test_platform.id, trial_days=0, is_annual=True + ) + + assert sub.status == SubscriptionStatus.ACTIVE.value + assert sub.is_annual is True + assert sub.trial_ends_at is None + expected_end = sub.period_start + timedelta(days=365) + assert abs((sub.period_end - expected_end).total_seconds()) < 1 + + def test_create_monthly_subscription( + self, db, test_merchant, test_platform, billing_tier_essential + ): + """Monthly subscription with no trial gets 30-day period.""" + sub = self.service.create_merchant_subscription( + db, test_merchant.id, test_platform.id, trial_days=0, is_annual=False + ) + + assert sub.status == SubscriptionStatus.ACTIVE.value + assert sub.is_annual is False + assert sub.trial_ends_at is None + expected_end = sub.period_start + timedelta(days=30) + assert abs((sub.period_end - expected_end).total_seconds()) < 1 + + def test_create_subscription_custom_tier( + self, db, test_merchant, test_platform, billing_tiers + ): + """Can specify a custom tier code.""" + sub = self.service.create_merchant_subscription( + db, test_merchant.id, test_platform.id, tier_code="professional" + ) + assert sub.tier_id == billing_tiers[1].id + + def test_create_duplicate_raises( + self, db, billing_subscription + ): + """Creating a second subscription for same merchant+platform raises ValueError.""" + with pytest.raises(ValueError, match="already has a subscription"): + self.service.create_merchant_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceGetOrCreate: + """Tests for get_or_create_subscription.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_get_or_create_returns_existing(self, db, billing_subscription): + """Returns existing subscription without creating a new one.""" + sub = self.service.get_or_create_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.id == billing_subscription.id + + def test_get_or_create_creates_new( + self, db, test_merchant, test_platform, billing_tier_essential + ): + """Creates new subscription when none exists.""" + sub = self.service.get_or_create_subscription( + db, test_merchant.id, test_platform.id + ) + assert sub is not None + assert sub.merchant_id == test_merchant.id + assert sub.status == SubscriptionStatus.TRIAL.value + + +# ============================================================================ +# Upgrade / Cancel / Reactivate +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceUpgradeTier: + """Tests for upgrade_tier.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_upgrade_tier_changes_tier_id(self, db, billing_subscription, billing_tiers): + """Tier is updated to the new tier's ID.""" + sub = self.service.upgrade_tier( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + "professional", + ) + assert sub.tier_id == billing_tiers[1].id + + def test_upgrade_from_trial_activates(self, db, billing_subscription, billing_tiers): + """Upgrading from trial status sets status to active.""" + billing_subscription.status = SubscriptionStatus.TRIAL.value + db.flush() + + sub = self.service.upgrade_tier( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + "professional", + ) + assert sub.status == SubscriptionStatus.ACTIVE.value + + def test_upgrade_active_stays_active(self, db, billing_subscription, billing_tiers): + """Upgrading an already-active subscription keeps status active.""" + billing_subscription.status = SubscriptionStatus.ACTIVE.value + db.flush() + + sub = self.service.upgrade_tier( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + "professional", + ) + assert sub.status == SubscriptionStatus.ACTIVE.value + + def test_upgrade_nonexistent_tier_raises(self, db, billing_subscription): + """Upgrading to a nonexistent tier raises ValueError.""" + with pytest.raises(ValueError, match="not found"): + self.service.upgrade_tier( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + "nonexistent_tier", + ) + + def test_upgrade_no_subscription_raises(self, db): + """Upgrading when no subscription exists raises SubscriptionNotFoundException.""" + with pytest.raises(SubscriptionNotFoundException): + self.service.upgrade_tier(db, 99999, 99999, "professional") + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceCancel: + """Tests for cancel_subscription.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_cancel_sets_status(self, db, billing_subscription): + """Cancellation sets status to cancelled.""" + sub = self.service.cancel_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.status == SubscriptionStatus.CANCELLED.value + + def test_cancel_records_timestamp(self, db, billing_subscription): + """Cancellation records cancelled_at.""" + sub = self.service.cancel_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.cancelled_at is not None + + def test_cancel_with_reason(self, db, billing_subscription): + """Cancellation stores the reason.""" + sub = self.service.cancel_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + reason="Too expensive", + ) + assert sub.cancellation_reason == "Too expensive" + + def test_cancel_no_subscription_raises(self, db): + """Cancelling when no subscription exists raises.""" + with pytest.raises(SubscriptionNotFoundException): + self.service.cancel_subscription(db, 99999, 99999) + + +@pytest.mark.unit +@pytest.mark.billing +class TestSubscriptionServiceReactivate: + """Tests for reactivate_subscription.""" + + def setup_method(self): + self.service = SubscriptionService() + + def test_reactivate_sets_active(self, db, billing_subscription): + """Reactivation sets status back to active.""" + # First cancel + self.service.cancel_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + + sub = self.service.reactivate_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.status == SubscriptionStatus.ACTIVE.value + + def test_reactivate_clears_cancellation(self, db, billing_subscription): + """Reactivation clears cancelled_at and cancellation_reason.""" + self.service.cancel_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + reason="Testing", + ) + + sub = self.service.reactivate_subscription( + db, + billing_subscription.merchant_id, + billing_subscription.platform_id, + ) + assert sub.cancelled_at is None + assert sub.cancellation_reason is None + + def test_reactivate_no_subscription_raises(self, db): + """Reactivating when no subscription exists raises.""" + with pytest.raises(SubscriptionNotFoundException): + self.service.reactivate_subscription(db, 99999, 99999) + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def billing_tier_essential(db): + """Create essential subscription tier.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + description="Essential plan", + price_monthly_cents=4900, + price_annual_cents=49000, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def billing_tiers(db): + """Create essential, professional, and business tiers.""" + tiers = [ + SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + price_annual_cents=49000, + display_order=1, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="professional", + name="Professional", + price_monthly_cents=9900, + price_annual_cents=99000, + display_order=2, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="business", + name="Business", + price_monthly_cents=19900, + price_annual_cents=199000, + display_order=3, + is_active=True, + is_public=True, + ), + ] + db.add_all(tiers) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def billing_subscription(db, test_merchant, test_platform, billing_tier_essential): + """Create an active merchant subscription (essential tier).""" + now = datetime.now(UTC) + sub = MerchantSubscription( + merchant_id=test_merchant.id, + platform_id=test_platform.id, + tier_id=billing_tier_essential.id, + status=SubscriptionStatus.ACTIVE.value, + period_start=now, + period_end=now + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub