test(billing): add integration route tests for all billing API endpoints

68 route tests covering admin, merchant, store, and platform billing APIs.
Store tests use real JWT auth (router-level deps can't be overridden);
Stripe-dependent endpoints are mocked at the route module level.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-11 23:31:44 +01:00
parent bf5bb69409
commit b265d0db51
5 changed files with 1583 additions and 0 deletions

View File

@@ -0,0 +1,510 @@
# app/modules/billing/tests/integration/test_admin_routes.py
"""
Integration tests for billing admin API routes.
Tests the admin subscription management endpoints at:
/api/v1/admin/subscriptions/*
Uses super_admin_headers fixture which bypasses module access checks.
"""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from app.modules.billing.models import (
BillingHistory,
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
)
from app.modules.tenancy.models import Merchant, Platform, User
# ============================================================================
# Fixtures
# ============================================================================
BASE = "/api/v1/admin/subscriptions"
@pytest.fixture
def rt_platform(db):
"""Create a platform for route tests."""
platform = Platform(
code=f"test_{uuid.uuid4().hex[:8]}",
name="Test Platform",
is_active=True,
)
db.add(platform)
db.commit()
db.refresh(platform)
return platform
@pytest.fixture
def rt_tiers(db, rt_platform):
"""Create subscription tiers for route tests."""
tiers = []
for i, (code, name, price) in enumerate([
("essential", "Essential", 0),
("professional", "Professional", 2900),
("business", "Business", 7900),
]):
tier = SubscriptionTier(
code=code,
name=name,
description=f"{name} tier",
price_monthly_cents=price,
price_annual_cents=price * 10 if price > 0 else 0,
display_order=i,
is_active=True,
is_public=True,
platform_id=rt_platform.id,
)
db.add(tier)
tiers.append(tier)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def rt_merchant(db, rt_platform):
"""Create a merchant with owner for route tests."""
from middleware.auth import AuthManager
auth = AuthManager()
owner = User(
email=f"merchant_{uuid.uuid4().hex[:8]}@test.com",
username=f"merchant_{uuid.uuid4().hex[:8]}",
hashed_password=auth.hash_password("pass123"),
role="store",
is_active=True,
)
db.add(owner)
db.commit()
db.refresh(owner)
merchant = Merchant(
name="Route Test Merchant",
owner_user_id=owner.id,
contact_email=owner.email,
is_active=True,
is_verified=True,
)
db.add(merchant)
db.commit()
db.refresh(merchant)
return merchant
@pytest.fixture
def rt_subscription(db, rt_merchant, rt_platform, rt_tiers):
"""Create a subscription for route tests."""
sub = MerchantSubscription(
merchant_id=rt_merchant.id,
platform_id=rt_platform.id,
tier_id=rt_tiers[1].id, # professional
status=SubscriptionStatus.ACTIVE.value,
is_annual=False,
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def rt_billing_history(db, rt_merchant):
"""Create billing history entries for route tests."""
records = []
for i in range(3):
record = BillingHistory(
merchant_id=rt_merchant.id,
invoice_number=f"INV-{1000 + i}",
invoice_date=datetime.now(UTC) - timedelta(days=30 * i),
subtotal_cents=2900,
tax_cents=493,
total_cents=3393,
amount_paid_cents=3393,
currency="EUR",
status="paid",
description=f"Invoice {i}",
)
db.add(record)
records.append(record)
db.commit()
for r in records:
db.refresh(r)
return records
# ============================================================================
# Tier Endpoints
# ============================================================================
class TestAdminListTiers:
"""Tests for GET /api/v1/admin/subscriptions/tiers."""
def test_list_tiers_success(self, client, super_admin_headers, rt_tiers):
response = client.get(f"{BASE}/tiers", headers=super_admin_headers)
assert response.status_code == 200
data = response.json()
assert "tiers" in data
assert "total" in data
assert data["total"] >= 3
def test_list_tiers_active_only_by_default(self, client, super_admin_headers, rt_tiers, db):
# Deactivate one tier
rt_tiers[2].is_active = False
db.commit()
response = client.get(f"{BASE}/tiers", headers=super_admin_headers)
assert response.status_code == 200
codes = [t["code"] for t in response.json()["tiers"]]
assert "business" not in codes
def test_list_tiers_include_inactive(self, client, super_admin_headers, rt_tiers, db):
rt_tiers[2].is_active = False
db.commit()
response = client.get(
f"{BASE}/tiers",
params={"include_inactive": True},
headers=super_admin_headers,
)
assert response.status_code == 200
codes = [t["code"] for t in response.json()["tiers"]]
assert "business" in codes
def test_list_tiers_filter_by_platform(self, client, super_admin_headers, rt_tiers, rt_platform):
response = client.get(
f"{BASE}/tiers",
params={"platform_id": rt_platform.id},
headers=super_admin_headers,
)
assert response.status_code == 200
assert response.json()["total"] == 3
def test_list_tiers_unauthorized(self, client):
response = client.get(f"{BASE}/tiers")
assert response.status_code in (401, 403)
class TestAdminGetTier:
"""Tests for GET /api/v1/admin/subscriptions/tiers/{tier_code}."""
def test_get_tier_success(self, client, super_admin_headers, rt_tiers):
response = client.get(
f"{BASE}/tiers/professional", headers=super_admin_headers
)
assert response.status_code == 200
data = response.json()
assert data["code"] == "professional"
assert data["name"] == "Professional"
assert data["price_monthly_cents"] == 2900
def test_get_tier_not_found(self, client, super_admin_headers):
response = client.get(
f"{BASE}/tiers/nonexistent", headers=super_admin_headers
)
assert response.status_code == 404
class TestAdminCreateTier:
"""Tests for POST /api/v1/admin/subscriptions/tiers."""
def test_create_tier_success(self, client, super_admin_headers, rt_platform):
response = client.post(
f"{BASE}/tiers",
json={
"code": "starter",
"name": "Starter",
"description": "Starter plan",
"price_monthly_cents": 990,
"price_annual_cents": 9900,
"display_order": 0,
"is_active": True,
"is_public": True,
"platform_id": rt_platform.id,
},
headers=super_admin_headers,
)
assert response.status_code == 201
data = response.json()
assert data["code"] == "starter"
assert data["price_monthly_cents"] == 990
def test_create_tier_duplicate_code(self, client, super_admin_headers, rt_tiers):
response = client.post(
f"{BASE}/tiers",
json={
"code": "essential",
"name": "Essential Dup",
"price_monthly_cents": 0,
},
headers=super_admin_headers,
)
assert response.status_code in (400, 409, 422)
class TestAdminUpdateTier:
"""Tests for PATCH /api/v1/admin/subscriptions/tiers/{tier_code}."""
def test_update_tier_success(self, client, super_admin_headers, rt_tiers):
response = client.patch(
f"{BASE}/tiers/professional",
json={"name": "Professional Plus", "price_monthly_cents": 3900},
headers=super_admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Professional Plus"
assert data["price_monthly_cents"] == 3900
def test_update_tier_not_found(self, client, super_admin_headers):
response = client.patch(
f"{BASE}/tiers/nonexistent",
json={"name": "Updated"},
headers=super_admin_headers,
)
assert response.status_code == 404
class TestAdminDeleteTier:
"""Tests for DELETE /api/v1/admin/subscriptions/tiers/{tier_code}."""
def test_delete_tier_success(self, client, super_admin_headers, rt_tiers):
response = client.delete(
f"{BASE}/tiers/business", headers=super_admin_headers
)
assert response.status_code == 204
def test_delete_tier_with_active_subs(
self, client, super_admin_headers, rt_subscription, rt_tiers
):
# Try to delete the tier used by rt_subscription (professional)
response = client.delete(
f"{BASE}/tiers/professional", headers=super_admin_headers
)
assert response.status_code in (400, 409, 422)
def test_delete_tier_not_found(self, client, super_admin_headers):
response = client.delete(
f"{BASE}/tiers/nonexistent", headers=super_admin_headers
)
assert response.status_code == 404
# ============================================================================
# Subscription Endpoints
# ============================================================================
class TestAdminListSubscriptions:
"""Tests for GET /api/v1/admin/subscriptions."""
def test_list_subscriptions_success(
self, client, super_admin_headers, rt_subscription
):
response = client.get(f"{BASE}", headers=super_admin_headers)
assert response.status_code == 200
data = response.json()
assert "subscriptions" in data
assert "total" in data
assert data["total"] >= 1
def test_list_subscriptions_pagination(
self, client, super_admin_headers, rt_subscription
):
response = client.get(
f"{BASE}",
params={"page": 1, "per_page": 5},
headers=super_admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["page"] == 1
assert data["per_page"] == 5
def test_list_subscriptions_filter_by_status(
self, client, super_admin_headers, rt_subscription
):
response = client.get(
f"{BASE}",
params={"status": "active"},
headers=super_admin_headers,
)
assert response.status_code == 200
for sub in response.json()["subscriptions"]:
assert sub["status"] == "active"
class TestAdminCreateSubscription:
"""Tests for POST /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}."""
def test_create_subscription_success(
self, client, super_admin_headers, rt_merchant, rt_platform, rt_tiers
):
response = client.post(
f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}",
json={
"merchant_id": rt_merchant.id,
"platform_id": rt_platform.id,
"tier_code": "essential",
"status": "trial",
"trial_days": 14,
"is_annual": False,
},
headers=super_admin_headers,
)
assert response.status_code == 201
data = response.json()
assert data["merchant_id"] == rt_merchant.id
assert data["platform_id"] == rt_platform.id
class TestAdminGetSubscription:
"""Tests for GET /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}."""
def test_get_subscription_success(
self, client, super_admin_headers, rt_subscription, rt_merchant, rt_platform
):
response = client.get(
f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}",
headers=super_admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["merchant_id"] == rt_merchant.id
assert data["status"] == "active"
def test_get_subscription_not_found(
self, client, super_admin_headers, rt_platform
):
response = client.get(
f"{BASE}/merchants/99999/platforms/{rt_platform.id}",
headers=super_admin_headers,
)
assert response.status_code == 404
class TestAdminUpdateSubscription:
"""Tests for PATCH /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}."""
def test_update_subscription_status(
self, client, super_admin_headers, rt_subscription, rt_merchant, rt_platform
):
response = client.patch(
f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}",
json={"status": "past_due"},
headers=super_admin_headers,
)
assert response.status_code == 200
assert response.json()["status"] == "past_due"
def test_update_subscription_tier(
self,
client,
super_admin_headers,
rt_subscription,
rt_merchant,
rt_platform,
rt_tiers,
):
response = client.patch(
f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}",
json={"tier_code": "business"},
headers=super_admin_headers,
)
assert response.status_code == 200
# ============================================================================
# Stats Endpoint
# ============================================================================
class TestAdminStats:
"""Tests for GET /api/v1/admin/subscriptions/stats."""
def test_get_stats_success(
self, client, super_admin_headers, rt_subscription
):
response = client.get(
f"{BASE}/stats", headers=super_admin_headers
)
assert response.status_code == 200
data = response.json()
assert "total_subscriptions" in data
assert "active_count" in data
assert "mrr_cents" in data
assert "arr_cents" in data
assert "tier_distribution" in data
assert data["active_count"] >= 1
def test_get_stats_empty(self, client, super_admin_headers):
response = client.get(
f"{BASE}/stats", headers=super_admin_headers
)
assert response.status_code == 200
data = response.json()
assert data["total_subscriptions"] == 0
# ============================================================================
# Billing History Endpoint
# ============================================================================
class TestAdminBillingHistory:
"""Tests for GET /api/v1/admin/subscriptions/billing/history."""
def test_list_billing_history_success(
self, client, super_admin_headers, rt_billing_history
):
response = client.get(
f"{BASE}/billing/history", headers=super_admin_headers
)
assert response.status_code == 200
data = response.json()
assert "invoices" in data
assert "total" in data
assert data["total"] >= 3
def test_list_billing_history_pagination(
self, client, super_admin_headers, rt_billing_history
):
response = client.get(
f"{BASE}/billing/history",
params={"page": 1, "per_page": 2},
headers=super_admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["per_page"] == 2
assert len(data["invoices"]) <= 2
def test_list_billing_history_filter_by_merchant(
self, client, super_admin_headers, rt_billing_history, rt_merchant
):
response = client.get(
f"{BASE}/billing/history",
params={"merchant_id": rt_merchant.id},
headers=super_admin_headers,
)
assert response.status_code == 200
assert response.json()["total"] == 3
def test_list_billing_history_empty(self, client, super_admin_headers):
response = client.get(
f"{BASE}/billing/history", headers=super_admin_headers
)
assert response.status_code == 200
assert response.json()["total"] == 0

View File

@@ -0,0 +1,433 @@
# app/modules/billing/tests/integration/test_merchant_routes.py
"""
Integration tests for merchant billing API routes.
Tests the merchant portal billing endpoints at:
/api/v1/merchants/billing/*
Authentication: Overrides get_current_merchant_from_cookie_or_header with a
mock that returns a UserContext for the merchant owner.
"""
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from app.api.deps import get_current_merchant_from_cookie_or_header
from app.modules.billing.models import (
BillingHistory,
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
)
from app.modules.tenancy.models import Merchant, Platform, User
from main import app
from models.schema.auth import UserContext
# ============================================================================
# Fixtures
# ============================================================================
BASE = "/api/v1/merchants/billing"
@pytest.fixture
def merch_owner(db):
"""Create a merchant owner user."""
from middleware.auth import AuthManager
auth = AuthManager()
user = User(
email=f"merchowner_{uuid.uuid4().hex[:8]}@test.com",
username=f"merchowner_{uuid.uuid4().hex[:8]}",
hashed_password=auth.hash_password("merchpass123"),
role="store",
is_active=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user
@pytest.fixture
def merch_platform(db):
"""Create a platform for merchant tests."""
platform = Platform(
code=f"merch_{uuid.uuid4().hex[:8]}",
name="Merchant Test Platform",
is_active=True,
)
db.add(platform)
db.commit()
db.refresh(platform)
return platform
@pytest.fixture
def merch_merchant(db, merch_owner):
"""Create a merchant owned by merch_owner."""
merchant = Merchant(
name="Merchant Route Test",
owner_user_id=merch_owner.id,
contact_email=merch_owner.email,
is_active=True,
is_verified=True,
)
db.add(merchant)
db.commit()
db.refresh(merchant)
return merchant
@pytest.fixture
def merch_tiers(db, merch_platform):
"""Create tiers for merchant tests."""
tiers = []
for i, (code, name, price) in enumerate([
("essential", "Essential", 0),
("professional", "Professional", 2900),
("business", "Business", 7900),
]):
tier = SubscriptionTier(
code=code,
name=name,
description=f"{name} tier",
price_monthly_cents=price,
price_annual_cents=price * 10 if price > 0 else 0,
display_order=i,
is_active=True,
is_public=True,
platform_id=merch_platform.id,
)
db.add(tier)
tiers.append(tier)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def merch_subscription(db, merch_merchant, merch_platform, merch_tiers):
"""Create a subscription for the merchant."""
sub = MerchantSubscription(
merchant_id=merch_merchant.id,
platform_id=merch_platform.id,
tier_id=merch_tiers[1].id, # professional
status=SubscriptionStatus.ACTIVE.value,
is_annual=False,
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def merch_invoices(db, merch_merchant):
"""Create invoice records for the merchant."""
records = []
for i in range(3):
record = BillingHistory(
merchant_id=merch_merchant.id,
invoice_number=f"MINV-{2000 + i}",
invoice_date=datetime.now(UTC) - timedelta(days=30 * i),
subtotal_cents=2900,
tax_cents=493,
total_cents=3393,
amount_paid_cents=3393,
currency="EUR",
status="paid",
description=f"Merchant invoice {i}",
)
db.add(record)
records.append(record)
db.commit()
for r in records:
db.refresh(r)
return records
@pytest.fixture
def merch_auth_headers(merch_owner, merch_merchant):
"""Override auth dependency to return a UserContext for the merchant owner."""
user_context = UserContext(
id=merch_owner.id,
email=merch_owner.email,
username=merch_owner.username,
role="store",
is_active=True,
)
def _override():
return user_context
app.dependency_overrides[get_current_merchant_from_cookie_or_header] = _override
yield {"Authorization": "Bearer fake-token"}
if get_current_merchant_from_cookie_or_header in app.dependency_overrides:
del app.dependency_overrides[get_current_merchant_from_cookie_or_header]
# ============================================================================
# Subscription Endpoints
# ============================================================================
class TestMerchantListSubscriptions:
"""Tests for GET /api/v1/merchants/billing/subscriptions."""
def test_list_subscriptions_success(
self, client, merch_auth_headers, merch_subscription, merch_merchant
):
response = client.get(
f"{BASE}/subscriptions", headers=merch_auth_headers
)
assert response.status_code == 200
data = response.json()
assert "subscriptions" in data
assert "total" in data
assert data["total"] >= 1
def test_list_subscriptions_includes_tier_info(
self, client, merch_auth_headers, merch_subscription
):
response = client.get(
f"{BASE}/subscriptions", headers=merch_auth_headers
)
assert response.status_code == 200
sub = response.json()["subscriptions"][0]
assert "tier" in sub
assert "tier_name" in sub
assert sub["tier"] == "professional"
def test_list_subscriptions_empty(
self, client, merch_auth_headers, merch_merchant
):
response = client.get(
f"{BASE}/subscriptions", headers=merch_auth_headers
)
assert response.status_code == 200
assert response.json()["total"] == 0
class TestMerchantGetSubscription:
"""Tests for GET /api/v1/merchants/billing/subscriptions/{platform_id}."""
def test_get_subscription_success(
self,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
):
response = client.get(
f"{BASE}/subscriptions/{merch_platform.id}",
headers=merch_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "subscription" in data
assert "tier" in data
assert data["subscription"]["status"] == "active"
def test_get_subscription_with_tier_details(
self,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
):
response = client.get(
f"{BASE}/subscriptions/{merch_platform.id}",
headers=merch_auth_headers,
)
assert response.status_code == 200
tier = response.json()["tier"]
assert tier is not None
assert tier["code"] == "professional"
assert tier["price_monthly_cents"] == 2900
def test_get_subscription_not_found(
self, client, merch_auth_headers, merch_merchant
):
response = client.get(
f"{BASE}/subscriptions/99999",
headers=merch_auth_headers,
)
assert response.status_code == 404
class TestMerchantGetAvailableTiers:
"""Tests for GET /api/v1/merchants/billing/subscriptions/{platform_id}/tiers."""
def test_get_tiers_success(
self,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
merch_tiers,
):
response = client.get(
f"{BASE}/subscriptions/{merch_platform.id}/tiers",
headers=merch_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "tiers" in data
assert "current_tier" in data
assert data["current_tier"] == "professional"
def test_get_tiers_includes_upgrade_info(
self,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
merch_tiers,
):
response = client.get(
f"{BASE}/subscriptions/{merch_platform.id}/tiers",
headers=merch_auth_headers,
)
assert response.status_code == 200
tiers = response.json()["tiers"]
assert len(tiers) >= 3
class TestMerchantChangeTier:
"""Tests for POST /api/v1/merchants/billing/subscriptions/{platform_id}/change-tier."""
@patch("app.modules.billing.routes.api.merchant.billing_service")
def test_change_tier_success(
self,
mock_billing,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
merch_tiers,
):
mock_billing.change_tier.return_value = {
"message": "Tier changed to business",
"new_tier": "business",
"effective_immediately": True,
}
response = client.post(
f"{BASE}/subscriptions/{merch_platform.id}/change-tier",
json={"tier_code": "business", "is_annual": False},
headers=merch_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "new_tier" in data or "message" in data
mock_billing.change_tier.assert_called_once()
def test_change_tier_no_stripe_returns_error(
self,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
):
"""Without Stripe subscription, change_tier returns 400."""
response = client.post(
f"{BASE}/subscriptions/{merch_platform.id}/change-tier",
json={"tier_code": "business", "is_annual": False},
headers=merch_auth_headers,
)
assert response.status_code == 400
class TestMerchantCheckout:
"""Tests for POST /api/v1/merchants/billing/subscriptions/{platform_id}/checkout."""
@patch("app.modules.billing.routes.api.merchant.billing_service")
def test_create_checkout_with_stripe(
self,
mock_billing,
client,
merch_auth_headers,
merch_subscription,
merch_platform,
merch_tiers,
):
mock_billing.create_checkout_session.return_value = {
"checkout_url": "https://checkout.stripe.com/test",
"session_id": "cs_test_123",
}
response = client.post(
f"{BASE}/subscriptions/{merch_platform.id}/checkout",
json={"tier_code": "business", "is_annual": False},
headers=merch_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["checkout_url"] == "https://checkout.stripe.com/test"
assert data["session_id"] == "cs_test_123"
mock_billing.create_checkout_session.assert_called_once()
# ============================================================================
# Invoice Endpoints
# ============================================================================
class TestMerchantInvoices:
"""Tests for GET /api/v1/merchants/billing/invoices."""
def test_list_invoices_success(
self, client, merch_auth_headers, merch_invoices
):
response = client.get(
f"{BASE}/invoices", headers=merch_auth_headers
)
assert response.status_code == 200
data = response.json()
assert "invoices" in data
assert "total" in data
assert data["total"] >= 3
def test_list_invoices_pagination(
self, client, merch_auth_headers, merch_invoices
):
response = client.get(
f"{BASE}/invoices",
params={"skip": 0, "limit": 2},
headers=merch_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert len(data["invoices"]) <= 2
def test_list_invoices_response_shape(
self, client, merch_auth_headers, merch_invoices
):
response = client.get(
f"{BASE}/invoices", headers=merch_auth_headers
)
assert response.status_code == 200
inv = response.json()["invoices"][0]
assert "id" in inv
assert "invoice_number" in inv
assert "invoice_date" in inv
assert "total_cents" in inv
assert "currency" in inv
assert "status" in inv
def test_list_invoices_empty(
self, client, merch_auth_headers, merch_merchant
):
response = client.get(
f"{BASE}/invoices", headers=merch_auth_headers
)
assert response.status_code == 200
assert response.json()["total"] == 0

View File

@@ -0,0 +1,226 @@
# app/modules/billing/tests/integration/test_platform_routes.py
"""
Integration tests for platform pricing API routes (public, no auth).
Tests the public pricing endpoints at:
/api/v1/platform/pricing/*
These are unauthenticated endpoints used by the marketing site and signup flow.
"""
import uuid
import pytest
from app.modules.billing.models import (
AddOnProduct,
SubscriptionTier,
TierCode,
)
from app.modules.tenancy.models import Platform
# ============================================================================
# Fixtures
# ============================================================================
BASE = "/api/v1/platform/pricing"
@pytest.fixture
def pricing_platform(db):
"""Create a platform for pricing tests."""
platform = Platform(
code=f"pricing_{uuid.uuid4().hex[:8]}",
name="Pricing Test Platform",
is_active=True,
)
db.add(platform)
db.commit()
db.refresh(platform)
return platform
@pytest.fixture
def pricing_tiers(db, pricing_platform):
"""Create public subscription tiers for pricing display."""
tiers = []
for i, (code, name, price_m, price_a) in enumerate([
(TierCode.ESSENTIAL.value, "Essential", 0, 0),
(TierCode.PROFESSIONAL.value, "Professional", 2900, 29000),
(TierCode.BUSINESS.value, "Business", 7900, 79000),
(TierCode.ENTERPRISE.value, "Enterprise", 19900, 199000),
]):
tier = SubscriptionTier(
code=code,
name=name,
description=f"{name} plan for growing businesses",
price_monthly_cents=price_m,
price_annual_cents=price_a,
display_order=i,
is_active=True,
is_public=True,
platform_id=pricing_platform.id,
)
db.add(tier)
tiers.append(tier)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def pricing_addons(db):
"""Create add-on products for pricing display."""
addons = []
for code, name, cat, price in [
("custom_domain", "Custom Domain", "hosting", 499),
("extra_products", "Extra Products Pack", "capacity", 999),
]:
addon = AddOnProduct(
code=code,
name=name,
description=f"{name} add-on",
category=cat,
price_cents=price,
billing_period="monthly",
is_active=True,
)
db.add(addon)
addons.append(addon)
db.commit()
for a in addons:
db.refresh(a)
return addons
# ============================================================================
# GET /pricing/tiers
# ============================================================================
class TestPlatformGetTiers:
"""Tests for GET /api/v1/platform/pricing/tiers."""
def test_get_tiers_success(self, client, pricing_tiers):
response = client.get(f"{BASE}/tiers")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 4
def test_get_tiers_response_shape(self, client, pricing_tiers):
response = client.get(f"{BASE}/tiers")
assert response.status_code == 200
tier = response.json()[0]
assert "code" in tier
assert "name" in tier
assert "price_monthly" in tier
assert "price_monthly_cents" in tier
assert "feature_codes" in tier
def test_get_tiers_excludes_inactive(self, client, pricing_tiers, db):
pricing_tiers[3].is_active = False
db.commit()
response = client.get(f"{BASE}/tiers")
assert response.status_code == 200
codes = [t["code"] for t in response.json()]
assert TierCode.ENTERPRISE.value not in codes
def test_get_tiers_popular_flag(self, client, pricing_tiers):
response = client.get(f"{BASE}/tiers")
assert response.status_code == 200
for tier in response.json():
if tier["code"] == TierCode.PROFESSIONAL.value:
assert tier["is_popular"] is True
elif tier["code"] == TierCode.ENTERPRISE.value:
assert tier["is_enterprise"] is True
def test_get_tiers_empty(self, client):
response = client.get(f"{BASE}/tiers")
assert response.status_code == 200
assert response.json() == []
# ============================================================================
# GET /pricing/tiers/{tier_code}
# ============================================================================
class TestPlatformGetTier:
"""Tests for GET /api/v1/platform/pricing/tiers/{tier_code}."""
def test_get_tier_success(self, client, pricing_tiers):
response = client.get(f"{BASE}/tiers/{TierCode.PROFESSIONAL.value}")
assert response.status_code == 200
data = response.json()
assert data["code"] == TierCode.PROFESSIONAL.value
assert data["price_monthly_cents"] == 2900
assert data["price_annual_cents"] == 29000
assert data["price_monthly"] == 29.0
def test_get_tier_not_found(self, client):
response = client.get(f"{BASE}/tiers/nonexistent")
assert response.status_code == 404
# ============================================================================
# GET /pricing/addons
# ============================================================================
class TestPlatformGetAddons:
"""Tests for GET /api/v1/platform/pricing/addons."""
def test_get_addons_success(self, client, pricing_addons):
response = client.get(f"{BASE}/addons")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) >= 2
def test_get_addons_response_shape(self, client, pricing_addons):
response = client.get(f"{BASE}/addons")
assert response.status_code == 200
addon = response.json()[0]
assert "code" in addon
assert "name" in addon
assert "price" in addon
assert "price_cents" in addon
assert "category" in addon
def test_get_addons_empty(self, client):
response = client.get(f"{BASE}/addons")
assert response.status_code == 200
assert response.json() == []
# ============================================================================
# GET /pricing (full pricing page)
# ============================================================================
class TestPlatformGetPricing:
"""Tests for GET /api/v1/platform/pricing."""
def test_get_pricing_success(self, client, pricing_tiers, pricing_addons):
response = client.get(f"{BASE}")
assert response.status_code == 200
data = response.json()
assert "tiers" in data
assert "addons" in data
assert "trial_days" in data
assert "annual_discount_months" in data
assert len(data["tiers"]) >= 4
assert len(data["addons"]) >= 2
assert data["annual_discount_months"] == 2
def test_get_pricing_empty_db(self, client):
response = client.get(f"{BASE}")
assert response.status_code == 200
data = response.json()
assert data["tiers"] == []
assert data["addons"] == []
assert "trial_days" in data

View File

@@ -0,0 +1,414 @@
# app/modules/billing/tests/integration/test_store_routes.py
"""
Integration tests for store billing API routes.
Tests the store billing endpoints at:
/api/v1/store/billing/*
Authentication: Uses real JWT tokens via store login endpoint.
The store routes require `require_module_access("billing", FrontendType.STORE)`
as a router-level dependency, which calls auth functions directly — so
dependency overrides don't work. Real tokens are needed.
"""
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from app.modules.billing.models import (
BillingHistory,
MerchantSubscription,
SubscriptionStatus,
SubscriptionTier,
)
from app.modules.tenancy.models import Merchant, Platform, Store, User
from app.modules.tenancy.models.store import StoreUser, StoreUserType
from app.modules.tenancy.models.store_platform import StorePlatform
# ============================================================================
# Fixtures
# ============================================================================
BASE = "/api/v1/store/billing"
@pytest.fixture
def store_platform(db):
"""Create a platform for store tests."""
platform = Platform(
code=f"store_{uuid.uuid4().hex[:8]}",
name="Store Test Platform",
is_active=True,
)
db.add(platform)
db.commit()
db.refresh(platform)
return platform
@pytest.fixture
def store_full_setup(db, store_platform):
"""
Create the full store setup: user, merchant, store, StoreUser, StorePlatform.
Returns a dict with all created objects for easy access.
"""
from middleware.auth import AuthManager
auth = AuthManager()
uid = uuid.uuid4().hex[:8]
# 1. Create store owner user
owner = User(
email=f"storeowner_{uid}@test.com",
username=f"storeowner_{uid}",
hashed_password=auth.hash_password("storepass123"),
role="store",
is_active=True,
)
db.add(owner)
db.commit()
db.refresh(owner)
# 2. Create merchant
merchant = Merchant(
name=f"Store Merchant {uid}",
owner_user_id=owner.id,
contact_email=owner.email,
is_active=True,
is_verified=True,
)
db.add(merchant)
db.commit()
db.refresh(merchant)
# 3. Create store
store = Store(
merchant_id=merchant.id,
store_code=f"STORETEST_{uid.upper()}",
subdomain=f"storetest{uid}",
name=f"Store Test {uid}",
is_active=True,
is_verified=True,
)
db.add(store)
db.commit()
db.refresh(store)
# 4. Create StoreUser (owner association)
store_user = StoreUser(
store_id=store.id,
user_id=owner.id,
user_type=StoreUserType.OWNER.value,
is_active=True,
)
db.add(store_user)
db.commit()
# 5. Link store to platform
sp = StorePlatform(
store_id=store.id,
platform_id=store_platform.id,
)
db.add(sp)
db.commit()
return {
"owner": owner,
"merchant": merchant,
"store": store,
"platform": store_platform,
}
@pytest.fixture
def store_auth_headers(client, store_full_setup):
"""
Get real JWT auth headers by logging in via store auth endpoint.
Uses the store login flow which generates a JWT with store context
(token_store_id, token_store_code, token_store_role).
"""
owner = store_full_setup["owner"]
response = client.post(
"/api/v1/store/auth/login",
json={
"email_or_username": owner.username,
"password": "storepass123",
},
)
assert response.status_code == 200, f"Store login failed: {response.text}"
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def store_tiers(db, store_platform):
"""Create tiers for store tests."""
tiers = []
for i, (code, name, price) in enumerate([
("essential", "Essential", 0),
("professional", "Professional", 2900),
("business", "Business", 7900),
]):
tier = SubscriptionTier(
code=code,
name=name,
description=f"{name} tier",
price_monthly_cents=price,
price_annual_cents=price * 10 if price > 0 else 0,
display_order=i,
is_active=True,
is_public=True,
platform_id=store_platform.id,
)
db.add(tier)
tiers.append(tier)
db.commit()
for t in tiers:
db.refresh(t)
return tiers
@pytest.fixture
def store_subscription(db, store_full_setup, store_tiers):
"""Create a subscription for the store's merchant."""
sub = MerchantSubscription(
merchant_id=store_full_setup["merchant"].id,
platform_id=store_full_setup["platform"].id,
tier_id=store_tiers[1].id, # professional
status=SubscriptionStatus.ACTIVE.value,
is_annual=False,
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(sub)
db.commit()
db.refresh(sub)
return sub
@pytest.fixture
def store_invoices(db, store_full_setup):
"""Create invoice records for the store's merchant."""
records = []
for i in range(2):
record = BillingHistory(
merchant_id=store_full_setup["merchant"].id,
invoice_number=f"SINV-{3000 + i}",
invoice_date=datetime.now(UTC) - timedelta(days=30 * i),
subtotal_cents=2900,
tax_cents=493,
total_cents=3393,
amount_paid_cents=3393,
currency="EUR",
status="paid",
description=f"Store invoice {i}",
)
db.add(record)
records.append(record)
db.commit()
for r in records:
db.refresh(r)
return records
# ============================================================================
# Core Billing Endpoints
# ============================================================================
class TestStoreGetSubscription:
"""Tests for GET /api/v1/store/billing/subscription."""
def test_get_subscription_success(
self, client, store_auth_headers, store_subscription, store_tiers
):
response = client.get(
f"{BASE}/subscription", headers=store_auth_headers
)
assert response.status_code == 200
data = response.json()
assert "tier_code" in data
assert "tier_name" in data
assert "status" in data
assert data["status"] == "active"
assert data["tier_code"] == "professional"
class TestStoreGetTiers:
"""Tests for GET /api/v1/store/billing/tiers."""
def test_get_tiers_success(
self, client, store_auth_headers, store_subscription, store_tiers
):
response = client.get(f"{BASE}/tiers", headers=store_auth_headers)
assert response.status_code == 200
data = response.json()
assert "tiers" in data
assert "current_tier" in data
assert data["current_tier"] == "professional"
assert len(data["tiers"]) >= 3
class TestStoreGetInvoices:
"""Tests for GET /api/v1/store/billing/invoices."""
def test_get_invoices_success(
self, client, store_auth_headers, store_invoices
):
response = client.get(f"{BASE}/invoices", headers=store_auth_headers)
assert response.status_code == 200
data = response.json()
assert "invoices" in data
assert "total" in data
assert data["total"] >= 2
def test_get_invoices_pagination(
self, client, store_auth_headers, store_invoices
):
response = client.get(
f"{BASE}/invoices",
params={"skip": 0, "limit": 1},
headers=store_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert len(data["invoices"]) <= 1
def test_get_invoices_empty(
self, client, store_auth_headers, store_full_setup
):
response = client.get(f"{BASE}/invoices", headers=store_auth_headers)
assert response.status_code == 200
assert response.json()["total"] == 0
# ============================================================================
# Checkout / Management Endpoints
# ============================================================================
class TestStoreChangeTier:
"""Tests for POST /api/v1/store/billing/change-tier."""
@patch("app.modules.billing.routes.api.store_checkout.billing_service")
def test_change_tier_success(
self, mock_billing, client, store_auth_headers, store_subscription, store_tiers
):
mock_billing.change_tier.return_value = {
"message": "Tier changed to business",
"new_tier": "business",
"effective_immediately": True,
}
response = client.post(
f"{BASE}/change-tier",
json={"tier_code": "business", "is_annual": False},
headers=store_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "new_tier" in data
def test_change_tier_no_stripe_returns_error(
self, client, store_auth_headers, store_subscription
):
"""Without Stripe subscription, change_tier returns 400."""
response = client.post(
f"{BASE}/change-tier",
json={"tier_code": "business", "is_annual": False},
headers=store_auth_headers,
)
assert response.status_code == 400
class TestStoreCancelSubscription:
"""Tests for POST /api/v1/store/billing/cancel."""
@patch("app.modules.billing.routes.api.store_checkout.billing_service")
def test_cancel_success(
self, mock_billing, client, store_auth_headers, store_subscription
):
mock_billing.cancel_subscription.return_value = {
"message": "Subscription cancelled",
"effective_date": (datetime.now(UTC) + timedelta(days=30)).isoformat(),
}
response = client.post(
f"{BASE}/cancel",
json={"reason": "Too expensive", "immediately": False},
headers=store_auth_headers,
)
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "effective_date" in data
class TestStoreReactivateSubscription:
"""Tests for POST /api/v1/store/billing/reactivate."""
@patch("app.modules.billing.routes.api.store_checkout.billing_service")
def test_reactivate_cancelled(
self, mock_billing, client, store_auth_headers, store_subscription, db
):
store_subscription.status = SubscriptionStatus.CANCELLED.value
store_subscription.cancelled_at = datetime.now(UTC)
db.commit()
mock_billing.reactivate_subscription.return_value = {
"message": "Subscription reactivated",
"status": "active",
}
response = client.post(
f"{BASE}/reactivate", headers=store_auth_headers
)
assert response.status_code == 200
class TestStoreUpcomingInvoice:
"""Tests for GET /api/v1/store/billing/upcoming-invoice."""
@patch("app.modules.billing.routes.api.store_checkout.billing_service")
def test_upcoming_invoice(
self, mock_billing, client, store_auth_headers, store_subscription
):
mock_billing.get_upcoming_invoice.return_value = {
"amount_due_cents": 2900,
"currency": "EUR",
"next_payment_date": None,
"line_items": [],
}
response = client.get(
f"{BASE}/upcoming-invoice", headers=store_auth_headers
)
assert response.status_code == 200
data = response.json()
assert "amount_due_cents" in data
assert "currency" in data
# ============================================================================
# Unauthorized Access
# ============================================================================
class TestStoreUnauthorized:
"""Tests for unauthorized store access."""
def test_subscription_no_auth(self, client):
response = client.get(f"{BASE}/subscription")
assert response.status_code in (401, 403)
def test_tiers_no_auth(self, client):
response = client.get(f"{BASE}/tiers")
assert response.status_code in (401, 403)
def test_invoices_no_auth(self, client):
response = client.get(f"{BASE}/invoices")
assert response.status_code in (401, 403)