test: add service tests and fix architecture violations

- Add comprehensive unit tests for FeatureService (24 tests)
- Add comprehensive unit tests for UsageService (11 tests)
- Fix API-002/API-003 architecture violations in feature/usage APIs
- Move database queries from API layer to service layer
- Create UsageService for usage and limits management
- Create custom exceptions (FeatureNotFoundError, TierNotFoundError)
- Fix ValidationException usage in content_pages.py
- Refactor vendor features API to use proper response models
- All 35 new tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-31 18:48:59 +01:00
parent 7d1a421826
commit aa4b5a4c63
10 changed files with 1474 additions and 408 deletions

View File

@@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api, get_db
from app.exceptions import ValidationException
from app.services.content_page_service import content_page_service
from models.database.user import User
@@ -178,11 +179,9 @@ def create_vendor_page(
Vendor pages override platform defaults for a specific vendor.
"""
if not page_data.vendor_id:
from fastapi import HTTPException
raise HTTPException(
status_code=400,
detail="vendor_id is required for vendor pages. Use /platform for platform defaults.",
raise ValidationException(
message="vendor_id is required for vendor pages. Use /platform for platform defaults.",
field="vendor_id",
)
page = content_page_service.create_page(

View File

@@ -11,15 +11,13 @@ Provides endpoints for:
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api
from app.core.database import get_db
from app.services.feature_service import feature_service
from models.database.feature import Feature
from models.database.subscription import SubscriptionTier
from models.database.user import User
router = APIRouter(prefix="/features")
@@ -103,6 +101,41 @@ class CategoryListResponse(BaseModel):
categories: list[str]
class TierFeatureDetailResponse(BaseModel):
"""Tier features with full details."""
tier_code: str
tier_name: str
features: list[dict]
feature_count: int
# ============================================================================
# Helper Functions
# ============================================================================
def _feature_to_response(feature) -> FeatureResponse:
"""Convert Feature model to response."""
return FeatureResponse(
id=feature.id,
code=feature.code,
name=feature.name,
description=feature.description,
category=feature.category,
ui_location=feature.ui_location,
ui_icon=feature.ui_icon,
ui_route=feature.ui_route,
ui_badge_text=feature.ui_badge_text,
minimum_tier_id=feature.minimum_tier_id,
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
is_active=feature.is_active,
is_visible=feature.is_visible,
display_order=feature.display_order,
)
# ============================================================================
# Endpoints
# ============================================================================
@@ -121,26 +154,7 @@ def list_features(
)
return FeatureListResponse(
features=[
FeatureResponse(
id=f.id,
code=f.code,
name=f.name,
description=f.description,
category=f.category,
ui_location=f.ui_location,
ui_icon=f.ui_icon,
ui_route=f.ui_route,
ui_badge_text=f.ui_badge_text,
minimum_tier_id=f.minimum_tier_id,
minimum_tier_code=f.minimum_tier.code if f.minimum_tier else None,
minimum_tier_name=f.minimum_tier.name if f.minimum_tier else None,
is_active=f.is_active,
is_visible=f.is_visible,
display_order=f.display_order,
)
for f in features
],
features=[_feature_to_response(f) for f in features],
total=len(features),
)
@@ -161,12 +175,7 @@ def list_tiers_with_features(
db: Session = Depends(get_db),
):
"""List all tiers with their feature assignments."""
tiers = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.is_active == True) # noqa: E712
.order_by(SubscriptionTier.display_order)
.all()
)
tiers = feature_service.get_all_tiers_with_features(db)
return TierListWithFeaturesResponse(
tiers=[
@@ -189,29 +198,19 @@ def get_feature(
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db),
):
"""Get a single feature by code."""
"""
Get a single feature by code.
Raises 404 if feature not found.
"""
feature = feature_service.get_feature_by_code(db, feature_code)
if not feature:
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
from app.exceptions import FeatureNotFoundError
return FeatureResponse(
id=feature.id,
code=feature.code,
name=feature.name,
description=feature.description,
category=feature.category,
ui_location=feature.ui_location,
ui_icon=feature.ui_icon,
ui_route=feature.ui_route,
ui_badge_text=feature.ui_badge_text,
minimum_tier_id=feature.minimum_tier_id,
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
is_active=feature.is_active,
is_visible=feature.is_visible,
display_order=feature.display_order,
)
raise FeatureNotFoundError(feature_code)
return _feature_to_response(feature)
@router.put("/{feature_code}", response_model=FeatureResponse)
@@ -221,73 +220,33 @@ def update_feature(
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db),
):
"""Update feature metadata."""
feature = db.query(Feature).filter(Feature.code == feature_code).first()
"""
Update feature metadata.
if not feature:
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
# Update fields if provided
if request.name is not None:
feature.name = request.name
if request.description is not None:
feature.description = request.description
if request.category is not None:
feature.category = request.category
if request.ui_location is not None:
feature.ui_location = request.ui_location
if request.ui_icon is not None:
feature.ui_icon = request.ui_icon
if request.ui_route is not None:
feature.ui_route = request.ui_route
if request.ui_badge_text is not None:
feature.ui_badge_text = request.ui_badge_text
if request.is_active is not None:
feature.is_active = request.is_active
if request.is_visible is not None:
feature.is_visible = request.is_visible
if request.display_order is not None:
feature.display_order = request.display_order
# Update minimum tier if provided
if request.minimum_tier_code is not None:
if request.minimum_tier_code == "":
feature.minimum_tier_id = None
else:
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == request.minimum_tier_code)
.first()
)
if not tier:
raise HTTPException(
status_code=400,
detail=f"Tier '{request.minimum_tier_code}' not found",
)
feature.minimum_tier_id = tier.id
Raises 404 if feature not found, 400 if tier code is invalid.
"""
feature = feature_service.update_feature(
db,
feature_code,
name=request.name,
description=request.description,
category=request.category,
ui_location=request.ui_location,
ui_icon=request.ui_icon,
ui_route=request.ui_route,
ui_badge_text=request.ui_badge_text,
minimum_tier_code=request.minimum_tier_code,
is_active=request.is_active,
is_visible=request.is_visible,
display_order=request.display_order,
)
db.commit()
db.refresh(feature)
logger.info(f"Updated feature {feature_code} by admin {current_user.id}")
return FeatureResponse(
id=feature.id,
code=feature.code,
name=feature.name,
description=feature.description,
category=feature.category,
ui_location=feature.ui_location,
ui_icon=feature.ui_icon,
ui_route=feature.ui_route,
ui_badge_text=feature.ui_badge_text,
minimum_tier_id=feature.minimum_tier_id,
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
is_active=feature.is_active,
is_visible=feature.is_visible,
display_order=feature.display_order,
)
return _feature_to_response(feature)
@router.put("/tiers/{tier_code}/features", response_model=TierFeaturesResponse)
@@ -297,54 +256,46 @@ def update_tier_features(
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db),
):
"""Update features assigned to a tier."""
try:
tier = feature_service.update_tier_features(db, tier_code, request.feature_codes)
db.commit()
"""
Update features assigned to a tier.
logger.info(
f"Updated tier {tier_code} features to {len(request.feature_codes)} features "
f"by admin {current_user.id}"
)
Raises 404 if tier not found, 422 if any feature codes are invalid.
"""
tier = feature_service.update_tier_features(db, tier_code, request.feature_codes)
db.commit()
return TierFeaturesResponse(
id=tier.id,
code=tier.code,
name=tier.name,
description=tier.description,
features=tier.features or [],
feature_count=len(tier.features or []),
)
logger.info(
f"Updated tier {tier_code} features to {len(request.feature_codes)} features "
f"by admin {current_user.id}"
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return TierFeaturesResponse(
id=tier.id,
code=tier.code,
name=tier.name,
description=tier.description,
features=tier.features or [],
feature_count=len(tier.features or []),
)
@router.get("/tiers/{tier_code}/features")
@router.get("/tiers/{tier_code}/features", response_model=TierFeatureDetailResponse)
def get_tier_features(
tier_code: str,
current_user: User = Depends(get_current_admin_api),
db: Session = Depends(get_db),
):
"""Get features assigned to a specific tier."""
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
"""
Get features assigned to a specific tier with full details.
if not tier:
raise HTTPException(status_code=404, detail=f"Tier '{tier_code}' not found")
Raises 404 if tier not found.
"""
tier, features = feature_service.get_tier_features_with_details(db, tier_code)
# Get full feature details for the tier's features
feature_codes = tier.features or []
features = (
db.query(Feature)
.filter(Feature.code.in_(feature_codes))
.order_by(Feature.category, Feature.display_order)
.all()
)
return {
"tier_code": tier.code,
"tier_name": tier.name,
"features": [
return TierFeatureDetailResponse(
tier_code=tier.code,
tier_name=tier.name,
features=[
{
"code": f.code,
"name": f.name,
@@ -353,5 +304,5 @@ def get_tier_features(
}
for f in features
],
"feature_count": len(features),
}
feature_count=len(features),
)

View File

@@ -16,12 +16,13 @@ Endpoints:
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from app.exceptions import FeatureNotFoundError
from app.services.feature_service import feature_service
from models.database.user import User
@@ -99,6 +100,13 @@ class FeatureGroupedResponse(BaseModel):
total_count: int
class FeatureCheckResponse(BaseModel):
"""Quick feature availability check response."""
has_feature: bool
feature_code: str
# ============================================================================
# Endpoints
# ============================================================================
@@ -285,7 +293,7 @@ def get_feature_detail(
# Get feature
feature = feature_service.get_feature_by_code(db, feature_code)
if not feature:
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
raise FeatureNotFoundError(feature_code)
# Check availability
is_available = feature_service.has_feature(db, vendor_id, feature_code)
@@ -317,7 +325,7 @@ def get_feature_detail(
)
@router.get("/check/{feature_code}")
@router.get("/check/{feature_code}", response_model=FeatureCheckResponse)
def check_feature(
feature_code: str,
current_user: User = Depends(get_current_vendor_api),
@@ -332,9 +340,9 @@ def check_feature(
feature_code: The feature code
Returns:
{"has_feature": true/false}
has_feature and feature_code
"""
vendor_id = current_user.token_vendor_id
has_feature = feature_service.has_feature(db, vendor_id, feature_code)
return {"has_feature": has_feature, "feature_code": feature_code}
return FeatureCheckResponse(has_feature=has_feature, feature_code=feature_code)

