test(billing): add comprehensive service layer tests and fix deactivate_tier bug
Add 139 tests across 3 test files for the billing service layer: - test_subscription_service.py (37 tests): tier lookup, subscription CRUD, upgrades, cancellation - test_admin_subscription_service.py (39 tests): admin tier/subscription management, stats, billing history - test_billing_service.py (43 tests): rewritten with correct fixtures after store→merchant migration Fix production bug in deactivate_tier() — BusinessLogicException was missing required error_code argument, now uses TIER_HAS_ACTIVE_SUBSCRIPTIONS. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -119,7 +119,8 @@ class AdminSubscriptionService:
|
||||
|
||||
if active_subs > 0:
|
||||
raise BusinessLogicException(
|
||||
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
||||
f"Cannot delete tier: {active_subs} active subscriptions are using it",
|
||||
"TIER_HAS_ACTIVE_SUBSCRIPTIONS",
|
||||
)
|
||||
|
||||
tier.is_active = False
|
||||
|
||||
@@ -0,0 +1,642 @@
|
||||
# app/modules/billing/tests/unit/test_admin_subscription_service.py
|
||||
"""Unit tests for AdminSubscriptionService (admin tier & subscription management)."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (
|
||||
BusinessLogicException,
|
||||
ConflictException,
|
||||
ResourceNotFoundException,
|
||||
)
|
||||
from app.modules.billing.exceptions import TierNotFoundException
|
||||
from app.modules.billing.models import (
|
||||
BillingHistory,
|
||||
MerchantSubscription,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
)
|
||||
from app.modules.billing.services.admin_subscription_service import (
|
||||
AdminSubscriptionService,
|
||||
)
|
||||
from app.modules.tenancy.models import Merchant
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tier Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminGetTiers:
|
||||
"""Tests for get_tiers with filtering."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_get_tiers_active_only_by_default(self, db, admin_billing_tiers):
|
||||
"""Default call returns only active tiers."""
|
||||
tiers = self.service.get_tiers(db)
|
||||
|
||||
assert len(tiers) == 3
|
||||
assert all(t.is_active for t in tiers)
|
||||
|
||||
def test_get_tiers_include_inactive(self, db, admin_billing_tiers):
|
||||
"""include_inactive=True returns all tiers."""
|
||||
admin_billing_tiers[0].is_active = False
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_tiers(db, include_inactive=True)
|
||||
assert len(tiers) == 3
|
||||
|
||||
def test_get_tiers_active_excludes_inactive(self, db, admin_billing_tiers):
|
||||
"""Active filter excludes inactive tiers."""
|
||||
admin_billing_tiers[0].is_active = False
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_tiers(db, include_inactive=False)
|
||||
assert len(tiers) == 2
|
||||
assert all(t.code != "essential" for t in tiers)
|
||||
|
||||
def test_get_tiers_ordered_by_display_order(self, db, admin_billing_tiers):
|
||||
"""Tiers are returned in display_order."""
|
||||
tiers = self.service.get_tiers(db)
|
||||
orders = [t.display_order for t in tiers]
|
||||
assert orders == sorted(orders)
|
||||
|
||||
def test_get_tiers_platform_filter(self, db, test_platform, admin_billing_tiers):
|
||||
"""Platform filter returns platform-specific + global tiers."""
|
||||
admin_billing_tiers[0].platform_id = test_platform.id
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_tiers(db, platform_id=test_platform.id)
|
||||
codes = {t.code for t in tiers}
|
||||
assert "essential" in codes # Platform-specific
|
||||
assert "professional" in codes # Global
|
||||
assert "business" in codes # Global
|
||||
|
||||
def test_get_tiers_platform_filter_excludes_other(
|
||||
self, db, test_platform, another_platform, admin_billing_tiers
|
||||
):
|
||||
"""Tiers from another platform are excluded."""
|
||||
admin_billing_tiers[0].platform_id = another_platform.id
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_tiers(db, platform_id=test_platform.id)
|
||||
assert all(t.code != "essential" for t in tiers)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminGetTierByCode:
|
||||
"""Tests for get_tier_by_code."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_get_tier_by_code_found(self, db, admin_billing_tiers):
|
||||
"""Returns the tier when it exists."""
|
||||
tier = self.service.get_tier_by_code(db, "professional")
|
||||
assert tier.code == "professional"
|
||||
assert tier.name == "Professional"
|
||||
|
||||
def test_get_tier_by_code_not_found_raises(self, db):
|
||||
"""Raises TierNotFoundException for nonexistent code."""
|
||||
with pytest.raises(TierNotFoundException):
|
||||
self.service.get_tier_by_code(db, "nonexistent")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminCreateTier:
|
||||
"""Tests for create_tier."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_create_tier_success(self, db):
|
||||
"""Creates a new tier with provided data."""
|
||||
data = {
|
||||
"code": "starter",
|
||||
"name": "Starter",
|
||||
"description": "A starter plan",
|
||||
"price_monthly_cents": 1900,
|
||||
"price_annual_cents": 19000,
|
||||
"display_order": 0,
|
||||
"is_active": True,
|
||||
"is_public": True,
|
||||
}
|
||||
tier = self.service.create_tier(db, data)
|
||||
db.flush()
|
||||
|
||||
assert tier.code == "starter"
|
||||
assert tier.name == "Starter"
|
||||
assert tier.price_monthly_cents == 1900
|
||||
|
||||
def test_create_tier_duplicate_raises(self, db, admin_billing_tiers):
|
||||
"""Duplicate tier code raises ConflictException."""
|
||||
data = {
|
||||
"code": "essential",
|
||||
"name": "Essential Duplicate",
|
||||
"price_monthly_cents": 4900,
|
||||
"display_order": 1,
|
||||
"is_active": True,
|
||||
"is_public": True,
|
||||
}
|
||||
with pytest.raises(ConflictException, match="already exists"):
|
||||
self.service.create_tier(db, data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminUpdateTier:
|
||||
"""Tests for update_tier."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_update_tier_name(self, db, admin_billing_tiers):
|
||||
"""Updates tier fields."""
|
||||
tier = self.service.update_tier(db, "essential", {"name": "Essential Plus"})
|
||||
assert tier.name == "Essential Plus"
|
||||
|
||||
def test_update_tier_price(self, db, admin_billing_tiers):
|
||||
"""Updates tier pricing."""
|
||||
tier = self.service.update_tier(
|
||||
db, "essential", {"price_monthly_cents": 5900}
|
||||
)
|
||||
assert tier.price_monthly_cents == 5900
|
||||
|
||||
def test_update_tier_nonexistent_raises(self, db):
|
||||
"""Updating a nonexistent tier raises."""
|
||||
with pytest.raises(TierNotFoundException):
|
||||
self.service.update_tier(db, "nonexistent", {"name": "X"})
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminDeactivateTier:
|
||||
"""Tests for deactivate_tier (soft-delete)."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_deactivate_tier_sets_inactive(self, db, admin_billing_tiers):
|
||||
"""Deactivation sets is_active to False."""
|
||||
self.service.deactivate_tier(db, "business")
|
||||
db.flush()
|
||||
|
||||
tier = db.query(SubscriptionTier).filter(
|
||||
SubscriptionTier.code == "business"
|
||||
).first()
|
||||
assert tier.is_active is False
|
||||
|
||||
def test_deactivate_tier_with_active_subs_raises(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""Cannot deactivate a tier used by active subscriptions."""
|
||||
with pytest.raises(BusinessLogicException, match="active subscriptions"):
|
||||
self.service.deactivate_tier(db, "essential")
|
||||
|
||||
def test_deactivate_tier_with_cancelled_subs_ok(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""Can deactivate a tier used only by cancelled subscriptions."""
|
||||
admin_billing_subscription.status = SubscriptionStatus.CANCELLED.value
|
||||
db.flush()
|
||||
|
||||
# Should not raise
|
||||
self.service.deactivate_tier(db, "essential")
|
||||
|
||||
def test_deactivate_nonexistent_tier_raises(self, db):
|
||||
"""Deactivating a nonexistent tier raises."""
|
||||
with pytest.raises(TierNotFoundException):
|
||||
self.service.deactivate_tier(db, "nonexistent")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Subscription Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminListSubscriptions:
|
||||
"""Tests for list_subscriptions with pagination and filtering."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_list_subscriptions_empty(self, db):
|
||||
"""Returns empty results when no subscriptions exist."""
|
||||
result = self.service.list_subscriptions(db)
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
assert result["page"] == 1
|
||||
|
||||
def test_list_subscriptions_returns_results(
|
||||
self, db, admin_billing_subscription
|
||||
):
|
||||
"""Returns subscriptions joined with merchant."""
|
||||
result = self.service.list_subscriptions(db)
|
||||
assert result["total"] == 1
|
||||
sub, merchant = result["results"][0]
|
||||
assert isinstance(sub, MerchantSubscription)
|
||||
assert isinstance(merchant, Merchant)
|
||||
|
||||
def test_list_subscriptions_filter_by_status(
|
||||
self, db, admin_billing_subscription, admin_billing_subscription_trial
|
||||
):
|
||||
"""Status filter returns only matching subscriptions."""
|
||||
result = self.service.list_subscriptions(db, status="trial")
|
||||
assert result["total"] == 1
|
||||
sub, _ = result["results"][0]
|
||||
assert sub.status == "trial"
|
||||
|
||||
def test_list_subscriptions_filter_by_tier(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""Tier filter returns only matching subscriptions."""
|
||||
result = self.service.list_subscriptions(db, tier="essential")
|
||||
assert result["total"] == 1
|
||||
|
||||
result = self.service.list_subscriptions(db, tier="professional")
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_list_subscriptions_search_by_merchant_name(
|
||||
self, db, admin_billing_subscription, test_merchant
|
||||
):
|
||||
"""Search filters by merchant name (case-insensitive)."""
|
||||
# test_merchant has a name like "Test Merchant xxxxxxxx"
|
||||
result = self.service.list_subscriptions(db, search="Test Merchant")
|
||||
assert result["total"] == 1
|
||||
|
||||
result = self.service.list_subscriptions(db, search="NonExistent Corp")
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_list_subscriptions_pagination(
|
||||
self, db, admin_billing_multiple_subscriptions
|
||||
):
|
||||
"""Pagination returns correct page with correct count."""
|
||||
result = self.service.list_subscriptions(db, page=1, per_page=2)
|
||||
assert result["total"] == 3
|
||||
assert len(result["results"]) == 2
|
||||
assert result["pages"] == 2
|
||||
|
||||
result2 = self.service.list_subscriptions(db, page=2, per_page=2)
|
||||
assert len(result2["results"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminGetSubscription:
|
||||
"""Tests for get_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_get_subscription_found(self, db, admin_billing_subscription):
|
||||
"""Returns (subscription, merchant) tuple."""
|
||||
sub, merchant = self.service.get_subscription(
|
||||
db,
|
||||
admin_billing_subscription.merchant_id,
|
||||
admin_billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.id == admin_billing_subscription.id
|
||||
assert isinstance(merchant, Merchant)
|
||||
|
||||
def test_get_subscription_not_found_raises(self, db):
|
||||
"""Raises ResourceNotFoundException when not found."""
|
||||
with pytest.raises(ResourceNotFoundException):
|
||||
self.service.get_subscription(db, 99999, 99999)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminUpdateSubscription:
|
||||
"""Tests for update_subscription — the tier_code→tier_id resolution bug area."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_update_subscription_status(self, db, admin_billing_subscription):
|
||||
"""Can update subscription status."""
|
||||
sub, _ = self.service.update_subscription(
|
||||
db,
|
||||
admin_billing_subscription.merchant_id,
|
||||
admin_billing_subscription.platform_id,
|
||||
{"status": SubscriptionStatus.PAST_DUE.value},
|
||||
)
|
||||
assert sub.status == SubscriptionStatus.PAST_DUE.value
|
||||
|
||||
def test_update_subscription_tier_code_resolves_to_tier_id(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""
|
||||
tier_code in update_data is resolved to tier_id (FK).
|
||||
|
||||
This is the exact bug that was fixed in commit d1fe358 — the admin form
|
||||
sends tier_code (string) but the model uses tier_id (int FK).
|
||||
"""
|
||||
sub, _ = self.service.update_subscription(
|
||||
db,
|
||||
admin_billing_subscription.merchant_id,
|
||||
admin_billing_subscription.platform_id,
|
||||
{"tier_code": "professional"},
|
||||
)
|
||||
assert sub.tier_id == admin_billing_tiers[1].id
|
||||
|
||||
def test_update_subscription_tier_code_nonexistent_raises(
|
||||
self, db, admin_billing_subscription
|
||||
):
|
||||
"""Nonexistent tier_code raises TierNotFoundException."""
|
||||
with pytest.raises(TierNotFoundException):
|
||||
self.service.update_subscription(
|
||||
db,
|
||||
admin_billing_subscription.merchant_id,
|
||||
admin_billing_subscription.platform_id,
|
||||
{"tier_code": "nonexistent"},
|
||||
)
|
||||
|
||||
def test_update_subscription_not_found_raises(self, db):
|
||||
"""Updating a nonexistent subscription raises."""
|
||||
with pytest.raises(ResourceNotFoundException):
|
||||
self.service.update_subscription(db, 99999, 99999, {"status": "active"})
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Billing History
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminListBillingHistory:
|
||||
"""Tests for list_billing_history."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_list_billing_history_empty(self, db):
|
||||
"""Returns empty when no records exist."""
|
||||
result = self.service.list_billing_history(db)
|
||||
assert result["total"] == 0
|
||||
assert result["results"] == []
|
||||
|
||||
def test_list_billing_history_with_records(
|
||||
self, db, admin_billing_history_records
|
||||
):
|
||||
"""Returns billing history joined with merchant."""
|
||||
result = self.service.list_billing_history(db)
|
||||
assert result["total"] == 3
|
||||
bh, merchant = result["results"][0]
|
||||
assert isinstance(bh, BillingHistory)
|
||||
assert isinstance(merchant, Merchant)
|
||||
|
||||
def test_list_billing_history_filter_by_merchant(
|
||||
self, db, admin_billing_history_records, test_merchant
|
||||
):
|
||||
"""Merchant filter returns only matching records."""
|
||||
result = self.service.list_billing_history(
|
||||
db, merchant_id=test_merchant.id
|
||||
)
|
||||
assert result["total"] == 3
|
||||
|
||||
result = self.service.list_billing_history(db, merchant_id=99999)
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_list_billing_history_filter_by_status(
|
||||
self, db, admin_billing_history_records
|
||||
):
|
||||
"""Status filter returns matching records."""
|
||||
result = self.service.list_billing_history(db, status="paid")
|
||||
assert result["total"] == 3
|
||||
|
||||
result = self.service.list_billing_history(db, status="failed")
|
||||
assert result["total"] == 0
|
||||
|
||||
def test_list_billing_history_pagination(
|
||||
self, db, admin_billing_history_records
|
||||
):
|
||||
"""Pagination works correctly."""
|
||||
result = self.service.list_billing_history(db, page=1, per_page=2)
|
||||
assert result["total"] == 3
|
||||
assert len(result["results"]) == 2
|
||||
assert result["pages"] == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Statistics
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestAdminGetStats:
|
||||
"""Tests for get_stats (dashboard analytics)."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = AdminSubscriptionService()
|
||||
|
||||
def test_get_stats_empty_db(self, db):
|
||||
"""Stats with no subscriptions return zeroes."""
|
||||
stats = self.service.get_stats(db)
|
||||
|
||||
assert stats["total_subscriptions"] == 0
|
||||
assert stats["active_count"] == 0
|
||||
assert stats["trial_count"] == 0
|
||||
assert stats["mrr_cents"] == 0
|
||||
assert stats["arr_cents"] == 0
|
||||
assert stats["tier_distribution"] == {}
|
||||
|
||||
def test_get_stats_counts_by_status(
|
||||
self, db, admin_billing_subscription, admin_billing_subscription_trial
|
||||
):
|
||||
"""Status counts are accurate."""
|
||||
stats = self.service.get_stats(db)
|
||||
|
||||
assert stats["total_subscriptions"] == 2
|
||||
assert stats["active_count"] == 1
|
||||
assert stats["trial_count"] == 1
|
||||
|
||||
def test_get_stats_tier_distribution(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""Tier distribution counts active/trial subscriptions by tier name."""
|
||||
stats = self.service.get_stats(db)
|
||||
|
||||
assert "Essential" in stats["tier_distribution"]
|
||||
assert stats["tier_distribution"]["Essential"] == 1
|
||||
|
||||
def test_get_stats_mrr_calculation_monthly(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""MRR for a monthly subscription equals the monthly price."""
|
||||
stats = self.service.get_stats(db)
|
||||
# Essential tier is 4900 cents/month
|
||||
assert stats["mrr_cents"] == 4900
|
||||
assert stats["arr_cents"] == 4900 * 12
|
||||
|
||||
def test_get_stats_mrr_calculation_annual(
|
||||
self, db, admin_billing_tiers, admin_billing_subscription
|
||||
):
|
||||
"""MRR for an annual subscription is annual_price / 12."""
|
||||
admin_billing_subscription.is_annual = True
|
||||
db.flush()
|
||||
|
||||
stats = self.service.get_stats(db)
|
||||
# Essential annual = 49000 cents → MRR = 49000 / 12 = 4083
|
||||
assert stats["mrr_cents"] == 49000 // 12
|
||||
assert stats["arr_cents"] == 49000
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_billing_tiers(db):
|
||||
"""Create essential, professional, business tiers for admin tests."""
|
||||
tiers = [
|
||||
SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="professional",
|
||||
name="Professional",
|
||||
price_monthly_cents=9900,
|
||||
price_annual_cents=99000,
|
||||
display_order=2,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="business",
|
||||
name="Business",
|
||||
price_monthly_cents=19900,
|
||||
price_annual_cents=199000,
|
||||
display_order=3,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
]
|
||||
db.add_all(tiers)
|
||||
db.commit()
|
||||
for t in tiers:
|
||||
db.refresh(t)
|
||||
return tiers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_billing_subscription(db, test_merchant, test_platform, admin_billing_tiers):
|
||||
"""Create an active subscription on the essential tier."""
|
||||
now = datetime.now(UTC)
|
||||
sub = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=admin_billing_tiers[0].id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(sub)
|
||||
return sub
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_billing_subscription_trial(
|
||||
db, other_merchant, another_platform, admin_billing_tiers
|
||||
):
|
||||
"""Create a trial subscription on another merchant/platform."""
|
||||
now = datetime.now(UTC)
|
||||
sub = MerchantSubscription(
|
||||
merchant_id=other_merchant.id,
|
||||
platform_id=another_platform.id,
|
||||
tier_id=admin_billing_tiers[0].id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=14),
|
||||
trial_ends_at=now + timedelta(days=14),
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(sub)
|
||||
return sub
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_billing_multiple_subscriptions(
|
||||
db, test_merchant, other_merchant, test_platform, another_platform,
|
||||
admin_billing_tiers, platform_factory,
|
||||
):
|
||||
"""Create 3 subscriptions across different merchants/platforms."""
|
||||
now = datetime.now(UTC)
|
||||
third_platform = platform_factory()
|
||||
|
||||
subs = [
|
||||
MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=admin_billing_tiers[0].id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
),
|
||||
MerchantSubscription(
|
||||
merchant_id=other_merchant.id,
|
||||
platform_id=another_platform.id,
|
||||
tier_id=admin_billing_tiers[1].id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=14),
|
||||
),
|
||||
MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=third_platform.id,
|
||||
tier_id=admin_billing_tiers[2].id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
),
|
||||
]
|
||||
db.add_all(subs)
|
||||
db.commit()
|
||||
for s in subs:
|
||||
db.refresh(s)
|
||||
return subs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_billing_history_records(db, test_merchant):
|
||||
"""Create 3 billing history records for test_merchant."""
|
||||
records = []
|
||||
for i in range(3):
|
||||
record = BillingHistory(
|
||||
merchant_id=test_merchant.id,
|
||||
stripe_invoice_id=f"in_admin_test_{i}",
|
||||
invoice_number=f"ADM-{i:03d}",
|
||||
invoice_date=datetime.now(UTC) - timedelta(days=i * 30),
|
||||
subtotal_cents=4900,
|
||||
tax_cents=0,
|
||||
total_cents=4900,
|
||||
amount_paid_cents=4900,
|
||||
currency="EUR",
|
||||
status="paid",
|
||||
)
|
||||
records.append(record)
|
||||
db.add_all(records)
|
||||
db.commit()
|
||||
for r in records:
|
||||
db.refresh(r)
|
||||
return records
|
||||
@@ -1,12 +1,11 @@
|
||||
# tests/unit/services/test_billing_service.py
|
||||
# app/modules/billing/tests/unit/test_billing_service.py
|
||||
"""Unit tests for BillingService."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.tenancy.exceptions import StoreNotFoundException
|
||||
from app.modules.billing.services.billing_service import (
|
||||
BillingService,
|
||||
NoActiveSubscriptionError,
|
||||
@@ -25,26 +24,107 @@ from app.modules.billing.models import (
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tier Lookup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceTiers:
|
||||
"""Test suite for BillingService tier operations."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Initialize service instance before each test."""
|
||||
self.service = BillingService()
|
||||
|
||||
def test_get_tier_by_code_found(self, db, bs_tier_essential):
|
||||
"""Returns the active tier."""
|
||||
tier = self.service.get_tier_by_code(db, "essential")
|
||||
assert tier.code == "essential"
|
||||
|
||||
def test_get_tier_by_code_not_found(self, db):
|
||||
"""Test getting non-existent tier raises error."""
|
||||
"""Nonexistent tier raises TierNotFoundError."""
|
||||
with pytest.raises(TierNotFoundError) as exc_info:
|
||||
self.service.get_tier_by_code(db, "nonexistent")
|
||||
|
||||
assert exc_info.value.tier_code == "nonexistent"
|
||||
|
||||
def test_get_tier_by_code_inactive_not_returned(self, db, bs_tier_essential):
|
||||
"""Inactive tier raises TierNotFoundError (only active tiers returned)."""
|
||||
bs_tier_essential.is_active = False
|
||||
db.flush()
|
||||
|
||||
with pytest.raises(TierNotFoundError):
|
||||
self.service.get_tier_by_code(db, "essential")
|
||||
|
||||
|
||||
# TestBillingServiceCheckout removed — depends on refactored store_id-based API
|
||||
# TestBillingServicePortal removed — depends on refactored store_id-based API
|
||||
# ============================================================================
|
||||
# Available Tiers with Upgrade/Downgrade Flags
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceAvailableTiers:
|
||||
"""Test suite for get_available_tiers (upgrade/downgrade detection)."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_get_available_tiers_returns_all(self, db, bs_tiers):
|
||||
"""Returns all active public tiers."""
|
||||
tier_list, tier_order = self.service.get_available_tiers(db, None)
|
||||
assert len(tier_list) == 3
|
||||
|
||||
def test_get_available_tiers_marks_current(self, db, bs_tiers):
|
||||
"""Current tier is marked with is_current=True."""
|
||||
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[1].id)
|
||||
|
||||
current = [t for t in tier_list if t["is_current"]]
|
||||
assert len(current) == 1
|
||||
assert current[0]["code"] == "professional"
|
||||
|
||||
def test_get_available_tiers_upgrade_flags(self, db, bs_tiers):
|
||||
"""Tiers with higher display_order have can_upgrade=True."""
|
||||
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[0].id)
|
||||
|
||||
essential = next(t for t in tier_list if t["code"] == "essential")
|
||||
professional = next(t for t in tier_list if t["code"] == "professional")
|
||||
business = next(t for t in tier_list if t["code"] == "business")
|
||||
|
||||
assert essential["is_current"] is True
|
||||
assert essential["can_upgrade"] is False
|
||||
assert professional["can_upgrade"] is True
|
||||
assert business["can_upgrade"] is True
|
||||
|
||||
def test_get_available_tiers_downgrade_flags(self, db, bs_tiers):
|
||||
"""Tiers with lower display_order have can_downgrade=True."""
|
||||
tier_list, _ = self.service.get_available_tiers(db, bs_tiers[2].id)
|
||||
|
||||
essential = next(t for t in tier_list if t["code"] == "essential")
|
||||
professional = next(t for t in tier_list if t["code"] == "professional")
|
||||
business = next(t for t in tier_list if t["code"] == "business")
|
||||
|
||||
assert essential["can_downgrade"] is True
|
||||
assert professional["can_downgrade"] is True
|
||||
assert business["is_current"] is True
|
||||
assert business["can_downgrade"] is False
|
||||
|
||||
def test_get_available_tiers_no_current_tier(self, db, bs_tiers):
|
||||
"""When current_tier_id is None, no tier is marked current."""
|
||||
tier_list, _ = self.service.get_available_tiers(db, None)
|
||||
assert all(t["is_current"] is False for t in tier_list)
|
||||
|
||||
def test_get_available_tiers_returns_tier_order_map(self, db, bs_tiers):
|
||||
"""Returns tier_order map of code → display_order."""
|
||||
_, tier_order = self.service.get_available_tiers(db, None)
|
||||
assert tier_order["essential"] == 1
|
||||
assert tier_order["professional"] == 2
|
||||
assert tier_order["business"] == 3
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Invoices
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -53,17 +133,42 @@ class TestBillingServiceInvoices:
|
||||
"""Test suite for BillingService invoice operations."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Initialize service instance before each test."""
|
||||
self.service = BillingService()
|
||||
|
||||
def test_get_invoices_empty(self, db, test_store):
|
||||
"""Test getting invoices when none exist."""
|
||||
invoices, total = self.service.get_invoices(db, test_store.id)
|
||||
|
||||
def test_get_invoices_empty(self, db, test_merchant):
|
||||
"""Returns empty list and zero total when no invoices exist."""
|
||||
invoices, total = self.service.get_invoices(db, test_merchant.id)
|
||||
assert invoices == []
|
||||
assert total == 0
|
||||
|
||||
# test_get_invoices_with_data and test_get_invoices_pagination removed — fixture model mismatch after migration
|
||||
def test_get_invoices_with_data(self, db, bs_billing_history):
|
||||
"""Returns invoices for the merchant."""
|
||||
merchant_id = bs_billing_history[0].merchant_id
|
||||
invoices, total = self.service.get_invoices(db, merchant_id)
|
||||
assert total == 3
|
||||
assert len(invoices) == 3
|
||||
|
||||
def test_get_invoices_pagination(self, db, bs_billing_history):
|
||||
"""Pagination limits and offsets results."""
|
||||
merchant_id = bs_billing_history[0].merchant_id
|
||||
invoices, total = self.service.get_invoices(db, merchant_id, skip=0, limit=2)
|
||||
assert total == 3
|
||||
assert len(invoices) == 2
|
||||
|
||||
invoices2, _ = self.service.get_invoices(db, merchant_id, skip=2, limit=2)
|
||||
assert len(invoices2) == 1
|
||||
|
||||
def test_get_invoices_ordered_by_date_desc(self, db, bs_billing_history):
|
||||
"""Invoices are returned newest first."""
|
||||
merchant_id = bs_billing_history[0].merchant_id
|
||||
invoices, _ = self.service.get_invoices(db, merchant_id)
|
||||
dates = [inv.invoice_date for inv in invoices]
|
||||
assert dates == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Add-ons
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -72,55 +177,442 @@ class TestBillingServiceAddons:
|
||||
"""Test suite for BillingService addon operations."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Initialize service instance before each test."""
|
||||
self.service = BillingService()
|
||||
|
||||
def test_get_available_addons_empty(self, db):
|
||||
"""Test getting addons when none exist."""
|
||||
"""Returns empty when no addons exist."""
|
||||
addons = self.service.get_available_addons(db)
|
||||
assert addons == []
|
||||
|
||||
def test_get_available_addons_with_data(self, db, test_addon_products):
|
||||
"""Test getting all available addons."""
|
||||
"""Returns all active addons."""
|
||||
addons = self.service.get_available_addons(db)
|
||||
|
||||
assert len(addons) == 3
|
||||
assert all(addon.is_active for addon in addons)
|
||||
|
||||
def test_get_available_addons_by_category(self, db, test_addon_products):
|
||||
"""Test filtering addons by category."""
|
||||
"""Filters by category."""
|
||||
domain_addons = self.service.get_available_addons(db, category="domain")
|
||||
|
||||
assert len(domain_addons) == 1
|
||||
assert domain_addons[0].category == "domain"
|
||||
|
||||
def test_get_store_addons_empty(self, db, test_store):
|
||||
"""Test getting store addons when none purchased."""
|
||||
"""Returns empty when store has no purchased addons."""
|
||||
addons = self.service.get_store_addons(db, test_store.id)
|
||||
assert addons == []
|
||||
|
||||
|
||||
|
||||
# TestBillingServiceCancellation removed — depends on refactored store_id-based API
|
||||
# TestBillingServiceStore removed — get_store method was removed from BillingService
|
||||
# ============================================================================
|
||||
# Subscription with Tier
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# ==================== Fixtures ====================
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceSubscriptionWithTier:
|
||||
"""Tests for get_subscription_with_tier."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_get_subscription_with_tier_existing(
|
||||
self, db, bs_subscription
|
||||
):
|
||||
"""Returns (subscription, tier) tuple for existing subscription."""
|
||||
sub, tier = self.service.get_subscription_with_tier(
|
||||
db, bs_subscription.merchant_id, bs_subscription.platform_id
|
||||
)
|
||||
assert sub.id == bs_subscription.id
|
||||
assert tier is not None
|
||||
assert tier.code == "essential"
|
||||
|
||||
def test_get_subscription_with_tier_creates_if_missing(
|
||||
self, db, test_merchant, test_platform, bs_tier_essential
|
||||
):
|
||||
"""Creates a trial subscription when none exists (via get_or_create)."""
|
||||
sub, tier = self.service.get_subscription_with_tier(
|
||||
db, test_merchant.id, test_platform.id
|
||||
)
|
||||
assert sub is not None
|
||||
assert sub.status == SubscriptionStatus.TRIAL.value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Change Tier
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceChangeTier:
|
||||
"""Tests for change_tier (the tier upgrade/downgrade flow)."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_change_tier_no_subscription_raises(self, db, bs_tiers):
|
||||
"""Raises NoActiveSubscriptionError when no subscription exists."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.change_tier(db, 99999, 99999, "professional", False)
|
||||
|
||||
def test_change_tier_no_stripe_subscription_raises(
|
||||
self, db, bs_subscription, bs_tiers
|
||||
):
|
||||
"""Raises when subscription has no stripe_subscription_id."""
|
||||
# bs_subscription has no Stripe IDs
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.change_tier(
|
||||
db,
|
||||
bs_subscription.merchant_id,
|
||||
bs_subscription.platform_id,
|
||||
"professional",
|
||||
False,
|
||||
)
|
||||
|
||||
def test_change_tier_nonexistent_tier_raises(
|
||||
self, db, bs_stripe_subscription
|
||||
):
|
||||
"""Raises TierNotFoundError for nonexistent tier."""
|
||||
with pytest.raises(TierNotFoundError):
|
||||
self.service.change_tier(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
"nonexistent",
|
||||
False,
|
||||
)
|
||||
|
||||
def test_change_tier_no_price_id_raises(
|
||||
self, db, bs_stripe_subscription, bs_tiers
|
||||
):
|
||||
"""Raises StripePriceNotConfiguredError when tier has no Stripe price."""
|
||||
# bs_tiers have no stripe_price_* set
|
||||
with pytest.raises(StripePriceNotConfiguredError):
|
||||
self.service.change_tier(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
"professional",
|
||||
False,
|
||||
)
|
||||
|
||||
@patch("app.modules.billing.services.billing_service.stripe_service")
|
||||
def test_change_tier_success(
|
||||
self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
|
||||
):
|
||||
"""Successful tier change updates local subscription and calls Stripe."""
|
||||
mock_stripe.is_configured = True
|
||||
mock_stripe.update_subscription = MagicMock()
|
||||
|
||||
result = self.service.change_tier(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
"professional",
|
||||
False,
|
||||
)
|
||||
|
||||
assert result["new_tier"] == "professional"
|
||||
assert result["effective_immediately"] is True
|
||||
assert bs_stripe_subscription.tier_id == bs_tiers_with_stripe[1].id
|
||||
mock_stripe.update_subscription.assert_called_once()
|
||||
|
||||
@patch("app.modules.billing.services.billing_service.stripe_service")
|
||||
def test_change_tier_annual_uses_annual_price(
|
||||
self, mock_stripe, db, bs_stripe_subscription, bs_tiers_with_stripe
|
||||
):
|
||||
"""Annual billing selects stripe_price_annual_id."""
|
||||
mock_stripe.is_configured = True
|
||||
mock_stripe.update_subscription = MagicMock()
|
||||
|
||||
self.service.change_tier(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
"professional",
|
||||
True,
|
||||
)
|
||||
|
||||
call_args = mock_stripe.update_subscription.call_args
|
||||
assert call_args.kwargs["new_price_id"] == "price_pro_annual"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _is_upgrade
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceIsUpgrade:
|
||||
"""Tests for _is_upgrade helper."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_is_upgrade_true(self, db, bs_tiers):
|
||||
"""Higher display_order is an upgrade."""
|
||||
assert self.service._is_upgrade(db, bs_tiers[0].id, bs_tiers[2].id) is True
|
||||
|
||||
def test_is_upgrade_false_downgrade(self, db, bs_tiers):
|
||||
"""Lower display_order is not an upgrade."""
|
||||
assert self.service._is_upgrade(db, bs_tiers[2].id, bs_tiers[0].id) is False
|
||||
|
||||
def test_is_upgrade_same_tier(self, db, bs_tiers):
|
||||
"""Same tier is not an upgrade."""
|
||||
assert self.service._is_upgrade(db, bs_tiers[1].id, bs_tiers[1].id) is False
|
||||
|
||||
def test_is_upgrade_none_ids(self, db):
|
||||
"""None tier IDs return False."""
|
||||
assert self.service._is_upgrade(db, None, None) is False
|
||||
assert self.service._is_upgrade(db, None, 1) is False
|
||||
assert self.service._is_upgrade(db, 1, None) is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cancel Subscription
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceCancel:
|
||||
"""Tests for cancel_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_cancel_no_subscription_raises(self, db):
|
||||
"""Raises when no subscription found."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.cancel_subscription(db, 99999, 99999, None, False)
|
||||
|
||||
def test_cancel_no_stripe_id_raises(self, db, bs_subscription):
|
||||
"""Raises when subscription has no stripe_subscription_id."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.cancel_subscription(
|
||||
db,
|
||||
bs_subscription.merchant_id,
|
||||
bs_subscription.platform_id,
|
||||
"reason",
|
||||
False,
|
||||
)
|
||||
|
||||
@patch("app.modules.billing.services.billing_service.stripe_service")
|
||||
def test_cancel_success(self, mock_stripe, db, bs_stripe_subscription):
|
||||
"""Cancellation records timestamp and reason."""
|
||||
mock_stripe.is_configured = True
|
||||
mock_stripe.cancel_subscription = MagicMock()
|
||||
|
||||
result = self.service.cancel_subscription(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
"Too expensive",
|
||||
False,
|
||||
)
|
||||
|
||||
assert result["message"] == "Subscription cancelled successfully"
|
||||
assert bs_stripe_subscription.cancelled_at is not None
|
||||
assert bs_stripe_subscription.cancellation_reason == "Too expensive"
|
||||
mock_stripe.cancel_subscription.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Reactivate Subscription
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceReactivate:
|
||||
"""Tests for reactivate_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_reactivate_no_subscription_raises(self, db):
|
||||
"""Raises when no subscription found."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.reactivate_subscription(db, 99999, 99999)
|
||||
|
||||
def test_reactivate_not_cancelled_raises(self, db, bs_stripe_subscription):
|
||||
"""Raises SubscriptionNotCancelledError when not cancelled."""
|
||||
with pytest.raises(SubscriptionNotCancelledError):
|
||||
self.service.reactivate_subscription(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
)
|
||||
|
||||
@patch("app.modules.billing.services.billing_service.stripe_service")
|
||||
def test_reactivate_success(self, mock_stripe, db, bs_stripe_subscription):
|
||||
"""Reactivation clears cancellation and calls Stripe."""
|
||||
mock_stripe.is_configured = True
|
||||
mock_stripe.reactivate_subscription = MagicMock()
|
||||
|
||||
# Cancel first
|
||||
bs_stripe_subscription.cancelled_at = datetime.now(UTC)
|
||||
bs_stripe_subscription.cancellation_reason = "Testing"
|
||||
db.flush()
|
||||
|
||||
result = self.service.reactivate_subscription(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
)
|
||||
|
||||
assert result["message"] == "Subscription reactivated successfully"
|
||||
assert bs_stripe_subscription.cancelled_at is None
|
||||
assert bs_stripe_subscription.cancellation_reason is None
|
||||
mock_stripe.reactivate_subscription.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Checkout Session
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceCheckout:
|
||||
"""Tests for create_checkout_session."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_checkout_stripe_not_configured_raises(self, db, bs_tiers_with_stripe):
|
||||
"""Raises PaymentSystemNotConfiguredError when Stripe is off."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = False
|
||||
|
||||
with pytest.raises(PaymentSystemNotConfiguredError):
|
||||
self.service.create_checkout_session(
|
||||
db, 1, 1, "essential", False, "http://ok", "http://cancel"
|
||||
)
|
||||
|
||||
def test_checkout_nonexistent_tier_raises(self, db):
|
||||
"""Raises TierNotFoundError for nonexistent tier."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = True
|
||||
|
||||
with pytest.raises(TierNotFoundError):
|
||||
self.service.create_checkout_session(
|
||||
db, 1, 1, "nonexistent", False, "http://ok", "http://cancel"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Portal Session
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServicePortal:
|
||||
"""Tests for create_portal_session."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_portal_stripe_not_configured_raises(self, db):
|
||||
"""Raises PaymentSystemNotConfiguredError when Stripe is off."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = False
|
||||
|
||||
with pytest.raises(PaymentSystemNotConfiguredError):
|
||||
self.service.create_portal_session(db, 1, 1, "http://return")
|
||||
|
||||
def test_portal_no_subscription_raises(self, db):
|
||||
"""Raises NoActiveSubscriptionError when no subscription found."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = True
|
||||
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.create_portal_session(db, 99999, 99999, "http://return")
|
||||
|
||||
def test_portal_no_customer_id_raises(self, db, bs_subscription):
|
||||
"""Raises when subscription has no stripe_customer_id."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = True
|
||||
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.create_portal_session(
|
||||
db,
|
||||
bs_subscription.merchant_id,
|
||||
bs_subscription.platform_id,
|
||||
"http://return",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Upcoming Invoice
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestBillingServiceUpcomingInvoice:
|
||||
"""Tests for get_upcoming_invoice."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = BillingService()
|
||||
|
||||
def test_upcoming_invoice_no_subscription_raises(self, db):
|
||||
"""Raises when no subscription exists."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.get_upcoming_invoice(db, 99999, 99999)
|
||||
|
||||
def test_upcoming_invoice_no_customer_id_raises(self, db, bs_subscription):
|
||||
"""Raises when subscription has no stripe_customer_id."""
|
||||
with pytest.raises(NoActiveSubscriptionError):
|
||||
self.service.get_upcoming_invoice(
|
||||
db, bs_subscription.merchant_id, bs_subscription.platform_id
|
||||
)
|
||||
|
||||
def test_upcoming_invoice_stripe_not_configured_returns_empty(
|
||||
self, db, bs_stripe_subscription
|
||||
):
|
||||
"""Returns empty invoice when Stripe is not configured."""
|
||||
with patch(
|
||||
"app.modules.billing.services.billing_service.stripe_service"
|
||||
) as mock_stripe:
|
||||
mock_stripe.is_configured = False
|
||||
|
||||
result = self.service.get_upcoming_invoice(
|
||||
db,
|
||||
bs_stripe_subscription.merchant_id,
|
||||
bs_stripe_subscription.platform_id,
|
||||
)
|
||||
|
||||
assert result["amount_due_cents"] == 0
|
||||
assert result["line_items"] == []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_subscription_tier(db):
|
||||
"""Create a basic subscription tier."""
|
||||
def bs_tier_essential(db):
|
||||
"""Create essential subscription tier."""
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
description="Essential plan",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
orders_per_month=100,
|
||||
products_limit=200,
|
||||
team_members=1,
|
||||
features=["basic_support"],
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
@@ -132,39 +624,14 @@ def test_subscription_tier(db):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_subscription_tier_with_stripe(db):
|
||||
"""Create a subscription tier with Stripe configuration."""
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
description="Essential plan",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
orders_per_month=100,
|
||||
products_limit=200,
|
||||
team_members=1,
|
||||
features=["basic_support"],
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
stripe_product_id="prod_test123",
|
||||
stripe_price_monthly_id="price_test123",
|
||||
stripe_price_annual_id="price_test456",
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
db.refresh(tier)
|
||||
return tier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_subscription_tiers(db):
|
||||
"""Create multiple subscription tiers."""
|
||||
def bs_tiers(db):
|
||||
"""Create three tiers without Stripe config."""
|
||||
tiers = [
|
||||
SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
@@ -173,6 +640,7 @@ def test_subscription_tiers(db):
|
||||
code="professional",
|
||||
name="Professional",
|
||||
price_monthly_cents=9900,
|
||||
price_annual_cents=99000,
|
||||
display_order=2,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
@@ -181,6 +649,7 @@ def test_subscription_tiers(db):
|
||||
code="business",
|
||||
name="Business",
|
||||
price_monthly_cents=19900,
|
||||
price_annual_cents=199000,
|
||||
display_order=3,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
@@ -188,136 +657,107 @@ def test_subscription_tiers(db):
|
||||
]
|
||||
db.add_all(tiers)
|
||||
db.commit()
|
||||
for tier in tiers:
|
||||
db.refresh(tier)
|
||||
for t in tiers:
|
||||
db.refresh(t)
|
||||
return tiers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_subscription(db, test_store):
|
||||
"""Create a basic subscription for testing."""
|
||||
# Create tier first
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
)
|
||||
db.add(tier)
|
||||
def bs_tiers_with_stripe(db):
|
||||
"""Create tiers with Stripe price IDs configured."""
|
||||
tiers = [
|
||||
SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
stripe_product_id="prod_essential",
|
||||
stripe_price_monthly_id="price_ess_monthly",
|
||||
stripe_price_annual_id="price_ess_annual",
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="professional",
|
||||
name="Professional",
|
||||
price_monthly_cents=9900,
|
||||
price_annual_cents=99000,
|
||||
display_order=2,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
stripe_product_id="prod_professional",
|
||||
stripe_price_monthly_id="price_pro_monthly",
|
||||
stripe_price_annual_id="price_pro_annual",
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="business",
|
||||
name="Business",
|
||||
price_monthly_cents=19900,
|
||||
price_annual_cents=199000,
|
||||
display_order=3,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
stripe_product_id="prod_business",
|
||||
stripe_price_monthly_id="price_biz_monthly",
|
||||
stripe_price_annual_id="price_biz_annual",
|
||||
),
|
||||
]
|
||||
db.add_all(tiers)
|
||||
db.commit()
|
||||
|
||||
subscription = MerchantSubscription(
|
||||
store_id=test_store.id,
|
||||
tier="essential",
|
||||
status=SubscriptionStatus.ACTIVE,
|
||||
period_start=datetime.now(timezone.utc),
|
||||
period_end=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
return subscription
|
||||
for t in tiers:
|
||||
db.refresh(t)
|
||||
return tiers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_active_subscription(db, test_store):
|
||||
def bs_subscription(db, test_merchant, test_platform, bs_tier_essential):
|
||||
"""Create an active merchant subscription (no Stripe IDs)."""
|
||||
now = datetime.now(UTC)
|
||||
sub = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=bs_tier_essential.id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(sub)
|
||||
return sub
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bs_stripe_subscription(db, test_merchant, test_platform, bs_tier_essential):
|
||||
"""Create an active subscription with Stripe IDs."""
|
||||
# Create tier first if not exists
|
||||
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first()
|
||||
if not tier:
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
|
||||
subscription = MerchantSubscription(
|
||||
store_id=test_store.id,
|
||||
tier="essential",
|
||||
status=SubscriptionStatus.ACTIVE,
|
||||
now = datetime.now(UTC)
|
||||
sub = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=bs_tier_essential.id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
stripe_customer_id="cus_test123",
|
||||
stripe_subscription_id="sub_test123",
|
||||
period_start=datetime.now(timezone.utc),
|
||||
period_end=datetime.now(timezone.utc),
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
)
|
||||
db.add(subscription)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
return subscription
|
||||
db.refresh(sub)
|
||||
return sub
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_cancelled_subscription(db, test_store):
|
||||
"""Create a cancelled subscription."""
|
||||
# Create tier first if not exists
|
||||
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first()
|
||||
if not tier:
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
|
||||
subscription = MerchantSubscription(
|
||||
store_id=test_store.id,
|
||||
tier="essential",
|
||||
status=SubscriptionStatus.ACTIVE,
|
||||
stripe_customer_id="cus_test123",
|
||||
stripe_subscription_id="sub_test123",
|
||||
period_start=datetime.now(timezone.utc),
|
||||
period_end=datetime.now(timezone.utc),
|
||||
cancelled_at=datetime.now(timezone.utc),
|
||||
cancellation_reason="Too expensive",
|
||||
)
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
return subscription
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_billing_history(db, test_store):
|
||||
"""Create a billing history record."""
|
||||
record = BillingHistory(
|
||||
store_id=test_store.id,
|
||||
stripe_invoice_id="in_test123",
|
||||
invoice_number="INV-001",
|
||||
invoice_date=datetime.now(timezone.utc),
|
||||
subtotal_cents=4900,
|
||||
tax_cents=0,
|
||||
total_cents=4900,
|
||||
amount_paid_cents=4900,
|
||||
currency="EUR",
|
||||
status="paid",
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
db.refresh(record)
|
||||
return record
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_multiple_invoices(db, test_store):
|
||||
"""Create multiple billing history records."""
|
||||
def bs_billing_history(db, test_merchant):
|
||||
"""Create billing history records for test_merchant."""
|
||||
records = []
|
||||
for i in range(5):
|
||||
for i in range(3):
|
||||
record = BillingHistory(
|
||||
store_id=test_store.id,
|
||||
stripe_invoice_id=f"in_test{i}",
|
||||
invoice_number=f"INV-{i:03d}",
|
||||
invoice_date=datetime.now(timezone.utc),
|
||||
merchant_id=test_merchant.id,
|
||||
stripe_invoice_id=f"in_bs_test_{i}",
|
||||
invoice_number=f"BS-{i:03d}",
|
||||
invoice_date=datetime.now(UTC) - timedelta(days=i * 30),
|
||||
subtotal_cents=4900,
|
||||
tax_cents=0,
|
||||
total_cents=4900,
|
||||
@@ -328,6 +768,8 @@ def test_multiple_invoices(db, test_store):
|
||||
records.append(record)
|
||||
db.add_all(records)
|
||||
db.commit()
|
||||
for r in records:
|
||||
db.refresh(r)
|
||||
return records
|
||||
|
||||
|
||||
|
||||
579
app/modules/billing/tests/unit/test_subscription_service.py
Normal file
579
app/modules/billing/tests/unit/test_subscription_service.py
Normal file
@@ -0,0 +1,579 @@
|
||||
# app/modules/billing/tests/unit/test_subscription_service.py
|
||||
"""Unit tests for SubscriptionService (merchant-level subscription management)."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.billing.exceptions import SubscriptionNotFoundException
|
||||
from app.modules.billing.models import (
|
||||
MerchantSubscription,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
TierCode,
|
||||
)
|
||||
from app.modules.billing.services.subscription_service import SubscriptionService
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tier Information
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceTierLookup:
|
||||
"""Tests for tier lookup methods."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_get_tier_by_code_returns_tier(self, db, billing_tier_essential):
|
||||
"""Existing tier code returns the tier object."""
|
||||
tier = self.service.get_tier_by_code(db, "essential")
|
||||
|
||||
assert tier is not None
|
||||
assert tier.code == "essential"
|
||||
assert tier.id == billing_tier_essential.id
|
||||
|
||||
def test_get_tier_by_code_nonexistent_returns_none(self, db):
|
||||
"""Nonexistent tier code returns None (does not raise)."""
|
||||
tier = self.service.get_tier_by_code(db, "nonexistent")
|
||||
assert tier is None
|
||||
|
||||
def test_get_tier_id_returns_id(self, db, billing_tier_essential):
|
||||
"""get_tier_id returns the integer ID for a valid code."""
|
||||
tier_id = self.service.get_tier_id(db, "essential")
|
||||
assert tier_id == billing_tier_essential.id
|
||||
|
||||
def test_get_tier_id_nonexistent_returns_none(self, db):
|
||||
"""get_tier_id returns None for nonexistent code."""
|
||||
tier_id = self.service.get_tier_id(db, "nonexistent")
|
||||
assert tier_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceGetAllTiers:
|
||||
"""Tests for get_all_tiers with platform filtering."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_get_all_tiers_returns_active_public_tiers(self, db, billing_tiers):
|
||||
"""Returns only active + public tiers, ordered by display_order."""
|
||||
tiers = self.service.get_all_tiers(db)
|
||||
|
||||
assert len(tiers) == 3
|
||||
assert tiers[0].code == "essential"
|
||||
assert tiers[1].code == "professional"
|
||||
assert tiers[2].code == "business"
|
||||
|
||||
def test_get_all_tiers_excludes_inactive(self, db, billing_tiers):
|
||||
"""Inactive tiers are excluded."""
|
||||
# Deactivate one
|
||||
billing_tiers[0].is_active = False
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_all_tiers(db)
|
||||
assert len(tiers) == 2
|
||||
assert all(t.code != "essential" for t in tiers)
|
||||
|
||||
def test_get_all_tiers_excludes_non_public(self, db, billing_tiers):
|
||||
"""Non-public tiers are excluded."""
|
||||
billing_tiers[2].is_public = False
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_all_tiers(db)
|
||||
assert len(tiers) == 2
|
||||
assert all(t.code != "business" for t in tiers)
|
||||
|
||||
def test_get_all_tiers_with_platform_filter(self, db, test_platform, billing_tiers):
|
||||
"""Platform filter returns platform-specific tiers + global tiers."""
|
||||
# Make one tier platform-specific
|
||||
billing_tiers[0].platform_id = test_platform.id
|
||||
# billing_tiers[1] and [2] remain global (platform_id=NULL)
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
|
||||
|
||||
codes = {t.code for t in tiers}
|
||||
assert "essential" in codes # Platform-specific
|
||||
assert "professional" in codes # Global
|
||||
assert "business" in codes # Global
|
||||
|
||||
def test_get_all_tiers_platform_filter_excludes_other_platform(
|
||||
self, db, test_platform, another_platform, billing_tiers
|
||||
):
|
||||
"""Tiers belonging to another platform are excluded."""
|
||||
billing_tiers[0].platform_id = another_platform.id
|
||||
db.flush()
|
||||
|
||||
tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)
|
||||
|
||||
codes = {t.code for t in tiers}
|
||||
assert "essential" not in codes
|
||||
assert len(tiers) == 2
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Merchant Subscription CRUD
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceGetMerchantSubscription:
|
||||
"""Tests for subscription retrieval methods."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_get_merchant_subscription_found(
|
||||
self, db, billing_subscription
|
||||
):
|
||||
"""Returns subscription when it exists."""
|
||||
sub = self.service.get_merchant_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub is not None
|
||||
assert sub.id == billing_subscription.id
|
||||
|
||||
def test_get_merchant_subscription_not_found(self, db):
|
||||
"""Returns None when no subscription exists."""
|
||||
sub = self.service.get_merchant_subscription(db, 99999, 99999)
|
||||
assert sub is None
|
||||
|
||||
def test_get_merchant_subscription_eager_loads_tier(
|
||||
self, db, billing_subscription
|
||||
):
|
||||
"""Subscription comes with tier relationship loaded."""
|
||||
sub = self.service.get_merchant_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
# Accessing .tier should not trigger a lazy load error
|
||||
assert sub.tier is not None
|
||||
assert sub.tier.code == "essential"
|
||||
|
||||
def test_get_merchant_subscription_eager_loads_platform(
|
||||
self, db, billing_subscription
|
||||
):
|
||||
"""Subscription comes with platform relationship loaded."""
|
||||
sub = self.service.get_merchant_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.platform is not None
|
||||
|
||||
def test_get_merchant_subscriptions_all(
|
||||
self, db, test_merchant, test_platform, another_platform, billing_tier_essential
|
||||
):
|
||||
"""Returns all subscriptions for a merchant across platforms."""
|
||||
now = datetime.now(UTC)
|
||||
sub1 = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=billing_tier_essential.id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
)
|
||||
sub2 = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=another_platform.id,
|
||||
tier_id=billing_tier_essential.id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=14),
|
||||
)
|
||||
db.add_all([sub1, sub2])
|
||||
db.flush()
|
||||
|
||||
subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
|
||||
assert len(subs) == 2
|
||||
|
||||
def test_get_merchant_subscriptions_empty(self, db, test_merchant):
|
||||
"""Returns empty list when merchant has no subscriptions."""
|
||||
subs = self.service.get_merchant_subscriptions(db, test_merchant.id)
|
||||
assert subs == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceGetOrRaise:
|
||||
"""Tests for get_subscription_or_raise."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_get_subscription_or_raise_found(self, db, billing_subscription):
|
||||
"""Returns subscription when found."""
|
||||
sub = self.service.get_subscription_or_raise(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.id == billing_subscription.id
|
||||
|
||||
def test_get_subscription_or_raise_not_found(self, db):
|
||||
"""Raises SubscriptionNotFoundException when not found."""
|
||||
with pytest.raises(SubscriptionNotFoundException):
|
||||
self.service.get_subscription_or_raise(db, 99999, 99999)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Create Subscription
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceCreate:
|
||||
"""Tests for create_merchant_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_create_trial_subscription(
|
||||
self, db, test_merchant, test_platform, billing_tier_essential
|
||||
):
|
||||
"""Default creation starts a 14-day trial."""
|
||||
sub = self.service.create_merchant_subscription(
|
||||
db, test_merchant.id, test_platform.id
|
||||
)
|
||||
|
||||
assert sub.merchant_id == test_merchant.id
|
||||
assert sub.platform_id == test_platform.id
|
||||
assert sub.status == SubscriptionStatus.TRIAL.value
|
||||
assert sub.tier_id == billing_tier_essential.id
|
||||
assert sub.trial_ends_at is not None
|
||||
assert sub.is_annual is False
|
||||
|
||||
def test_create_subscription_period_end_matches_trial(
|
||||
self, db, test_merchant, test_platform, billing_tier_essential
|
||||
):
|
||||
"""Period end is set to trial_days from now."""
|
||||
sub = self.service.create_merchant_subscription(
|
||||
db, test_merchant.id, test_platform.id, trial_days=7
|
||||
)
|
||||
|
||||
expected_end = sub.period_start + timedelta(days=7)
|
||||
# Allow 1 second tolerance
|
||||
assert abs((sub.period_end - expected_end).total_seconds()) < 1
|
||||
|
||||
def test_create_annual_subscription(
|
||||
self, db, test_merchant, test_platform, billing_tier_essential
|
||||
):
|
||||
"""Annual subscription with no trial gets 365-day period."""
|
||||
sub = self.service.create_merchant_subscription(
|
||||
db, test_merchant.id, test_platform.id, trial_days=0, is_annual=True
|
||||
)
|
||||
|
||||
assert sub.status == SubscriptionStatus.ACTIVE.value
|
||||
assert sub.is_annual is True
|
||||
assert sub.trial_ends_at is None
|
||||
expected_end = sub.period_start + timedelta(days=365)
|
||||
assert abs((sub.period_end - expected_end).total_seconds()) < 1
|
||||
|
||||
def test_create_monthly_subscription(
|
||||
self, db, test_merchant, test_platform, billing_tier_essential
|
||||
):
|
||||
"""Monthly subscription with no trial gets 30-day period."""
|
||||
sub = self.service.create_merchant_subscription(
|
||||
db, test_merchant.id, test_platform.id, trial_days=0, is_annual=False
|
||||
)
|
||||
|
||||
assert sub.status == SubscriptionStatus.ACTIVE.value
|
||||
assert sub.is_annual is False
|
||||
assert sub.trial_ends_at is None
|
||||
expected_end = sub.period_start + timedelta(days=30)
|
||||
assert abs((sub.period_end - expected_end).total_seconds()) < 1
|
||||
|
||||
def test_create_subscription_custom_tier(
|
||||
self, db, test_merchant, test_platform, billing_tiers
|
||||
):
|
||||
"""Can specify a custom tier code."""
|
||||
sub = self.service.create_merchant_subscription(
|
||||
db, test_merchant.id, test_platform.id, tier_code="professional"
|
||||
)
|
||||
assert sub.tier_id == billing_tiers[1].id
|
||||
|
||||
def test_create_duplicate_raises(
|
||||
self, db, billing_subscription
|
||||
):
|
||||
"""Creating a second subscription for same merchant+platform raises ValueError."""
|
||||
with pytest.raises(ValueError, match="already has a subscription"):
|
||||
self.service.create_merchant_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceGetOrCreate:
|
||||
"""Tests for get_or_create_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_get_or_create_returns_existing(self, db, billing_subscription):
|
||||
"""Returns existing subscription without creating a new one."""
|
||||
sub = self.service.get_or_create_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.id == billing_subscription.id
|
||||
|
||||
def test_get_or_create_creates_new(
|
||||
self, db, test_merchant, test_platform, billing_tier_essential
|
||||
):
|
||||
"""Creates new subscription when none exists."""
|
||||
sub = self.service.get_or_create_subscription(
|
||||
db, test_merchant.id, test_platform.id
|
||||
)
|
||||
assert sub is not None
|
||||
assert sub.merchant_id == test_merchant.id
|
||||
assert sub.status == SubscriptionStatus.TRIAL.value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Upgrade / Cancel / Reactivate
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceUpgradeTier:
|
||||
"""Tests for upgrade_tier."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_upgrade_tier_changes_tier_id(self, db, billing_subscription, billing_tiers):
|
||||
"""Tier is updated to the new tier's ID."""
|
||||
sub = self.service.upgrade_tier(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
"professional",
|
||||
)
|
||||
assert sub.tier_id == billing_tiers[1].id
|
||||
|
||||
def test_upgrade_from_trial_activates(self, db, billing_subscription, billing_tiers):
|
||||
"""Upgrading from trial status sets status to active."""
|
||||
billing_subscription.status = SubscriptionStatus.TRIAL.value
|
||||
db.flush()
|
||||
|
||||
sub = self.service.upgrade_tier(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
"professional",
|
||||
)
|
||||
assert sub.status == SubscriptionStatus.ACTIVE.value
|
||||
|
||||
def test_upgrade_active_stays_active(self, db, billing_subscription, billing_tiers):
|
||||
"""Upgrading an already-active subscription keeps status active."""
|
||||
billing_subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
db.flush()
|
||||
|
||||
sub = self.service.upgrade_tier(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
"professional",
|
||||
)
|
||||
assert sub.status == SubscriptionStatus.ACTIVE.value
|
||||
|
||||
def test_upgrade_nonexistent_tier_raises(self, db, billing_subscription):
|
||||
"""Upgrading to a nonexistent tier raises ValueError."""
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
self.service.upgrade_tier(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
"nonexistent_tier",
|
||||
)
|
||||
|
||||
def test_upgrade_no_subscription_raises(self, db):
|
||||
"""Upgrading when no subscription exists raises SubscriptionNotFoundException."""
|
||||
with pytest.raises(SubscriptionNotFoundException):
|
||||
self.service.upgrade_tier(db, 99999, 99999, "professional")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceCancel:
|
||||
"""Tests for cancel_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_cancel_sets_status(self, db, billing_subscription):
|
||||
"""Cancellation sets status to cancelled."""
|
||||
sub = self.service.cancel_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.status == SubscriptionStatus.CANCELLED.value
|
||||
|
||||
def test_cancel_records_timestamp(self, db, billing_subscription):
|
||||
"""Cancellation records cancelled_at."""
|
||||
sub = self.service.cancel_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.cancelled_at is not None
|
||||
|
||||
def test_cancel_with_reason(self, db, billing_subscription):
|
||||
"""Cancellation stores the reason."""
|
||||
sub = self.service.cancel_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
reason="Too expensive",
|
||||
)
|
||||
assert sub.cancellation_reason == "Too expensive"
|
||||
|
||||
def test_cancel_no_subscription_raises(self, db):
|
||||
"""Cancelling when no subscription exists raises."""
|
||||
with pytest.raises(SubscriptionNotFoundException):
|
||||
self.service.cancel_subscription(db, 99999, 99999)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestSubscriptionServiceReactivate:
|
||||
"""Tests for reactivate_subscription."""
|
||||
|
||||
def setup_method(self):
|
||||
self.service = SubscriptionService()
|
||||
|
||||
def test_reactivate_sets_active(self, db, billing_subscription):
|
||||
"""Reactivation sets status back to active."""
|
||||
# First cancel
|
||||
self.service.cancel_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
|
||||
sub = self.service.reactivate_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.status == SubscriptionStatus.ACTIVE.value
|
||||
|
||||
def test_reactivate_clears_cancellation(self, db, billing_subscription):
|
||||
"""Reactivation clears cancelled_at and cancellation_reason."""
|
||||
self.service.cancel_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
reason="Testing",
|
||||
)
|
||||
|
||||
sub = self.service.reactivate_subscription(
|
||||
db,
|
||||
billing_subscription.merchant_id,
|
||||
billing_subscription.platform_id,
|
||||
)
|
||||
assert sub.cancelled_at is None
|
||||
assert sub.cancellation_reason is None
|
||||
|
||||
def test_reactivate_no_subscription_raises(self, db):
|
||||
"""Reactivating when no subscription exists raises."""
|
||||
with pytest.raises(SubscriptionNotFoundException):
|
||||
self.service.reactivate_subscription(db, 99999, 99999)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def billing_tier_essential(db):
|
||||
"""Create essential subscription tier."""
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
description="Essential plan",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
db.refresh(tier)
|
||||
return tier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def billing_tiers(db):
|
||||
"""Create essential, professional, and business tiers."""
|
||||
tiers = [
|
||||
SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=4900,
|
||||
price_annual_cents=49000,
|
||||
display_order=1,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="professional",
|
||||
name="Professional",
|
||||
price_monthly_cents=9900,
|
||||
price_annual_cents=99000,
|
||||
display_order=2,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
SubscriptionTier(
|
||||
code="business",
|
||||
name="Business",
|
||||
price_monthly_cents=19900,
|
||||
price_annual_cents=199000,
|
||||
display_order=3,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
),
|
||||
]
|
||||
db.add_all(tiers)
|
||||
db.commit()
|
||||
for t in tiers:
|
||||
db.refresh(t)
|
||||
return tiers
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def billing_subscription(db, test_merchant, test_platform, billing_tier_essential):
|
||||
"""Create an active merchant subscription (essential tier)."""
|
||||
now = datetime.now(UTC)
|
||||
sub = MerchantSubscription(
|
||||
merchant_id=test_merchant.id,
|
||||
platform_id=test_platform.id,
|
||||
tier_id=billing_tier_essential.id,
|
||||
status=SubscriptionStatus.ACTIVE.value,
|
||||
period_start=now,
|
||||
period_end=now + timedelta(days=30),
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(sub)
|
||||
return sub
|
||||
Reference in New Issue
Block a user