test(billing): add comprehensive service layer tests and fix deactivate_tier bug

Add 139 tests across 3 test files for the billing service layer:
- test_subscription_service.py (37 tests): tier lookup, subscription CRUD, upgrades, cancellation
- test_admin_subscription_service.py (39 tests): admin tier/subscription management, stats, billing history
- test_billing_service.py (43 tests): rewritten with correct fixtures after store→merchant migration

Fix production bug in deactivate_tier() — BusinessLogicException was missing
required error_code argument, now uses TIER_HAS_ACTIVE_SUBSCRIPTIONS.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-11 22:55:04 +01:00
parent bfb9b3c119
commit ad8f1c9008
4 changed files with 1836 additions and 172 deletions

View File

@@ -119,7 +119,8 @@ class AdminSubscriptionService:
if active_subs > 0: if active_subs > 0:
raise BusinessLogicException( raise BusinessLogicException(
f"Cannot delete tier: {active_subs} active subscriptions are using it" f"Cannot delete tier: {active_subs} active subscriptions are using it",
"TIER_HAS_ACTIVE_SUBSCRIPTIONS",
) )
tier.is_active = False tier.is_active = False

View File

@@ -0,0 +1,642 @@
# app/modules/billing/tests/unit/test_admin_subscription_service.py
"""Unit tests for AdminSubscriptionService (admin tier & subscription management)."""
from datetime import UTC, datetime, timedelta
import pytest
from app.exceptions import (
BusinessLogicException,
ConflictException,
ResourceNotFoundException,
)
from app.modules.billing.exceptions import TierNotFoundException
from app.modules.billing.models import (
BillingHistory,
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
)
from app.modules.billing.services.admin_subscription_service import (
AdminSubscriptionService,
)
from app.modules.tenancy.models import Merchant
# ============================================================================
# Tier Management
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestAdminGetTiers:
"""Tests for get_tiers with filtering."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_get_tiers_active_only_by_default(self, db, admin_billing_tiers):
"""Default call returns only active tiers."""
tiers = self.service.get_tiers(db)
assert len(tiers) == 3
assert all(t.is_active for t in tiers)
def test_get_tiers_include_inactive(self, db, admin_billing_tiers):
"""include_inactive=True returns all tiers."""
admin_billing_tiers[0].is_active = False
db.flush()
tiers = self.service.get_tiers(db, include_inactive=True)
assert len(tiers) == 3
def test_get_tiers_active_excludes_inactive(self, db, admin_billing_tiers):
"""Active filter excludes inactive tiers."""
admin_billing_tiers[0].is_active = False
db.flush()
tiers = self.service.get_tiers(db, include_inactive=False)
assert len(tiers) == 2
assert all(t.code != "essential" for t in tiers)
def test_get_tiers_ordered_by_display_order(self, db, admin_billing_tiers):
"""Tiers are returned in display_order."""
tiers = self.service.get_tiers(db)
orders = [t.display_order for t in tiers]
assert orders == sorted(orders)
def test_get_tiers_platform_filter(self, db, test_platform, admin_billing_tiers):
"""Platform filter returns platform-specific + global tiers."""
admin_billing_tiers[0].platform_id = test_platform.id
db.flush()
tiers = self.service.get_tiers(db, platform_id=test_platform.id)
codes = {t.code for t in tiers}
assert "essential" in codes # Platform-specific
assert "professional" in codes # Global
assert "business" in codes # Global
def test_get_tiers_platform_filter_excludes_other(
self, db, test_platform, another_platform, admin_billing_tiers
):
"""Tiers from another platform are excluded."""
admin_billing_tiers[0].platform_id = another_platform.id
db.flush()
tiers = self.service.get_tiers(db, platform_id=test_platform.id)
assert all(t.code != "essential" for t in tiers)
@pytest.mark.unit
@pytest.mark.billing
class TestAdminGetTierByCode:
"""Tests for get_tier_by_code."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_get_tier_by_code_found(self, db, admin_billing_tiers):
"""Returns the tier when it exists."""
tier = self.service.get_tier_by_code(db, "professional")
assert tier.code == "professional"
assert tier.name == "Professional"
def test_get_tier_by_code_not_found_raises(self, db):
"""Raises TierNotFoundException for nonexistent code."""
with pytest.raises(TierNotFoundException):
self.service.get_tier_by_code(db, "nonexistent")
@pytest.mark.unit
@pytest.mark.billing
class TestAdminCreateTier:
"""Tests for create_tier."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_create_tier_success(self, db):
"""Creates a new tier with provided data."""
data = {
"code": "starter",
"name": "Starter",
"description": "A starter plan",
"price_monthly_cents": 1900,
"price_annual_cents": 19000,
"display_order": 0,
"is_active": True,
"is_public": True,
}
tier = self.service.create_tier(db, data)
db.flush()
assert tier.code == "starter"
assert tier.name == "Starter"
assert tier.price_monthly_cents == 1900
def test_create_tier_duplicate_raises(self, db, admin_billing_tiers):
"""Duplicate tier code raises ConflictException."""
data = {
"code": "essential",
"name": "Essential Duplicate",
"price_monthly_cents": 4900,
"display_order": 1,
"is_active": True,
"is_public": True,
}
with pytest.raises(ConflictException, match="already exists"):
self.service.create_tier(db, data)
@pytest.mark.unit
@pytest.mark.billing
class TestAdminUpdateTier:
"""Tests for update_tier."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_update_tier_name(self, db, admin_billing_tiers):
"""Updates tier fields."""
tier = self.service.update_tier(db, "essential", {"name": "Essential Plus"})
assert tier.name == "Essential Plus"
def test_update_tier_price(self, db, admin_billing_tiers):
"""Updates tier pricing."""
tier = self.service.update_tier(
db, "essential", {"price_monthly_cents": 5900}
)
assert tier.price_monthly_cents == 5900
def test_update_tier_nonexistent_raises(self, db):
"""Updating a nonexistent tier raises."""
with pytest.raises(TierNotFoundException):
self.service.update_tier(db, "nonexistent", {"name": "X"})
@pytest.mark.unit
@pytest.mark.billing
class TestAdminDeactivateTier:
"""Tests for deactivate_tier (soft-delete)."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_deactivate_tier_sets_inactive(self, db, admin_billing_tiers):
"""Deactivation sets is_active to False."""
self.service.deactivate_tier(db, "business")
db.flush()
tier = db.query(SubscriptionTier).filter(
SubscriptionTier.code == "business"
).first()
assert tier.is_active is False
def test_deactivate_tier_with_active_subs_raises(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""Cannot deactivate a tier used by active subscriptions."""
with pytest.raises(BusinessLogicException, match="active subscriptions"):
self.service.deactivate_tier(db, "essential")
def test_deactivate_tier_with_cancelled_subs_ok(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""Can deactivate a tier used only by cancelled subscriptions."""
admin_billing_subscription.status = SubscriptionStatus.CANCELLED.value
db.flush()
# Should not raise
self.service.deactivate_tier(db, "essential")
def test_deactivate_nonexistent_tier_raises(self, db):
"""Deactivating a nonexistent tier raises."""
with pytest.raises(TierNotFoundException):
self.service.deactivate_tier(db, "nonexistent")
# ============================================================================
# Subscription Management
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestAdminListSubscriptions:
"""Tests for list_subscriptions with pagination and filtering."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_list_subscriptions_empty(self, db):
"""Returns empty results when no subscriptions exist."""
result = self.service.list_subscriptions(db)
assert result["total"] == 0
assert result["results"] == []
assert result["page"] == 1
def test_list_subscriptions_returns_results(
self, db, admin_billing_subscription
):
"""Returns subscriptions joined with merchant."""
result = self.service.list_subscriptions(db)
assert result["total"] == 1
sub, merchant = result["results"][0]
assert isinstance(sub, MerchantSubscription)
assert isinstance(merchant, Merchant)
def test_list_subscriptions_filter_by_status(
self, db, admin_billing_subscription, admin_billing_subscription_trial
):
"""Status filter returns only matching subscriptions."""
result = self.service.list_subscriptions(db, status="trial")
assert result["total"] == 1
sub, _ = result["results"][0]
assert sub.status == "trial"
def test_list_subscriptions_filter_by_tier(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""Tier filter returns only matching subscriptions."""
result = self.service.list_subscriptions(db, tier="essential")
assert result["total"] == 1
result = self.service.list_subscriptions(db, tier="professional")
assert result["total"] == 0
def test_list_subscriptions_search_by_merchant_name(
self, db, admin_billing_subscription, test_merchant
):
"""Search filters by merchant name (case-insensitive)."""
# test_merchant has a name like "Test Merchant xxxxxxxx"
result = self.service.list_subscriptions(db, search="Test Merchant")
assert result["total"] == 1
result = self.service.list_subscriptions(db, search="NonExistent Corp")
assert result["total"] == 0
def test_list_subscriptions_pagination(
self, db, admin_billing_multiple_subscriptions
):
"""Pagination returns correct page with correct count."""
result = self.service.list_subscriptions(db, page=1, per_page=2)
assert result["total"] == 3
assert len(result["results"]) == 2
assert result["pages"] == 2
result2 = self.service.list_subscriptions(db, page=2, per_page=2)
assert len(result2["results"]) == 1
@pytest.mark.unit
@pytest.mark.billing
class TestAdminGetSubscription:
"""Tests for get_subscription."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_get_subscription_found(self, db, admin_billing_subscription):
"""Returns (subscription, merchant) tuple."""
sub, merchant = self.service.get_subscription(
db,
admin_billing_subscription.merchant_id,
admin_billing_subscription.platform_id,
)
assert sub.id == admin_billing_subscription.id
assert isinstance(merchant, Merchant)
def test_get_subscription_not_found_raises(self, db):
"""Raises ResourceNotFoundException when not found."""
with pytest.raises(ResourceNotFoundException):
self.service.get_subscription(db, 99999, 99999)
@pytest.mark.unit
@pytest.mark.billing
class TestAdminUpdateSubscription:
"""Tests for update_subscription — the tier_code→tier_id resolution bug area."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_update_subscription_status(self, db, admin_billing_subscription):
"""Can update subscription status."""
sub, _ = self.service.update_subscription(
db,
admin_billing_subscription.merchant_id,
admin_billing_subscription.platform_id,
{"status": SubscriptionStatus.PAST_DUE.value},
)
assert sub.status == SubscriptionStatus.PAST_DUE.value
def test_update_subscription_tier_code_resolves_to_tier_id(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""
tier_code in update_data is resolved to tier_id (FK).
This is the exact bug that was fixed in commit d1fe358 — the admin form
sends tier_code (string) but the model uses tier_id (int FK).
"""
sub, _ = self.service.update_subscription(
db,
admin_billing_subscription.merchant_id,
admin_billing_subscription.platform_id,
{"tier_code": "professional"},
)
assert sub.tier_id == admin_billing_tiers[1].id
def test_update_subscription_tier_code_nonexistent_raises(
self, db, admin_billing_subscription
):
"""Nonexistent tier_code raises TierNotFoundException."""
with pytest.raises(TierNotFoundException):
self.service.update_subscription(
db,
admin_billing_subscription.merchant_id,
admin_billing_subscription.platform_id,
{"tier_code": "nonexistent"},
)
def test_update_subscription_not_found_raises(self, db):
"""Updating a nonexistent subscription raises."""
with pytest.raises(ResourceNotFoundException):
self.service.update_subscription(db, 99999, 99999, {"status": "active"})
# ============================================================================
# Billing History
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestAdminListBillingHistory:
"""Tests for list_billing_history."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_list_billing_history_empty(self, db):
"""Returns empty when no records exist."""
result = self.service.list_billing_history(db)
assert result["total"] == 0
assert result["results"] == []
def test_list_billing_history_with_records(
self, db, admin_billing_history_records
):
"""Returns billing history joined with merchant."""
result = self.service.list_billing_history(db)
assert result["total"] == 3
bh, merchant = result["results"][0]
assert isinstance(bh, BillingHistory)
assert isinstance(merchant, Merchant)
def test_list_billing_history_filter_by_merchant(
self, db, admin_billing_history_records, test_merchant
):
"""Merchant filter returns only matching records."""
result = self.service.list_billing_history(
db, merchant_id=test_merchant.id
)
assert result["total"] == 3
result = self.service.list_billing_history(db, merchant_id=99999)
assert result["total"] == 0
def test_list_billing_history_filter_by_status(
self, db, admin_billing_history_records
):
"""Status filter returns matching records."""
result = self.service.list_billing_history(db, status="paid")
assert result["total"] == 3
result = self.service.list_billing_history(db, status="failed")
assert result["total"] == 0
def test_list_billing_history_pagination(
self, db, admin_billing_history_records
):
"""Pagination works correctly."""
result = self.service.list_billing_history(db, page=1, per_page=2)
assert result["total"] == 3
assert len(result["results"]) == 2
assert result["pages"] == 2
# ============================================================================
# Statistics
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestAdminGetStats:
"""Tests for get_stats (dashboard analytics)."""
def setup_method(self):
self.service = AdminSubscriptionService()
def test_get_stats_empty_db(self, db):
"""Stats with no subscriptions return zeroes."""
stats = self.service.get_stats(db)
assert stats["total_subscriptions"] == 0
assert stats["active_count"] == 0
assert stats["trial_count"] == 0
assert stats["mrr_cents"] == 0
assert stats["arr_cents"] == 0
assert stats["tier_distribution"] == {}
def test_get_stats_counts_by_status(
self, db, admin_billing_subscription, admin_billing_subscription_trial
):
"""Status counts are accurate."""
stats = self.service.get_stats(db)
assert stats["total_subscriptions"] == 2
assert stats["active_count"] == 1
assert stats["trial_count"] == 1
def test_get_stats_tier_distribution(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""Tier distribution counts active/trial subscriptions by tier name."""
stats = self.service.get_stats(db)
assert "Essential" in stats["tier_distribution"]
assert stats["tier_distribution"]["Essential"] == 1
def test_get_stats_mrr_calculation_monthly(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""MRR for a monthly subscription equals the monthly price."""
stats = self.service.get_stats(db)
# Essential tier is 4900 cents/month
assert stats["mrr_cents"] == 4900
assert stats["arr_cents"] == 4900 * 12
def test_get_stats_mrr_calculation_annual(
self, db, admin_billing_tiers, admin_billing_subscription
):
"""MRR for an annual subscription is annual_price / 12."""
admin_billing_subscription.is_annual = True
db.flush()
stats = self.service.get_stats(db)
# Essential annual = 49000 cents → MRR = 49000 / 12 = 4083
assert stats["mrr_cents"] == 49000 // 12
assert stats["arr_cents"] == 49000
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def admin_billing_tiers(db):
"""Create essential, professional, business tiers for admin tests."""
tiers = [
SubscriptionTier(
code="essential",
name="Essential",
price_monthly_cents=4900,
price_annual_cents=49000,
display_order=1,
is_active=True,
is_public=True,
),
SubscriptionTier(
code="professional",
name="Professional",
price_monthly_cents=9900,
price_annual_cents=99000,
display_order=2,
is_active=True,
is_public=True,
),
SubscriptionTier(
code="business",
name="Business",
price_monthly_cents=19900,
price_annual_cents=199000,
display_order=3,
is_active=True,
is_public=True,
),
]
db.add_all(tiers)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def admin_billing_subscription(db, test_merchant, test_platform, admin_billing_tiers):
"""Create an active subscription on the essential tier."""
now = datetime.now(UTC)
sub = MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=test_platform.id,
tier_id=admin_billing_tiers[0].id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def admin_billing_subscription_trial(
db, other_merchant, another_platform, admin_billing_tiers
):
"""Create a trial subscription on another merchant/platform."""
now = datetime.now(UTC)
sub = MerchantSubscription(
merchant_id=other_merchant.id,
platform_id=another_platform.id,
tier_id=admin_billing_tiers[0].id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=now + timedelta(days=14),
trial_ends_at=now + timedelta(days=14),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def admin_billing_multiple_subscriptions(
db, test_merchant, other_merchant, test_platform, another_platform,
admin_billing_tiers, platform_factory,
):
"""Create 3 subscriptions across different merchants/platforms."""
now = datetime.now(UTC)
third_platform = platform_factory()
subs = [
MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=test_platform.id,
tier_id=admin_billing_tiers[0].id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
),
MerchantSubscription(
merchant_id=other_merchant.id,
platform_id=another_platform.id,
tier_id=admin_billing_tiers[1].id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=now + timedelta(days=14),
),
MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=third_platform.id,
tier_id=admin_billing_tiers[2].id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
),
]
db.add_all(subs)
db.commit()
for s in subs:
db.refresh(s)
return subs
@pytest.fixture
def admin_billing_history_records(db, test_merchant):
"""Create 3 billing history records for test_merchant."""
records = []
for i in range(3):
record = BillingHistory(
merchant_id=test_merchant.id,
stripe_invoice_id=f"in_admin_test_{i}",
invoice_number=f"ADM-{i:03d}",
invoice_date=datetime.now(UTC) - timedelta(days=i * 30),
subtotal_cents=4900,
tax_cents=0,
total_cents=4900,
amount_paid_cents=4900,
currency="EUR",
status="paid",
)
records.append(record)
db.add_all(records)
db.commit()
for r in records:
db.refresh(r)
return records

View File

@@ -1,12 +1,11 @@
# tests/unit/services/test_billing_service.py # app/modules/billing/tests/unit/test_billing_service.py
"""Unit tests for BillingService.""" """Unit tests for BillingService."""
from datetime import datetime, timezone from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from app.modules.tenancy.exceptions import StoreNotFoundException
from app.modules.billing.services.billing_service import ( from app.modules.billing.services.billing_service import (
BillingService, BillingService,
NoActiveSubscriptionError, NoActiveSubscriptionError,
@@ -25,26 +24,107 @@ from app.modules.billing.models import (
) )
# ============================================================================
# Tier Lookup
# ============================================================================
@pytest.mark.unit @pytest.mark.unit
@pytest.mark.billing @pytest.mark.billing
class TestBillingServiceTiers: class TestBillingServiceTiers:
"""Test suite for BillingService tier operations.""" """Test suite for BillingService tier operations."""
def setup_method(self): def setup_method(self):
"""Initialize service instance before each test."""
self.service = BillingService() self.service = BillingService()
def test_get_tier_by_code_found(self, db, bs_tier_essential):
"""Returns the active tier."""
tier = self.service.get_tier_by_code(db, "essential")
assert tier.code == "essential"
def test_get_tier_by_code_not_found(self, db): def test_get_tier_by_code_not_found(self, db):
"""Test getting non-existent tier raises error.""" """Nonexistent tier raises TierNotFoundError."""
with pytest.raises(TierNotFoundError) as exc_info: with pytest.raises(TierNotFoundError) as exc_info:
self.service.get_tier_by_code(db, "nonexistent") self.service.get_tier_by_code(db, "nonexistent")
assert exc_info.value.tier_code == "nonexistent" assert exc_info.value.tier_code == "nonexistent"
def test_get_tier_by_code_inactive_not_returned(self, db, bs_tier_essential):
"""Inactive tier raises TierNotFoundError (only active tiers returned)."""
bs_tier_essential.is_active = False
db.flush()
with pytest.raises(TierNotFoundError):
self.service.get_tier_by_code(db, "essential")
# TestBillingServiceCheckout removed — depends on refactored store_id-based API # ============================================================================
# TestBillingServicePortal removed — depends on refactored store_id-based API # Available Tiers with Upgrade/Downgrade Flags
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceAvailableTiers:
"""Test suite for get_available_tiers (upgrade/downgrade detection)."""
def setup_method(self):
self.service = BillingService()
def test_get_available_tiers_returns_all(self, db, bs_tiers):
"""Returns all active public tiers."""
tier_list, tier_order = self.service.get_available_tiers(db, None)
assert len(tier_list) == 3
def test_get_available_tiers_marks_current(self, db, bs_tiers):
"""Current tier is marked with is_current=True."""
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[1].id)
current = [t for t in tier_list if t["is_current"]]
assert len(current) == 1
assert current[0]["code"] == "professional"
def test_get_available_tiers_upgrade_flags(self, db, bs_tiers):
"""Tiers with higher display_order have can_upgrade=True."""
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[0].id)
essential = next(t for t in tier_list if t["code"] == "essential")
professional = next(t for t in tier_list if t["code"] == "professional")
business = next(t for t in tier_list if t["code"] == "business")
assert essential["is_current"] is True
assert essential["can_upgrade"] is False
assert professional["can_upgrade"] is True
assert business["can_upgrade"] is True
def test_get_available_tiers_downgrade_flags(self, db, bs_tiers):
"""Tiers with lower display_order have can_downgrade=True."""
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[2].id)
essential = next(t for t in tier_list if t["code"] == "essential")
professional = next(t for t in tier_list if t["code"] == "professional")
business = next(t for t in tier_list if t["code"] == "business")
assert essential["can_downgrade"] is True
assert professional["can_downgrade"] is True
assert business["is_current"] is True
assert business["can_downgrade"] is False
def test_get_available_tiers_no_current_tier(self, db, bs_tiers):
"""When current_tier_id is None, no tier is marked current."""
tier_list, _ = self.service.get_available_tiers(db, None)
assert all(t["is_current"] is False for t in tier_list)
def test_get_available_tiers_returns_tier_order_map(self, db, bs_tiers):
"""Returns tier_order map of code → display_order."""
_, tier_order = self.service.get_available_tiers(db, None)
assert tier_order["essential"] == 1
assert tier_order["professional"] == 2
assert tier_order["business"] == 3
# ============================================================================
# Invoices
# ============================================================================
@pytest.mark.unit @pytest.mark.unit
@@ -53,17 +133,42 @@ class TestBillingServiceInvoices:
"""Test suite for BillingService invoice operations.""" """Test suite for BillingService invoice operations."""
def setup_method(self): def setup_method(self):
"""Initialize service instance before each test."""
self.service = BillingService() self.service = BillingService()
def test_get_invoices_empty(self, db, test_store): def test_get_invoices_empty(self, db, test_merchant):
"""Test getting invoices when none exist.""" """Returns empty list and zero total when no invoices exist."""
invoices, total = self.service.get_invoices(db, test_store.id) invoices, total = self.service.get_invoices(db, test_merchant.id)
assert invoices == [] assert invoices == []
assert total == 0 assert total == 0
# test_get_invoices_with_data and test_get_invoices_pagination removed — fixture model mismatch after migration def test_get_invoices_with_data(self, db, bs_billing_history):
"""Returns invoices for the merchant."""
merchant_id = bs_billing_history[0].merchant_id
invoices, total = self.service.get_invoices(db, merchant_id)
assert total == 3
assert len(invoices) == 3
def test_get_invoices_pagination(self, db, bs_billing_history):
"""Pagination limits and offsets results."""
merchant_id = bs_billing_history[0].merchant_id
invoices, total = self.service.get_invoices(db, merchant_id, skip=0, limit=2)
assert total == 3
assert len(invoices) == 2
invoices2, _ = self.service.get_invoices(db, merchant_id, skip=2, limit=2)
assert len(invoices2) == 1
def test_get_invoices_ordered_by_date_desc(self, db, bs_billing_history):
"""Invoices are returned newest first."""
merchant_id = bs_billing_history[0].merchant_id
invoices, _ = self.service.get_invoices(db, merchant_id)
dates = [inv.invoice_date for inv in invoices]
assert dates == sorted(dates, reverse=True)
# ============================================================================
# Add-ons
# ============================================================================
@pytest.mark.unit @pytest.mark.unit
@@ -72,55 +177,442 @@ class TestBillingServiceAddons:
"""Test suite for BillingService addon operations.""" """Test suite for BillingService addon operations."""
def setup_method(self): def setup_method(self):
"""Initialize service instance before each test."""
self.service = BillingService() self.service = BillingService()
def test_get_available_addons_empty(self, db): def test_get_available_addons_empty(self, db):
"""Test getting addons when none exist.""" """Returns empty when no addons exist."""
addons = self.service.get_available_addons(db) addons = self.service.get_available_addons(db)
assert addons == [] assert addons == []
def test_get_available_addons_with_data(self, db, test_addon_products): def test_get_available_addons_with_data(self, db, test_addon_products):
"""Test getting all available addons.""" """Returns all active addons."""
addons = self.service.get_available_addons(db) addons = self.service.get_available_addons(db)
assert len(addons) == 3 assert len(addons) == 3
assert all(addon.is_active for addon in addons) assert all(addon.is_active for addon in addons)
def test_get_available_addons_by_category(self, db, test_addon_products): def test_get_available_addons_by_category(self, db, test_addon_products):
"""Test filtering addons by category.""" """Filters by category."""
domain_addons = self.service.get_available_addons(db, category="domain") domain_addons = self.service.get_available_addons(db, category="domain")
assert len(domain_addons) == 1 assert len(domain_addons) == 1
assert domain_addons[0].category == "domain" assert domain_addons[0].category == "domain"
def test_get_store_addons_empty(self, db, test_store): def test_get_store_addons_empty(self, db, test_store):
"""Test getting store addons when none purchased.""" """Returns empty when store has no purchased addons."""
addons = self.service.get_store_addons(db, test_store.id) addons = self.service.get_store_addons(db, test_store.id)
assert addons == [] assert addons == []
# ============================================================================
# TestBillingServiceCancellation removed — depends on refactored store_id-based API # Subscription with Tier
# TestBillingServiceStore removed — get_store method was removed from BillingService # ============================================================================
# ==================== Fixtures ==================== @pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceSubscriptionWithTier:
"""Tests for get_subscription_with_tier."""
def setup_method(self):
self.service = BillingService()
def test_get_subscription_with_tier_existing(
self, db, bs_subscription
):
"""Returns (subscription, tier) tuple for existing subscription."""
sub, tier = self.service.get_subscription_with_tier(
db, bs_subscription.merchant_id, bs_subscription.platform_id
)
assert sub.id == bs_subscription.id
assert tier is not None
assert tier.code == "essential"
def test_get_subscription_with_tier_creates_if_missing(
self, db, test_merchant, test_platform, bs_tier_essential
):
"""Creates a trial subscription when none exists (via get_or_create)."""
sub, tier = self.service.get_subscription_with_tier(
db, test_merchant.id, test_platform.id
)
assert sub is not None
assert sub.status == SubscriptionStatus.TRIAL.value
# ============================================================================
# Change Tier
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceChangeTier:
"""Tests for change_tier (the tier upgrade/downgrade flow)."""
def setup_method(self):
self.service = BillingService()
def test_change_tier_no_subscription_raises(self, db, bs_tiers):
"""Raises NoActiveSubscriptionError when no subscription exists."""
with pytest.raises(NoActiveSubscriptionError):
self.service.change_tier(db, 99999, 99999, "professional", False)
def test_change_tier_no_stripe_subscription_raises(
self, db, bs_subscription, bs_tiers
):
"""Raises when subscription has no stripe_subscription_id."""
# bs_subscription has no Stripe IDs
with pytest.raises(NoActiveSubscriptionError):
self.service.change_tier(
db,
bs_subscription.merchant_id,
bs_subscription.platform_id,
"professional",
False,
)
def test_change_tier_nonexistent_tier_raises(
self, db, bs_stripe_subscription
):
"""Raises TierNotFoundError for nonexistent tier."""
with pytest.raises(TierNotFoundError):
self.service.change_tier(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
"nonexistent",
False,
)
def test_change_tier_no_price_id_raises(
self, db, bs_stripe_subscription, bs_tiers
):
"""Raises StripePriceNotConfiguredError when tier has no Stripe price."""
# bs_tiers have no stripe_price_* set
with pytest.raises(StripePriceNotConfiguredError):
self.service.change_tier(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
"professional",
False,
)
@patch("app.modules.billing.services.billing_service.stripe_service")
def test_change_tier_success(
self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
):
"""Successful tier change updates local subscription and calls Stripe."""
mock_stripe.is_configured = True
mock_stripe.update_subscription = MagicMock()
result = self.service.change_tier(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
"professional",
False,
)
assert result["new_tier"] == "professional"
assert result["effective_immediately"] is True
assert bs_stripe_subscription.tier_id == bs_tiers_with_stripe[1].id
mock_stripe.update_subscription.assert_called_once()
@patch("app.modules.billing.services.billing_service.stripe_service")
def test_change_tier_annual_uses_annual_price(
self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
):
"""Annual billing selects stripe_price_annual_id."""
mock_stripe.is_configured = True
mock_stripe.update_subscription = MagicMock()
self.service.change_tier(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
"professional",
True,
)
call_args = mock_stripe.update_subscription.call_args
assert call_args.kwargs["new_price_id"] == "price_pro_annual"
# ============================================================================
# _is_upgrade
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceIsUpgrade:
"""Tests for _is_upgrade helper."""
def setup_method(self):
self.service = BillingService()
def test_is_upgrade_true(self, db, bs_tiers):
"""Higher display_order is an upgrade."""
assert self.service._is_upgrade(db, bs_tiers[0].id, bs_tiers[2].id) is True
def test_is_upgrade_false_downgrade(self, db, bs_tiers):
"""Lower display_order is not an upgrade."""
assert self.service._is_upgrade(db, bs_tiers[2].id, bs_tiers[0].id) is False
def test_is_upgrade_same_tier(self, db, bs_tiers):
"""Same tier is not an upgrade."""
assert self.service._is_upgrade(db, bs_tiers[1].id, bs_tiers[1].id) is False
def test_is_upgrade_none_ids(self, db):
"""None tier IDs return False."""
assert self.service._is_upgrade(db, None, None) is False
assert self.service._is_upgrade(db, None, 1) is False
assert self.service._is_upgrade(db, 1, None) is False
# ============================================================================
# Cancel Subscription
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceCancel:
"""Tests for cancel_subscription."""
def setup_method(self):
self.service = BillingService()
def test_cancel_no_subscription_raises(self, db):
"""Raises when no subscription found."""
with pytest.raises(NoActiveSubscriptionError):
self.service.cancel_subscription(db, 99999, 99999, None, False)
def test_cancel_no_stripe_id_raises(self, db, bs_subscription):
"""Raises when subscription has no stripe_subscription_id."""
with pytest.raises(NoActiveSubscriptionError):
self.service.cancel_subscription(
db,
bs_subscription.merchant_id,
bs_subscription.platform_id,
"reason",
False,
)
@patch("app.modules.billing.services.billing_service.stripe_service")
def test_cancel_success(self, mock_stripe, db, bs_stripe_subscription):
"""Cancellation records timestamp and reason."""
mock_stripe.is_configured = True
mock_stripe.cancel_subscription = MagicMock()
result = self.service.cancel_subscription(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
"Too expensive",
False,
)
assert result["message"] == "Subscription cancelled successfully"
assert bs_stripe_subscription.cancelled_at is not None
assert bs_stripe_subscription.cancellation_reason == "Too expensive"
mock_stripe.cancel_subscription.assert_called_once()
# ============================================================================
# Reactivate Subscription
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceReactivate:
"""Tests for reactivate_subscription."""
def setup_method(self):
self.service = BillingService()
def test_reactivate_no_subscription_raises(self, db):
"""Raises when no subscription found."""
with pytest.raises(NoActiveSubscriptionError):
self.service.reactivate_subscription(db, 99999, 99999)
def test_reactivate_not_cancelled_raises(self, db, bs_stripe_subscription):
"""Raises SubscriptionNotCancelledError when not cancelled."""
with pytest.raises(SubscriptionNotCancelledError):
self.service.reactivate_subscription(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
)
@patch("app.modules.billing.services.billing_service.stripe_service")
def test_reactivate_success(self, mock_stripe, db, bs_stripe_subscription):
"""Reactivation clears cancellation and calls Stripe."""
mock_stripe.is_configured = True
mock_stripe.reactivate_subscription = MagicMock()
# Cancel first
bs_stripe_subscription.cancelled_at = datetime.now(UTC)
bs_stripe_subscription.cancellation_reason = "Testing"
db.flush()
result = self.service.reactivate_subscription(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
)
assert result["message"] == "Subscription reactivated successfully"
assert bs_stripe_subscription.cancelled_at is None
assert bs_stripe_subscription.cancellation_reason is None
mock_stripe.reactivate_subscription.assert_called_once()
# ============================================================================
# Checkout Session
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceCheckout:
"""Tests for create_checkout_session."""
def setup_method(self):
self.service = BillingService()
def test_checkout_stripe_not_configured_raises(self, db, bs_tiers_with_stripe):
"""Raises PaymentSystemNotConfiguredError when Stripe is off."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = False
with pytest.raises(PaymentSystemNotConfiguredError):
self.service.create_checkout_session(
db, 1, 1, "essential", False, "http://ok", "http://cancel"
)
def test_checkout_nonexistent_tier_raises(self, db):
"""Raises TierNotFoundError for nonexistent tier."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = True
with pytest.raises(TierNotFoundError):
self.service.create_checkout_session(
db, 1, 1, "nonexistent", False, "http://ok", "http://cancel"
)
# ============================================================================
# Portal Session
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServicePortal:
"""Tests for create_portal_session."""
def setup_method(self):
self.service = BillingService()
def test_portal_stripe_not_configured_raises(self, db):
"""Raises PaymentSystemNotConfiguredError when Stripe is off."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = False
with pytest.raises(PaymentSystemNotConfiguredError):
self.service.create_portal_session(db, 1, 1, "http://return")
def test_portal_no_subscription_raises(self, db):
"""Raises NoActiveSubscriptionError when no subscription found."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = True
with pytest.raises(NoActiveSubscriptionError):
self.service.create_portal_session(db, 99999, 99999, "http://return")
def test_portal_no_customer_id_raises(self, db, bs_subscription):
"""Raises when subscription has no stripe_customer_id."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = True
with pytest.raises(NoActiveSubscriptionError):
self.service.create_portal_session(
db,
bs_subscription.merchant_id,
bs_subscription.platform_id,
"http://return",
)
# ============================================================================
# Upcoming Invoice
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestBillingServiceUpcomingInvoice:
"""Tests for get_upcoming_invoice."""
def setup_method(self):
self.service = BillingService()
def test_upcoming_invoice_no_subscription_raises(self, db):
"""Raises when no subscription exists."""
with pytest.raises(NoActiveSubscriptionError):
self.service.get_upcoming_invoice(db, 99999, 99999)
def test_upcoming_invoice_no_customer_id_raises(self, db, bs_subscription):
"""Raises when subscription has no stripe_customer_id."""
with pytest.raises(NoActiveSubscriptionError):
self.service.get_upcoming_invoice(
db, bs_subscription.merchant_id, bs_subscription.platform_id
)
def test_upcoming_invoice_stripe_not_configured_returns_empty(
self, db, bs_stripe_subscription
):
"""Returns empty invoice when Stripe is not configured."""
with patch(
"app.modules.billing.services.billing_service.stripe_service"
) as mock_stripe:
mock_stripe.is_configured = False
result = self.service.get_upcoming_invoice(
db,
bs_stripe_subscription.merchant_id,
bs_stripe_subscription.platform_id,
)
assert result["amount_due_cents"] == 0
assert result["line_items"] == []
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture @pytest.fixture
def test_subscription_tier(db): def bs_tier_essential(db):
"""Create a basic subscription tier.""" """Create essential subscription tier."""
tier = SubscriptionTier( tier = SubscriptionTier(
code="essential", code="essential",
name="Essential", name="Essential",
description="Essential plan", description="Essential plan",
price_monthly_cents=4900, price_monthly_cents=4900,
price_annual_cents=49000, price_annual_cents=49000,
orders_per_month=100,
products_limit=200,
team_members=1,
features=["basic_support"],
display_order=1, display_order=1,
is_active=True, is_active=True,
is_public=True, is_public=True,
@@ -132,39 +624,14 @@ def test_subscription_tier(db):
@pytest.fixture @pytest.fixture
def test_subscription_tier_with_stripe(db): def bs_tiers(db):
"""Create a subscription tier with Stripe configuration.""" """Create three tiers without Stripe config."""
tier = SubscriptionTier(
code="essential",
name="Essential",
description="Essential plan",
price_monthly_cents=4900,
price_annual_cents=49000,
orders_per_month=100,
products_limit=200,
team_members=1,
features=["basic_support"],
display_order=1,
is_active=True,
is_public=True,
stripe_product_id="prod_test123",
stripe_price_monthly_id="price_test123",
stripe_price_annual_id="price_test456",
)
db.add(tier)
db.commit()
db.refresh(tier)
return tier
@pytest.fixture
def test_subscription_tiers(db):
"""Create multiple subscription tiers."""
tiers = [ tiers = [
SubscriptionTier( SubscriptionTier(
code="essential", code="essential",
name="Essential", name="Essential",
price_monthly_cents=4900, price_monthly_cents=4900,
price_annual_cents=49000,
display_order=1, display_order=1,
is_active=True, is_active=True,
is_public=True, is_public=True,
@@ -173,6 +640,7 @@ def test_subscription_tiers(db):
code="professional", code="professional",
name="Professional", name="Professional",
price_monthly_cents=9900, price_monthly_cents=9900,
price_annual_cents=99000,
display_order=2, display_order=2,
is_active=True, is_active=True,
is_public=True, is_public=True,
@@ -181,6 +649,7 @@ def test_subscription_tiers(db):
code="business", code="business",
name="Business", name="Business",
price_monthly_cents=19900, price_monthly_cents=19900,
price_annual_cents=199000,
display_order=3, display_order=3,
is_active=True, is_active=True,
is_public=True, is_public=True,
@@ -188,136 +657,107 @@ def test_subscription_tiers(db):
] ]
db.add_all(tiers) db.add_all(tiers)
db.commit() db.commit()
for tier in tiers: for t in tiers:
db.refresh(tier) db.refresh(t)
return tiers return tiers
@pytest.fixture @pytest.fixture
def test_subscription(db, test_store): def bs_tiers_with_stripe(db):
"""Create a basic subscription for testing.""" """Create tiers with Stripe price IDs configured."""
# Create tier first tiers = [
tier = SubscriptionTier( SubscriptionTier(
code="essential", code="essential",
name="Essential", name="Essential",
price_monthly_cents=4900, price_monthly_cents=4900,
display_order=1, price_annual_cents=49000,
is_active=True, display_order=1,
is_public=True, is_active=True,
) is_public=True,
db.add(tier) stripe_product_id="prod_essential",
stripe_price_monthly_id="price_ess_monthly",
stripe_price_annual_id="price_ess_annual",
),
SubscriptionTier(
code="professional",
name="Professional",
price_monthly_cents=9900,
price_annual_cents=99000,
display_order=2,
is_active=True,
is_public=True,
stripe_product_id="prod_professional",
stripe_price_monthly_id="price_pro_monthly",
stripe_price_annual_id="price_pro_annual",
),
SubscriptionTier(
code="business",
name="Business",
price_monthly_cents=19900,
price_annual_cents=199000,
display_order=3,
is_active=True,
is_public=True,
stripe_product_id="prod_business",
stripe_price_monthly_id="price_biz_monthly",
stripe_price_annual_id="price_biz_annual",
),
]
db.add_all(tiers)
db.commit() db.commit()
for t in tiers:
subscription = MerchantSubscription( db.refresh(t)
store_id=test_store.id, return tiers
tier="essential",
status=SubscriptionStatus.ACTIVE,
period_start=datetime.now(timezone.utc),
period_end=datetime.now(timezone.utc),
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.fixture @pytest.fixture
def test_active_subscription(db, test_store): def bs_subscription(db, test_merchant, test_platform, bs_tier_essential):
"""Create an active merchant subscription (no Stripe IDs)."""
now = datetime.now(UTC)
sub = MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=test_platform.id,
tier_id=bs_tier_essential.id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def bs_stripe_subscription(db, test_merchant, test_platform, bs_tier_essential):
"""Create an active subscription with Stripe IDs.""" """Create an active subscription with Stripe IDs."""
# Create tier first if not exists now = datetime.now(UTC)
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() sub = MerchantSubscription(
if not tier: merchant_id=test_merchant.id,
tier = SubscriptionTier( platform_id=test_platform.id,
code="essential", tier_id=bs_tier_essential.id,
name="Essential", status=SubscriptionStatus.ACTIVE.value,
price_monthly_cents=4900,
display_order=1,
is_active=True,
is_public=True,
)
db.add(tier)
db.commit()
subscription = MerchantSubscription(
store_id=test_store.id,
tier="essential",
status=SubscriptionStatus.ACTIVE,
stripe_customer_id="cus_test123", stripe_customer_id="cus_test123",
stripe_subscription_id="sub_test123", stripe_subscription_id="sub_test123",
period_start=datetime.now(timezone.utc), period_start=now,
period_end=datetime.now(timezone.utc), period_end=now + timedelta(days=30),
) )
db.add(subscription) db.add(sub)
db.commit() db.commit()
db.refresh(subscription) db.refresh(sub)
return subscription return sub
@pytest.fixture @pytest.fixture
def test_cancelled_subscription(db, test_store): def bs_billing_history(db, test_merchant):
"""Create a cancelled subscription.""" """Create billing history records for test_merchant."""
# Create tier first if not exists
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first()
if not tier:
tier = SubscriptionTier(
code="essential",
name="Essential",
price_monthly_cents=4900,
display_order=1,
is_active=True,
is_public=True,
)
db.add(tier)
db.commit()
subscription = MerchantSubscription(
store_id=test_store.id,
tier="essential",
status=SubscriptionStatus.ACTIVE,
stripe_customer_id="cus_test123",
stripe_subscription_id="sub_test123",
period_start=datetime.now(timezone.utc),
period_end=datetime.now(timezone.utc),
cancelled_at=datetime.now(timezone.utc),
cancellation_reason="Too expensive",
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.fixture
def test_billing_history(db, test_store):
"""Create a billing history record."""
record = BillingHistory(
store_id=test_store.id,
stripe_invoice_id="in_test123",
invoice_number="INV-001",
invoice_date=datetime.now(timezone.utc),
subtotal_cents=4900,
tax_cents=0,
total_cents=4900,
amount_paid_cents=4900,
currency="EUR",
status="paid",
)
db.add(record)
db.commit()
db.refresh(record)
return record
@pytest.fixture
def test_multiple_invoices(db, test_store):
"""Create multiple billing history records."""
records = [] records = []
for i in range(5): for i in range(3):
record = BillingHistory( record = BillingHistory(
store_id=test_store.id, merchant_id=test_merchant.id,
stripe_invoice_id=f"in_test{i}", stripe_invoice_id=f"in_bs_test_{i}",
invoice_number=f"INV-{i:03d}", invoice_number=f"BS-{i:03d}",
invoice_date=datetime.now(timezone.utc), invoice_date=datetime.now(UTC) - timedelta(days=i * 30),
subtotal_cents=4900, subtotal_cents=4900,
tax_cents=0, tax_cents=0,
total_cents=4900, total_cents=4900,
@@ -328,6 +768,8 @@ def test_multiple_invoices(db, test_store):
records.append(record) records.append(record)
db.add_all(records) db.add_all(records)
db.commit() db.commit()
for r in records:
db.refresh(r)
return records return records

View File

@@ -0,0 +1,579 @@
# app/modules/billing/tests/unit/test_subscription_service.py
"""Unit tests for SubscriptionService (merchant-level subscription management)."""
from datetime import UTC, datetime, timedelta
import pytest
from app.modules.billing.exceptions import SubscriptionNotFoundException
from app.modules.billing.models import (
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
TierCode,
)
from app.modules.billing.services.subscription_service import SubscriptionService
# ============================================================================
# Tier Information
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceTierLookup:
"""Tests for tier lookup methods."""
def setup_method(self):
self.service = SubscriptionService()
def test_get_tier_by_code_returns_tier(self, db, billing_tier_essential):
"""Existing tier code returns the tier object."""
tier = self.service.get_tier_by_code(db, "essential")
assert tier is not None
assert tier.code == "essential"
assert tier.id == billing_tier_essential.id
def test_get_tier_by_code_nonexistent_returns_none(self, db):
"""Nonexistent tier code returns None (does not raise)."""
tier = self.service.get_tier_by_code(db, "nonexistent")
assert tier is None
def test_get_tier_id_returns_id(self, db, billing_tier_essential):
"""get_tier_id returns the integer ID for a valid code."""
tier_id = self.service.get_tier_id(db, "essential")
assert tier_id == billing_tier_essential.id
def test_get_tier_id_nonexistent_returns_none(self, db):
"""get_tier_id returns None for nonexistent code."""
tier_id = self.service.get_tier_id(db, "nonexistent")
assert tier_id is None
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetAllTiers:
"""Tests for get_all_tiers with platform filtering."""
def setup_method(self):
self.service = SubscriptionService()
def test_get_all_tiers_returns_active_public_tiers(self, db, billing_tiers):
"""Returns only active + public tiers, ordered by display_order."""
tiers = self.service.get_all_tiers(db)
assert len(tiers) == 3
assert tiers[0].code == "essential"
assert tiers[1].code == "professional"
assert tiers[2].code == "business"
def test_get_all_tiers_excludes_inactive(self, db, billing_tiers):
"""Inactive tiers are excluded."""
# Deactivate one
billing_tiers[0].is_active = False
db.flush()
tiers = self.service.get_all_tiers(db)
assert len(tiers) == 2
assert all(t.code != "essential" for t in tiers)
def test_get_all_tiers_excludes_non_public(self, db, billing_tiers):
"""Non-public tiers are excluded."""
billing_tiers[2].is_public = False
db.flush()
tiers = self.service.get_all_tiers(db)
assert len(tiers) == 2
assert all(t.code != "business" for t in tiers)
def test_get_all_tiers_with_platform_filter(self, db, test_platform, billing_tiers):
"""Platform filter returns platform-specific tiers + global tiers."""
# Make one tier platform-specific
billing_tiers[0].platform_id = test_platform.id
# billing_tiers[1] and [2] remain global (platform_id=NULL)
db.flush()
tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
codes = {t.code for t in tiers}
assert "essential" in codes # Platform-specific
assert "professional" in codes # Global
assert "business" in codes # Global
def test_get_all_tiers_platform_filter_excludes_other_platform(
self, db, test_platform, another_platform, billing_tiers
):
"""Tiers belonging to another platform are excluded."""
billing_tiers[0].platform_id = another_platform.id
db.flush()
tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
codes = {t.code for t in tiers}
assert "essential" not in codes
assert len(tiers) == 2
# ============================================================================
# Merchant Subscription CRUD
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetMerchantSubscription:
"""Tests for subscription retrieval methods."""
def setup_method(self):
self.service = SubscriptionService()
def test_get_merchant_subscription_found(
self, db, billing_subscription
):
"""Returns subscription when it exists."""
sub = self.service.get_merchant_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub is not None
assert sub.id == billing_subscription.id
def test_get_merchant_subscription_not_found(self, db):
"""Returns None when no subscription exists."""
sub = self.service.get_merchant_subscription(db, 99999, 99999)
assert sub is None
def test_get_merchant_subscription_eager_loads_tier(
self, db, billing_subscription
):
"""Subscription comes with tier relationship loaded."""
sub = self.service.get_merchant_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
# Accessing .tier should not trigger a lazy load error
assert sub.tier is not None
assert sub.tier.code == "essential"
def test_get_merchant_subscription_eager_loads_platform(
self, db, billing_subscription
):
"""Subscription comes with platform relationship loaded."""
sub = self.service.get_merchant_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.platform is not None
def test_get_merchant_subscriptions_all(
self, db, test_merchant, test_platform, another_platform, billing_tier_essential
):
"""Returns all subscriptions for a merchant across platforms."""
now = datetime.now(UTC)
sub1 = MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=test_platform.id,
tier_id=billing_tier_essential.id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
)
sub2 = MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=another_platform.id,
tier_id=billing_tier_essential.id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=now + timedelta(days=14),
)
db.add_all([sub1, sub2])
db.flush()
subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
assert len(subs) == 2
def test_get_merchant_subscriptions_empty(self, db, test_merchant):
"""Returns empty list when merchant has no subscriptions."""
subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
assert subs == []
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrRaise:
"""Tests for get_subscription_or_raise."""
def setup_method(self):
self.service = SubscriptionService()
def test_get_subscription_or_raise_found(self, db, billing_subscription):
"""Returns subscription when found."""
sub = self.service.get_subscription_or_raise(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.id == billing_subscription.id
def test_get_subscription_or_raise_not_found(self, db):
"""Raises SubscriptionNotFoundException when not found."""
with pytest.raises(SubscriptionNotFoundException):
self.service.get_subscription_or_raise(db, 99999, 99999)
# ============================================================================
# Create Subscription
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCreate:
"""Tests for create_merchant_subscription."""
def setup_method(self):
self.service = SubscriptionService()
def test_create_trial_subscription(
self, db, test_merchant, test_platform, billing_tier_essential
):
"""Default creation starts a 14-day trial."""
sub = self.service.create_merchant_subscription(
db, test_merchant.id, test_platform.id
)
assert sub.merchant_id == test_merchant.id
assert sub.platform_id == test_platform.id
assert sub.status == SubscriptionStatus.TRIAL.value
assert sub.tier_id == billing_tier_essential.id
assert sub.trial_ends_at is not None
assert sub.is_annual is False
def test_create_subscription_period_end_matches_trial(
self, db, test_merchant, test_platform, billing_tier_essential
):
"""Period end is set to trial_days from now."""
sub = self.service.create_merchant_subscription(
db, test_merchant.id, test_platform.id, trial_days=7
)
expected_end = sub.period_start + timedelta(days=7)
# Allow 1 second tolerance
assert abs((sub.period_end - expected_end).total_seconds()) < 1
def test_create_annual_subscription(
self, db, test_merchant, test_platform, billing_tier_essential
):
"""Annual subscription with no trial gets 365-day period."""
sub = self.service.create_merchant_subscription(
db, test_merchant.id, test_platform.id, trial_days=0, is_annual=True
)
assert sub.status == SubscriptionStatus.ACTIVE.value
assert sub.is_annual is True
assert sub.trial_ends_at is None
expected_end = sub.period_start + timedelta(days=365)
assert abs((sub.period_end - expected_end).total_seconds()) < 1
def test_create_monthly_subscription(
self, db, test_merchant, test_platform, billing_tier_essential
):
"""Monthly subscription with no trial gets 30-day period."""
sub = self.service.create_merchant_subscription(
db, test_merchant.id, test_platform.id, trial_days=0, is_annual=False
)
assert sub.status == SubscriptionStatus.ACTIVE.value
assert sub.is_annual is False
assert sub.trial_ends_at is None
expected_end = sub.period_start + timedelta(days=30)
assert abs((sub.period_end - expected_end).total_seconds()) < 1
def test_create_subscription_custom_tier(
self, db, test_merchant, test_platform, billing_tiers
):
"""Can specify a custom tier code."""
sub = self.service.create_merchant_subscription(
db, test_merchant.id, test_platform.id, tier_code="professional"
)
assert sub.tier_id == billing_tiers[1].id
def test_create_duplicate_raises(
self, db, billing_subscription
):
"""Creating a second subscription for same merchant+platform raises ValueError."""
with pytest.raises(ValueError, match="already has a subscription"):
self.service.create_merchant_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrCreate:
"""Tests for get_or_create_subscription."""
def setup_method(self):
self.service = SubscriptionService()
def test_get_or_create_returns_existing(self, db, billing_subscription):
"""Returns existing subscription without creating a new one."""
sub = self.service.get_or_create_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.id == billing_subscription.id
def test_get_or_create_creates_new(
self, db, test_merchant, test_platform, billing_tier_essential
):
"""Creates new subscription when none exists."""
sub = self.service.get_or_create_subscription(
db, test_merchant.id, test_platform.id
)
assert sub is not None
assert sub.merchant_id == test_merchant.id
assert sub.status == SubscriptionStatus.TRIAL.value
# ============================================================================
# Upgrade / Cancel / Reactivate
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceUpgradeTier:
"""Tests for upgrade_tier."""
def setup_method(self):
self.service = SubscriptionService()
def test_upgrade_tier_changes_tier_id(self, db, billing_subscription, billing_tiers):
"""Tier is updated to the new tier's ID."""
sub = self.service.upgrade_tier(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
"professional",
)
assert sub.tier_id == billing_tiers[1].id
def test_upgrade_from_trial_activates(self, db, billing_subscription, billing_tiers):
"""Upgrading from trial status sets status to active."""
billing_subscription.status = SubscriptionStatus.TRIAL.value
db.flush()
sub = self.service.upgrade_tier(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
"professional",
)
assert sub.status == SubscriptionStatus.ACTIVE.value
def test_upgrade_active_stays_active(self, db, billing_subscription, billing_tiers):
"""Upgrading an already-active subscription keeps status active."""
billing_subscription.status = SubscriptionStatus.ACTIVE.value
db.flush()
sub = self.service.upgrade_tier(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
"professional",
)
assert sub.status == SubscriptionStatus.ACTIVE.value
def test_upgrade_nonexistent_tier_raises(self, db, billing_subscription):
"""Upgrading to a nonexistent tier raises ValueError."""
with pytest.raises(ValueError, match="not found"):
self.service.upgrade_tier(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
"nonexistent_tier",
)
def test_upgrade_no_subscription_raises(self, db):
"""Upgrading when no subscription exists raises SubscriptionNotFoundException."""
with pytest.raises(SubscriptionNotFoundException):
self.service.upgrade_tier(db, 99999, 99999, "professional")
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCancel:
"""Tests for cancel_subscription."""
def setup_method(self):
self.service = SubscriptionService()
def test_cancel_sets_status(self, db, billing_subscription):
"""Cancellation sets status to cancelled."""
sub = self.service.cancel_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.status == SubscriptionStatus.CANCELLED.value
def test_cancel_records_timestamp(self, db, billing_subscription):
"""Cancellation records cancelled_at."""
sub = self.service.cancel_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.cancelled_at is not None
def test_cancel_with_reason(self, db, billing_subscription):
"""Cancellation stores the reason."""
sub = self.service.cancel_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
reason="Too expensive",
)
assert sub.cancellation_reason == "Too expensive"
def test_cancel_no_subscription_raises(self, db):
"""Cancelling when no subscription exists raises."""
with pytest.raises(SubscriptionNotFoundException):
self.service.cancel_subscription(db, 99999, 99999)
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceReactivate:
"""Tests for reactivate_subscription."""
def setup_method(self):
self.service = SubscriptionService()
def test_reactivate_sets_active(self, db, billing_subscription):
"""Reactivation sets status back to active."""
# First cancel
self.service.cancel_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
sub = self.service.reactivate_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.status == SubscriptionStatus.ACTIVE.value
def test_reactivate_clears_cancellation(self, db, billing_subscription):
"""Reactivation clears cancelled_at and cancellation_reason."""
self.service.cancel_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
reason="Testing",
)
sub = self.service.reactivate_subscription(
db,
billing_subscription.merchant_id,
billing_subscription.platform_id,
)
assert sub.cancelled_at is None
assert sub.cancellation_reason is None
def test_reactivate_no_subscription_raises(self, db):
"""Reactivating when no subscription exists raises."""
with pytest.raises(SubscriptionNotFoundException):
self.service.reactivate_subscription(db, 99999, 99999)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def billing_tier_essential(db):
"""Create essential subscription tier."""
tier = SubscriptionTier(
code="essential",
name="Essential",
description="Essential plan",
price_monthly_cents=4900,
price_annual_cents=49000,
display_order=1,
is_active=True,
is_public=True,
)
db.add(tier)
db.commit()
db.refresh(tier)
return tier
@pytest.fixture
def billing_tiers(db):
"""Create essential, professional, and business tiers."""
tiers = [
SubscriptionTier(
code="essential",
name="Essential",
price_monthly_cents=4900,
price_annual_cents=49000,
display_order=1,
is_active=True,
is_public=True,
),
SubscriptionTier(
code="professional",
name="Professional",
price_monthly_cents=9900,
price_annual_cents=99000,
display_order=2,
is_active=True,
is_public=True,
),
SubscriptionTier(
code="business",
name="Business",
price_monthly_cents=19900,
price_annual_cents=199000,
display_order=3,
is_active=True,
is_public=True,
),
]
db.add_all(tiers)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def billing_subscription(db, test_merchant, test_platform, billing_tier_essential):
"""Create an active merchant subscription (essential tier)."""
now = datetime.now(UTC)
sub = MerchantSubscription(
merchant_id=test_merchant.id,
platform_id=test_platform.id,
tier_id=billing_tier_essential.id,
status=SubscriptionStatus.ACTIVE.value,
period_start=now,
period_end=now + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub