Files
orion/app/modules/billing/tests/unit/test_subscription_service.py
Samir Boulahtit 32acc76b49 feat: platform-aware storefront routing and billing improvements
Overhaul storefront URL routing to be platform-aware:
- Dev: /platforms/{code}/storefront/{store_code}/
- Prod: subdomain.platform.lu/ (internally rewritten to /storefront/)
- Add subdomain detection in PlatformContextMiddleware
- Add /storefront/ path rewrite for prod mode (subdomain/custom domain)
- Remove all silent platform fallbacks (platform_id=1)
- Add require_platform dependency for clean endpoint validation
- Update route registration, templates, module definitions, base_url calc
- Update StoreContextMiddleware for /storefront/ path detection
- Remove /stores/ from FrontendDetector STOREFRONT_PATH_PREFIXES

Billing service improvements:
- Add store_platform_sync_service to keep store_platforms in sync
- Make tier lookups platform-aware across billing services
- Add tiers for all platforms in seed data
- Add demo subscriptions to seed

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 23:42:41 +01:00

582 lines
20 KiB
Python

# app/modules/billing/tests/unit/test_subscription_service.py
"""Unit tests for SubscriptionService (merchant-level subscription management)."""
from datetime import UTC, datetime, timedelta
import pytest
from app.modules.billing.exceptions import SubscriptionNotFoundException
from app.modules.billing.models import (
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
)
from app.modules.billing.services.subscription_service import SubscriptionService
# ============================================================================
# Tier Information
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceTierLookup:
    """Single-tier lookups by code: the tier object and its integer ID."""

    def setup_method(self):
        # A new service instance for every test method.
        self.service = SubscriptionService()

    def test_get_tier_by_code_returns_tier(self, db, billing_tier_essential):
        """Existing tier code returns the tier object."""
        found = self.service.get_tier_by_code(db, "essential")

        assert found is not None
        assert found.code == "essential"
        assert found.id == billing_tier_essential.id

    def test_get_tier_by_code_nonexistent_returns_none(self, db):
        """Nonexistent tier code returns None (does not raise)."""
        assert self.service.get_tier_by_code(db, "nonexistent") is None

    def test_get_tier_id_returns_id(self, db, billing_tier_essential):
        """get_tier_id returns the integer ID for a valid code."""
        assert self.service.get_tier_id(db, "essential") == billing_tier_essential.id

    def test_get_tier_id_nonexistent_returns_none(self, db):
        """get_tier_id returns None for nonexistent code."""
        assert self.service.get_tier_id(db, "nonexistent") is None
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetAllTiers:
    """Tests for get_all_tiers: active/public filtering, ordering, platform scoping."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_all_tiers_returns_active_public_tiers(self, db, billing_tiers):
        """Returns only active + public tiers, ordered by display_order."""
        tiers = self.service.get_all_tiers(db)

        assert [t.code for t in tiers] == ["essential", "professional", "business"]

    def test_get_all_tiers_excludes_inactive(self, db, billing_tiers):
        """Inactive tiers are excluded."""
        billing_tiers[0].is_active = False  # "essential"
        db.flush()

        tiers = self.service.get_all_tiers(db)

        assert [t.code for t in tiers] == ["professional", "business"]

    def test_get_all_tiers_excludes_non_public(self, db, billing_tiers):
        """Non-public tiers are excluded."""
        billing_tiers[2].is_public = False  # "business"
        db.flush()

        tiers = self.service.get_all_tiers(db)

        assert [t.code for t in tiers] == ["essential", "professional"]

    def test_get_all_tiers_with_platform_filter(self, db, test_platform, billing_tiers):
        """Platform filter returns platform-specific tiers + global tiers."""
        # The billing_tiers fixture binds every tier to test_platform, so the
        # "global" tiers must be made explicit here; otherwise this test never
        # exercises the global-tier (platform_id=None) case it describes.
        billing_tiers[0].platform_id = test_platform.id  # platform-specific
        billing_tiers[1].platform_id = None  # global
        billing_tiers[2].platform_id = None  # global
        db.flush()

        tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)

        assert {t.code for t in tiers} == {"essential", "professional", "business"}

    def test_get_all_tiers_platform_filter_excludes_other_platform(
        self, db, test_platform, another_platform, billing_tiers
    ):
        """Tiers belonging to another platform are excluded."""
        billing_tiers[0].platform_id = another_platform.id
        db.flush()

        tiers = self.service.get_all_tiers(db, platform_id=test_platform.id)

        assert {t.code for t in tiers} == {"professional", "business"}
        # Also guards against the same tier being returned twice.
        assert len(tiers) == 2
# ============================================================================
# Merchant Subscription CRUD
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetMerchantSubscription:
    """Tests for subscription retrieval methods."""

    def setup_method(self):
        self.service = SubscriptionService()

    def _fetch(self, db, subscription):
        """Re-read ``subscription`` through the service by (merchant_id, platform_id)."""
        return self.service.get_merchant_subscription(
            db,
            subscription.merchant_id,
            subscription.platform_id,
        )

    def test_get_merchant_subscription_found(self, db, billing_subscription):
        """Returns subscription when it exists."""
        sub = self._fetch(db, billing_subscription)

        assert sub is not None
        assert sub.id == billing_subscription.id

    def test_get_merchant_subscription_not_found(self, db):
        """Returns None when no subscription exists."""
        sub = self.service.get_merchant_subscription(db, 99999, 99999)

        assert sub is None

    def test_get_merchant_subscription_eager_loads_tier(self, db, billing_subscription):
        """Subscription comes with its tier relationship populated."""
        sub = self._fetch(db, billing_subscription)

        # NOTE(review): the session is still open here, so a lazy load would
        # also succeed. This verifies the relationship resolves to the right
        # tier, not that it was eagerly loaded.
        assert sub.tier is not None
        assert sub.tier.code == "essential"

    def test_get_merchant_subscription_eager_loads_platform(self, db, billing_subscription):
        """Subscription comes with its platform relationship populated."""
        sub = self._fetch(db, billing_subscription)

        # NOTE(review): with an open session a lazy load would also succeed;
        # this checks the relationship points at the subscription's platform.
        assert sub.platform is not None
        assert sub.platform.id == billing_subscription.platform_id

    def test_get_merchant_subscriptions_all(
        self, db, test_merchant, test_platform, another_platform, billing_tier_essential
    ):
        """Returns all subscriptions for a merchant across platforms."""
        now = datetime.now(UTC)
        # Both rows reuse the essential tier; only the platform differs, which
        # is what this test is about.
        sub1 = MerchantSubscription(
            merchant_id=test_merchant.id,
            platform_id=test_platform.id,
            tier_id=billing_tier_essential.id,
            status=SubscriptionStatus.ACTIVE.value,
            period_start=now,
            period_end=now + timedelta(days=30),
        )
        sub2 = MerchantSubscription(
            merchant_id=test_merchant.id,
            platform_id=another_platform.id,
            tier_id=billing_tier_essential.id,
            status=SubscriptionStatus.TRIAL.value,
            period_start=now,
            period_end=now + timedelta(days=14),
        )
        db.add_all([sub1, sub2])
        db.flush()

        subs = self.service.get_merchant_subscriptions(db, test_merchant.id)

        assert len(subs) == 2
        assert {s.platform_id for s in subs} == {test_platform.id, another_platform.id}
        assert all(s.merchant_id == test_merchant.id for s in subs)

    def test_get_merchant_subscriptions_empty(self, db, test_merchant):
        """Returns empty list when merchant has no subscriptions."""
        subs = self.service.get_merchant_subscriptions(db, test_merchant.id)

        assert subs == []
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrRaise:
    """get_subscription_or_raise: returns the row or raises when absent."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_subscription_or_raise_found(self, db, billing_subscription):
        """An existing merchant+platform subscription is returned."""
        merchant_id = billing_subscription.merchant_id
        platform_id = billing_subscription.platform_id

        found = self.service.get_subscription_or_raise(db, merchant_id, platform_id)

        assert found.id == billing_subscription.id

    def test_get_subscription_or_raise_not_found(self, db):
        """A missing subscription raises SubscriptionNotFoundException."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.get_subscription_or_raise(db, 99999, 99999)
# ============================================================================
# Create Subscription
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCreate:
    """Tests for create_merchant_subscription."""

    def setup_method(self):
        self.service = SubscriptionService()

    @staticmethod
    def _assert_period_days(sub, days):
        """Assert ``sub.period_end`` lies ``days`` days after ``sub.period_start``.

        Allows a 1 second tolerance, as the individual checks did before this
        helper existed.
        """
        expected_end = sub.period_start + timedelta(days=days)
        assert abs((sub.period_end - expected_end).total_seconds()) < 1

    def test_create_trial_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Default creation starts a 14-day trial."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id
        )

        assert sub.merchant_id == test_merchant.id
        assert sub.platform_id == test_platform.id
        assert sub.status == SubscriptionStatus.TRIAL.value
        assert sub.tier_id == billing_tier_essential.id
        assert sub.trial_ends_at is not None
        assert sub.is_annual is False
        # The docstring promises a 14-day trial. Pin the period length too; the
        # explicit-trial_days test shows period_end tracks trial_days.
        self._assert_period_days(sub, 14)

    def test_create_subscription_period_end_matches_trial(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Period end is set to trial_days from now."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=7
        )

        self._assert_period_days(sub, 7)

    def test_create_annual_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Annual subscription with no trial gets 365-day period."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=0, is_annual=True
        )

        assert sub.status == SubscriptionStatus.ACTIVE.value
        assert sub.is_annual is True
        assert sub.trial_ends_at is None
        self._assert_period_days(sub, 365)

    def test_create_monthly_subscription(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """Monthly subscription with no trial gets 30-day period."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, trial_days=0, is_annual=False
        )

        assert sub.status == SubscriptionStatus.ACTIVE.value
        assert sub.is_annual is False
        assert sub.trial_ends_at is None
        self._assert_period_days(sub, 30)

    def test_create_subscription_custom_tier(
        self, db, test_merchant, test_platform, billing_tiers
    ):
        """Can specify a custom tier code."""
        sub = self.service.create_merchant_subscription(
            db, test_merchant.id, test_platform.id, tier_code="professional"
        )

        assert sub.tier_id == billing_tiers[1].id

    def test_create_duplicate_raises(self, db, billing_subscription):
        """Creating a second subscription for same merchant+platform raises ValueError."""
        with pytest.raises(ValueError, match="already has a subscription"):
            self.service.create_merchant_subscription(
                db,
                billing_subscription.merchant_id,
                billing_subscription.platform_id,
            )
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceGetOrCreate:
    """get_or_create_subscription: reuse an existing row, otherwise create one."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_get_or_create_returns_existing(self, db, billing_subscription):
        """An existing subscription is handed back instead of a new one."""
        result = self.service.get_or_create_subscription(
            db, billing_subscription.merchant_id, billing_subscription.platform_id
        )

        assert result.id == billing_subscription.id

    def test_get_or_create_creates_new(
        self, db, test_merchant, test_platform, billing_tier_essential
    ):
        """With no existing subscription, a new trial subscription is created."""
        merchant_id = test_merchant.id
        result = self.service.get_or_create_subscription(db, merchant_id, test_platform.id)

        assert result is not None
        assert result.merchant_id == merchant_id
        assert result.status == SubscriptionStatus.TRIAL.value
# ============================================================================
# Upgrade / Cancel / Reactivate
# ============================================================================
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceUpgradeTier:
    """upgrade_tier: tier switching, status transitions, and error paths."""

    def setup_method(self):
        self.service = SubscriptionService()

    def _upgrade(self, db, subscription, tier_code):
        """Upgrade ``subscription`` (looked up by merchant+platform) to ``tier_code``."""
        return self.service.upgrade_tier(
            db, subscription.merchant_id, subscription.platform_id, tier_code
        )

    def test_upgrade_tier_changes_tier_id(self, db, billing_subscription, billing_tiers):
        """The subscription now points at the target tier's ID."""
        upgraded = self._upgrade(db, billing_subscription, "professional")

        assert upgraded.tier_id == billing_tiers[1].id

    def test_upgrade_from_trial_activates(self, db, billing_subscription, billing_tiers):
        """A trial subscription becomes active once upgraded."""
        billing_subscription.status = SubscriptionStatus.TRIAL.value
        db.flush()

        upgraded = self._upgrade(db, billing_subscription, "professional")

        assert upgraded.status == SubscriptionStatus.ACTIVE.value

    def test_upgrade_active_stays_active(self, db, billing_subscription, billing_tiers):
        """An already-active subscription remains active after upgrading."""
        billing_subscription.status = SubscriptionStatus.ACTIVE.value
        db.flush()

        upgraded = self._upgrade(db, billing_subscription, "professional")

        assert upgraded.status == SubscriptionStatus.ACTIVE.value

    def test_upgrade_nonexistent_tier_raises(self, db, billing_subscription):
        """An unknown target tier code raises ValueError."""
        with pytest.raises(ValueError, match="not found"):
            self._upgrade(db, billing_subscription, "nonexistent_tier")

    def test_upgrade_no_subscription_raises(self, db):
        """Upgrading without any subscription raises SubscriptionNotFoundException."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.upgrade_tier(db, 99999, 99999, "professional")
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceCancel:
    """cancel_subscription: status, timestamp, reason, and missing-row handling."""

    def setup_method(self):
        self.service = SubscriptionService()

    def _cancel(self, db, subscription, **kwargs):
        """Cancel ``subscription`` by merchant+platform, forwarding extra kwargs."""
        return self.service.cancel_subscription(
            db, subscription.merchant_id, subscription.platform_id, **kwargs
        )

    def test_cancel_sets_status(self, db, billing_subscription):
        """After cancelling, the status reads as cancelled."""
        cancelled = self._cancel(db, billing_subscription)

        assert cancelled.status == SubscriptionStatus.CANCELLED.value

    def test_cancel_records_timestamp(self, db, billing_subscription):
        """Cancelling stamps cancelled_at."""
        cancelled = self._cancel(db, billing_subscription)

        assert cancelled.cancelled_at is not None

    def test_cancel_with_reason(self, db, billing_subscription):
        """A supplied reason is persisted on the subscription."""
        cancelled = self._cancel(db, billing_subscription, reason="Too expensive")

        assert cancelled.cancellation_reason == "Too expensive"

    def test_cancel_no_subscription_raises(self, db):
        """Cancelling a missing subscription raises SubscriptionNotFoundException."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.cancel_subscription(db, 99999, 99999)
