diff --git a/app/modules/billing/tests/integration/__init__.py b/app/modules/billing/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/modules/billing/tests/integration/test_admin_routes.py b/app/modules/billing/tests/integration/test_admin_routes.py new file mode 100644 index 00000000..86023941 --- /dev/null +++ b/app/modules/billing/tests/integration/test_admin_routes.py @@ -0,0 +1,510 @@ +# app/modules/billing/tests/integration/test_admin_routes.py +""" +Integration tests for billing admin API routes. + +Tests the admin subscription management endpoints at: + /api/v1/admin/subscriptions/* + +Uses super_admin_headers fixture which bypasses module access checks. +""" + +import uuid +from datetime import UTC, datetime, timedelta + +import pytest + +from app.modules.billing.models import ( + BillingHistory, + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, +) +from app.modules.tenancy.models import Merchant, Platform, User + + +# ============================================================================ +# Fixtures +# ============================================================================ + +BASE = "/api/v1/admin/subscriptions" + + +@pytest.fixture +def rt_platform(db): + """Create a platform for route tests.""" + platform = Platform( + code=f"test_{uuid.uuid4().hex[:8]}", + name="Test Platform", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def rt_tiers(db, rt_platform): + """Create subscription tiers for route tests.""" + tiers = [] + for i, (code, name, price) in enumerate([ + ("essential", "Essential", 0), + ("professional", "Professional", 2900), + ("business", "Business", 7900), + ]): + tier = SubscriptionTier( + code=code, + name=name, + description=f"{name} tier", + price_monthly_cents=price, + price_annual_cents=price * 10 if price > 0 else 0, + display_order=i, + is_active=True, + is_public=True, + platform_id=rt_platform.id, + ) + db.add(tier) + tiers.append(tier) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def rt_merchant(db, rt_platform): + """Create a merchant with owner for route tests.""" + from middleware.auth import AuthManager + + auth = AuthManager() + owner = User( + email=f"merchant_{uuid.uuid4().hex[:8]}@test.com", + username=f"merchant_{uuid.uuid4().hex[:8]}", + hashed_password=auth.hash_password("pass123"), + role="store", + is_active=True, + ) + db.add(owner) + db.commit() + db.refresh(owner) + + merchant = Merchant( + name="Route Test Merchant", + owner_user_id=owner.id, + contact_email=owner.email, + is_active=True, + is_verified=True, + ) + db.add(merchant) + db.commit() + db.refresh(merchant) + return merchant + + +@pytest.fixture +def rt_subscription(db, rt_merchant, rt_platform, rt_tiers): + """Create a subscription for route tests.""" + sub = MerchantSubscription( + merchant_id=rt_merchant.id, + platform_id=rt_platform.id, + tier_id=rt_tiers[1].id, # professional + status=SubscriptionStatus.ACTIVE.value, + is_annual=False, + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def rt_billing_history(db, rt_merchant): + """Create billing history entries for route tests.""" + records = [] + for i in range(3): + record = BillingHistory( + merchant_id=rt_merchant.id, + invoice_number=f"INV-{1000 + i}", + invoice_date=datetime.now(UTC) - timedelta(days=30 * i), + subtotal_cents=2900, + tax_cents=493, + total_cents=3393, + amount_paid_cents=3393, + currency="EUR", + status="paid", + description=f"Invoice {i}", + ) + db.add(record) + records.append(record) + db.commit() + for r in records: + db.refresh(r) + return records + + +# ============================================================================ +# Tier Endpoints +# ============================================================================ + + +class TestAdminListTiers: + """Tests for GET /api/v1/admin/subscriptions/tiers.""" + + def test_list_tiers_success(self, client, super_admin_headers, rt_tiers): + response = client.get(f"{BASE}/tiers", headers=super_admin_headers) + assert response.status_code == 200 + data = response.json() + assert "tiers" in data + assert "total" in data + assert data["total"] >= 3 + + def test_list_tiers_active_only_by_default(self, client, super_admin_headers, rt_tiers, db): + # Deactivate one tier + rt_tiers[2].is_active = False + db.commit() + + response = client.get(f"{BASE}/tiers", headers=super_admin_headers) + assert response.status_code == 200 + codes = [t["code"] for t in response.json()["tiers"]] + assert "business" not in codes + + def test_list_tiers_include_inactive(self, client, super_admin_headers, rt_tiers, db): + rt_tiers[2].is_active = False + db.commit() + + response = client.get( + f"{BASE}/tiers", + params={"include_inactive": True}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + codes = [t["code"] for t in response.json()["tiers"]] + assert "business" in codes + + def test_list_tiers_filter_by_platform(self, client, super_admin_headers, rt_tiers, rt_platform): + response = client.get( + f"{BASE}/tiers", + params={"platform_id": rt_platform.id}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert response.json()["total"] == 3 + + def test_list_tiers_unauthorized(self, client): + response = client.get(f"{BASE}/tiers") + assert response.status_code in (401, 403) + + +class TestAdminGetTier: + """Tests for GET /api/v1/admin/subscriptions/tiers/{tier_code}.""" + + def test_get_tier_success(self, client, super_admin_headers, rt_tiers): + response = client.get( + f"{BASE}/tiers/professional", headers=super_admin_headers + ) + assert response.status_code == 200 + data = response.json() + assert data["code"] == "professional" + assert data["name"] == "Professional" + assert data["price_monthly_cents"] == 2900 + + def test_get_tier_not_found(self, client, super_admin_headers): + response = client.get( + f"{BASE}/tiers/nonexistent", headers=super_admin_headers + ) + assert response.status_code == 404 + + +class TestAdminCreateTier: + """Tests for POST /api/v1/admin/subscriptions/tiers.""" + + def test_create_tier_success(self, client, super_admin_headers, rt_platform): + response = client.post( + f"{BASE}/tiers", + json={ + "code": "starter", + "name": "Starter", + "description": "Starter plan", + "price_monthly_cents": 990, + "price_annual_cents": 9900, + "display_order": 0, + "is_active": True, + "is_public": True, + "platform_id": rt_platform.id, + }, + headers=super_admin_headers, + ) + assert response.status_code == 201 + data = response.json() + assert data["code"] == "starter" + assert data["price_monthly_cents"] == 990 + + def test_create_tier_duplicate_code(self, client, super_admin_headers, rt_tiers): + response = client.post( + f"{BASE}/tiers", + json={ + "code": "essential", + "name": "Essential Dup", + "price_monthly_cents": 0, + }, + headers=super_admin_headers, + ) + assert response.status_code in (400, 409, 422) + + +class TestAdminUpdateTier: + """Tests for PATCH /api/v1/admin/subscriptions/tiers/{tier_code}.""" + + def test_update_tier_success(self, client, super_admin_headers, rt_tiers): + response = client.patch( + f"{BASE}/tiers/professional", + json={"name": "Professional Plus", "price_monthly_cents": 3900}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Professional Plus" + assert data["price_monthly_cents"] == 3900 + + def test_update_tier_not_found(self, client, super_admin_headers): + response = client.patch( + f"{BASE}/tiers/nonexistent", + json={"name": "Updated"}, + headers=super_admin_headers, + ) + assert response.status_code == 404 + + +class TestAdminDeleteTier: + """Tests for DELETE /api/v1/admin/subscriptions/tiers/{tier_code}.""" + + def test_delete_tier_success(self, client, super_admin_headers, rt_tiers): + response = client.delete( + f"{BASE}/tiers/business", headers=super_admin_headers + ) + assert response.status_code == 204 + + def test_delete_tier_with_active_subs( + self, client, super_admin_headers, rt_subscription, rt_tiers + ): + # Try to delete the tier used by rt_subscription (professional) + response = client.delete( + f"{BASE}/tiers/professional", headers=super_admin_headers + ) + assert response.status_code in (400, 409, 422) + + def test_delete_tier_not_found(self, client, super_admin_headers): + response = client.delete( + f"{BASE}/tiers/nonexistent", headers=super_admin_headers + ) + assert response.status_code == 404 + + +# ============================================================================ +# Subscription Endpoints +# ============================================================================ + + +class TestAdminListSubscriptions: + """Tests for GET /api/v1/admin/subscriptions.""" + + def test_list_subscriptions_success( + self, client, super_admin_headers, rt_subscription + ): + response = client.get(f"{BASE}", headers=super_admin_headers) + assert response.status_code == 200 + data = response.json() + assert "subscriptions" in data + assert "total" in data + assert data["total"] >= 1 + + def test_list_subscriptions_pagination( + self, client, super_admin_headers, rt_subscription + ): + response = client.get( + f"{BASE}", + params={"page": 1, "per_page": 5}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["page"] == 1 + assert data["per_page"] == 5 + + def test_list_subscriptions_filter_by_status( + self, client, super_admin_headers, rt_subscription + ): + response = client.get( + f"{BASE}", + params={"status": "active"}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + for sub in response.json()["subscriptions"]: + assert sub["status"] == "active" + + +class TestAdminCreateSubscription: + """Tests for POST /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}.""" + + def test_create_subscription_success( + self, client, super_admin_headers, rt_merchant, rt_platform, rt_tiers + ): + response = client.post( + f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}", + json={ + "merchant_id": rt_merchant.id, + "platform_id": rt_platform.id, + "tier_code": "essential", + "status": "trial", + "trial_days": 14, + "is_annual": False, + }, + headers=super_admin_headers, + ) + assert response.status_code == 201 + data = response.json() + assert data["merchant_id"] == rt_merchant.id + assert data["platform_id"] == rt_platform.id + + +class TestAdminGetSubscription: + """Tests for GET /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}.""" + + def test_get_subscription_success( + self, client, super_admin_headers, rt_subscription, rt_merchant, rt_platform + ): + response = client.get( + f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}", + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["merchant_id"] == rt_merchant.id + assert data["status"] == "active" + + def test_get_subscription_not_found( + self, client, super_admin_headers, rt_platform + ): + response = client.get( + f"{BASE}/merchants/99999/platforms/{rt_platform.id}", + headers=super_admin_headers, + ) + assert response.status_code == 404 + + +class TestAdminUpdateSubscription: + """Tests for PATCH /api/v1/admin/subscriptions/merchants/{id}/platforms/{id}.""" + + def test_update_subscription_status( + self, client, super_admin_headers, rt_subscription, rt_merchant, rt_platform + ): + response = client.patch( + f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}", + json={"status": "past_due"}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert response.json()["status"] == "past_due" + + def test_update_subscription_tier( + self, + client, + super_admin_headers, + rt_subscription, + rt_merchant, + rt_platform, + rt_tiers, + ): + response = client.patch( + f"{BASE}/merchants/{rt_merchant.id}/platforms/{rt_platform.id}", + json={"tier_code": "business"}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + + +# ============================================================================ +# Stats Endpoint +# ============================================================================ + + +class TestAdminStats: + """Tests for GET /api/v1/admin/subscriptions/stats.""" + + def test_get_stats_success( + self, client, super_admin_headers, rt_subscription + ): + response = client.get( + f"{BASE}/stats", headers=super_admin_headers + ) + assert response.status_code == 200 + data = response.json() + assert "total_subscriptions" in data + assert "active_count" in data + assert "mrr_cents" in data + assert "arr_cents" in data + assert "tier_distribution" in data + assert data["active_count"] >= 1 + + def test_get_stats_empty(self, client, super_admin_headers): + response = client.get( + f"{BASE}/stats", headers=super_admin_headers + ) + assert response.status_code == 200 + data = response.json() + assert data["total_subscriptions"] == 0 + + +# ============================================================================ +# Billing History Endpoint +# ============================================================================ + + +class TestAdminBillingHistory: + """Tests for GET /api/v1/admin/subscriptions/billing/history.""" + + def test_list_billing_history_success( + self, client, super_admin_headers, rt_billing_history + ): + response = client.get( + f"{BASE}/billing/history", headers=super_admin_headers + ) + assert response.status_code == 200 + data = response.json() + assert "invoices" in data + assert "total" in data + assert data["total"] >= 3 + + def test_list_billing_history_pagination( + self, client, super_admin_headers, rt_billing_history + ): + response = client.get( + f"{BASE}/billing/history", + params={"page": 1, "per_page": 2}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["per_page"] == 2 + assert len(data["invoices"]) <= 2 + + def test_list_billing_history_filter_by_merchant( + self, client, super_admin_headers, rt_billing_history, rt_merchant + ): + response = client.get( + f"{BASE}/billing/history", + params={"merchant_id": rt_merchant.id}, + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert response.json()["total"] == 3 + + def test_list_billing_history_empty(self, client, super_admin_headers): + response = client.get( + f"{BASE}/billing/history", headers=super_admin_headers + ) + assert response.status_code == 200 + assert response.json()["total"] == 0 diff --git a/app/modules/billing/tests/integration/test_merchant_routes.py b/app/modules/billing/tests/integration/test_merchant_routes.py new file mode 100644 index 00000000..434cf444 --- /dev/null +++ b/app/modules/billing/tests/integration/test_merchant_routes.py @@ -0,0 +1,433 @@ +# app/modules/billing/tests/integration/test_merchant_routes.py +""" +Integration tests for merchant billing API routes. + +Tests the merchant portal billing endpoints at: + /api/v1/merchants/billing/* + +Authentication: Overrides get_current_merchant_from_cookie_or_header with a +mock that returns a UserContext for the merchant owner. +""" + +import uuid +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest + +from app.api.deps import get_current_merchant_from_cookie_or_header +from app.modules.billing.models import ( + BillingHistory, + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, +) +from app.modules.tenancy.models import Merchant, Platform, User +from main import app +from models.schema.auth import UserContext + + +# ============================================================================ +# Fixtures +# ============================================================================ + +BASE = "/api/v1/merchants/billing" + + +@pytest.fixture +def merch_owner(db): + """Create a merchant owner user.""" + from middleware.auth import AuthManager + + auth = AuthManager() + user = User( + email=f"merchowner_{uuid.uuid4().hex[:8]}@test.com", + username=f"merchowner_{uuid.uuid4().hex[:8]}", + hashed_password=auth.hash_password("merchpass123"), + role="store", + is_active=True, + ) + db.add(user) + db.commit() + db.refresh(user) + return user + + +@pytest.fixture +def merch_platform(db): + """Create a platform for merchant tests.""" + platform = Platform( + code=f"merch_{uuid.uuid4().hex[:8]}", + name="Merchant Test Platform", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def merch_merchant(db, merch_owner): + """Create a merchant owned by merch_owner.""" + merchant = Merchant( + name="Merchant Route Test", + owner_user_id=merch_owner.id, + contact_email=merch_owner.email, + is_active=True, + is_verified=True, + ) + db.add(merchant) + db.commit() + db.refresh(merchant) + return merchant + + +@pytest.fixture +def merch_tiers(db, merch_platform): + """Create tiers for merchant tests.""" + tiers = [] + for i, (code, name, price) in enumerate([ + ("essential", "Essential", 0), + ("professional", "Professional", 2900), + ("business", "Business", 7900), + ]): + tier = SubscriptionTier( + code=code, + name=name, + description=f"{name} tier", + price_monthly_cents=price, + price_annual_cents=price * 10 if price > 0 else 0, + display_order=i, + is_active=True, + is_public=True, + platform_id=merch_platform.id, + ) + db.add(tier) + tiers.append(tier) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def merch_subscription(db, merch_merchant, merch_platform, merch_tiers): + """Create a subscription for the merchant.""" + sub = MerchantSubscription( + merchant_id=merch_merchant.id, + platform_id=merch_platform.id, + tier_id=merch_tiers[1].id, # professional + status=SubscriptionStatus.ACTIVE.value, + is_annual=False, + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def merch_invoices(db, merch_merchant): + """Create invoice records for the merchant.""" + records = [] + for i in range(3): + record = BillingHistory( + merchant_id=merch_merchant.id, + invoice_number=f"MINV-{2000 + i}", + invoice_date=datetime.now(UTC) - timedelta(days=30 * i), + subtotal_cents=2900, + tax_cents=493, + total_cents=3393, + amount_paid_cents=3393, + currency="EUR", + status="paid", + description=f"Merchant invoice {i}", + ) + db.add(record) + records.append(record) + db.commit() + for r in records: + db.refresh(r) + return records + + +@pytest.fixture +def merch_auth_headers(merch_owner, merch_merchant): + """Override auth dependency to return a UserContext for the merchant owner.""" + user_context = UserContext( + id=merch_owner.id, + email=merch_owner.email, + username=merch_owner.username, + role="store", + is_active=True, + ) + + def _override(): + return user_context + + app.dependency_overrides[get_current_merchant_from_cookie_or_header] = _override + yield {"Authorization": "Bearer fake-token"} + if get_current_merchant_from_cookie_or_header in app.dependency_overrides: + del app.dependency_overrides[get_current_merchant_from_cookie_or_header] + + +# ============================================================================ +# Subscription Endpoints +# ============================================================================ + + +class TestMerchantListSubscriptions: + """Tests for GET /api/v1/merchants/billing/subscriptions.""" + + def test_list_subscriptions_success( + self, client, merch_auth_headers, merch_subscription, merch_merchant + ): + response = client.get( + f"{BASE}/subscriptions", headers=merch_auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert "subscriptions" in data + assert "total" in data + assert data["total"] >= 1 + + def test_list_subscriptions_includes_tier_info( + self, client, merch_auth_headers, merch_subscription + ): + response = client.get( + f"{BASE}/subscriptions", headers=merch_auth_headers + ) + assert response.status_code == 200 + sub = response.json()["subscriptions"][0] + assert "tier" in sub + assert "tier_name" in sub + assert sub["tier"] == "professional" + + def test_list_subscriptions_empty( + self, client, merch_auth_headers, merch_merchant + ): + response = client.get( + f"{BASE}/subscriptions", headers=merch_auth_headers + ) + assert response.status_code == 200 + assert response.json()["total"] == 0 + + +class TestMerchantGetSubscription: + """Tests for GET /api/v1/merchants/billing/subscriptions/{platform_id}.""" + + def test_get_subscription_success( + self, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + ): + response = client.get( + f"{BASE}/subscriptions/{merch_platform.id}", + headers=merch_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "subscription" in data + assert "tier" in data + assert data["subscription"]["status"] == "active" + + def test_get_subscription_with_tier_details( + self, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + ): + response = client.get( + f"{BASE}/subscriptions/{merch_platform.id}", + headers=merch_auth_headers, + ) + assert response.status_code == 200 + tier = response.json()["tier"] + assert tier is not None + assert tier["code"] == "professional" + assert tier["price_monthly_cents"] == 2900 + + def test_get_subscription_not_found( + self, client, merch_auth_headers, merch_merchant + ): + response = client.get( + f"{BASE}/subscriptions/99999", + headers=merch_auth_headers, + ) + assert response.status_code == 404 + + +class TestMerchantGetAvailableTiers: + """Tests for GET /api/v1/merchants/billing/subscriptions/{platform_id}/tiers.""" + + def test_get_tiers_success( + self, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + merch_tiers, + ): + response = client.get( + f"{BASE}/subscriptions/{merch_platform.id}/tiers", + headers=merch_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "tiers" in data + assert "current_tier" in data + assert data["current_tier"] == "professional" + + def test_get_tiers_includes_upgrade_info( + self, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + merch_tiers, + ): + response = client.get( + f"{BASE}/subscriptions/{merch_platform.id}/tiers", + headers=merch_auth_headers, + ) + assert response.status_code == 200 + tiers = response.json()["tiers"] + assert len(tiers) >= 3 + + +class TestMerchantChangeTier: + """Tests for POST /api/v1/merchants/billing/subscriptions/{platform_id}/change-tier.""" + + @patch("app.modules.billing.routes.api.merchant.billing_service") + def test_change_tier_success( + self, + mock_billing, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + merch_tiers, + ): + mock_billing.change_tier.return_value = { + "message": "Tier changed to business", + "new_tier": "business", + "effective_immediately": True, + } + + response = client.post( + f"{BASE}/subscriptions/{merch_platform.id}/change-tier", + json={"tier_code": "business", "is_annual": False}, + headers=merch_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "new_tier" in data or "message" in data + mock_billing.change_tier.assert_called_once() + + def test_change_tier_no_stripe_returns_error( + self, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + ): + """Without Stripe subscription, change_tier returns 400.""" + response = client.post( + f"{BASE}/subscriptions/{merch_platform.id}/change-tier", + json={"tier_code": "business", "is_annual": False}, + headers=merch_auth_headers, + ) + assert response.status_code == 400 + + +class TestMerchantCheckout: + """Tests for POST /api/v1/merchants/billing/subscriptions/{platform_id}/checkout.""" + + @patch("app.modules.billing.routes.api.merchant.billing_service") + def test_create_checkout_with_stripe( + self, + mock_billing, + client, + merch_auth_headers, + merch_subscription, + merch_platform, + merch_tiers, + ): + mock_billing.create_checkout_session.return_value = { + "checkout_url": "https://checkout.stripe.com/test", + "session_id": "cs_test_123", + } + + response = client.post( + f"{BASE}/subscriptions/{merch_platform.id}/checkout", + json={"tier_code": "business", "is_annual": False}, + headers=merch_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert data["checkout_url"] == "https://checkout.stripe.com/test" + assert data["session_id"] == "cs_test_123" + mock_billing.create_checkout_session.assert_called_once() + + +# ============================================================================ +# Invoice Endpoints +# ============================================================================ + + +class TestMerchantInvoices: + """Tests for GET /api/v1/merchants/billing/invoices.""" + + def test_list_invoices_success( + self, client, merch_auth_headers, merch_invoices + ): + response = client.get( + f"{BASE}/invoices", headers=merch_auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert "invoices" in data + assert "total" in data + assert data["total"] >= 3 + + def test_list_invoices_pagination( + self, client, merch_auth_headers, merch_invoices + ): + response = client.get( + f"{BASE}/invoices", + params={"skip": 0, "limit": 2}, + headers=merch_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["invoices"]) <= 2 + + def test_list_invoices_response_shape( + self, client, merch_auth_headers, merch_invoices + ): + response = client.get( + f"{BASE}/invoices", headers=merch_auth_headers + ) + assert response.status_code == 200 + inv = response.json()["invoices"][0] + assert "id" in inv + assert "invoice_number" in inv + assert "invoice_date" in inv + assert "total_cents" in inv + assert "currency" in inv + assert "status" in inv + + def test_list_invoices_empty( + self, client, merch_auth_headers, merch_merchant + ): + response = client.get( + f"{BASE}/invoices", headers=merch_auth_headers + ) + assert response.status_code == 200 + assert response.json()["total"] == 0 diff --git a/app/modules/billing/tests/integration/test_platform_routes.py b/app/modules/billing/tests/integration/test_platform_routes.py new file mode 100644 index 00000000..5d29af77 --- /dev/null +++ b/app/modules/billing/tests/integration/test_platform_routes.py @@ -0,0 +1,226 @@ +# app/modules/billing/tests/integration/test_platform_routes.py +""" +Integration tests for platform pricing API routes (public, no auth). + +Tests the public pricing endpoints at: + /api/v1/platform/pricing/* + +These are unauthenticated endpoints used by the marketing site and signup flow. +""" + +import uuid + +import pytest + +from app.modules.billing.models import ( + AddOnProduct, + SubscriptionTier, + TierCode, +) +from app.modules.tenancy.models import Platform + + +# ============================================================================ +# Fixtures +# ============================================================================ + +BASE = "/api/v1/platform/pricing" + + +@pytest.fixture +def pricing_platform(db): + """Create a platform for pricing tests.""" + platform = Platform( + code=f"pricing_{uuid.uuid4().hex[:8]}", + name="Pricing Test Platform", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def pricing_tiers(db, pricing_platform): + """Create public subscription tiers for pricing display.""" + tiers = [] + for i, (code, name, price_m, price_a) in enumerate([ + (TierCode.ESSENTIAL.value, "Essential", 0, 0), + (TierCode.PROFESSIONAL.value, "Professional", 2900, 29000), + (TierCode.BUSINESS.value, "Business", 7900, 79000), + (TierCode.ENTERPRISE.value, "Enterprise", 19900, 199000), + ]): + tier = SubscriptionTier( + code=code, + name=name, + description=f"{name} plan for growing businesses", + price_monthly_cents=price_m, + price_annual_cents=price_a, + display_order=i, + is_active=True, + is_public=True, + platform_id=pricing_platform.id, + ) + db.add(tier) + tiers.append(tier) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def pricing_addons(db): + """Create add-on products for pricing display.""" + addons = [] + for code, name, cat, price in [ + ("custom_domain", "Custom Domain", "hosting", 499), + ("extra_products", "Extra Products Pack", "capacity", 999), + ]: + addon = AddOnProduct( + code=code, + name=name, + description=f"{name} add-on", + category=cat, + price_cents=price, + billing_period="monthly", + is_active=True, + ) + db.add(addon) + addons.append(addon) + db.commit() + for a in addons: + db.refresh(a) + return addons + + +# ============================================================================ +# GET /pricing/tiers +# ============================================================================ + + +class TestPlatformGetTiers: + """Tests for GET /api/v1/platform/pricing/tiers.""" + + def test_get_tiers_success(self, client, pricing_tiers): + response = client.get(f"{BASE}/tiers") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) >= 4 + + def test_get_tiers_response_shape(self, client, pricing_tiers): + response = client.get(f"{BASE}/tiers") + assert response.status_code == 200 + tier = response.json()[0] + assert "code" in tier + assert "name" in tier + assert "price_monthly" in tier + assert "price_monthly_cents" in tier + assert "feature_codes" in tier + + def test_get_tiers_excludes_inactive(self, client, pricing_tiers, db): + pricing_tiers[3].is_active = False + db.commit() + + response = client.get(f"{BASE}/tiers") + assert response.status_code == 200 + codes = [t["code"] for t in response.json()] + assert TierCode.ENTERPRISE.value not in codes + + def test_get_tiers_popular_flag(self, client, pricing_tiers): + response = client.get(f"{BASE}/tiers") + assert response.status_code == 200 + for tier in response.json(): + if tier["code"] == TierCode.PROFESSIONAL.value: + assert tier["is_popular"] is True + elif tier["code"] == TierCode.ENTERPRISE.value: + assert tier["is_enterprise"] is True + + def test_get_tiers_empty(self, client): + response = client.get(f"{BASE}/tiers") + assert response.status_code == 200 + assert response.json() == [] + + +# ============================================================================ +# GET /pricing/tiers/{tier_code} +# ============================================================================ + + +class TestPlatformGetTier: + """Tests for GET /api/v1/platform/pricing/tiers/{tier_code}.""" + + def test_get_tier_success(self, client, pricing_tiers): + response = client.get(f"{BASE}/tiers/{TierCode.PROFESSIONAL.value}") + assert response.status_code == 200 + data = response.json() + assert data["code"] == TierCode.PROFESSIONAL.value + assert data["price_monthly_cents"] == 2900 + assert data["price_annual_cents"] == 29000 + assert data["price_monthly"] == 29.0 + + def test_get_tier_not_found(self, client): + response = client.get(f"{BASE}/tiers/nonexistent") + assert response.status_code == 404 + + +# ============================================================================ +# GET /pricing/addons +# ============================================================================ + + +class TestPlatformGetAddons: + """Tests for GET /api/v1/platform/pricing/addons.""" + + def test_get_addons_success(self, client, pricing_addons): + response = client.get(f"{BASE}/addons") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) >= 2 + + def test_get_addons_response_shape(self, client, pricing_addons): + response = client.get(f"{BASE}/addons") + assert response.status_code == 200 + addon = response.json()[0] + assert "code" in addon + assert "name" in addon + assert "price" in addon + assert "price_cents" in addon + assert "category" in addon + + def test_get_addons_empty(self, client): + response = client.get(f"{BASE}/addons") + assert response.status_code == 200 + assert response.json() == [] + + +# ============================================================================ +# GET /pricing (full pricing page) +# ============================================================================ + + +class TestPlatformGetPricing: + """Tests for GET /api/v1/platform/pricing.""" + + def test_get_pricing_success(self, client, pricing_tiers, pricing_addons): + response = client.get(f"{BASE}") + assert response.status_code == 200 + data = response.json() + assert "tiers" in data + assert "addons" in data + assert "trial_days" in data + assert "annual_discount_months" in data + assert len(data["tiers"]) >= 4 + assert len(data["addons"]) >= 2 + assert data["annual_discount_months"] == 2 + + def test_get_pricing_empty_db(self, client): + response = client.get(f"{BASE}") + assert response.status_code == 200 + data = response.json() + assert data["tiers"] == [] + assert data["addons"] == [] + assert "trial_days" in data diff --git a/app/modules/billing/tests/integration/test_store_routes.py b/app/modules/billing/tests/integration/test_store_routes.py new file mode 100644 index 00000000..fb7479e7 --- /dev/null +++ b/app/modules/billing/tests/integration/test_store_routes.py @@ -0,0 +1,414 @@ +# app/modules/billing/tests/integration/test_store_routes.py +""" +Integration tests for store billing API routes. + +Tests the store billing endpoints at: + /api/v1/store/billing/* + +Authentication: Uses real JWT tokens via store login endpoint. +The store routes require `require_module_access("billing", FrontendType.STORE)` +as a router-level dependency, which calls auth functions directly — so +dependency overrides don't work. Real tokens are needed. +""" + +import uuid +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import pytest + +from app.modules.billing.models import ( + BillingHistory, + MerchantSubscription, + SubscriptionStatus, + SubscriptionTier, +) +from app.modules.tenancy.models import Merchant, Platform, Store, User +from app.modules.tenancy.models.store import StoreUser, StoreUserType +from app.modules.tenancy.models.store_platform import StorePlatform + + +# ============================================================================ +# Fixtures +# ============================================================================ + +BASE = "/api/v1/store/billing" + + +@pytest.fixture +def store_platform(db): + """Create a platform for store tests.""" + platform = Platform( + code=f"store_{uuid.uuid4().hex[:8]}", + name="Store Test Platform", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def store_full_setup(db, store_platform): + """ + Create the full store setup: user, merchant, store, StoreUser, StorePlatform. + + Returns a dict with all created objects for easy access. + """ + from middleware.auth import AuthManager + + auth = AuthManager() + uid = uuid.uuid4().hex[:8] + + # 1. Create store owner user + owner = User( + email=f"storeowner_{uid}@test.com", + username=f"storeowner_{uid}", + hashed_password=auth.hash_password("storepass123"), + role="store", + is_active=True, + ) + db.add(owner) + db.commit() + db.refresh(owner) + + # 2. Create merchant + merchant = Merchant( + name=f"Store Merchant {uid}", + owner_user_id=owner.id, + contact_email=owner.email, + is_active=True, + is_verified=True, + ) + db.add(merchant) + db.commit() + db.refresh(merchant) + + # 3. Create store + store = Store( + merchant_id=merchant.id, + store_code=f"STORETEST_{uid.upper()}", + subdomain=f"storetest{uid}", + name=f"Store Test {uid}", + is_active=True, + is_verified=True, + ) + db.add(store) + db.commit() + db.refresh(store) + + # 4. Create StoreUser (owner association) + store_user = StoreUser( + store_id=store.id, + user_id=owner.id, + user_type=StoreUserType.OWNER.value, + is_active=True, + ) + db.add(store_user) + db.commit() + + # 5. Link store to platform + sp = StorePlatform( + store_id=store.id, + platform_id=store_platform.id, + ) + db.add(sp) + db.commit() + + return { + "owner": owner, + "merchant": merchant, + "store": store, + "platform": store_platform, + } + + +@pytest.fixture +def store_auth_headers(client, store_full_setup): + """ + Get real JWT auth headers by logging in via store auth endpoint. + + Uses the store login flow which generates a JWT with store context + (token_store_id, token_store_code, token_store_role). + """ + owner = store_full_setup["owner"] + response = client.post( + "/api/v1/store/auth/login", + json={ + "email_or_username": owner.username, + "password": "storepass123", + }, + ) + assert response.status_code == 200, f"Store login failed: {response.text}" + token = response.json()["access_token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture +def store_tiers(db, store_platform): + """Create tiers for store tests.""" + tiers = [] + for i, (code, name, price) in enumerate([ + ("essential", "Essential", 0), + ("professional", "Professional", 2900), + ("business", "Business", 7900), + ]): + tier = SubscriptionTier( + code=code, + name=name, + description=f"{name} tier", + price_monthly_cents=price, + price_annual_cents=price * 10 if price > 0 else 0, + display_order=i, + is_active=True, + is_public=True, + platform_id=store_platform.id, + ) + db.add(tier) + tiers.append(tier) + db.commit() + for t in tiers: + db.refresh(t) + return tiers + + +@pytest.fixture +def store_subscription(db, store_full_setup, store_tiers): + """Create a subscription for the store's merchant.""" + sub = MerchantSubscription( + merchant_id=store_full_setup["merchant"].id, + platform_id=store_full_setup["platform"].id, + tier_id=store_tiers[1].id, # professional + status=SubscriptionStatus.ACTIVE.value, + is_annual=False, + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return sub + + +@pytest.fixture +def store_invoices(db, store_full_setup): + """Create invoice records for the store's merchant.""" + records = [] + for i in range(2): + record = BillingHistory( + merchant_id=store_full_setup["merchant"].id, + invoice_number=f"SINV-{3000 + i}", + invoice_date=datetime.now(UTC) - timedelta(days=30 * i), + subtotal_cents=2900, + tax_cents=493, + total_cents=3393, + amount_paid_cents=3393, + currency="EUR", + status="paid", + description=f"Store invoice {i}", + ) + db.add(record) + records.append(record) + db.commit() + for r in records: + db.refresh(r) + return records + + +# ============================================================================ +# Core Billing Endpoints +# ============================================================================ + + +class TestStoreGetSubscription: + """Tests for GET /api/v1/store/billing/subscription.""" + + def test_get_subscription_success( + self, client, store_auth_headers, store_subscription, store_tiers + ): + response = client.get( + f"{BASE}/subscription", headers=store_auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert "tier_code" in data + assert "tier_name" in data + assert "status" in data + assert data["status"] == "active" + assert data["tier_code"] == "professional" + + +class TestStoreGetTiers: + """Tests for GET /api/v1/store/billing/tiers.""" + + def test_get_tiers_success( + self, client, store_auth_headers, store_subscription, store_tiers + ): + response = client.get(f"{BASE}/tiers", headers=store_auth_headers) + assert response.status_code == 200 + data = response.json() + assert "tiers" in data + assert "current_tier" in data + assert data["current_tier"] == "professional" + assert len(data["tiers"]) >= 3 + + +class TestStoreGetInvoices: + """Tests for GET /api/v1/store/billing/invoices.""" + + def test_get_invoices_success( + self, client, store_auth_headers, store_invoices + ): + response = client.get(f"{BASE}/invoices", headers=store_auth_headers) + assert response.status_code == 200 + data = response.json() + assert "invoices" in data + assert "total" in data + assert data["total"] >= 2 + + def test_get_invoices_pagination( + self, client, store_auth_headers, store_invoices + ): + response = client.get( + f"{BASE}/invoices", + params={"skip": 0, "limit": 1}, + headers=store_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["invoices"]) <= 1 + + def test_get_invoices_empty( + self, client, store_auth_headers, store_full_setup + ): + response = client.get(f"{BASE}/invoices", headers=store_auth_headers) + assert response.status_code == 200 + assert response.json()["total"] == 0 + + +# ============================================================================ +# Checkout / Management Endpoints +# ============================================================================ + + +class TestStoreChangeTier: + """Tests for POST /api/v1/store/billing/change-tier.""" + + @patch("app.modules.billing.routes.api.store_checkout.billing_service") + def test_change_tier_success( + self, mock_billing, client, store_auth_headers, store_subscription, store_tiers + ): + mock_billing.change_tier.return_value = { + "message": "Tier changed to business", + "new_tier": "business", + "effective_immediately": True, + } + response = client.post( + f"{BASE}/change-tier", + json={"tier_code": "business", "is_annual": False}, + headers=store_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "message" in data + assert "new_tier" in data + + def test_change_tier_no_stripe_returns_error( + self, client, store_auth_headers, store_subscription + ): + """Without Stripe subscription, change_tier returns 400.""" + response = client.post( + f"{BASE}/change-tier", + json={"tier_code": "business", "is_annual": False}, + headers=store_auth_headers, + ) + assert response.status_code == 400 + + +class TestStoreCancelSubscription: + """Tests for POST /api/v1/store/billing/cancel.""" + + @patch("app.modules.billing.routes.api.store_checkout.billing_service") + def test_cancel_success( + self, mock_billing, client, store_auth_headers, store_subscription + ): + mock_billing.cancel_subscription.return_value = { + "message": "Subscription cancelled", + "effective_date": (datetime.now(UTC) + timedelta(days=30)).isoformat(), + } + response = client.post( + f"{BASE}/cancel", + json={"reason": "Too expensive", "immediately": False}, + headers=store_auth_headers, + ) + assert response.status_code == 200 + data = response.json() + assert "message" in data + assert "effective_date" in data + + +class TestStoreReactivateSubscription: + """Tests for POST /api/v1/store/billing/reactivate.""" + + @patch("app.modules.billing.routes.api.store_checkout.billing_service") + def test_reactivate_cancelled( + self, mock_billing, client, store_auth_headers, store_subscription, db + ): + store_subscription.status = SubscriptionStatus.CANCELLED.value + store_subscription.cancelled_at = datetime.now(UTC) + db.commit() + + mock_billing.reactivate_subscription.return_value = { + "message": "Subscription reactivated", + "status": "active", + } + + response = client.post( + f"{BASE}/reactivate", headers=store_auth_headers + ) + assert response.status_code == 200 + + +class TestStoreUpcomingInvoice: + """Tests for GET /api/v1/store/billing/upcoming-invoice.""" + + @patch("app.modules.billing.routes.api.store_checkout.billing_service") + def test_upcoming_invoice( + self, mock_billing, client, store_auth_headers, store_subscription + ): + mock_billing.get_upcoming_invoice.return_value = { + "amount_due_cents": 2900, + "currency": "EUR", + "next_payment_date": None, + "line_items": [], + } + + response = client.get( + f"{BASE}/upcoming-invoice", headers=store_auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert "amount_due_cents" in data + assert "currency" in data + + +# ============================================================================ +# Unauthorized Access +# ============================================================================ + + +class TestStoreUnauthorized: + """Tests for unauthorized store access.""" + + def test_subscription_no_auth(self, client): + response = client.get(f"{BASE}/subscription") + assert response.status_code in (401, 403) + + def test_tiers_no_auth(self, client): + response = client.get(f"{BASE}/tiers") + assert response.status_code in (401, 403) + + def test_invoices_no_auth(self, client): + response = client.get(f"{BASE}/invoices") + assert response.status_code in (401, 403)