View File

@@ -12,16 +12,12 @@ import logging
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api
from app.core.database import get_db
from app.services.subscription_service import subscription_service
from models.database.product import Product
from models.database.subscription import SubscriptionTier
from app.services.usage_service import usage_service
from models.database.user import User
from models.database.vendor import VendorUser
router = APIRouter(prefix="/usage")
logger = logging.getLogger(__name__)
@@ -106,181 +102,44 @@ def get_usage(
"""
vendor_id = current_user.token_vendor_id
# Get subscription
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
# Get current tier
tier = subscription.tier_obj
if not tier:
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == subscription.tier)
.first()
)
# Calculate usage metrics
usage_metrics = []
# Orders this period
orders_current = subscription.orders_this_period or 0
orders_limit = subscription.orders_limit
orders_unlimited = orders_limit is None or orders_limit < 0
orders_percentage = 0 if orders_unlimited else (orders_current / orders_limit * 100 if orders_limit > 0 else 100)
usage_metrics.append(
UsageMetric(
name="orders",
current=orders_current,
limit=None if orders_unlimited else orders_limit,
percentage=orders_percentage,
is_unlimited=orders_unlimited,
is_at_limit=not orders_unlimited and orders_current >= orders_limit,
is_approaching_limit=not orders_unlimited and orders_percentage >= 80,
)
)
# Products
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
products_limit = subscription.products_limit
products_unlimited = products_limit is None or products_limit < 0
products_percentage = 0 if products_unlimited else (products_count / products_limit * 100 if products_limit > 0 else 100)
usage_metrics.append(
UsageMetric(
name="products",
current=products_count,
limit=None if products_unlimited else products_limit,
percentage=products_percentage,
is_unlimited=products_unlimited,
is_at_limit=not products_unlimited and products_count >= products_limit,
is_approaching_limit=not products_unlimited and products_percentage >= 80,
)
)
# Team members
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
.scalar()
or 0
)
team_limit = subscription.team_members_limit
team_unlimited = team_limit is None or team_limit < 0
team_percentage = 0 if team_unlimited else (team_count / team_limit * 100 if team_limit > 0 else 100)
usage_metrics.append(
UsageMetric(
name="team_members",
current=team_count,
limit=None if team_unlimited else team_limit,
percentage=team_percentage,
is_unlimited=team_unlimited,
is_at_limit=not team_unlimited and team_count >= team_limit,
is_approaching_limit=not team_unlimited and team_percentage >= 80,
)
)
# Check for approaching/reached limits
has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics)
has_limits_reached = any(m.is_at_limit for m in usage_metrics)
# Get next tier for upgrade
all_tiers = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.is_active == True) # noqa: E712
.order_by(SubscriptionTier.display_order)
.all()
)
current_tier_order = tier.display_order if tier else 0
next_tier = None
for t in all_tiers:
if t.display_order > current_tier_order:
next_tier = t
break
is_highest_tier = next_tier is None
# Build upgrade info
upgrade_tier_info = None
upgrade_reasons = []
if next_tier:
# Calculate benefits
benefits = []
if next_tier.orders_per_month and (not tier or (tier.orders_per_month and next_tier.orders_per_month > tier.orders_per_month)):
if next_tier.orders_per_month < 0:
benefits.append("Unlimited orders per month")
else:
benefits.append(f"{next_tier.orders_per_month:,} orders/month")
if next_tier.products_limit and (not tier or (tier.products_limit and next_tier.products_limit > tier.products_limit)):
if next_tier.products_limit < 0:
benefits.append("Unlimited products")
else:
benefits.append(f"{next_tier.products_limit:,} products")
if next_tier.team_members and (not tier or (tier.team_members and next_tier.team_members > tier.team_members)):
if next_tier.team_members < 0:
benefits.append("Unlimited team members")
else:
benefits.append(f"{next_tier.team_members} team members")
# Add feature benefits
current_features = set(tier.features) if tier and tier.features else set()
next_features = set(next_tier.features) if next_tier.features else set()
new_features = next_features - current_features
feature_names = {
"analytics_dashboard": "Advanced Analytics",
"api_access": "API Access",
"automation_rules": "Automation Rules",
"team_roles": "Team Roles & Permissions",
"custom_domain": "Custom Domain",
"webhooks": "Webhooks",
"accounting_export": "Accounting Export",
}
for feature in list(new_features)[:3]: # Show top 3
if feature in feature_names:
benefits.append(feature_names[feature])
current_price = tier.price_monthly_cents if tier else 0
upgrade_tier_info = UpgradeTierInfo(
code=next_tier.code,
name=next_tier.name,
price_monthly_cents=next_tier.price_monthly_cents,
price_increase_cents=next_tier.price_monthly_cents - current_price,
benefits=benefits,
)
# Build upgrade reasons
if has_limits_reached:
for m in usage_metrics:
if m.is_at_limit:
upgrade_reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit")
elif has_limits_approaching:
for m in usage_metrics:
if m.is_approaching_limit:
upgrade_reasons.append(f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)")
# Get usage data from service
usage_data = usage_service.get_vendor_usage(db, vendor_id)
# Convert to response
return UsageResponse(
tier=TierInfo(
code=tier.code if tier else subscription.tier,
name=tier.name if tier else subscription.tier.title(),
price_monthly_cents=tier.price_monthly_cents if tier else 0,
is_highest_tier=is_highest_tier,
code=usage_data.tier.code,
name=usage_data.tier.name,
price_monthly_cents=usage_data.tier.price_monthly_cents,
is_highest_tier=usage_data.tier.is_highest_tier,
),
usage=usage_metrics,
has_limits_approaching=has_limits_approaching,
has_limits_reached=has_limits_reached,
upgrade_available=not is_highest_tier,
upgrade_tier=upgrade_tier_info,
upgrade_reasons=upgrade_reasons,
usage=[
UsageMetric(
name=m.name,
current=m.current,
limit=m.limit,
percentage=m.percentage,
is_unlimited=m.is_unlimited,
is_at_limit=m.is_at_limit,
is_approaching_limit=m.is_approaching_limit,
)
for m in usage_data.usage
],
has_limits_approaching=usage_data.has_limits_approaching,
has_limits_reached=usage_data.has_limits_reached,
upgrade_available=usage_data.upgrade_available,
upgrade_tier=(
UpgradeTierInfo(
code=usage_data.upgrade_tier.code,
name=usage_data.upgrade_tier.name,
price_monthly_cents=usage_data.upgrade_tier.price_monthly_cents,
price_increase_cents=usage_data.upgrade_tier.price_increase_cents,
benefits=usage_data.upgrade_tier.benefits,
)
if usage_data.upgrade_tier
else None
),
upgrade_reasons=usage_data.upgrade_reasons,
)
@@ -303,78 +162,16 @@ def check_limit(
"""
vendor_id = current_user.token_vendor_id
if limit_type == "orders":
can_proceed, message = subscription_service.can_create_order(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = subscription.orders_this_period if subscription else 0
limit = subscription.orders_limit if subscription else 0
elif limit_type == "products":
can_proceed, message = subscription_service.can_add_product(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
limit = subscription.products_limit if subscription else 0
elif limit_type == "team_members":
can_proceed, message = subscription_service.can_add_team_member(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
.scalar()
or 0
)
limit = subscription.team_members_limit if subscription else 0
else:
return LimitCheckResponse(
limit_type=limit_type,
can_proceed=True,
current=0,
limit=None,
percentage=0,
message=f"Unknown limit type: {limit_type}",
)
# Calculate percentage
is_unlimited = limit is None or limit < 0
percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100)
# Get upgrade info if at limit
upgrade_tier_code = None
upgrade_tier_name = None
if not can_proceed:
subscription = subscription_service.get_subscription(db, vendor_id)
current_tier = subscription.tier_obj if subscription else None
if current_tier:
next_tier = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.display_order > current_tier.display_order,
)
.order_by(SubscriptionTier.display_order)
.first()
)
if next_tier:
upgrade_tier_code = next_tier.code
upgrade_tier_name = next_tier.name
# Check limit using service
check_data = usage_service.check_limit(db, vendor_id, limit_type)
return LimitCheckResponse(
limit_type=limit_type,
can_proceed=can_proceed,
current=current,
limit=None if is_unlimited else limit,
percentage=percentage,
message=message,
upgrade_tier_code=upgrade_tier_code,
upgrade_tier_name=upgrade_tier_name,
limit_type=check_data.limit_type,
can_proceed=check_data.can_proceed,
current=check_data.current,
limit=check_data.limit,
percentage=check_data.percentage,
message=check_data.message,
upgrade_tier_code=check_data.upgrade_tier_code,
upgrade_tier_name=check_data.upgrade_tier_name,
)