@pytest.mark.unit
@pytest.mark.billing
class TestSubscriptionServiceReactivate:
    """reactivate_subscription: undoing a cancellation."""

    def setup_method(self):
        self.service = SubscriptionService()

    def test_reactivate_sets_active(self, db, billing_subscription):
        """A cancelled subscription goes back to active."""
        key = (billing_subscription.merchant_id, billing_subscription.platform_id)
        # Cancel first so there is something to reactivate.
        self.service.cancel_subscription(db, *key)

        revived = self.service.reactivate_subscription(db, *key)

        assert revived.status == SubscriptionStatus.ACTIVE.value

    def test_reactivate_clears_cancellation(self, db, billing_subscription):
        """Reactivating wipes cancelled_at and cancellation_reason."""
        key = (billing_subscription.merchant_id, billing_subscription.platform_id)
        self.service.cancel_subscription(db, *key, reason="Testing")

        revived = self.service.reactivate_subscription(db, *key)

        assert revived.cancelled_at is None
        assert revived.cancellation_reason is None

    def test_reactivate_no_subscription_raises(self, db):
        """Reactivating a missing subscription raises SubscriptionNotFoundException."""
        with pytest.raises(SubscriptionNotFoundException):
            self.service.reactivate_subscription(db, 99999, 99999)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def billing_tier_essential(db, test_platform):
    """Persist and return an active, public "essential" tier on test_platform."""
    attrs = {
        "code": "essential",
        "name": "Essential",
        "description": "Essential plan",
        "price_monthly_cents": 4900,
        "price_annual_cents": 49000,
        "display_order": 1,
        "is_active": True,
        "is_public": True,
        "platform_id": test_platform.id,
    }
    essential = SubscriptionTier(**attrs)
    db.add(essential)
    db.commit()
    # Reload so DB-generated columns (e.g. the primary key) are populated.
    db.refresh(essential)
    return essential
