diff --git a/app/modules/billing/routes/api/admin_features.py b/app/modules/billing/routes/api/admin_features.py index 349bdceb..fcac5d8d 100644 --- a/app/modules/billing/routes/api/admin_features.py +++ b/app/modules/billing/routes/api/admin_features.py @@ -86,11 +86,11 @@ def get_feature_catalog( @admin_features_router.get( - "/tiers/{tier_code}/limits", + "/tiers/{tier_id}/limits", response_model=list[TierFeatureLimitEntry], ) def get_tier_feature_limits( - tier_code: str = Path(..., description="Tier code"), + tier_id: int = Path(..., description="Tier ID"), current_user: UserContext = Depends(get_current_admin_api), db: Session = Depends(get_db), ): @@ -100,7 +100,7 @@ def get_tier_feature_limits( Returns all TierFeatureLimit rows associated with the tier, each containing a feature_code and its optional limit_value. """ - rows = feature_service.get_tier_feature_limits(db, tier_code) + rows = feature_service.get_tier_feature_limits(db, tier_id) return [ TierFeatureLimitEntry( @@ -113,12 +113,12 @@ def get_tier_feature_limits( @admin_features_router.put( - "/tiers/{tier_code}/limits", + "/tiers/{tier_id}/limits", response_model=list[TierFeatureLimitEntry], ) def upsert_tier_feature_limits( entries: list[TierFeatureLimitEntry], - tier_code: str = Path(..., description="Tier code"), + tier_id: int = Path(..., description="Tier ID"), current_user: UserContext = Depends(get_current_admin_api), db: Session = Depends(get_db), ): @@ -136,15 +136,15 @@ def upsert_tier_feature_limits( raise InvalidFeatureCodesError(invalid_codes) new_rows = feature_service.upsert_tier_feature_limits( - db, tier_code, [e.model_dump() for e in entries] + db, tier_id, [e.model_dump() for e in entries] ) db.commit() logger.info( - "Admin %s replaced tier '%s' feature limits (%d entries)", + "Admin %s replaced tier %d feature limits (%d entries)", current_user.id, - tier_code, + tier_id, len(new_rows), ) diff --git a/app/modules/billing/services/feature_service.py b/app/modules/billing/services/feature_service.py index 73b0a5c7..c3e109e2 100644 --- a/app/modules/billing/services/feature_service.py +++ b/app/modules/billing/services/feature_service.py @@ -450,30 +450,24 @@ class FeatureService: # Tier Feature Limit Management # ========================================================================= - def get_tier_feature_limits(self, db: Session, tier_code: str) -> list: + def get_tier_feature_limits(self, db: Session, tier_id: int) -> list: """Get feature limits for a tier.""" - from app.modules.billing.services import admin_subscription_service - - tier = admin_subscription_service.get_tier_by_code(db, tier_code) return ( db.query(TierFeatureLimit) - .filter(TierFeatureLimit.tier_id == tier.id) + .filter(TierFeatureLimit.tier_id == tier_id) .order_by(TierFeatureLimit.feature_code) .all() ) - def upsert_tier_feature_limits(self, db: Session, tier_code: str, entries: list[dict]) -> list: + def upsert_tier_feature_limits(self, db: Session, tier_id: int, entries: list[dict]) -> list: """Replace feature limits for a tier. Returns list of new TierFeatureLimit objects.""" - from app.modules.billing.services import admin_subscription_service - - tier = admin_subscription_service.get_tier_by_code(db, tier_code) - db.query(TierFeatureLimit).filter(TierFeatureLimit.tier_id == tier.id).delete() + db.query(TierFeatureLimit).filter(TierFeatureLimit.tier_id == tier_id).delete() new_rows = [] for entry in entries: if not entry.get("enabled", True): continue row = TierFeatureLimit( - tier_id=tier.id, + tier_id=tier_id, feature_code=entry["feature_code"], limit_value=entry.get("limit_value"), ) diff --git a/app/modules/billing/static/admin/js/subscription-tiers.js b/app/modules/billing/static/admin/js/subscription-tiers.js index 66979f47..d6220d39 100644 --- a/app/modules/billing/static/admin/js/subscription-tiers.js +++ b/app/modules/billing/static/admin/js/subscription-tiers.js @@ -273,7 +273,7 @@ function adminSubscriptionTiers() { try { // Load tier's current feature limits - const data = await apiClient.get(`/admin/subscriptions/features/tiers/${tier.code}/limits`); + const data = await apiClient.get(`/admin/subscriptions/features/tiers/${tier.id}/limits`); // data is TierFeatureLimitEntry[]: [{feature_code, limit_value, enabled}] this.selectedFeatures = []; for (const entry of (data || [])) { @@ -327,7 +327,7 @@ function adminSubscriptionTiers() { })); await apiClient.put( - `/admin/subscriptions/features/tiers/${this.selectedTierForFeatures.code}/limits`, + `/admin/subscriptions/features/tiers/${this.selectedTierForFeatures.id}/limits`, entries ); diff --git a/app/modules/billing/tests/integration/test_admin_features_routes.py b/app/modules/billing/tests/integration/test_admin_features_routes.py new file mode 100644 index 00000000..ab75b0ec --- /dev/null +++ b/app/modules/billing/tests/integration/test_admin_features_routes.py @@ -0,0 +1,423 @@ +# app/modules/billing/tests/integration/test_admin_features_routes.py +""" +Integration tests for admin feature management API routes. + +Tests the feature limit endpoints at: + /api/v1/admin/subscriptions/features/* + +Covers: +- GET /features/catalog +- GET /features/tiers/{tier_id}/limits +- PUT /features/tiers/{tier_id}/limits +- Regression: tiers with duplicate codes across platforms are isolated by tier_id + +Uses super_admin_headers fixture which bypasses module access checks. +""" + +import uuid + +import pytest + +from app.modules.billing.models import SubscriptionTier +from app.modules.billing.models.tier_feature_limit import TierFeatureLimit +from app.modules.tenancy.models import Platform + +# ============================================================================ +# Fixtures +# ============================================================================ + +BASE = "/api/v1/admin/subscriptions/features" + + +@pytest.fixture +def ft_platform(db): + """Create a platform for feature route tests.""" + platform = Platform( + code=f"feat_{uuid.uuid4().hex[:8]}", + name="Feature Test Platform", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def ft_second_platform(db): + """Second platform for cross-platform isolation tests.""" + platform = Platform( + code=f"feat2_{uuid.uuid4().hex[:8]}", + name="Feature Test Platform 2", + is_active=True, + ) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def ft_tier(db, ft_platform): + """Create a tier for feature route tests.""" + tier = SubscriptionTier( + code=f"essential_{uuid.uuid4().hex[:6]}", + name="Essential", + price_monthly_cents=1000, + display_order=0, + is_active=True, + is_public=True, + platform_id=ft_platform.id, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def ft_duplicate_code_tiers(db, ft_platform, ft_second_platform): + """Create two tiers with the SAME code but different platforms. + + This is the exact scenario that caused the tier_code ambiguity bug. + """ + tier_a = SubscriptionTier( + code="essential", + name="Essential (Platform A)", + price_monthly_cents=1000, + display_order=0, + is_active=True, + is_public=True, + platform_id=ft_platform.id, + ) + tier_b = SubscriptionTier( + code="essential", + name="Essential (Platform B)", + price_monthly_cents=2000, + display_order=0, + is_active=True, + is_public=True, + platform_id=ft_second_platform.id, + ) + db.add(tier_a) + db.add(tier_b) + db.commit() + db.refresh(tier_a) + db.refresh(tier_b) + return tier_a, tier_b + + +@pytest.fixture +def ft_tier_with_features(db, ft_tier): + """Pre-populate a tier with feature limits.""" + features = [ + TierFeatureLimit(tier_id=ft_tier.id, feature_code="basic_shop", limit_value=None), + TierFeatureLimit(tier_id=ft_tier.id, feature_code="team_members", limit_value=5), + ] + for f in features: + db.add(f) + db.commit() + # Refresh so the tier's selectin-loaded feature_limits relationship is up to date + db.refresh(ft_tier) + return features + + +# ============================================================================ +# Feature Catalog +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.billing +class TestFeatureCatalog: + """Tests for GET /features/catalog.""" + + def test_get_catalog(self, client, super_admin_headers): + """Returns the feature catalog grouped by category.""" + response = client.get(f"{BASE}/catalog", headers=super_admin_headers) + assert response.status_code == 200 + data = response.json() + assert "features" in data + assert isinstance(data["features"], dict) + + def test_catalog_requires_auth(self, client): + """Catalog endpoint requires authentication.""" + response = client.get(f"{BASE}/catalog") + assert response.status_code == 401 + + +# ============================================================================ +# GET Tier Feature Limits +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.billing +class TestGetTierFeatureLimits: + """Tests for GET /features/tiers/{tier_id}/limits.""" + + def test_get_limits_empty(self, client, super_admin_headers, ft_tier): + """Returns empty list for tier with no features.""" + response = client.get( + f"{BASE}/tiers/{ft_tier.id}/limits", + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert response.json() == [] + + def test_get_limits_with_features( + self, client, super_admin_headers, ft_tier, ft_tier_with_features + ): + """Returns feature limit entries for a tier.""" + response = client.get( + f"{BASE}/tiers/{ft_tier.id}/limits", + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + codes = {e["feature_code"] for e in data} + assert codes == {"basic_shop", "team_members"} + # Check limit values + for entry in data: + assert entry["enabled"] is True + if entry["feature_code"] == "team_members": + assert entry["limit_value"] == 5 + else: + assert entry["limit_value"] is None + + def test_get_limits_requires_auth(self, client, ft_tier): + """Endpoint requires authentication.""" + response = client.get(f"{BASE}/tiers/{ft_tier.id}/limits") + assert response.status_code == 401 + + +# ============================================================================ +# PUT Tier Feature Limits +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.billing +class TestUpsertTierFeatureLimits: + """Tests for PUT /features/tiers/{tier_id}/limits.""" + + def test_save_features(self, client, super_admin_headers, ft_tier): + """Saves feature limits and returns the saved entries.""" + # Get valid feature codes from catalog + catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json() + all_codes = [] + for features in catalog["features"].values(): + for f in features: + all_codes.append(f["code"]) + + # Use the first two valid codes + entries = [ + {"feature_code": all_codes[0], "limit_value": None, "enabled": True}, + {"feature_code": all_codes[1], "limit_value": 10, "enabled": True}, + ] + + response = client.put( + f"{BASE}/tiers/{ft_tier.id}/limits", + json=entries, + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_save_replaces_existing( + self, client, super_admin_headers, ft_tier, ft_tier_with_features + ): + """Saving new features replaces the old ones entirely.""" + # Get a valid feature code + catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json() + valid_code = next( + f["code"] + for features in catalog["features"].values() + for f in features + ) + + entries = [ + {"feature_code": valid_code, "limit_value": None, "enabled": True}, + ] + + response = client.put( + f"{BASE}/tiers/{ft_tier.id}/limits", + json=entries, + headers=super_admin_headers, + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + + # Verify old features are gone + get_response = client.get( + f"{BASE}/tiers/{ft_tier.id}/limits", + headers=super_admin_headers, + ) + assert len(get_response.json()) == 1 + + def test_save_empty_clears_features( + self, client, super_admin_headers, ft_tier, ft_tier_with_features + ): + """Saving an empty list removes all features.""" + response = client.put( + f"{BASE}/tiers/{ft_tier.id}/limits", + json=[], + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert response.json() == [] + + # Verify cleared + get_response = client.get( + f"{BASE}/tiers/{ft_tier.id}/limits", + headers=super_admin_headers, + ) + assert get_response.json() == [] + + def test_save_rejects_invalid_feature_codes(self, client, super_admin_headers, ft_tier): + """Returns error for unknown feature codes.""" + entries = [ + {"feature_code": "totally_fake_feature_xyz", "limit_value": None, "enabled": True}, + ] + + response = client.put( + f"{BASE}/tiers/{ft_tier.id}/limits", + json=entries, + headers=super_admin_headers, + ) + # Should fail validation + assert response.status_code in (400, 422) + + def test_save_requires_auth(self, client, ft_tier): + """Endpoint requires authentication.""" + response = client.put( + f"{BASE}/tiers/{ft_tier.id}/limits", + json=[], + ) + assert response.status_code == 401 + + +# ============================================================================ +# Cross-Platform Isolation (Regression Test) +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.billing +class TestCrossPlatformIsolation: + """ + Regression tests for the tier_code ambiguity bug. + + When multiple tiers share the same code (e.g., "essential" across platforms), + feature operations must use tier_id to avoid saving to the wrong tier. + """ + + def test_features_saved_to_correct_tier( + self, client, super_admin_headers, ft_duplicate_code_tiers + ): + """Features saved by tier_id go to the correct tier, not the first match.""" + tier_a, tier_b = ft_duplicate_code_tiers + + # Get valid feature codes + catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json() + codes = [ + f["code"] + for features in catalog["features"].values() + for f in features + ] + + # Save features to tier B (second platform) + entries_b = [ + {"feature_code": codes[0], "limit_value": None, "enabled": True}, + {"feature_code": codes[1], "limit_value": 50, "enabled": True}, + ] + response = client.put( + f"{BASE}/tiers/{tier_b.id}/limits", + json=entries_b, + headers=super_admin_headers, + ) + assert response.status_code == 200 + assert len(response.json()) == 2 + + # Tier A should still have 0 features + response_a = client.get( + f"{BASE}/tiers/{tier_a.id}/limits", + headers=super_admin_headers, + ) + assert response_a.status_code == 200 + assert len(response_a.json()) == 0 + + # Tier B should have 2 features + response_b = client.get( + f"{BASE}/tiers/{tier_b.id}/limits", + headers=super_admin_headers, + ) + assert response_b.status_code == 200 + assert len(response_b.json()) == 2 + + def test_features_do_not_leak_between_same_code_tiers( + self, client, super_admin_headers, ft_duplicate_code_tiers + ): + """Saving features to one tier doesn't affect another with the same code.""" + tier_a, tier_b = ft_duplicate_code_tiers + + # Get valid feature codes + catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json() + codes = [ + f["code"] + for features in catalog["features"].values() + for f in features + ] + + # Save different features to each tier + client.put( + f"{BASE}/tiers/{tier_a.id}/limits", + json=[{"feature_code": codes[0], "limit_value": None, "enabled": True}], + headers=super_admin_headers, + ) + client.put( + f"{BASE}/tiers/{tier_b.id}/limits", + json=[{"feature_code": codes[1], "limit_value": None, "enabled": True}], + headers=super_admin_headers, + ) + + # Each tier should have exactly its own feature + resp_a = client.get(f"{BASE}/tiers/{tier_a.id}/limits", headers=super_admin_headers) + resp_b = client.get(f"{BASE}/tiers/{tier_b.id}/limits", headers=super_admin_headers) + + assert len(resp_a.json()) == 1 + assert resp_a.json()[0]["feature_code"] == codes[0] + + assert len(resp_b.json()) == 1 + assert resp_b.json()[0]["feature_code"] == codes[1] + + +# ============================================================================ +# Feature Count in Tier List (End-to-End) +# ============================================================================ + + +@pytest.mark.integration +@pytest.mark.billing +class TestTierListFeatureCount: + """Tests that the tier list endpoint includes correct feature counts.""" + + def test_tier_list_includes_feature_codes( + self, client, super_admin_headers, ft_tier, ft_tier_with_features + ): + """GET /tiers returns feature_codes for each tier.""" + response = client.get( + "/api/v1/admin/subscriptions/tiers", + headers=super_admin_headers, + ) + assert response.status_code == 200 + + # Find our test tier in the response + tiers = response.json()["tiers"] + our_tier = next((t for t in tiers if t["id"] == ft_tier.id), None) + assert our_tier is not None + assert len(our_tier["feature_codes"]) == 2 + assert set(our_tier["feature_codes"]) == {"basic_shop", "team_members"} diff --git a/app/modules/billing/tests/unit/test_feature_service.py b/app/modules/billing/tests/unit/test_feature_service.py index 3c07ef15..cdb466de 100644 --- a/app/modules/billing/tests/unit/test_feature_service.py +++ b/app/modules/billing/tests/unit/test_feature_service.py @@ -2,7 +2,10 @@ import pytest +from app.modules.billing.models import SubscriptionTier +from app.modules.billing.models.tier_feature_limit import TierFeatureLimit from app.modules.billing.services.feature_service import FeatureService +from app.modules.tenancy.models import Platform @pytest.mark.unit @@ -16,3 +19,223 @@ class TestFeatureService: def test_service_instantiation(self): """Service can be instantiated.""" assert self.service is not None + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def fs_platform(db): + """Create a platform for feature service tests.""" + platform = Platform(code="fs_test", name="FS Test Platform", is_active=True) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def fs_second_platform(db): + """Create a second platform to test cross-platform isolation.""" + platform = Platform(code="fs_test2", name="FS Test Platform 2", is_active=True) + db.add(platform) + db.commit() + db.refresh(platform) + return platform + + +@pytest.fixture +def fs_tier(db, fs_platform): + """Create a tier for feature service tests.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=1000, + display_order=0, + is_active=True, + is_public=True, + platform_id=fs_platform.id, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def fs_same_code_tier(db, fs_second_platform): + """Create a tier with the SAME code but different platform.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=2000, + display_order=0, + is_active=True, + is_public=True, + platform_id=fs_second_platform.id, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def fs_tier_with_features(db, fs_tier): + """Create a tier with pre-existing feature limits.""" + features = [ + TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_a", limit_value=None), + TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_b", limit_value=100), + TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_c", limit_value=50), + ] + for f in features: + db.add(f) + db.commit() + return features + + +# ============================================================================ +# get_tier_feature_limits +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestGetTierFeatureLimits: + """Tests for FeatureService.get_tier_feature_limits.""" + + def test_returns_limits_for_tier(self, db, fs_tier_with_features, fs_tier): + """Returns all feature limit rows for the given tier ID.""" + service = FeatureService() + rows = service.get_tier_feature_limits(db, fs_tier.id) + assert len(rows) == 3 + codes = {r.feature_code for r in rows} + assert codes == {"feature_a", "feature_b", "feature_c"} + + def test_returns_empty_for_tier_without_features(self, db, fs_tier): + """Returns empty list for a tier with no feature limits.""" + service = FeatureService() + rows = service.get_tier_feature_limits(db, fs_tier.id) + assert rows == [] + + def test_returns_empty_for_nonexistent_tier(self, db): + """Returns empty list for a tier ID that doesn't exist.""" + service = FeatureService() + rows = service.get_tier_feature_limits(db, 999999) + assert rows == [] + + def test_isolates_by_tier_id(self, db, fs_tier, fs_same_code_tier, fs_tier_with_features): + """Features for one tier don't leak to another with the same code.""" + service = FeatureService() + + # fs_tier has 3 features + rows_tier1 = service.get_tier_feature_limits(db, fs_tier.id) + assert len(rows_tier1) == 3 + + # fs_same_code_tier (same code, different platform) has 0 + rows_tier2 = service.get_tier_feature_limits(db, fs_same_code_tier.id) + assert len(rows_tier2) == 0 + + +# ============================================================================ +# upsert_tier_feature_limits +# ============================================================================ + + +@pytest.mark.unit +@pytest.mark.billing +class TestUpsertTierFeatureLimits: + """Tests for FeatureService.upsert_tier_feature_limits.""" + + def test_inserts_new_features(self, db, fs_tier): + """Creates feature limit rows for a tier.""" + service = FeatureService() + entries = [ + {"feature_code": "feat_x", "limit_value": None, "enabled": True}, + {"feature_code": "feat_y", "limit_value": 200, "enabled": True}, + ] + rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries) + db.commit() + + assert len(rows) == 2 + assert {r.feature_code for r in rows} == {"feat_x", "feat_y"} + + def test_replaces_existing_features(self, db, fs_tier, fs_tier_with_features): + """Upsert deletes old features and inserts new ones.""" + service = FeatureService() + entries = [ + {"feature_code": "new_feature", "limit_value": None, "enabled": True}, + ] + rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries) + db.commit() + + assert len(rows) == 1 + assert rows[0].feature_code == "new_feature" + + # Old features should be gone + remaining = service.get_tier_feature_limits(db, fs_tier.id) + assert len(remaining) == 1 + assert remaining[0].feature_code == "new_feature" + + def test_skips_disabled_entries(self, db, fs_tier): + """Entries with enabled=False are not persisted.""" + service = FeatureService() + entries = [ + {"feature_code": "enabled_feat", "limit_value": None, "enabled": True}, + {"feature_code": "disabled_feat", "limit_value": None, "enabled": False}, + ] + rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries) + db.commit() + + assert len(rows) == 1 + assert rows[0].feature_code == "enabled_feat" + + def test_saves_to_correct_tier_by_id(self, db, fs_tier, fs_same_code_tier): + """ + Regression test: saving by tier_id targets the exact tier, + not another tier that happens to share the same code. + """ + service = FeatureService() + entries = [ + {"feature_code": "platform_specific", "limit_value": None, "enabled": True}, + ] + + # Save to the second tier (same code "essential", different platform) + service.upsert_tier_feature_limits(db, fs_same_code_tier.id, entries) + db.commit() + + # First tier should have 0 features + rows_tier1 = service.get_tier_feature_limits(db, fs_tier.id) + assert len(rows_tier1) == 0 + + # Second tier should have 1 feature + rows_tier2 = service.get_tier_feature_limits(db, fs_same_code_tier.id) + assert len(rows_tier2) == 1 + assert rows_tier2[0].feature_code == "platform_specific" + + def test_clears_all_features_with_empty_list(self, db, fs_tier, fs_tier_with_features): + """Passing an empty list removes all features.""" + service = FeatureService() + rows = service.upsert_tier_feature_limits(db, fs_tier.id, []) + db.commit() + + assert len(rows) == 0 + remaining = service.get_tier_feature_limits(db, fs_tier.id) + assert len(remaining) == 0 + + def test_preserves_limit_values(self, db, fs_tier): + """Limit values (including None for unlimited) are stored correctly.""" + service = FeatureService() + entries = [ + {"feature_code": "unlimited", "limit_value": None, "enabled": True}, + {"feature_code": "limited", "limit_value": 42, "enabled": True}, + ] + service.upsert_tier_feature_limits(db, fs_tier.id, entries) + db.commit() + + rows = service.get_tier_feature_limits(db, fs_tier.id) + limits = {r.feature_code: r.limit_value for r in rows} + assert limits["unlimited"] is None + assert limits["limited"] == 42