View File

@@ -263,6 +263,13 @@ from .onboarding import (
OnboardingSyncNotCompleteException,
)
# Feature management exceptions
from .feature import (
FeatureNotFoundError,
InvalidFeatureCodesError,
TierNotFoundError,
)
__all__ = [
# Base exceptions
"WizamartException",
@@ -456,4 +463,8 @@ __all__ = [
"OnboardingCsvUrlRequiredException",
"OnboardingSyncJobNotFoundException",
"OnboardingSyncNotCompleteException",
# Feature exceptions
"FeatureNotFoundError",
"TierNotFoundError",
"InvalidFeatureCodesError",
]

42
app/exceptions/feature.py Normal file
View File

@@ -0,0 +1,42 @@
# app/exceptions/feature.py
"""
Feature management exceptions.
"""
from app.exceptions.base import ResourceNotFoundException, ValidationException
class FeatureNotFoundError(ResourceNotFoundException):
"""Feature not found."""
def __init__(self, feature_code: str):
super().__init__(
resource_type="Feature",
identifier=feature_code,
message=f"Feature '{feature_code}' not found",
)
self.feature_code = feature_code
class TierNotFoundError(ResourceNotFoundException):
"""Subscription tier not found."""
def __init__(self, tier_code: str):
super().__init__(
resource_type="SubscriptionTier",
identifier=tier_code,
message=f"Tier '{tier_code}' not found",
)
self.tier_code = tier_code
class InvalidFeatureCodesError(ValidationException):
"""Invalid feature codes provided."""
def __init__(self, invalid_codes: set[str]):
codes_str = ", ".join(sorted(invalid_codes))
super().__init__(
message=f"Invalid feature codes: {codes_str}",
details={"invalid_codes": list(invalid_codes)},
)
self.invalid_codes = invalid_codes

View File

@@ -29,6 +29,11 @@ from functools import lru_cache
from sqlalchemy.orm import Session, joinedload
from app.exceptions.feature import (
FeatureNotFoundError,
InvalidFeatureCodesError,
TierNotFoundError,
)
from models.database.feature import Feature, FeatureCode
from models.database.subscription import SubscriptionTier, VendorSubscription
@@ -370,6 +375,51 @@ class FeatureService:
# Admin Operations
# =========================================================================
def get_all_tiers_with_features(self, db: Session) -> list[SubscriptionTier]:
"""Get all active tiers with their features for admin."""
return (
db.query(SubscriptionTier)
.filter(SubscriptionTier.is_active == True) # noqa: E712
.order_by(SubscriptionTier.display_order)
.all()
)
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
"""
Get tier by code, raising exception if not found.
Raises:
TierNotFoundError: If tier not found
"""
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
if not tier:
raise TierNotFoundError(tier_code)
return tier
def get_tier_features_with_details(
self, db: Session, tier_code: str
) -> tuple[SubscriptionTier, list[Feature]]:
"""
Get tier with full feature details.
Returns:
Tuple of (tier, list of Feature objects)
Raises:
TierNotFoundError: If tier not found
"""
tier = self.get_tier_by_code(db, tier_code)
feature_codes = tier.features or []
features = (
db.query(Feature)
.filter(Feature.code.in_(feature_codes))
.order_by(Feature.category, Feature.display_order)
.all()
)
return tier, features
def update_tier_features(
self, db: Session, tier_code: str, feature_codes: list[str]
) -> SubscriptionTier:
@@ -383,11 +433,15 @@ class FeatureService:
Returns:
Updated tier
Raises:
TierNotFoundError: If tier not found
InvalidFeatureCodesError: If any feature codes are invalid
"""
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
if not tier:
raise ValueError(f"Tier '{tier_code}' not found")
raise TierNotFoundError(tier_code)
# Validate feature codes exist
valid_codes = {
@@ -395,7 +449,7 @@ class FeatureService:
}
invalid = set(feature_codes) - valid_codes
if invalid:
raise ValueError(f"Invalid feature codes: {invalid}")
raise InvalidFeatureCodesError(invalid)
tier.features = feature_codes
@@ -405,6 +459,86 @@ class FeatureService:
logger.info(f"Updated features for tier {tier_code}: {len(feature_codes)} features")
return tier
def update_feature(
self,
db: Session,
feature_code: str,
name: str | None = None,
description: str | None = None,
category: str | None = None,
ui_location: str | None = None,
ui_icon: str | None = None,
ui_route: str | None = None,
ui_badge_text: str | None = None,
minimum_tier_code: str | None = None,
is_active: bool | None = None,
is_visible: bool | None = None,
display_order: int | None = None,
) -> Feature:
"""
Update feature metadata.
Args:
db: Database session
feature_code: Feature code to update
... other optional fields to update
Returns:
Updated feature
Raises:
FeatureNotFoundError: If feature not found
TierNotFoundError: If minimum_tier_code provided but not found
"""
feature = (
db.query(Feature)
.options(joinedload(Feature.minimum_tier))
.filter(Feature.code == feature_code)
.first()
)
if not feature:
raise FeatureNotFoundError(feature_code)
# Update fields if provided
if name is not None:
feature.name = name
if description is not None:
feature.description = description
if category is not None:
feature.category = category
if ui_location is not None:
feature.ui_location = ui_location
if ui_icon is not None:
feature.ui_icon = ui_icon
if ui_route is not None:
feature.ui_route = ui_route
if ui_badge_text is not None:
feature.ui_badge_text = ui_badge_text
if is_active is not None:
feature.is_active = is_active
if is_visible is not None:
feature.is_visible = is_visible
if display_order is not None:
feature.display_order = display_order
# Update minimum tier if provided
if minimum_tier_code is not None:
if minimum_tier_code == "":
feature.minimum_tier_id = None
else:
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == minimum_tier_code)
.first()
)
if not tier:
raise TierNotFoundError(minimum_tier_code)
feature.minimum_tier_id = tier.id
logger.info(f"Updated feature {feature_code}")
return feature
def update_feature_minimum_tier(
self, db: Session, feature_code: str, tier_code: str | None
) -> Feature:
@@ -415,16 +549,20 @@ class FeatureService:
db: Database session
feature_code: Feature code
tier_code: Tier code or None
Raises:
FeatureNotFoundError: If feature not found
TierNotFoundError: If tier_code provided but not found
"""
feature = db.query(Feature).filter(Feature.code == feature_code).first()
if not feature:
raise ValueError(f"Feature '{feature_code}' not found")
raise FeatureNotFoundError(feature_code)
if tier_code:
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
if not tier:
raise ValueError(f"Tier '{tier_code}' not found")
raise TierNotFoundError(tier_code)
feature.minimum_tier_id = tier.id
else:
feature.minimum_tier_id = None

View File

@@ -0,0 +1,445 @@
# app/services/usage_service.py
"""
Usage and limits service.
Provides methods for:
- Getting current usage vs limits
- Calculating upgrade recommendations
- Checking limits before actions
"""
import logging
from dataclasses import dataclass
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.subscription import SubscriptionTier, VendorSubscription
from models.database.vendor import VendorUser
logger = logging.getLogger(__name__)
@dataclass
class UsageMetricData:
"""Usage metric data."""
name: str
current: int
limit: int | None
percentage: float
is_unlimited: bool
is_at_limit: bool
is_approaching_limit: bool
@dataclass
class TierInfoData:
"""Tier information."""
code: str
name: str
price_monthly_cents: int
is_highest_tier: bool
@dataclass
class UpgradeTierData:
"""Upgrade tier information."""
code: str
name: str
price_monthly_cents: int
price_increase_cents: int
benefits: list[str]
@dataclass
class UsageData:
"""Full usage data."""
tier: TierInfoData
usage: list[UsageMetricData]
has_limits_approaching: bool
has_limits_reached: bool
upgrade_available: bool
upgrade_tier: UpgradeTierData | None
upgrade_reasons: list[str]
@dataclass
class LimitCheckData:
"""Limit check result."""
limit_type: str
can_proceed: bool
current: int
limit: int | None
percentage: float
message: str | None
upgrade_tier_code: str | None
upgrade_tier_name: str | None
class UsageService:
"""Service for usage and limits management."""
def get_vendor_usage(self, db: Session, vendor_id: int) -> UsageData:
"""
Get comprehensive usage data for a vendor.
Returns current usage, limits, and upgrade recommendations.
"""
from app.services.subscription_service import subscription_service
# Get subscription
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
# Get current tier
tier = self._get_tier(db, subscription)
# Calculate usage metrics
usage_metrics = self._calculate_usage_metrics(db, vendor_id, subscription)
# Check for approaching/reached limits
has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics)
has_limits_reached = any(m.is_at_limit for m in usage_metrics)
# Get upgrade info
next_tier = self._get_next_tier(db, tier)
is_highest_tier = next_tier is None
# Build upgrade info
upgrade_tier_info = None
upgrade_reasons = []
if next_tier:
upgrade_tier_info = self._build_upgrade_tier_info(tier, next_tier)
upgrade_reasons = self._build_upgrade_reasons(
usage_metrics, has_limits_reached, has_limits_approaching
)
return UsageData(
tier=TierInfoData(
code=tier.code if tier else subscription.tier,
name=tier.name if tier else subscription.tier.title(),
price_monthly_cents=tier.price_monthly_cents if tier else 0,
is_highest_tier=is_highest_tier,
),
usage=usage_metrics,
has_limits_approaching=has_limits_approaching,
has_limits_reached=has_limits_reached,
upgrade_available=not is_highest_tier,
upgrade_tier=upgrade_tier_info,
upgrade_reasons=upgrade_reasons,
)
def check_limit(
self, db: Session, vendor_id: int, limit_type: str
) -> LimitCheckData:
"""
Check a specific limit before performing an action.
Args:
db: Database session
vendor_id: Vendor ID
limit_type: One of "orders", "products", "team_members"
Returns:
LimitCheckData with proceed status and upgrade info
"""
from app.services.subscription_service import subscription_service
if limit_type == "orders":
can_proceed, message = subscription_service.can_create_order(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = subscription.orders_this_period if subscription else 0
limit = subscription.orders_limit if subscription else 0
elif limit_type == "products":
can_proceed, message = subscription_service.can_add_product(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = self._get_product_count(db, vendor_id)
limit = subscription.products_limit if subscription else 0
elif limit_type == "team_members":
can_proceed, message = subscription_service.can_add_team_member(db, vendor_id)
subscription = subscription_service.get_subscription(db, vendor_id)
current = self._get_team_member_count(db, vendor_id)
limit = subscription.team_members_limit if subscription else 0
else:
return LimitCheckData(
limit_type=limit_type,
can_proceed=True,
current=0,
limit=None,
percentage=0,
message=f"Unknown limit type: {limit_type}",
upgrade_tier_code=None,
upgrade_tier_name=None,
)
# Calculate percentage
is_unlimited = limit is None or limit < 0
percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100)
# Get upgrade info if at limit
upgrade_tier_code = None
upgrade_tier_name = None
if not can_proceed:
subscription = subscription_service.get_subscription(db, vendor_id)
current_tier = subscription.tier_obj if subscription else None
if current_tier:
next_tier = self._get_next_tier(db, current_tier)
if next_tier:
upgrade_tier_code = next_tier.code
upgrade_tier_name = next_tier.name
return LimitCheckData(
limit_type=limit_type,
can_proceed=can_proceed,
current=current,
limit=None if is_unlimited else limit,
percentage=percentage,
message=message,
upgrade_tier_code=upgrade_tier_code,
upgrade_tier_name=upgrade_tier_name,
)
# =========================================================================
# Private Helper Methods
# =========================================================================
def _get_tier(
self, db: Session, subscription: VendorSubscription
) -> SubscriptionTier | None:
"""Get tier from subscription or query by code."""
tier = subscription.tier_obj
if not tier:
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == subscription.tier)
.first()
)
return tier
def _get_product_count(self, db: Session, vendor_id: int) -> int:
"""Get product count for vendor."""
return (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
def _get_team_member_count(self, db: Session, vendor_id: int) -> int:
"""Get active team member count for vendor."""
return (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
.scalar()
or 0
)
def _calculate_usage_metrics(
self, db: Session, vendor_id: int, subscription: VendorSubscription
) -> list[UsageMetricData]:
"""Calculate all usage metrics for a vendor."""
metrics = []
# Orders this period
orders_current = subscription.orders_this_period or 0
orders_limit = subscription.orders_limit
orders_unlimited = orders_limit is None or orders_limit < 0
orders_percentage = (
0
if orders_unlimited
else (orders_current / orders_limit * 100 if orders_limit > 0 else 100)
)
metrics.append(
UsageMetricData(
name="orders",
current=orders_current,
limit=None if orders_unlimited else orders_limit,
percentage=orders_percentage,
is_unlimited=orders_unlimited,
is_at_limit=not orders_unlimited and orders_current >= orders_limit,
is_approaching_limit=not orders_unlimited and orders_percentage >= 80,
)
)
# Products
products_count = self._get_product_count(db, vendor_id)
products_limit = subscription.products_limit
products_unlimited = products_limit is None or products_limit < 0
products_percentage = (
0
if products_unlimited
else (products_count / products_limit * 100 if products_limit > 0 else 100)
)
metrics.append(
UsageMetricData(
name="products",
current=products_count,
limit=None if products_unlimited else products_limit,
percentage=products_percentage,
is_unlimited=products_unlimited,
is_at_limit=not products_unlimited and products_count >= products_limit,
is_approaching_limit=not products_unlimited and products_percentage >= 80,
)
)
# Team members
team_count = self._get_team_member_count(db, vendor_id)
team_limit = subscription.team_members_limit
team_unlimited = team_limit is None or team_limit < 0
team_percentage = (
0
if team_unlimited
else (team_count / team_limit * 100 if team_limit > 0 else 100)
)
metrics.append(
UsageMetricData(
name="team_members",
current=team_count,
limit=None if team_unlimited else team_limit,
percentage=team_percentage,
is_unlimited=team_unlimited,
is_at_limit=not team_unlimited and team_count >= team_limit,
is_approaching_limit=not team_unlimited and team_percentage >= 80,
)
)
return metrics
def _get_next_tier(
self, db: Session, current_tier: SubscriptionTier | None
) -> SubscriptionTier | None:
"""Get next tier for upgrade."""
current_tier_order = current_tier.display_order if current_tier else 0
return (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.display_order > current_tier_order,
)
.order_by(SubscriptionTier.display_order)
.first()
)
def _build_upgrade_tier_info(
self, current_tier: SubscriptionTier | None, next_tier: SubscriptionTier
) -> UpgradeTierData:
"""Build upgrade tier information with benefits."""
benefits = []
# Numeric limit benefits
if next_tier.orders_per_month and (
not current_tier
or (
current_tier.orders_per_month
and next_tier.orders_per_month > current_tier.orders_per_month
)
):
if next_tier.orders_per_month < 0:
benefits.append("Unlimited orders per month")
else:
benefits.append(f"{next_tier.orders_per_month:,} orders/month")
if next_tier.products_limit and (
not current_tier
or (
current_tier.products_limit
and next_tier.products_limit > current_tier.products_limit
)
):
if next_tier.products_limit < 0:
benefits.append("Unlimited products")
else:
benefits.append(f"{next_tier.products_limit:,} products")
if next_tier.team_members and (
not current_tier
or (
current_tier.team_members
and next_tier.team_members > current_tier.team_members
)
):
if next_tier.team_members < 0:
benefits.append("Unlimited team members")
else:
benefits.append(f"{next_tier.team_members} team members")
# Feature benefits
current_features = (
set(current_tier.features) if current_tier and current_tier.features else set()
)
next_features = set(next_tier.features) if next_tier.features else set()
new_features = next_features - current_features
feature_names = {
"analytics_dashboard": "Advanced Analytics",
"api_access": "API Access",
"automation_rules": "Automation Rules",
"team_roles": "Team Roles & Permissions",
"custom_domain": "Custom Domain",
"webhooks": "Webhooks",
"accounting_export": "Accounting Export",
}
for feature in list(new_features)[:3]:
if feature in feature_names:
benefits.append(feature_names[feature])
current_price = current_tier.price_monthly_cents if current_tier else 0
return UpgradeTierData(
code=next_tier.code,
name=next_tier.name,
price_monthly_cents=next_tier.price_monthly_cents,
price_increase_cents=next_tier.price_monthly_cents - current_price,
benefits=benefits,
)
def _build_upgrade_reasons(
self,
usage_metrics: list[UsageMetricData],
has_limits_reached: bool,
has_limits_approaching: bool,
) -> list[str]:
"""Build upgrade reasons based on usage."""
reasons = []
if has_limits_reached:
for m in usage_metrics:
if m.is_at_limit:
reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit")
elif has_limits_approaching:
for m in usage_metrics:
if m.is_approaching_limit:
reasons.append(
f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)"
)
return reasons
# Singleton instance
usage_service = UsageService()
__all__ = [
"usage_service",
"UsageService",
"UsageData",
"UsageMetricData",
"TierInfoData",
"UpgradeTierData",
"LimitCheckData",
]