@pytest.fixture
def billing_tiers(db, test_platform, billing_tier_essential):
    """Return the essential, professional, and business tiers in display order.

    The essential entry is the ``billing_tier_essential`` fixture itself rather
    than a second "essential" row. Tests that request both this fixture and
    ``billing_tier_essential`` (directly or through ``billing_subscription``)
    would otherwise get two "essential" tiers on the same platform. That could
    make code-based tier lookups ambiguous.
    """
    extra_tiers = [
        SubscriptionTier(
            code="professional",
            name="Professional",
            price_monthly_cents=9900,
            price_annual_cents=99000,
            display_order=2,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
        SubscriptionTier(
            code="business",
            name="Business",
            price_monthly_cents=19900,
            price_annual_cents=199000,
            display_order=3,
            is_active=True,
            is_public=True,
            platform_id=test_platform.id,
        ),
    ]
    db.add_all(extra_tiers)
    db.commit()
    for tier in extra_tiers:
        db.refresh(tier)
    return [billing_tier_essential, *extra_tiers]
@pytest.fixture
def billing_subscription(db, test_merchant, test_platform, billing_tier_essential):
    """Persist and return an ACTIVE essential-tier subscription with a 30-day period."""
    started = datetime.now(UTC)
    ends = started + timedelta(days=30)
    subscription = MerchantSubscription(
        merchant_id=test_merchant.id,
        platform_id=test_platform.id,
        tier_id=billing_tier_essential.id,
        status=SubscriptionStatus.ACTIVE.value,
        period_start=started,
        period_end=ends,
    )
    db.add(subscription)
    db.commit()
    db.refresh(subscription)
    return subscription