fix(billing): use tier_id instead of tier_code for feature limit endpoints
Some checks failed
Some checks failed
Tier codes are not unique across platforms (e.g., "essential" exists for OMS, marketplace, and loyalty). Using tier_code caused feature limits to be saved to the wrong tier. Switched to tier_id (unique PK) in routes, service, and frontend JS. Added comprehensive unit and integration tests including cross-platform isolation regression tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -86,11 +86,11 @@ def get_feature_catalog(
|
||||
|
||||
|
||||
@admin_features_router.get(
|
||||
"/tiers/{tier_code}/limits",
|
||||
"/tiers/{tier_id}/limits",
|
||||
response_model=list[TierFeatureLimitEntry],
|
||||
)
|
||||
def get_tier_feature_limits(
|
||||
tier_code: str = Path(..., description="Tier code"),
|
||||
tier_id: int = Path(..., description="Tier ID"),
|
||||
current_user: UserContext = Depends(get_current_admin_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
@@ -100,7 +100,7 @@ def get_tier_feature_limits(
|
||||
Returns all TierFeatureLimit rows associated with the tier,
|
||||
each containing a feature_code and its optional limit_value.
|
||||
"""
|
||||
rows = feature_service.get_tier_feature_limits(db, tier_code)
|
||||
rows = feature_service.get_tier_feature_limits(db, tier_id)
|
||||
|
||||
return [
|
||||
TierFeatureLimitEntry(
|
||||
@@ -113,12 +113,12 @@ def get_tier_feature_limits(
|
||||
|
||||
|
||||
@admin_features_router.put(
|
||||
"/tiers/{tier_code}/limits",
|
||||
"/tiers/{tier_id}/limits",
|
||||
response_model=list[TierFeatureLimitEntry],
|
||||
)
|
||||
def upsert_tier_feature_limits(
|
||||
entries: list[TierFeatureLimitEntry],
|
||||
tier_code: str = Path(..., description="Tier code"),
|
||||
tier_id: int = Path(..., description="Tier ID"),
|
||||
current_user: UserContext = Depends(get_current_admin_api),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
@@ -136,15 +136,15 @@ def upsert_tier_feature_limits(
|
||||
raise InvalidFeatureCodesError(invalid_codes)
|
||||
|
||||
new_rows = feature_service.upsert_tier_feature_limits(
|
||||
db, tier_code, [e.model_dump() for e in entries]
|
||||
db, tier_id, [e.model_dump() for e in entries]
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Admin %s replaced tier '%s' feature limits (%d entries)",
|
||||
"Admin %s replaced tier %d feature limits (%d entries)",
|
||||
current_user.id,
|
||||
tier_code,
|
||||
tier_id,
|
||||
len(new_rows),
|
||||
)
|
||||
|
||||
|
||||
@@ -450,30 +450,24 @@ class FeatureService:
|
||||
# Tier Feature Limit Management
|
||||
# =========================================================================
|
||||
|
||||
def get_tier_feature_limits(self, db: Session, tier_code: str) -> list:
|
||||
def get_tier_feature_limits(self, db: Session, tier_id: int) -> list:
|
||||
"""Get feature limits for a tier."""
|
||||
from app.modules.billing.services import admin_subscription_service
|
||||
|
||||
tier = admin_subscription_service.get_tier_by_code(db, tier_code)
|
||||
return (
|
||||
db.query(TierFeatureLimit)
|
||||
.filter(TierFeatureLimit.tier_id == tier.id)
|
||||
.filter(TierFeatureLimit.tier_id == tier_id)
|
||||
.order_by(TierFeatureLimit.feature_code)
|
||||
.all()
|
||||
)
|
||||
|
||||
def upsert_tier_feature_limits(self, db: Session, tier_code: str, entries: list[dict]) -> list:
|
||||
def upsert_tier_feature_limits(self, db: Session, tier_id: int, entries: list[dict]) -> list:
|
||||
"""Replace feature limits for a tier. Returns list of new TierFeatureLimit objects."""
|
||||
from app.modules.billing.services import admin_subscription_service
|
||||
|
||||
tier = admin_subscription_service.get_tier_by_code(db, tier_code)
|
||||
db.query(TierFeatureLimit).filter(TierFeatureLimit.tier_id == tier.id).delete()
|
||||
db.query(TierFeatureLimit).filter(TierFeatureLimit.tier_id == tier_id).delete()
|
||||
new_rows = []
|
||||
for entry in entries:
|
||||
if not entry.get("enabled", True):
|
||||
continue
|
||||
row = TierFeatureLimit(
|
||||
tier_id=tier.id,
|
||||
tier_id=tier_id,
|
||||
feature_code=entry["feature_code"],
|
||||
limit_value=entry.get("limit_value"),
|
||||
)
|
||||
|
||||
@@ -273,7 +273,7 @@ function adminSubscriptionTiers() {
|
||||
|
||||
try {
|
||||
// Load tier's current feature limits
|
||||
const data = await apiClient.get(`/admin/subscriptions/features/tiers/${tier.code}/limits`);
|
||||
const data = await apiClient.get(`/admin/subscriptions/features/tiers/${tier.id}/limits`);
|
||||
// data is TierFeatureLimitEntry[]: [{feature_code, limit_value, enabled}]
|
||||
this.selectedFeatures = [];
|
||||
for (const entry of (data || [])) {
|
||||
@@ -327,7 +327,7 @@ function adminSubscriptionTiers() {
|
||||
}));
|
||||
|
||||
await apiClient.put(
|
||||
`/admin/subscriptions/features/tiers/${this.selectedTierForFeatures.code}/limits`,
|
||||
`/admin/subscriptions/features/tiers/${this.selectedTierForFeatures.id}/limits`,
|
||||
entries
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
# app/modules/billing/tests/integration/test_admin_features_routes.py
|
||||
"""
|
||||
Integration tests for admin feature management API routes.
|
||||
|
||||
Tests the feature limit endpoints at:
|
||||
/api/v1/admin/subscriptions/features/*
|
||||
|
||||
Covers:
|
||||
- GET /features/catalog
|
||||
- GET /features/tiers/{tier_id}/limits
|
||||
- PUT /features/tiers/{tier_id}/limits
|
||||
- Regression: tiers with duplicate codes across platforms are isolated by tier_id
|
||||
|
||||
Uses super_admin_headers fixture which bypasses module access checks.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.billing.models import SubscriptionTier
|
||||
from app.modules.billing.models.tier_feature_limit import TierFeatureLimit
|
||||
from app.modules.tenancy.models import Platform
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
BASE = "/api/v1/admin/subscriptions/features"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ft_platform(db):
|
||||
"""Create a platform for feature route tests."""
|
||||
platform = Platform(
|
||||
code=f"feat_{uuid.uuid4().hex[:8]}",
|
||||
name="Feature Test Platform",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(platform)
|
||||
db.commit()
|
||||
db.refresh(platform)
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ft_second_platform(db):
|
||||
"""Second platform for cross-platform isolation tests."""
|
||||
platform = Platform(
|
||||
code=f"feat2_{uuid.uuid4().hex[:8]}",
|
||||
name="Feature Test Platform 2",
|
||||
is_active=True,
|
||||
)
|
||||
db.add(platform)
|
||||
db.commit()
|
||||
db.refresh(platform)
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ft_tier(db, ft_platform):
|
||||
"""Create a tier for feature route tests."""
|
||||
tier = SubscriptionTier(
|
||||
code=f"essential_{uuid.uuid4().hex[:6]}",
|
||||
name="Essential",
|
||||
price_monthly_cents=1000,
|
||||
display_order=0,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
platform_id=ft_platform.id,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
db.refresh(tier)
|
||||
return tier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ft_duplicate_code_tiers(db, ft_platform, ft_second_platform):
|
||||
"""Create two tiers with the SAME code but different platforms.
|
||||
|
||||
This is the exact scenario that caused the tier_code ambiguity bug.
|
||||
"""
|
||||
tier_a = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential (Platform A)",
|
||||
price_monthly_cents=1000,
|
||||
display_order=0,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
platform_id=ft_platform.id,
|
||||
)
|
||||
tier_b = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential (Platform B)",
|
||||
price_monthly_cents=2000,
|
||||
display_order=0,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
platform_id=ft_second_platform.id,
|
||||
)
|
||||
db.add(tier_a)
|
||||
db.add(tier_b)
|
||||
db.commit()
|
||||
db.refresh(tier_a)
|
||||
db.refresh(tier_b)
|
||||
return tier_a, tier_b
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ft_tier_with_features(db, ft_tier):
|
||||
"""Pre-populate a tier with feature limits."""
|
||||
features = [
|
||||
TierFeatureLimit(tier_id=ft_tier.id, feature_code="basic_shop", limit_value=None),
|
||||
TierFeatureLimit(tier_id=ft_tier.id, feature_code="team_members", limit_value=5),
|
||||
]
|
||||
for f in features:
|
||||
db.add(f)
|
||||
db.commit()
|
||||
# Refresh so the tier's selectin-loaded feature_limits relationship is up to date
|
||||
db.refresh(ft_tier)
|
||||
return features
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Feature Catalog
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.billing
|
||||
class TestFeatureCatalog:
|
||||
"""Tests for GET /features/catalog."""
|
||||
|
||||
def test_get_catalog(self, client, super_admin_headers):
|
||||
"""Returns the feature catalog grouped by category."""
|
||||
response = client.get(f"{BASE}/catalog", headers=super_admin_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "features" in data
|
||||
assert isinstance(data["features"], dict)
|
||||
|
||||
def test_catalog_requires_auth(self, client):
|
||||
"""Catalog endpoint requires authentication."""
|
||||
response = client.get(f"{BASE}/catalog")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GET Tier Feature Limits
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.billing
|
||||
class TestGetTierFeatureLimits:
|
||||
"""Tests for GET /features/tiers/{tier_id}/limits."""
|
||||
|
||||
def test_get_limits_empty(self, client, super_admin_headers, ft_tier):
|
||||
"""Returns empty list for tier with no features."""
|
||||
response = client.get(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_get_limits_with_features(
|
||||
self, client, super_admin_headers, ft_tier, ft_tier_with_features
|
||||
):
|
||||
"""Returns feature limit entries for a tier."""
|
||||
response = client.get(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
codes = {e["feature_code"] for e in data}
|
||||
assert codes == {"basic_shop", "team_members"}
|
||||
# Check limit values
|
||||
for entry in data:
|
||||
assert entry["enabled"] is True
|
||||
if entry["feature_code"] == "team_members":
|
||||
assert entry["limit_value"] == 5
|
||||
else:
|
||||
assert entry["limit_value"] is None
|
||||
|
||||
def test_get_limits_requires_auth(self, client, ft_tier):
|
||||
"""Endpoint requires authentication."""
|
||||
response = client.get(f"{BASE}/tiers/{ft_tier.id}/limits")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PUT Tier Feature Limits
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.billing
|
||||
class TestUpsertTierFeatureLimits:
|
||||
"""Tests for PUT /features/tiers/{tier_id}/limits."""
|
||||
|
||||
def test_save_features(self, client, super_admin_headers, ft_tier):
|
||||
"""Saves feature limits and returns the saved entries."""
|
||||
# Get valid feature codes from catalog
|
||||
catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json()
|
||||
all_codes = []
|
||||
for features in catalog["features"].values():
|
||||
for f in features:
|
||||
all_codes.append(f["code"])
|
||||
|
||||
# Use the first two valid codes
|
||||
entries = [
|
||||
{"feature_code": all_codes[0], "limit_value": None, "enabled": True},
|
||||
{"feature_code": all_codes[1], "limit_value": 10, "enabled": True},
|
||||
]
|
||||
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
json=entries,
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
def test_save_replaces_existing(
|
||||
self, client, super_admin_headers, ft_tier, ft_tier_with_features
|
||||
):
|
||||
"""Saving new features replaces the old ones entirely."""
|
||||
# Get a valid feature code
|
||||
catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json()
|
||||
valid_code = next(
|
||||
f["code"]
|
||||
for features in catalog["features"].values()
|
||||
for f in features
|
||||
)
|
||||
|
||||
entries = [
|
||||
{"feature_code": valid_code, "limit_value": None, "enabled": True},
|
||||
]
|
||||
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
json=entries,
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
|
||||
# Verify old features are gone
|
||||
get_response = client.get(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert len(get_response.json()) == 1
|
||||
|
||||
def test_save_empty_clears_features(
|
||||
self, client, super_admin_headers, ft_tier, ft_tier_with_features
|
||||
):
|
||||
"""Saving an empty list removes all features."""
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
json=[],
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
# Verify cleared
|
||||
get_response = client.get(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert get_response.json() == []
|
||||
|
||||
def test_save_rejects_invalid_feature_codes(self, client, super_admin_headers, ft_tier):
|
||||
"""Returns error for unknown feature codes."""
|
||||
entries = [
|
||||
{"feature_code": "totally_fake_feature_xyz", "limit_value": None, "enabled": True},
|
||||
]
|
||||
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
json=entries,
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
# Should fail validation
|
||||
assert response.status_code in (400, 422)
|
||||
|
||||
def test_save_requires_auth(self, client, ft_tier):
|
||||
"""Endpoint requires authentication."""
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{ft_tier.id}/limits",
|
||||
json=[],
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cross-Platform Isolation (Regression Test)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.billing
|
||||
class TestCrossPlatformIsolation:
|
||||
"""
|
||||
Regression tests for the tier_code ambiguity bug.
|
||||
|
||||
When multiple tiers share the same code (e.g., "essential" across platforms),
|
||||
feature operations must use tier_id to avoid saving to the wrong tier.
|
||||
"""
|
||||
|
||||
def test_features_saved_to_correct_tier(
|
||||
self, client, super_admin_headers, ft_duplicate_code_tiers
|
||||
):
|
||||
"""Features saved by tier_id go to the correct tier, not the first match."""
|
||||
tier_a, tier_b = ft_duplicate_code_tiers
|
||||
|
||||
# Get valid feature codes
|
||||
catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json()
|
||||
codes = [
|
||||
f["code"]
|
||||
for features in catalog["features"].values()
|
||||
for f in features
|
||||
]
|
||||
|
||||
# Save features to tier B (second platform)
|
||||
entries_b = [
|
||||
{"feature_code": codes[0], "limit_value": None, "enabled": True},
|
||||
{"feature_code": codes[1], "limit_value": 50, "enabled": True},
|
||||
]
|
||||
response = client.put(
|
||||
f"{BASE}/tiers/{tier_b.id}/limits",
|
||||
json=entries_b,
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
|
||||
# Tier A should still have 0 features
|
||||
response_a = client.get(
|
||||
f"{BASE}/tiers/{tier_a.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response_a.status_code == 200
|
||||
assert len(response_a.json()) == 0
|
||||
|
||||
# Tier B should have 2 features
|
||||
response_b = client.get(
|
||||
f"{BASE}/tiers/{tier_b.id}/limits",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response_b.status_code == 200
|
||||
assert len(response_b.json()) == 2
|
||||
|
||||
def test_features_do_not_leak_between_same_code_tiers(
|
||||
self, client, super_admin_headers, ft_duplicate_code_tiers
|
||||
):
|
||||
"""Saving features to one tier doesn't affect another with the same code."""
|
||||
tier_a, tier_b = ft_duplicate_code_tiers
|
||||
|
||||
# Get valid feature codes
|
||||
catalog = client.get(f"{BASE}/catalog", headers=super_admin_headers).json()
|
||||
codes = [
|
||||
f["code"]
|
||||
for features in catalog["features"].values()
|
||||
for f in features
|
||||
]
|
||||
|
||||
# Save different features to each tier
|
||||
client.put(
|
||||
f"{BASE}/tiers/{tier_a.id}/limits",
|
||||
json=[{"feature_code": codes[0], "limit_value": None, "enabled": True}],
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
client.put(
|
||||
f"{BASE}/tiers/{tier_b.id}/limits",
|
||||
json=[{"feature_code": codes[1], "limit_value": None, "enabled": True}],
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
|
||||
# Each tier should have exactly its own feature
|
||||
resp_a = client.get(f"{BASE}/tiers/{tier_a.id}/limits", headers=super_admin_headers)
|
||||
resp_b = client.get(f"{BASE}/tiers/{tier_b.id}/limits", headers=super_admin_headers)
|
||||
|
||||
assert len(resp_a.json()) == 1
|
||||
assert resp_a.json()[0]["feature_code"] == codes[0]
|
||||
|
||||
assert len(resp_b.json()) == 1
|
||||
assert resp_b.json()[0]["feature_code"] == codes[1]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Feature Count in Tier List (End-to-End)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.billing
|
||||
class TestTierListFeatureCount:
|
||||
"""Tests that the tier list endpoint includes correct feature counts."""
|
||||
|
||||
def test_tier_list_includes_feature_codes(
|
||||
self, client, super_admin_headers, ft_tier, ft_tier_with_features
|
||||
):
|
||||
"""GET /tiers returns feature_codes for each tier."""
|
||||
response = client.get(
|
||||
"/api/v1/admin/subscriptions/tiers",
|
||||
headers=super_admin_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Find our test tier in the response
|
||||
tiers = response.json()["tiers"]
|
||||
our_tier = next((t for t in tiers if t["id"] == ft_tier.id), None)
|
||||
assert our_tier is not None
|
||||
assert len(our_tier["feature_codes"]) == 2
|
||||
assert set(our_tier["feature_codes"]) == {"basic_shop", "team_members"}
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from app.modules.billing.models import SubscriptionTier
|
||||
from app.modules.billing.models.tier_feature_limit import TierFeatureLimit
|
||||
from app.modules.billing.services.feature_service import FeatureService
|
||||
from app.modules.tenancy.models import Platform
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -16,3 +19,223 @@ class TestFeatureService:
|
||||
def test_service_instantiation(self):
|
||||
"""Service can be instantiated."""
|
||||
assert self.service is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fs_platform(db):
|
||||
"""Create a platform for feature service tests."""
|
||||
platform = Platform(code="fs_test", name="FS Test Platform", is_active=True)
|
||||
db.add(platform)
|
||||
db.commit()
|
||||
db.refresh(platform)
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fs_second_platform(db):
|
||||
"""Create a second platform to test cross-platform isolation."""
|
||||
platform = Platform(code="fs_test2", name="FS Test Platform 2", is_active=True)
|
||||
db.add(platform)
|
||||
db.commit()
|
||||
db.refresh(platform)
|
||||
return platform
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fs_tier(db, fs_platform):
|
||||
"""Create a tier for feature service tests."""
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=1000,
|
||||
display_order=0,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
platform_id=fs_platform.id,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
db.refresh(tier)
|
||||
return tier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fs_same_code_tier(db, fs_second_platform):
|
||||
"""Create a tier with the SAME code but different platform."""
|
||||
tier = SubscriptionTier(
|
||||
code="essential",
|
||||
name="Essential",
|
||||
price_monthly_cents=2000,
|
||||
display_order=0,
|
||||
is_active=True,
|
||||
is_public=True,
|
||||
platform_id=fs_second_platform.id,
|
||||
)
|
||||
db.add(tier)
|
||||
db.commit()
|
||||
db.refresh(tier)
|
||||
return tier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fs_tier_with_features(db, fs_tier):
|
||||
"""Create a tier with pre-existing feature limits."""
|
||||
features = [
|
||||
TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_a", limit_value=None),
|
||||
TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_b", limit_value=100),
|
||||
TierFeatureLimit(tier_id=fs_tier.id, feature_code="feature_c", limit_value=50),
|
||||
]
|
||||
for f in features:
|
||||
db.add(f)
|
||||
db.commit()
|
||||
return features
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# get_tier_feature_limits
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestGetTierFeatureLimits:
|
||||
"""Tests for FeatureService.get_tier_feature_limits."""
|
||||
|
||||
def test_returns_limits_for_tier(self, db, fs_tier_with_features, fs_tier):
|
||||
"""Returns all feature limit rows for the given tier ID."""
|
||||
service = FeatureService()
|
||||
rows = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert len(rows) == 3
|
||||
codes = {r.feature_code for r in rows}
|
||||
assert codes == {"feature_a", "feature_b", "feature_c"}
|
||||
|
||||
def test_returns_empty_for_tier_without_features(self, db, fs_tier):
|
||||
"""Returns empty list for a tier with no feature limits."""
|
||||
service = FeatureService()
|
||||
rows = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert rows == []
|
||||
|
||||
def test_returns_empty_for_nonexistent_tier(self, db):
|
||||
"""Returns empty list for a tier ID that doesn't exist."""
|
||||
service = FeatureService()
|
||||
rows = service.get_tier_feature_limits(db, 999999)
|
||||
assert rows == []
|
||||
|
||||
def test_isolates_by_tier_id(self, db, fs_tier, fs_same_code_tier, fs_tier_with_features):
|
||||
"""Features for one tier don't leak to another with the same code."""
|
||||
service = FeatureService()
|
||||
|
||||
# fs_tier has 3 features
|
||||
rows_tier1 = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert len(rows_tier1) == 3
|
||||
|
||||
# fs_same_code_tier (same code, different platform) has 0
|
||||
rows_tier2 = service.get_tier_feature_limits(db, fs_same_code_tier.id)
|
||||
assert len(rows_tier2) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# upsert_tier_feature_limits
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.billing
|
||||
class TestUpsertTierFeatureLimits:
|
||||
"""Tests for FeatureService.upsert_tier_feature_limits."""
|
||||
|
||||
def test_inserts_new_features(self, db, fs_tier):
|
||||
"""Creates feature limit rows for a tier."""
|
||||
service = FeatureService()
|
||||
entries = [
|
||||
{"feature_code": "feat_x", "limit_value": None, "enabled": True},
|
||||
{"feature_code": "feat_y", "limit_value": 200, "enabled": True},
|
||||
]
|
||||
rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries)
|
||||
db.commit()
|
||||
|
||||
assert len(rows) == 2
|
||||
assert {r.feature_code for r in rows} == {"feat_x", "feat_y"}
|
||||
|
||||
def test_replaces_existing_features(self, db, fs_tier, fs_tier_with_features):
|
||||
"""Upsert deletes old features and inserts new ones."""
|
||||
service = FeatureService()
|
||||
entries = [
|
||||
{"feature_code": "new_feature", "limit_value": None, "enabled": True},
|
||||
]
|
||||
rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries)
|
||||
db.commit()
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0].feature_code == "new_feature"
|
||||
|
||||
# Old features should be gone
|
||||
remaining = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].feature_code == "new_feature"
|
||||
|
||||
def test_skips_disabled_entries(self, db, fs_tier):
|
||||
"""Entries with enabled=False are not persisted."""
|
||||
service = FeatureService()
|
||||
entries = [
|
||||
{"feature_code": "enabled_feat", "limit_value": None, "enabled": True},
|
||||
{"feature_code": "disabled_feat", "limit_value": None, "enabled": False},
|
||||
]
|
||||
rows = service.upsert_tier_feature_limits(db, fs_tier.id, entries)
|
||||
db.commit()
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0].feature_code == "enabled_feat"
|
||||
|
||||
def test_saves_to_correct_tier_by_id(self, db, fs_tier, fs_same_code_tier):
|
||||
"""
|
||||
Regression test: saving by tier_id targets the exact tier,
|
||||
not another tier that happens to share the same code.
|
||||
"""
|
||||
service = FeatureService()
|
||||
entries = [
|
||||
{"feature_code": "platform_specific", "limit_value": None, "enabled": True},
|
||||
]
|
||||
|
||||
# Save to the second tier (same code "essential", different platform)
|
||||
service.upsert_tier_feature_limits(db, fs_same_code_tier.id, entries)
|
||||
db.commit()
|
||||
|
||||
# First tier should have 0 features
|
||||
rows_tier1 = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert len(rows_tier1) == 0
|
||||
|
||||
# Second tier should have 1 feature
|
||||
rows_tier2 = service.get_tier_feature_limits(db, fs_same_code_tier.id)
|
||||
assert len(rows_tier2) == 1
|
||||
assert rows_tier2[0].feature_code == "platform_specific"
|
||||
|
||||
def test_clears_all_features_with_empty_list(self, db, fs_tier, fs_tier_with_features):
|
||||
"""Passing an empty list removes all features."""
|
||||
service = FeatureService()
|
||||
rows = service.upsert_tier_feature_limits(db, fs_tier.id, [])
|
||||
db.commit()
|
||||
|
||||
assert len(rows) == 0
|
||||
remaining = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
assert len(remaining) == 0
|
||||
|
||||
def test_preserves_limit_values(self, db, fs_tier):
|
||||
"""Limit values (including None for unlimited) are stored correctly."""
|
||||
service = FeatureService()
|
||||
entries = [
|
||||
{"feature_code": "unlimited", "limit_value": None, "enabled": True},
|
||||
{"feature_code": "limited", "limit_value": 42, "enabled": True},
|
||||
]
|
||||
service.upsert_tier_feature_limits(db, fs_tier.id, entries)
|
||||
db.commit()
|
||||
|
||||
rows = service.get_tier_feature_limits(db, fs_tier.id)
|
||||
limits = {r.feature_code: r.limit_value for r in rows}
|
||||
assert limits["unlimited"] is None
|
||||
assert limits["limited"] == 42
|
||||
|
||||
Reference in New Issue
Block a user