View File

@@ -0,0 +1,367 @@
# tests/unit/services/test_feature_service.py
"""Unit tests for FeatureService."""
import pytest
from app.exceptions import FeatureNotFoundError, InvalidFeatureCodesError, TierNotFoundError
from app.services.feature_service import FeatureService, feature_service
from models.database.feature import Feature
from models.database.subscription import SubscriptionTier, VendorSubscription
@pytest.mark.unit
@pytest.mark.features
class TestFeatureServiceAvailability:
"""Test suite for feature availability checking."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = FeatureService()
def test_has_feature_true(self, db, test_vendor_with_subscription):
"""Test has_feature returns True for available feature."""
vendor_id = test_vendor_with_subscription.id
result = self.service.has_feature(db, vendor_id, "basic_reports")
assert result is True
def test_has_feature_false(self, db, test_vendor_with_subscription):
"""Test has_feature returns False for unavailable feature."""
vendor_id = test_vendor_with_subscription.id
result = self.service.has_feature(db, vendor_id, "api_access")
assert result is False
def test_has_feature_no_subscription(self, db, test_vendor):
"""Test has_feature returns False for vendor without subscription."""
result = self.service.has_feature(db, test_vendor.id, "basic_reports")
assert result is False
def test_get_vendor_feature_codes(self, db, test_vendor_with_subscription):
"""Test getting all feature codes for vendor."""
vendor_id = test_vendor_with_subscription.id
features = self.service.get_vendor_feature_codes(db, vendor_id)
assert isinstance(features, set)
assert "basic_reports" in features
assert "api_access" not in features
@pytest.mark.unit
@pytest.mark.features
class TestFeatureServiceListing:
"""Test suite for feature listing operations."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = FeatureService()
def test_get_vendor_features(self, db, test_vendor_with_subscription, test_features):
"""Test getting all features with availability."""
vendor_id = test_vendor_with_subscription.id
features = self.service.get_vendor_features(db, vendor_id)
assert len(features) > 0
basic_reports = next((f for f in features if f.code == "basic_reports"), None)
assert basic_reports is not None
assert basic_reports.is_available is True
api_access = next((f for f in features if f.code == "api_access"), None)
assert api_access is not None
assert api_access.is_available is False
def test_get_vendor_features_by_category(
self, db, test_vendor_with_subscription, test_features
):
"""Test filtering features by category."""
vendor_id = test_vendor_with_subscription.id
features = self.service.get_vendor_features(db, vendor_id, category="analytics")
assert all(f.category == "analytics" for f in features)
def test_get_vendor_features_available_only(
self, db, test_vendor_with_subscription, test_features
):
"""Test getting only available features."""
vendor_id = test_vendor_with_subscription.id
features = self.service.get_vendor_features(
db, vendor_id, include_unavailable=False
)
assert all(f.is_available for f in features)
def test_get_available_feature_codes(self, db, test_vendor_with_subscription):
"""Test getting simple list of available codes."""
vendor_id = test_vendor_with_subscription.id
codes = self.service.get_available_feature_codes(db, vendor_id)
assert isinstance(codes, list)
assert "basic_reports" in codes
@pytest.mark.unit
@pytest.mark.features
class TestFeatureServiceMetadata:
"""Test suite for feature metadata operations."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = FeatureService()
def test_get_feature_by_code(self, db, test_features):
"""Test getting feature by code."""
feature = self.service.get_feature_by_code(db, "basic_reports")
assert feature is not None
assert feature.code == "basic_reports"
assert feature.name == "Basic Reports"
def test_get_feature_by_code_not_found(self, db, test_features):
"""Test getting non-existent feature returns None."""
feature = self.service.get_feature_by_code(db, "nonexistent")
assert feature is None
def test_get_feature_upgrade_info(self, db, test_features, test_subscription_tiers):
"""Test getting upgrade info for locked feature."""
info = self.service.get_feature_upgrade_info(db, "api_access")
assert info is not None
assert info.feature_code == "api_access"
assert info.required_tier_code == "professional"
def test_get_feature_upgrade_info_no_minimum_tier(self, db, test_features):
"""Test upgrade info for feature without minimum tier."""
# basic_reports has no minimum tier in fixtures
info = self.service.get_feature_upgrade_info(db, "basic_reports")
assert info is None
def test_get_all_features(self, db, test_features):
"""Test getting all features for admin."""
features = self.service.get_all_features(db)
assert len(features) >= 3
def test_get_all_features_by_category(self, db, test_features):
"""Test filtering features by category."""
features = self.service.get_all_features(db, category="analytics")
assert all(f.category == "analytics" for f in features)
def test_get_categories(self, db, test_features):
"""Test getting unique categories."""
categories = self.service.get_categories(db)
assert "analytics" in categories
assert "integrations" in categories
@pytest.mark.unit
@pytest.mark.features
class TestFeatureServiceCache:
"""Test suite for cache operations."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = FeatureService()
def test_cache_invalidation(self, db, test_vendor_with_subscription):
"""Test cache invalidation for vendor."""
vendor_id = test_vendor_with_subscription.id
# Prime the cache
self.service.get_vendor_feature_codes(db, vendor_id)
assert self.service._cache.get(vendor_id) is not None
# Invalidate
self.service.invalidate_vendor_cache(vendor_id)
assert self.service._cache.get(vendor_id) is None
def test_cache_invalidate_all(self, db, test_vendor_with_subscription):
"""Test invalidating entire cache."""
vendor_id = test_vendor_with_subscription.id
# Prime the cache
self.service.get_vendor_feature_codes(db, vendor_id)
# Invalidate all
self.service.invalidate_all_cache()
assert self.service._cache.get(vendor_id) is None
@pytest.mark.unit
@pytest.mark.features
class TestFeatureServiceAdmin:
"""Test suite for admin operations."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = FeatureService()
def test_get_all_tiers_with_features(self, db, test_subscription_tiers):
"""Test getting all tiers."""
tiers = self.service.get_all_tiers_with_features(db)
assert len(tiers) == 2
assert tiers[0].code == "essential"
assert tiers[1].code == "professional"
def test_update_tier_features(self, db, test_subscription_tiers, test_features):
"""Test updating tier features."""
tier = self.service.update_tier_features(
db, "essential", ["basic_reports", "api_access"]
)
db.commit()
assert "api_access" in tier.features
def test_update_tier_features_invalid_codes(
self, db, test_subscription_tiers, test_features
):
"""Test updating tier with invalid feature codes."""
with pytest.raises(InvalidFeatureCodesError) as exc_info:
self.service.update_tier_features(
db, "essential", ["basic_reports", "nonexistent_feature"]
)
assert "nonexistent_feature" in exc_info.value.invalid_codes
def test_update_tier_features_tier_not_found(self, db, test_features):
"""Test updating non-existent tier."""
with pytest.raises(TierNotFoundError) as exc_info:
self.service.update_tier_features(db, "nonexistent", ["basic_reports"])
assert exc_info.value.tier_code == "nonexistent"
def test_update_feature(self, db, test_features):
"""Test updating feature metadata."""
feature = self.service.update_feature(
db, "basic_reports", name="Updated Reports", description="New description"
)
db.commit()
assert feature.name == "Updated Reports"
assert feature.description == "New description"
def test_update_feature_not_found(self, db, test_features):
"""Test updating non-existent feature."""
with pytest.raises(FeatureNotFoundError) as exc_info:
self.service.update_feature(db, "nonexistent", name="Test")
assert exc_info.value.feature_code == "nonexistent"
def test_update_feature_minimum_tier(
self, db, test_features, test_subscription_tiers
):
"""Test updating feature minimum tier."""
feature = self.service.update_feature(
db, "basic_reports", minimum_tier_code="professional"
)
db.commit()
db.refresh(feature)
assert feature.minimum_tier is not None
assert feature.minimum_tier.code == "professional"
# ==================== Fixtures ====================
@pytest.fixture
def test_subscription_tiers(db):
"""Create multiple subscription tiers."""
tiers = [
SubscriptionTier(
code="essential",
name="Essential",
description="Essential plan",
price_monthly_cents=4900,
price_annual_cents=49000,
orders_per_month=100,
products_limit=500,
team_members=2,
features=["basic_reports"],
is_active=True,
display_order=1,
),
SubscriptionTier(
code="professional",
name="Professional",
description="Professional plan",
price_monthly_cents=9900,
price_annual_cents=99000,
orders_per_month=500,
products_limit=2000,
team_members=5,
features=["basic_reports", "api_access", "analytics_dashboard"],
is_active=True,
display_order=2,
),
]
for tier in tiers:
db.add(tier)
db.commit()
for tier in tiers:
db.refresh(tier)
return tiers
@pytest.fixture
def test_vendor_with_subscription(db, test_vendor, test_subscription_tiers):
"""Create a vendor with an active subscription."""
from datetime import datetime, timezone
essential_tier = test_subscription_tiers[0] # Use the essential tier from tiers list
now = datetime.now(timezone.utc)
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
tier_id=essential_tier.id,
status="active",
period_start=now,
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
orders_this_period=10,
)
db.add(subscription)
db.commit()
db.refresh(test_vendor)
return test_vendor
@pytest.fixture
def test_features(db, test_subscription_tiers):
"""Create test features."""
features = [
Feature(
code="basic_reports",
name="Basic Reports",
description="View basic analytics reports",
category="analytics",
ui_location="sidebar",
ui_icon="chart-bar",
is_active=True,
display_order=1,
),
Feature(
code="api_access",
name="API Access",
description="Access the REST API",
category="integrations",
ui_location="settings",
ui_icon="code",
minimum_tier_id=test_subscription_tiers[1].id, # Professional
is_active=True,
display_order=2,
),
Feature(
code="analytics_dashboard",
name="Analytics Dashboard",
description="Advanced analytics dashboard",
category="analytics",
ui_location="sidebar",
ui_icon="presentation-chart-line",
minimum_tier_id=test_subscription_tiers[1].id,
is_active=True,
display_order=3,
),
]
for feature in features:
db.add(feature)
db.commit()
for feature in features:
db.refresh(feature)
return features

View File

@@ -0,0 +1,308 @@
# tests/unit/services/test_usage_service.py
"""Unit tests for UsageService."""
import pytest
from app.services.usage_service import UsageService, usage_service
from models.database.product import Product
from models.database.subscription import SubscriptionTier, VendorSubscription
from models.database.vendor import VendorUser
@pytest.mark.unit
@pytest.mark.usage
class TestUsageServiceGetUsage:
"""Test suite for get_vendor_usage operation."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = UsageService()
def test_get_vendor_usage_basic(self, db, test_vendor_with_subscription):
"""Test getting basic usage data."""
vendor_id = test_vendor_with_subscription.id
usage = self.service.get_vendor_usage(db, vendor_id)
assert usage.tier.code == "essential"
assert usage.tier.name == "Essential"
assert len(usage.usage) == 3
def test_get_vendor_usage_metrics(self, db, test_vendor_with_subscription):
"""Test usage metrics are calculated correctly."""
vendor_id = test_vendor_with_subscription.id
usage = self.service.get_vendor_usage(db, vendor_id)
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
assert orders_metric is not None
assert orders_metric.current == 10
assert orders_metric.limit == 100
assert orders_metric.percentage == 10.0
assert orders_metric.is_unlimited is False
def test_get_vendor_usage_at_limit(self, db, test_vendor_at_limit):
"""Test usage shows at limit correctly."""
vendor_id = test_vendor_at_limit.id
usage = self.service.get_vendor_usage(db, vendor_id)
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
assert orders_metric.is_at_limit is True
assert usage.has_limits_reached is True
def test_get_vendor_usage_approaching_limit(self, db, test_vendor_approaching_limit):
"""Test usage shows approaching limit correctly."""
vendor_id = test_vendor_approaching_limit.id
usage = self.service.get_vendor_usage(db, vendor_id)
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
assert orders_metric.is_approaching_limit is True
assert usage.has_limits_approaching is True
def test_get_vendor_usage_upgrade_available(
self, db, test_vendor_with_subscription, test_professional_tier
):
"""Test upgrade info when not on highest tier."""
vendor_id = test_vendor_with_subscription.id
usage = self.service.get_vendor_usage(db, vendor_id)
assert usage.upgrade_available is True
assert usage.upgrade_tier is not None
assert usage.upgrade_tier.code == "professional"
def test_get_vendor_usage_highest_tier(self, db, test_vendor_on_professional):
"""Test no upgrade when on highest tier."""
vendor_id = test_vendor_on_professional.id
usage = self.service.get_vendor_usage(db, vendor_id)
assert usage.tier.is_highest_tier is True
assert usage.upgrade_available is False
assert usage.upgrade_tier is None
@pytest.mark.unit
@pytest.mark.usage
class TestUsageServiceCheckLimit:
"""Test suite for check_limit operation."""
def setup_method(self):
"""Initialize service instance before each test."""
self.service = UsageService()
def test_check_orders_limit_can_proceed(self, db, test_vendor_with_subscription):
"""Test checking orders limit when under limit."""
vendor_id = test_vendor_with_subscription.id
result = self.service.check_limit(db, vendor_id, "orders")
assert result.can_proceed is True
assert result.current == 10
assert result.limit == 100
def test_check_products_limit(self, db, test_vendor_with_products):
"""Test checking products limit."""
vendor_id = test_vendor_with_products.id
result = self.service.check_limit(db, vendor_id, "products")
assert result.can_proceed is True
assert result.current == 5
assert result.limit == 500
def test_check_team_members_limit(self, db, test_vendor_with_team):
"""Test checking team members limit when at limit."""
vendor_id = test_vendor_with_team.id
result = self.service.check_limit(db, vendor_id, "team_members")
# At limit (2/2) - can_proceed should be False
assert result.can_proceed is False
assert result.current == 2
assert result.limit == 2
assert result.percentage == 100.0
def test_check_unknown_limit_type(self, db, test_vendor_with_subscription):
"""Test checking unknown limit type."""
vendor_id = test_vendor_with_subscription.id
result = self.service.check_limit(db, vendor_id, "unknown")
assert result.can_proceed is True
assert "Unknown limit type" in result.message
def test_check_limit_upgrade_info_when_blocked(self, db, test_vendor_at_limit):
"""Test upgrade info is provided when at limit."""
vendor_id = test_vendor_at_limit.id
result = self.service.check_limit(db, vendor_id, "orders")
assert result.can_proceed is False
assert result.upgrade_tier_code == "professional"
assert result.upgrade_tier_name == "Professional"
# ==================== Fixtures ====================
@pytest.fixture
def test_essential_tier(db):
"""Create essential subscription tier."""
tier = SubscriptionTier(
code="essential",
name="Essential",
description="Essential plan",
price_monthly_cents=4900,
price_annual_cents=49000,
orders_per_month=100,
products_limit=500,
team_members=2,
features=["basic_reports"],
is_active=True,
display_order=1,
)
db.add(tier)
db.commit()
db.refresh(tier)
return tier
@pytest.fixture
def test_professional_tier(db, test_essential_tier):
"""Create professional subscription tier."""
tier = SubscriptionTier(
code="professional",
name="Professional",
description="Professional plan",
price_monthly_cents=9900,
price_annual_cents=99000,
orders_per_month=500,
products_limit=2000,
team_members=10,
features=["basic_reports", "api_access", "analytics_dashboard"],
is_active=True,
display_order=2,
)
db.add(tier)
db.commit()
db.refresh(tier)
return tier
@pytest.fixture
def test_vendor_with_subscription(db, test_vendor, test_essential_tier):
"""Create vendor with active subscription."""
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
tier_id=test_essential_tier.id,
status="active",
period_start=now,
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
orders_this_period=10,
)
db.add(subscription)
db.commit()
db.refresh(test_vendor)
return test_vendor
@pytest.fixture
def test_vendor_at_limit(db, test_vendor, test_essential_tier, test_professional_tier):
"""Create vendor at order limit."""
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
tier_id=test_essential_tier.id,
status="active",
period_start=now,
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
orders_this_period=100, # At limit
)
db.add(subscription)
db.commit()
db.refresh(test_vendor)
return test_vendor
@pytest.fixture
def test_vendor_approaching_limit(db, test_vendor, test_essential_tier):
"""Create vendor approaching order limit (>=80%)."""
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
tier_id=test_essential_tier.id,
status="active",
period_start=now,
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
orders_this_period=85, # 85% of 100
)
db.add(subscription)
db.commit()
db.refresh(test_vendor)
return test_vendor
@pytest.fixture
def test_vendor_on_professional(db, test_vendor, test_professional_tier):
"""Create vendor on highest tier."""
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="professional",
tier_id=test_professional_tier.id,
status="active",
period_start=now,
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
orders_this_period=50,
)
db.add(subscription)
db.commit()
db.refresh(test_vendor)
return test_vendor
@pytest.fixture
def test_vendor_with_products(db, test_vendor_with_subscription, marketplace_product_factory):
"""Create vendor with products."""
for i in range(5):
# Create marketplace product first
mp = marketplace_product_factory(db, title=f"Test Product {i}")
product = Product(
vendor_id=test_vendor_with_subscription.id,
marketplace_product_id=mp.id,
price_cents=1000,
is_active=True,
)
db.add(product)
db.commit()
return test_vendor_with_subscription
@pytest.fixture
def test_vendor_with_team(db, test_vendor_with_subscription, test_user, other_user):
"""Create vendor with team members (owner + team member = 2)."""
from models.database.vendor import VendorUserType
# Add owner
owner = VendorUser(
vendor_id=test_vendor_with_subscription.id,
user_id=test_user.id,
user_type=VendorUserType.OWNER.value,
is_active=True,
)
db.add(owner)
# Add team member
team_member = VendorUser(
vendor_id=test_vendor_with_subscription.id,
user_id=other_user.id,
user_type=VendorUserType.TEAM_MEMBER.value,
is_active=True,
)
db.add(team_member)
db.commit()
return test_vendor_with_subscription