test: add service tests and fix architecture violations
- Add comprehensive unit tests for FeatureService (24 tests) - Add comprehensive unit tests for UsageService (11 tests) - Fix API-002/API-003 architecture violations in feature/usage APIs - Move database queries from API layer to service layer - Create UsageService for usage and limits management - Create custom exceptions (FeatureNotFoundError, TierNotFoundError) - Fix ValidationException usage in content_pages.py - Refactor vendor features API to use proper response models - All 35 new tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import get_current_admin_api, get_db
|
from app.api.deps import get_current_admin_api, get_db
|
||||||
|
from app.exceptions import ValidationException
|
||||||
from app.services.content_page_service import content_page_service
|
from app.services.content_page_service import content_page_service
|
||||||
from models.database.user import User
|
from models.database.user import User
|
||||||
|
|
||||||
@@ -178,11 +179,9 @@ def create_vendor_page(
|
|||||||
Vendor pages override platform defaults for a specific vendor.
|
Vendor pages override platform defaults for a specific vendor.
|
||||||
"""
|
"""
|
||||||
if not page_data.vendor_id:
|
if not page_data.vendor_id:
|
||||||
from fastapi import HTTPException
|
raise ValidationException(
|
||||||
|
message="vendor_id is required for vendor pages. Use /platform for platform defaults.",
|
||||||
raise HTTPException(
|
field="vendor_id",
|
||||||
status_code=400,
|
|
||||||
detail="vendor_id is required for vendor pages. Use /platform for platform defaults.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
page = content_page_service.create_page(
|
page = content_page_service.create_page(
|
||||||
|
|||||||
@@ -11,15 +11,13 @@ Provides endpoints for:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import get_current_admin_api
|
from app.api.deps import get_current_admin_api
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.services.feature_service import feature_service
|
from app.services.feature_service import feature_service
|
||||||
from models.database.feature import Feature
|
|
||||||
from models.database.subscription import SubscriptionTier
|
|
||||||
from models.database.user import User
|
from models.database.user import User
|
||||||
|
|
||||||
router = APIRouter(prefix="/features")
|
router = APIRouter(prefix="/features")
|
||||||
@@ -103,6 +101,41 @@ class CategoryListResponse(BaseModel):
|
|||||||
categories: list[str]
|
categories: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TierFeatureDetailResponse(BaseModel):
|
||||||
|
"""Tier features with full details."""
|
||||||
|
|
||||||
|
tier_code: str
|
||||||
|
tier_name: str
|
||||||
|
features: list[dict]
|
||||||
|
feature_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _feature_to_response(feature) -> FeatureResponse:
|
||||||
|
"""Convert Feature model to response."""
|
||||||
|
return FeatureResponse(
|
||||||
|
id=feature.id,
|
||||||
|
code=feature.code,
|
||||||
|
name=feature.name,
|
||||||
|
description=feature.description,
|
||||||
|
category=feature.category,
|
||||||
|
ui_location=feature.ui_location,
|
||||||
|
ui_icon=feature.ui_icon,
|
||||||
|
ui_route=feature.ui_route,
|
||||||
|
ui_badge_text=feature.ui_badge_text,
|
||||||
|
minimum_tier_id=feature.minimum_tier_id,
|
||||||
|
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
|
||||||
|
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
|
||||||
|
is_active=feature.is_active,
|
||||||
|
is_visible=feature.is_visible,
|
||||||
|
display_order=feature.display_order,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -121,26 +154,7 @@ def list_features(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return FeatureListResponse(
|
return FeatureListResponse(
|
||||||
features=[
|
features=[_feature_to_response(f) for f in features],
|
||||||
FeatureResponse(
|
|
||||||
id=f.id,
|
|
||||||
code=f.code,
|
|
||||||
name=f.name,
|
|
||||||
description=f.description,
|
|
||||||
category=f.category,
|
|
||||||
ui_location=f.ui_location,
|
|
||||||
ui_icon=f.ui_icon,
|
|
||||||
ui_route=f.ui_route,
|
|
||||||
ui_badge_text=f.ui_badge_text,
|
|
||||||
minimum_tier_id=f.minimum_tier_id,
|
|
||||||
minimum_tier_code=f.minimum_tier.code if f.minimum_tier else None,
|
|
||||||
minimum_tier_name=f.minimum_tier.name if f.minimum_tier else None,
|
|
||||||
is_active=f.is_active,
|
|
||||||
is_visible=f.is_visible,
|
|
||||||
display_order=f.display_order,
|
|
||||||
)
|
|
||||||
for f in features
|
|
||||||
],
|
|
||||||
total=len(features),
|
total=len(features),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -161,12 +175,7 @@ def list_tiers_with_features(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""List all tiers with their feature assignments."""
|
"""List all tiers with their feature assignments."""
|
||||||
tiers = (
|
tiers = feature_service.get_all_tiers_with_features(db)
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.is_active == True) # noqa: E712
|
|
||||||
.order_by(SubscriptionTier.display_order)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return TierListWithFeaturesResponse(
|
return TierListWithFeaturesResponse(
|
||||||
tiers=[
|
tiers=[
|
||||||
@@ -189,29 +198,19 @@ def get_feature(
|
|||||||
current_user: User = Depends(get_current_admin_api),
|
current_user: User = Depends(get_current_admin_api),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get a single feature by code."""
|
"""
|
||||||
|
Get a single feature by code.
|
||||||
|
|
||||||
|
Raises 404 if feature not found.
|
||||||
|
"""
|
||||||
feature = feature_service.get_feature_by_code(db, feature_code)
|
feature = feature_service.get_feature_by_code(db, feature_code)
|
||||||
|
|
||||||
if not feature:
|
if not feature:
|
||||||
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
|
from app.exceptions import FeatureNotFoundError
|
||||||
|
|
||||||
return FeatureResponse(
|
raise FeatureNotFoundError(feature_code)
|
||||||
id=feature.id,
|
|
||||||
code=feature.code,
|
return _feature_to_response(feature)
|
||||||
name=feature.name,
|
|
||||||
description=feature.description,
|
|
||||||
category=feature.category,
|
|
||||||
ui_location=feature.ui_location,
|
|
||||||
ui_icon=feature.ui_icon,
|
|
||||||
ui_route=feature.ui_route,
|
|
||||||
ui_badge_text=feature.ui_badge_text,
|
|
||||||
minimum_tier_id=feature.minimum_tier_id,
|
|
||||||
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
|
|
||||||
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
|
|
||||||
is_active=feature.is_active,
|
|
||||||
is_visible=feature.is_visible,
|
|
||||||
display_order=feature.display_order,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{feature_code}", response_model=FeatureResponse)
|
@router.put("/{feature_code}", response_model=FeatureResponse)
|
||||||
@@ -221,73 +220,33 @@ def update_feature(
|
|||||||
current_user: User = Depends(get_current_admin_api),
|
current_user: User = Depends(get_current_admin_api),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Update feature metadata."""
|
"""
|
||||||
feature = db.query(Feature).filter(Feature.code == feature_code).first()
|
Update feature metadata.
|
||||||
|
|
||||||
if not feature:
|
Raises 404 if feature not found, 400 if tier code is invalid.
|
||||||
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
|
"""
|
||||||
|
feature = feature_service.update_feature(
|
||||||
# Update fields if provided
|
db,
|
||||||
if request.name is not None:
|
feature_code,
|
||||||
feature.name = request.name
|
name=request.name,
|
||||||
if request.description is not None:
|
description=request.description,
|
||||||
feature.description = request.description
|
category=request.category,
|
||||||
if request.category is not None:
|
ui_location=request.ui_location,
|
||||||
feature.category = request.category
|
ui_icon=request.ui_icon,
|
||||||
if request.ui_location is not None:
|
ui_route=request.ui_route,
|
||||||
feature.ui_location = request.ui_location
|
ui_badge_text=request.ui_badge_text,
|
||||||
if request.ui_icon is not None:
|
minimum_tier_code=request.minimum_tier_code,
|
||||||
feature.ui_icon = request.ui_icon
|
is_active=request.is_active,
|
||||||
if request.ui_route is not None:
|
is_visible=request.is_visible,
|
||||||
feature.ui_route = request.ui_route
|
display_order=request.display_order,
|
||||||
if request.ui_badge_text is not None:
|
)
|
||||||
feature.ui_badge_text = request.ui_badge_text
|
|
||||||
if request.is_active is not None:
|
|
||||||
feature.is_active = request.is_active
|
|
||||||
if request.is_visible is not None:
|
|
||||||
feature.is_visible = request.is_visible
|
|
||||||
if request.display_order is not None:
|
|
||||||
feature.display_order = request.display_order
|
|
||||||
|
|
||||||
# Update minimum tier if provided
|
|
||||||
if request.minimum_tier_code is not None:
|
|
||||||
if request.minimum_tier_code == "":
|
|
||||||
feature.minimum_tier_id = None
|
|
||||||
else:
|
|
||||||
tier = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.code == request.minimum_tier_code)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not tier:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Tier '{request.minimum_tier_code}' not found",
|
|
||||||
)
|
|
||||||
feature.minimum_tier_id = tier.id
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(feature)
|
db.refresh(feature)
|
||||||
|
|
||||||
logger.info(f"Updated feature {feature_code} by admin {current_user.id}")
|
logger.info(f"Updated feature {feature_code} by admin {current_user.id}")
|
||||||
|
|
||||||
return FeatureResponse(
|
return _feature_to_response(feature)
|
||||||
id=feature.id,
|
|
||||||
code=feature.code,
|
|
||||||
name=feature.name,
|
|
||||||
description=feature.description,
|
|
||||||
category=feature.category,
|
|
||||||
ui_location=feature.ui_location,
|
|
||||||
ui_icon=feature.ui_icon,
|
|
||||||
ui_route=feature.ui_route,
|
|
||||||
ui_badge_text=feature.ui_badge_text,
|
|
||||||
minimum_tier_id=feature.minimum_tier_id,
|
|
||||||
minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None,
|
|
||||||
minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None,
|
|
||||||
is_active=feature.is_active,
|
|
||||||
is_visible=feature.is_visible,
|
|
||||||
display_order=feature.display_order,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/tiers/{tier_code}/features", response_model=TierFeaturesResponse)
|
@router.put("/tiers/{tier_code}/features", response_model=TierFeaturesResponse)
|
||||||
@@ -297,54 +256,46 @@ def update_tier_features(
|
|||||||
current_user: User = Depends(get_current_admin_api),
|
current_user: User = Depends(get_current_admin_api),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Update features assigned to a tier."""
|
"""
|
||||||
try:
|
Update features assigned to a tier.
|
||||||
tier = feature_service.update_tier_features(db, tier_code, request.feature_codes)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
logger.info(
|
Raises 404 if tier not found, 422 if any feature codes are invalid.
|
||||||
f"Updated tier {tier_code} features to {len(request.feature_codes)} features "
|
"""
|
||||||
f"by admin {current_user.id}"
|
tier = feature_service.update_tier_features(db, tier_code, request.feature_codes)
|
||||||
)
|
db.commit()
|
||||||
|
|
||||||
return TierFeaturesResponse(
|
logger.info(
|
||||||
id=tier.id,
|
f"Updated tier {tier_code} features to {len(request.feature_codes)} features "
|
||||||
code=tier.code,
|
f"by admin {current_user.id}"
|
||||||
name=tier.name,
|
)
|
||||||
description=tier.description,
|
|
||||||
features=tier.features or [],
|
|
||||||
feature_count=len(tier.features or []),
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
return TierFeaturesResponse(
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
id=tier.id,
|
||||||
|
code=tier.code,
|
||||||
|
name=tier.name,
|
||||||
|
description=tier.description,
|
||||||
|
features=tier.features or [],
|
||||||
|
feature_count=len(tier.features or []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tiers/{tier_code}/features")
|
@router.get("/tiers/{tier_code}/features", response_model=TierFeatureDetailResponse)
|
||||||
def get_tier_features(
|
def get_tier_features(
|
||||||
tier_code: str,
|
tier_code: str,
|
||||||
current_user: User = Depends(get_current_admin_api),
|
current_user: User = Depends(get_current_admin_api),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Get features assigned to a specific tier."""
|
"""
|
||||||
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
Get features assigned to a specific tier with full details.
|
||||||
|
|
||||||
if not tier:
|
Raises 404 if tier not found.
|
||||||
raise HTTPException(status_code=404, detail=f"Tier '{tier_code}' not found")
|
"""
|
||||||
|
tier, features = feature_service.get_tier_features_with_details(db, tier_code)
|
||||||
|
|
||||||
# Get full feature details for the tier's features
|
return TierFeatureDetailResponse(
|
||||||
feature_codes = tier.features or []
|
tier_code=tier.code,
|
||||||
features = (
|
tier_name=tier.name,
|
||||||
db.query(Feature)
|
features=[
|
||||||
.filter(Feature.code.in_(feature_codes))
|
|
||||||
.order_by(Feature.category, Feature.display_order)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"tier_code": tier.code,
|
|
||||||
"tier_name": tier.name,
|
|
||||||
"features": [
|
|
||||||
{
|
{
|
||||||
"code": f.code,
|
"code": f.code,
|
||||||
"name": f.name,
|
"name": f.name,
|
||||||
@@ -353,5 +304,5 @@ def get_tier_features(
|
|||||||
}
|
}
|
||||||
for f in features
|
for f in features
|
||||||
],
|
],
|
||||||
"feature_count": len(features),
|
feature_count=len(features),
|
||||||
}
|
)
|
||||||
|
|||||||
18
app/api/v1/vendor/features.py
vendored
18
app/api/v1/vendor/features.py
vendored
@@ -16,12 +16,13 @@ Endpoints:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import get_current_vendor_api
|
from app.api.deps import get_current_vendor_api
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
|
from app.exceptions import FeatureNotFoundError
|
||||||
from app.services.feature_service import feature_service
|
from app.services.feature_service import feature_service
|
||||||
from models.database.user import User
|
from models.database.user import User
|
||||||
|
|
||||||
@@ -99,6 +100,13 @@ class FeatureGroupedResponse(BaseModel):
|
|||||||
total_count: int
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureCheckResponse(BaseModel):
|
||||||
|
"""Quick feature availability check response."""
|
||||||
|
|
||||||
|
has_feature: bool
|
||||||
|
feature_code: str
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Endpoints
|
# Endpoints
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -285,7 +293,7 @@ def get_feature_detail(
|
|||||||
# Get feature
|
# Get feature
|
||||||
feature = feature_service.get_feature_by_code(db, feature_code)
|
feature = feature_service.get_feature_by_code(db, feature_code)
|
||||||
if not feature:
|
if not feature:
|
||||||
raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found")
|
raise FeatureNotFoundError(feature_code)
|
||||||
|
|
||||||
# Check availability
|
# Check availability
|
||||||
is_available = feature_service.has_feature(db, vendor_id, feature_code)
|
is_available = feature_service.has_feature(db, vendor_id, feature_code)
|
||||||
@@ -317,7 +325,7 @@ def get_feature_detail(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/check/{feature_code}")
|
@router.get("/check/{feature_code}", response_model=FeatureCheckResponse)
|
||||||
def check_feature(
|
def check_feature(
|
||||||
feature_code: str,
|
feature_code: str,
|
||||||
current_user: User = Depends(get_current_vendor_api),
|
current_user: User = Depends(get_current_vendor_api),
|
||||||
@@ -332,9 +340,9 @@ def check_feature(
|
|||||||
feature_code: The feature code
|
feature_code: The feature code
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{"has_feature": true/false}
|
has_feature and feature_code
|
||||||
"""
|
"""
|
||||||
vendor_id = current_user.token_vendor_id
|
vendor_id = current_user.token_vendor_id
|
||||||
has_feature = feature_service.has_feature(db, vendor_id, feature_code)
|
has_feature = feature_service.has_feature(db, vendor_id, feature_code)
|
||||||
|
|
||||||
return {"has_feature": has_feature, "feature_code": feature_code}
|
return FeatureCheckResponse(has_feature=has_feature, feature_code=feature_code)
|
||||||
|
|||||||
293
app/api/v1/vendor/usage.py
vendored
293
app/api/v1/vendor/usage.py
vendored
@@ -12,16 +12,12 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.api.deps import get_current_vendor_api
|
from app.api.deps import get_current_vendor_api
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.services.subscription_service import subscription_service
|
from app.services.usage_service import usage_service
|
||||||
from models.database.product import Product
|
|
||||||
from models.database.subscription import SubscriptionTier
|
|
||||||
from models.database.user import User
|
from models.database.user import User
|
||||||
from models.database.vendor import VendorUser
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/usage")
|
router = APIRouter(prefix="/usage")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -106,181 +102,44 @@ def get_usage(
|
|||||||
"""
|
"""
|
||||||
vendor_id = current_user.token_vendor_id
|
vendor_id = current_user.token_vendor_id
|
||||||
|
|
||||||
# Get subscription
|
# Get usage data from service
|
||||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
usage_data = usage_service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
# Get current tier
|
|
||||||
tier = subscription.tier_obj
|
|
||||||
if not tier:
|
|
||||||
tier = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.code == subscription.tier)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate usage metrics
|
|
||||||
usage_metrics = []
|
|
||||||
|
|
||||||
# Orders this period
|
|
||||||
orders_current = subscription.orders_this_period or 0
|
|
||||||
orders_limit = subscription.orders_limit
|
|
||||||
orders_unlimited = orders_limit is None or orders_limit < 0
|
|
||||||
orders_percentage = 0 if orders_unlimited else (orders_current / orders_limit * 100 if orders_limit > 0 else 100)
|
|
||||||
|
|
||||||
usage_metrics.append(
|
|
||||||
UsageMetric(
|
|
||||||
name="orders",
|
|
||||||
current=orders_current,
|
|
||||||
limit=None if orders_unlimited else orders_limit,
|
|
||||||
percentage=orders_percentage,
|
|
||||||
is_unlimited=orders_unlimited,
|
|
||||||
is_at_limit=not orders_unlimited and orders_current >= orders_limit,
|
|
||||||
is_approaching_limit=not orders_unlimited and orders_percentage >= 80,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Products
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
products_limit = subscription.products_limit
|
|
||||||
products_unlimited = products_limit is None or products_limit < 0
|
|
||||||
products_percentage = 0 if products_unlimited else (products_count / products_limit * 100 if products_limit > 0 else 100)
|
|
||||||
|
|
||||||
usage_metrics.append(
|
|
||||||
UsageMetric(
|
|
||||||
name="products",
|
|
||||||
current=products_count,
|
|
||||||
limit=None if products_unlimited else products_limit,
|
|
||||||
percentage=products_percentage,
|
|
||||||
is_unlimited=products_unlimited,
|
|
||||||
is_at_limit=not products_unlimited and products_count >= products_limit,
|
|
||||||
is_approaching_limit=not products_unlimited and products_percentage >= 80,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Team members
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
team_limit = subscription.team_members_limit
|
|
||||||
team_unlimited = team_limit is None or team_limit < 0
|
|
||||||
team_percentage = 0 if team_unlimited else (team_count / team_limit * 100 if team_limit > 0 else 100)
|
|
||||||
|
|
||||||
usage_metrics.append(
|
|
||||||
UsageMetric(
|
|
||||||
name="team_members",
|
|
||||||
current=team_count,
|
|
||||||
limit=None if team_unlimited else team_limit,
|
|
||||||
percentage=team_percentage,
|
|
||||||
is_unlimited=team_unlimited,
|
|
||||||
is_at_limit=not team_unlimited and team_count >= team_limit,
|
|
||||||
is_approaching_limit=not team_unlimited and team_percentage >= 80,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for approaching/reached limits
|
|
||||||
has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics)
|
|
||||||
has_limits_reached = any(m.is_at_limit for m in usage_metrics)
|
|
||||||
|
|
||||||
# Get next tier for upgrade
|
|
||||||
all_tiers = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.is_active == True) # noqa: E712
|
|
||||||
.order_by(SubscriptionTier.display_order)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
current_tier_order = tier.display_order if tier else 0
|
|
||||||
next_tier = None
|
|
||||||
for t in all_tiers:
|
|
||||||
if t.display_order > current_tier_order:
|
|
||||||
next_tier = t
|
|
||||||
break
|
|
||||||
|
|
||||||
is_highest_tier = next_tier is None
|
|
||||||
|
|
||||||
# Build upgrade info
|
|
||||||
upgrade_tier_info = None
|
|
||||||
upgrade_reasons = []
|
|
||||||
|
|
||||||
if next_tier:
|
|
||||||
# Calculate benefits
|
|
||||||
benefits = []
|
|
||||||
if next_tier.orders_per_month and (not tier or (tier.orders_per_month and next_tier.orders_per_month > tier.orders_per_month)):
|
|
||||||
if next_tier.orders_per_month < 0:
|
|
||||||
benefits.append("Unlimited orders per month")
|
|
||||||
else:
|
|
||||||
benefits.append(f"{next_tier.orders_per_month:,} orders/month")
|
|
||||||
|
|
||||||
if next_tier.products_limit and (not tier or (tier.products_limit and next_tier.products_limit > tier.products_limit)):
|
|
||||||
if next_tier.products_limit < 0:
|
|
||||||
benefits.append("Unlimited products")
|
|
||||||
else:
|
|
||||||
benefits.append(f"{next_tier.products_limit:,} products")
|
|
||||||
|
|
||||||
if next_tier.team_members and (not tier or (tier.team_members and next_tier.team_members > tier.team_members)):
|
|
||||||
if next_tier.team_members < 0:
|
|
||||||
benefits.append("Unlimited team members")
|
|
||||||
else:
|
|
||||||
benefits.append(f"{next_tier.team_members} team members")
|
|
||||||
|
|
||||||
# Add feature benefits
|
|
||||||
current_features = set(tier.features) if tier and tier.features else set()
|
|
||||||
next_features = set(next_tier.features) if next_tier.features else set()
|
|
||||||
new_features = next_features - current_features
|
|
||||||
|
|
||||||
feature_names = {
|
|
||||||
"analytics_dashboard": "Advanced Analytics",
|
|
||||||
"api_access": "API Access",
|
|
||||||
"automation_rules": "Automation Rules",
|
|
||||||
"team_roles": "Team Roles & Permissions",
|
|
||||||
"custom_domain": "Custom Domain",
|
|
||||||
"webhooks": "Webhooks",
|
|
||||||
"accounting_export": "Accounting Export",
|
|
||||||
}
|
|
||||||
for feature in list(new_features)[:3]: # Show top 3
|
|
||||||
if feature in feature_names:
|
|
||||||
benefits.append(feature_names[feature])
|
|
||||||
|
|
||||||
current_price = tier.price_monthly_cents if tier else 0
|
|
||||||
upgrade_tier_info = UpgradeTierInfo(
|
|
||||||
code=next_tier.code,
|
|
||||||
name=next_tier.name,
|
|
||||||
price_monthly_cents=next_tier.price_monthly_cents,
|
|
||||||
price_increase_cents=next_tier.price_monthly_cents - current_price,
|
|
||||||
benefits=benefits,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build upgrade reasons
|
|
||||||
if has_limits_reached:
|
|
||||||
for m in usage_metrics:
|
|
||||||
if m.is_at_limit:
|
|
||||||
upgrade_reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit")
|
|
||||||
elif has_limits_approaching:
|
|
||||||
for m in usage_metrics:
|
|
||||||
if m.is_approaching_limit:
|
|
||||||
upgrade_reasons.append(f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)")
|
|
||||||
|
|
||||||
|
# Convert to response
|
||||||
return UsageResponse(
|
return UsageResponse(
|
||||||
tier=TierInfo(
|
tier=TierInfo(
|
||||||
code=tier.code if tier else subscription.tier,
|
code=usage_data.tier.code,
|
||||||
name=tier.name if tier else subscription.tier.title(),
|
name=usage_data.tier.name,
|
||||||
price_monthly_cents=tier.price_monthly_cents if tier else 0,
|
price_monthly_cents=usage_data.tier.price_monthly_cents,
|
||||||
is_highest_tier=is_highest_tier,
|
is_highest_tier=usage_data.tier.is_highest_tier,
|
||||||
),
|
),
|
||||||
usage=usage_metrics,
|
usage=[
|
||||||
has_limits_approaching=has_limits_approaching,
|
UsageMetric(
|
||||||
has_limits_reached=has_limits_reached,
|
name=m.name,
|
||||||
upgrade_available=not is_highest_tier,
|
current=m.current,
|
||||||
upgrade_tier=upgrade_tier_info,
|
limit=m.limit,
|
||||||
upgrade_reasons=upgrade_reasons,
|
percentage=m.percentage,
|
||||||
|
is_unlimited=m.is_unlimited,
|
||||||
|
is_at_limit=m.is_at_limit,
|
||||||
|
is_approaching_limit=m.is_approaching_limit,
|
||||||
|
)
|
||||||
|
for m in usage_data.usage
|
||||||
|
],
|
||||||
|
has_limits_approaching=usage_data.has_limits_approaching,
|
||||||
|
has_limits_reached=usage_data.has_limits_reached,
|
||||||
|
upgrade_available=usage_data.upgrade_available,
|
||||||
|
upgrade_tier=(
|
||||||
|
UpgradeTierInfo(
|
||||||
|
code=usage_data.upgrade_tier.code,
|
||||||
|
name=usage_data.upgrade_tier.name,
|
||||||
|
price_monthly_cents=usage_data.upgrade_tier.price_monthly_cents,
|
||||||
|
price_increase_cents=usage_data.upgrade_tier.price_increase_cents,
|
||||||
|
benefits=usage_data.upgrade_tier.benefits,
|
||||||
|
)
|
||||||
|
if usage_data.upgrade_tier
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
upgrade_reasons=usage_data.upgrade_reasons,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -303,78 +162,16 @@ def check_limit(
|
|||||||
"""
|
"""
|
||||||
vendor_id = current_user.token_vendor_id
|
vendor_id = current_user.token_vendor_id
|
||||||
|
|
||||||
if limit_type == "orders":
|
# Check limit using service
|
||||||
can_proceed, message = subscription_service.can_create_order(db, vendor_id)
|
check_data = usage_service.check_limit(db, vendor_id, limit_type)
|
||||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
|
||||||
current = subscription.orders_this_period if subscription else 0
|
|
||||||
limit = subscription.orders_limit if subscription else 0
|
|
||||||
|
|
||||||
elif limit_type == "products":
|
|
||||||
can_proceed, message = subscription_service.can_add_product(db, vendor_id)
|
|
||||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
|
||||||
current = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
limit = subscription.products_limit if subscription else 0
|
|
||||||
|
|
||||||
elif limit_type == "team_members":
|
|
||||||
can_proceed, message = subscription_service.can_add_team_member(db, vendor_id)
|
|
||||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
|
||||||
current = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
limit = subscription.team_members_limit if subscription else 0
|
|
||||||
|
|
||||||
else:
|
|
||||||
return LimitCheckResponse(
|
|
||||||
limit_type=limit_type,
|
|
||||||
can_proceed=True,
|
|
||||||
current=0,
|
|
||||||
limit=None,
|
|
||||||
percentage=0,
|
|
||||||
message=f"Unknown limit type: {limit_type}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate percentage
|
|
||||||
is_unlimited = limit is None or limit < 0
|
|
||||||
percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100)
|
|
||||||
|
|
||||||
# Get upgrade info if at limit
|
|
||||||
upgrade_tier_code = None
|
|
||||||
upgrade_tier_name = None
|
|
||||||
|
|
||||||
if not can_proceed:
|
|
||||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
|
||||||
current_tier = subscription.tier_obj if subscription else None
|
|
||||||
|
|
||||||
if current_tier:
|
|
||||||
next_tier = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(
|
|
||||||
SubscriptionTier.is_active == True, # noqa: E712
|
|
||||||
SubscriptionTier.display_order > current_tier.display_order,
|
|
||||||
)
|
|
||||||
.order_by(SubscriptionTier.display_order)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if next_tier:
|
|
||||||
upgrade_tier_code = next_tier.code
|
|
||||||
upgrade_tier_name = next_tier.name
|
|
||||||
|
|
||||||
return LimitCheckResponse(
|
return LimitCheckResponse(
|
||||||
limit_type=limit_type,
|
limit_type=check_data.limit_type,
|
||||||
can_proceed=can_proceed,
|
can_proceed=check_data.can_proceed,
|
||||||
current=current,
|
current=check_data.current,
|
||||||
limit=None if is_unlimited else limit,
|
limit=check_data.limit,
|
||||||
percentage=percentage,
|
percentage=check_data.percentage,
|
||||||
message=message,
|
message=check_data.message,
|
||||||
upgrade_tier_code=upgrade_tier_code,
|
upgrade_tier_code=check_data.upgrade_tier_code,
|
||||||
upgrade_tier_name=upgrade_tier_name,
|
upgrade_tier_name=check_data.upgrade_tier_name,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -263,6 +263,13 @@ from .onboarding import (
|
|||||||
OnboardingSyncNotCompleteException,
|
OnboardingSyncNotCompleteException,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Feature management exceptions
|
||||||
|
from .feature import (
|
||||||
|
FeatureNotFoundError,
|
||||||
|
InvalidFeatureCodesError,
|
||||||
|
TierNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base exceptions
|
# Base exceptions
|
||||||
"WizamartException",
|
"WizamartException",
|
||||||
@@ -456,4 +463,8 @@ __all__ = [
|
|||||||
"OnboardingCsvUrlRequiredException",
|
"OnboardingCsvUrlRequiredException",
|
||||||
"OnboardingSyncJobNotFoundException",
|
"OnboardingSyncJobNotFoundException",
|
||||||
"OnboardingSyncNotCompleteException",
|
"OnboardingSyncNotCompleteException",
|
||||||
|
# Feature exceptions
|
||||||
|
"FeatureNotFoundError",
|
||||||
|
"TierNotFoundError",
|
||||||
|
"InvalidFeatureCodesError",
|
||||||
]
|
]
|
||||||
|
|||||||
42
app/exceptions/feature.py
Normal file
42
app/exceptions/feature.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# app/exceptions/feature.py
|
||||||
|
"""
|
||||||
|
Feature management exceptions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.exceptions.base import ResourceNotFoundException, ValidationException
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureNotFoundError(ResourceNotFoundException):
|
||||||
|
"""Feature not found."""
|
||||||
|
|
||||||
|
def __init__(self, feature_code: str):
|
||||||
|
super().__init__(
|
||||||
|
resource_type="Feature",
|
||||||
|
identifier=feature_code,
|
||||||
|
message=f"Feature '{feature_code}' not found",
|
||||||
|
)
|
||||||
|
self.feature_code = feature_code
|
||||||
|
|
||||||
|
|
||||||
|
class TierNotFoundError(ResourceNotFoundException):
|
||||||
|
"""Subscription tier not found."""
|
||||||
|
|
||||||
|
def __init__(self, tier_code: str):
|
||||||
|
super().__init__(
|
||||||
|
resource_type="SubscriptionTier",
|
||||||
|
identifier=tier_code,
|
||||||
|
message=f"Tier '{tier_code}' not found",
|
||||||
|
)
|
||||||
|
self.tier_code = tier_code
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFeatureCodesError(ValidationException):
|
||||||
|
"""Invalid feature codes provided."""
|
||||||
|
|
||||||
|
def __init__(self, invalid_codes: set[str]):
|
||||||
|
codes_str = ", ".join(sorted(invalid_codes))
|
||||||
|
super().__init__(
|
||||||
|
message=f"Invalid feature codes: {codes_str}",
|
||||||
|
details={"invalid_codes": list(invalid_codes)},
|
||||||
|
)
|
||||||
|
self.invalid_codes = invalid_codes
|
||||||
@@ -29,6 +29,11 @@ from functools import lru_cache
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from app.exceptions.feature import (
|
||||||
|
FeatureNotFoundError,
|
||||||
|
InvalidFeatureCodesError,
|
||||||
|
TierNotFoundError,
|
||||||
|
)
|
||||||
from models.database.feature import Feature, FeatureCode
|
from models.database.feature import Feature, FeatureCode
|
||||||
from models.database.subscription import SubscriptionTier, VendorSubscription
|
from models.database.subscription import SubscriptionTier, VendorSubscription
|
||||||
|
|
||||||
@@ -370,6 +375,51 @@ class FeatureService:
|
|||||||
# Admin Operations
|
# Admin Operations
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_all_tiers_with_features(self, db: Session) -> list[SubscriptionTier]:
|
||||||
|
"""Get all active tiers with their features for admin."""
|
||||||
|
return (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.is_active == True) # noqa: E712
|
||||||
|
.order_by(SubscriptionTier.display_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||||
|
"""
|
||||||
|
Get tier by code, raising exception if not found.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TierNotFoundError: If tier not found
|
||||||
|
"""
|
||||||
|
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
||||||
|
if not tier:
|
||||||
|
raise TierNotFoundError(tier_code)
|
||||||
|
return tier
|
||||||
|
|
||||||
|
def get_tier_features_with_details(
|
||||||
|
self, db: Session, tier_code: str
|
||||||
|
) -> tuple[SubscriptionTier, list[Feature]]:
|
||||||
|
"""
|
||||||
|
Get tier with full feature details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (tier, list of Feature objects)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TierNotFoundError: If tier not found
|
||||||
|
"""
|
||||||
|
tier = self.get_tier_by_code(db, tier_code)
|
||||||
|
feature_codes = tier.features or []
|
||||||
|
|
||||||
|
features = (
|
||||||
|
db.query(Feature)
|
||||||
|
.filter(Feature.code.in_(feature_codes))
|
||||||
|
.order_by(Feature.category, Feature.display_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return tier, features
|
||||||
|
|
||||||
def update_tier_features(
|
def update_tier_features(
|
||||||
self, db: Session, tier_code: str, feature_codes: list[str]
|
self, db: Session, tier_code: str, feature_codes: list[str]
|
||||||
) -> SubscriptionTier:
|
) -> SubscriptionTier:
|
||||||
@@ -383,11 +433,15 @@ class FeatureService:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated tier
|
Updated tier
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TierNotFoundError: If tier not found
|
||||||
|
InvalidFeatureCodesError: If any feature codes are invalid
|
||||||
"""
|
"""
|
||||||
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
||||||
|
|
||||||
if not tier:
|
if not tier:
|
||||||
raise ValueError(f"Tier '{tier_code}' not found")
|
raise TierNotFoundError(tier_code)
|
||||||
|
|
||||||
# Validate feature codes exist
|
# Validate feature codes exist
|
||||||
valid_codes = {
|
valid_codes = {
|
||||||
@@ -395,7 +449,7 @@ class FeatureService:
|
|||||||
}
|
}
|
||||||
invalid = set(feature_codes) - valid_codes
|
invalid = set(feature_codes) - valid_codes
|
||||||
if invalid:
|
if invalid:
|
||||||
raise ValueError(f"Invalid feature codes: {invalid}")
|
raise InvalidFeatureCodesError(invalid)
|
||||||
|
|
||||||
tier.features = feature_codes
|
tier.features = feature_codes
|
||||||
|
|
||||||
@@ -405,6 +459,86 @@ class FeatureService:
|
|||||||
logger.info(f"Updated features for tier {tier_code}: {len(feature_codes)} features")
|
logger.info(f"Updated features for tier {tier_code}: {len(feature_codes)} features")
|
||||||
return tier
|
return tier
|
||||||
|
|
||||||
|
def update_feature(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
feature_code: str,
|
||||||
|
name: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
ui_location: str | None = None,
|
||||||
|
ui_icon: str | None = None,
|
||||||
|
ui_route: str | None = None,
|
||||||
|
ui_badge_text: str | None = None,
|
||||||
|
minimum_tier_code: str | None = None,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
is_visible: bool | None = None,
|
||||||
|
display_order: int | None = None,
|
||||||
|
) -> Feature:
|
||||||
|
"""
|
||||||
|
Update feature metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
feature_code: Feature code to update
|
||||||
|
... other optional fields to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated feature
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeatureNotFoundError: If feature not found
|
||||||
|
TierNotFoundError: If minimum_tier_code provided but not found
|
||||||
|
"""
|
||||||
|
feature = (
|
||||||
|
db.query(Feature)
|
||||||
|
.options(joinedload(Feature.minimum_tier))
|
||||||
|
.filter(Feature.code == feature_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not feature:
|
||||||
|
raise FeatureNotFoundError(feature_code)
|
||||||
|
|
||||||
|
# Update fields if provided
|
||||||
|
if name is not None:
|
||||||
|
feature.name = name
|
||||||
|
if description is not None:
|
||||||
|
feature.description = description
|
||||||
|
if category is not None:
|
||||||
|
feature.category = category
|
||||||
|
if ui_location is not None:
|
||||||
|
feature.ui_location = ui_location
|
||||||
|
if ui_icon is not None:
|
||||||
|
feature.ui_icon = ui_icon
|
||||||
|
if ui_route is not None:
|
||||||
|
feature.ui_route = ui_route
|
||||||
|
if ui_badge_text is not None:
|
||||||
|
feature.ui_badge_text = ui_badge_text
|
||||||
|
if is_active is not None:
|
||||||
|
feature.is_active = is_active
|
||||||
|
if is_visible is not None:
|
||||||
|
feature.is_visible = is_visible
|
||||||
|
if display_order is not None:
|
||||||
|
feature.display_order = display_order
|
||||||
|
|
||||||
|
# Update minimum tier if provided
|
||||||
|
if minimum_tier_code is not None:
|
||||||
|
if minimum_tier_code == "":
|
||||||
|
feature.minimum_tier_id = None
|
||||||
|
else:
|
||||||
|
tier = (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.code == minimum_tier_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not tier:
|
||||||
|
raise TierNotFoundError(minimum_tier_code)
|
||||||
|
feature.minimum_tier_id = tier.id
|
||||||
|
|
||||||
|
logger.info(f"Updated feature {feature_code}")
|
||||||
|
return feature
|
||||||
|
|
||||||
def update_feature_minimum_tier(
|
def update_feature_minimum_tier(
|
||||||
self, db: Session, feature_code: str, tier_code: str | None
|
self, db: Session, feature_code: str, tier_code: str | None
|
||||||
) -> Feature:
|
) -> Feature:
|
||||||
@@ -415,16 +549,20 @@ class FeatureService:
|
|||||||
db: Database session
|
db: Database session
|
||||||
feature_code: Feature code
|
feature_code: Feature code
|
||||||
tier_code: Tier code or None
|
tier_code: Tier code or None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FeatureNotFoundError: If feature not found
|
||||||
|
TierNotFoundError: If tier_code provided but not found
|
||||||
"""
|
"""
|
||||||
feature = db.query(Feature).filter(Feature.code == feature_code).first()
|
feature = db.query(Feature).filter(Feature.code == feature_code).first()
|
||||||
|
|
||||||
if not feature:
|
if not feature:
|
||||||
raise ValueError(f"Feature '{feature_code}' not found")
|
raise FeatureNotFoundError(feature_code)
|
||||||
|
|
||||||
if tier_code:
|
if tier_code:
|
||||||
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first()
|
||||||
if not tier:
|
if not tier:
|
||||||
raise ValueError(f"Tier '{tier_code}' not found")
|
raise TierNotFoundError(tier_code)
|
||||||
feature.minimum_tier_id = tier.id
|
feature.minimum_tier_id = tier.id
|
||||||
else:
|
else:
|
||||||
feature.minimum_tier_id = None
|
feature.minimum_tier_id = None
|
||||||
|
|||||||
445
app/services/usage_service.py
Normal file
445
app/services/usage_service.py
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
# app/services/usage_service.py
|
||||||
|
"""
|
||||||
|
Usage and limits service.
|
||||||
|
|
||||||
|
Provides methods for:
|
||||||
|
- Getting current usage vs limits
|
||||||
|
- Calculating upgrade recommendations
|
||||||
|
- Checking limits before actions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.database.product import Product
|
||||||
|
from models.database.subscription import SubscriptionTier, VendorSubscription
|
||||||
|
from models.database.vendor import VendorUser
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageMetricData:
|
||||||
|
"""Usage metric data."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
current: int
|
||||||
|
limit: int | None
|
||||||
|
percentage: float
|
||||||
|
is_unlimited: bool
|
||||||
|
is_at_limit: bool
|
||||||
|
is_approaching_limit: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TierInfoData:
|
||||||
|
"""Tier information."""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
name: str
|
||||||
|
price_monthly_cents: int
|
||||||
|
is_highest_tier: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpgradeTierData:
|
||||||
|
"""Upgrade tier information."""
|
||||||
|
|
||||||
|
code: str
|
||||||
|
name: str
|
||||||
|
price_monthly_cents: int
|
||||||
|
price_increase_cents: int
|
||||||
|
benefits: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageData:
|
||||||
|
"""Full usage data."""
|
||||||
|
|
||||||
|
tier: TierInfoData
|
||||||
|
usage: list[UsageMetricData]
|
||||||
|
has_limits_approaching: bool
|
||||||
|
has_limits_reached: bool
|
||||||
|
upgrade_available: bool
|
||||||
|
upgrade_tier: UpgradeTierData | None
|
||||||
|
upgrade_reasons: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LimitCheckData:
|
||||||
|
"""Limit check result."""
|
||||||
|
|
||||||
|
limit_type: str
|
||||||
|
can_proceed: bool
|
||||||
|
current: int
|
||||||
|
limit: int | None
|
||||||
|
percentage: float
|
||||||
|
message: str | None
|
||||||
|
upgrade_tier_code: str | None
|
||||||
|
upgrade_tier_name: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class UsageService:
|
||||||
|
"""Service for usage and limits management."""
|
||||||
|
|
||||||
|
def get_vendor_usage(self, db: Session, vendor_id: int) -> UsageData:
|
||||||
|
"""
|
||||||
|
Get comprehensive usage data for a vendor.
|
||||||
|
|
||||||
|
Returns current usage, limits, and upgrade recommendations.
|
||||||
|
"""
|
||||||
|
from app.services.subscription_service import subscription_service
|
||||||
|
|
||||||
|
# Get subscription
|
||||||
|
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
# Get current tier
|
||||||
|
tier = self._get_tier(db, subscription)
|
||||||
|
|
||||||
|
# Calculate usage metrics
|
||||||
|
usage_metrics = self._calculate_usage_metrics(db, vendor_id, subscription)
|
||||||
|
|
||||||
|
# Check for approaching/reached limits
|
||||||
|
has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics)
|
||||||
|
has_limits_reached = any(m.is_at_limit for m in usage_metrics)
|
||||||
|
|
||||||
|
# Get upgrade info
|
||||||
|
next_tier = self._get_next_tier(db, tier)
|
||||||
|
is_highest_tier = next_tier is None
|
||||||
|
|
||||||
|
# Build upgrade info
|
||||||
|
upgrade_tier_info = None
|
||||||
|
upgrade_reasons = []
|
||||||
|
|
||||||
|
if next_tier:
|
||||||
|
upgrade_tier_info = self._build_upgrade_tier_info(tier, next_tier)
|
||||||
|
upgrade_reasons = self._build_upgrade_reasons(
|
||||||
|
usage_metrics, has_limits_reached, has_limits_approaching
|
||||||
|
)
|
||||||
|
|
||||||
|
return UsageData(
|
||||||
|
tier=TierInfoData(
|
||||||
|
code=tier.code if tier else subscription.tier,
|
||||||
|
name=tier.name if tier else subscription.tier.title(),
|
||||||
|
price_monthly_cents=tier.price_monthly_cents if tier else 0,
|
||||||
|
is_highest_tier=is_highest_tier,
|
||||||
|
),
|
||||||
|
usage=usage_metrics,
|
||||||
|
has_limits_approaching=has_limits_approaching,
|
||||||
|
has_limits_reached=has_limits_reached,
|
||||||
|
upgrade_available=not is_highest_tier,
|
||||||
|
upgrade_tier=upgrade_tier_info,
|
||||||
|
upgrade_reasons=upgrade_reasons,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_limit(
|
||||||
|
self, db: Session, vendor_id: int, limit_type: str
|
||||||
|
) -> LimitCheckData:
|
||||||
|
"""
|
||||||
|
Check a specific limit before performing an action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
vendor_id: Vendor ID
|
||||||
|
limit_type: One of "orders", "products", "team_members"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LimitCheckData with proceed status and upgrade info
|
||||||
|
"""
|
||||||
|
from app.services.subscription_service import subscription_service
|
||||||
|
|
||||||
|
if limit_type == "orders":
|
||||||
|
can_proceed, message = subscription_service.can_create_order(db, vendor_id)
|
||||||
|
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||||
|
current = subscription.orders_this_period if subscription else 0
|
||||||
|
limit = subscription.orders_limit if subscription else 0
|
||||||
|
|
||||||
|
elif limit_type == "products":
|
||||||
|
can_proceed, message = subscription_service.can_add_product(db, vendor_id)
|
||||||
|
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||||
|
current = self._get_product_count(db, vendor_id)
|
||||||
|
limit = subscription.products_limit if subscription else 0
|
||||||
|
|
||||||
|
elif limit_type == "team_members":
|
||||||
|
can_proceed, message = subscription_service.can_add_team_member(db, vendor_id)
|
||||||
|
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||||
|
current = self._get_team_member_count(db, vendor_id)
|
||||||
|
limit = subscription.team_members_limit if subscription else 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
return LimitCheckData(
|
||||||
|
limit_type=limit_type,
|
||||||
|
can_proceed=True,
|
||||||
|
current=0,
|
||||||
|
limit=None,
|
||||||
|
percentage=0,
|
||||||
|
message=f"Unknown limit type: {limit_type}",
|
||||||
|
upgrade_tier_code=None,
|
||||||
|
upgrade_tier_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate percentage
|
||||||
|
is_unlimited = limit is None or limit < 0
|
||||||
|
percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100)
|
||||||
|
|
||||||
|
# Get upgrade info if at limit
|
||||||
|
upgrade_tier_code = None
|
||||||
|
upgrade_tier_name = None
|
||||||
|
|
||||||
|
if not can_proceed:
|
||||||
|
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||||
|
current_tier = subscription.tier_obj if subscription else None
|
||||||
|
|
||||||
|
if current_tier:
|
||||||
|
next_tier = self._get_next_tier(db, current_tier)
|
||||||
|
if next_tier:
|
||||||
|
upgrade_tier_code = next_tier.code
|
||||||
|
upgrade_tier_name = next_tier.name
|
||||||
|
|
||||||
|
return LimitCheckData(
|
||||||
|
limit_type=limit_type,
|
||||||
|
can_proceed=can_proceed,
|
||||||
|
current=current,
|
||||||
|
limit=None if is_unlimited else limit,
|
||||||
|
percentage=percentage,
|
||||||
|
message=message,
|
||||||
|
upgrade_tier_code=upgrade_tier_code,
|
||||||
|
upgrade_tier_name=upgrade_tier_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Private Helper Methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def _get_tier(
|
||||||
|
self, db: Session, subscription: VendorSubscription
|
||||||
|
) -> SubscriptionTier | None:
|
||||||
|
"""Get tier from subscription or query by code."""
|
||||||
|
tier = subscription.tier_obj
|
||||||
|
if not tier:
|
||||||
|
tier = (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.code == subscription.tier)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return tier
|
||||||
|
|
||||||
|
def _get_product_count(self, db: Session, vendor_id: int) -> int:
|
||||||
|
"""Get product count for vendor."""
|
||||||
|
return (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_team_member_count(self, db: Session, vendor_id: int) -> int:
|
||||||
|
"""Get active team member count for vendor."""
|
||||||
|
return (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _calculate_usage_metrics(
|
||||||
|
self, db: Session, vendor_id: int, subscription: VendorSubscription
|
||||||
|
) -> list[UsageMetricData]:
|
||||||
|
"""Calculate all usage metrics for a vendor."""
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
# Orders this period
|
||||||
|
orders_current = subscription.orders_this_period or 0
|
||||||
|
orders_limit = subscription.orders_limit
|
||||||
|
orders_unlimited = orders_limit is None or orders_limit < 0
|
||||||
|
orders_percentage = (
|
||||||
|
0
|
||||||
|
if orders_unlimited
|
||||||
|
else (orders_current / orders_limit * 100 if orders_limit > 0 else 100)
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics.append(
|
||||||
|
UsageMetricData(
|
||||||
|
name="orders",
|
||||||
|
current=orders_current,
|
||||||
|
limit=None if orders_unlimited else orders_limit,
|
||||||
|
percentage=orders_percentage,
|
||||||
|
is_unlimited=orders_unlimited,
|
||||||
|
is_at_limit=not orders_unlimited and orders_current >= orders_limit,
|
||||||
|
is_approaching_limit=not orders_unlimited and orders_percentage >= 80,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Products
|
||||||
|
products_count = self._get_product_count(db, vendor_id)
|
||||||
|
products_limit = subscription.products_limit
|
||||||
|
products_unlimited = products_limit is None or products_limit < 0
|
||||||
|
products_percentage = (
|
||||||
|
0
|
||||||
|
if products_unlimited
|
||||||
|
else (products_count / products_limit * 100 if products_limit > 0 else 100)
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics.append(
|
||||||
|
UsageMetricData(
|
||||||
|
name="products",
|
||||||
|
current=products_count,
|
||||||
|
limit=None if products_unlimited else products_limit,
|
||||||
|
percentage=products_percentage,
|
||||||
|
is_unlimited=products_unlimited,
|
||||||
|
is_at_limit=not products_unlimited and products_count >= products_limit,
|
||||||
|
is_approaching_limit=not products_unlimited and products_percentage >= 80,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Team members
|
||||||
|
team_count = self._get_team_member_count(db, vendor_id)
|
||||||
|
team_limit = subscription.team_members_limit
|
||||||
|
team_unlimited = team_limit is None or team_limit < 0
|
||||||
|
team_percentage = (
|
||||||
|
0
|
||||||
|
if team_unlimited
|
||||||
|
else (team_count / team_limit * 100 if team_limit > 0 else 100)
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics.append(
|
||||||
|
UsageMetricData(
|
||||||
|
name="team_members",
|
||||||
|
current=team_count,
|
||||||
|
limit=None if team_unlimited else team_limit,
|
||||||
|
percentage=team_percentage,
|
||||||
|
is_unlimited=team_unlimited,
|
||||||
|
is_at_limit=not team_unlimited and team_count >= team_limit,
|
||||||
|
is_approaching_limit=not team_unlimited and team_percentage >= 80,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _get_next_tier(
|
||||||
|
self, db: Session, current_tier: SubscriptionTier | None
|
||||||
|
) -> SubscriptionTier | None:
|
||||||
|
"""Get next tier for upgrade."""
|
||||||
|
current_tier_order = current_tier.display_order if current_tier else 0
|
||||||
|
|
||||||
|
return (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(
|
||||||
|
SubscriptionTier.is_active == True, # noqa: E712
|
||||||
|
SubscriptionTier.display_order > current_tier_order,
|
||||||
|
)
|
||||||
|
.order_by(SubscriptionTier.display_order)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_upgrade_tier_info(
|
||||||
|
self, current_tier: SubscriptionTier | None, next_tier: SubscriptionTier
|
||||||
|
) -> UpgradeTierData:
|
||||||
|
"""Build upgrade tier information with benefits."""
|
||||||
|
benefits = []
|
||||||
|
|
||||||
|
# Numeric limit benefits
|
||||||
|
if next_tier.orders_per_month and (
|
||||||
|
not current_tier
|
||||||
|
or (
|
||||||
|
current_tier.orders_per_month
|
||||||
|
and next_tier.orders_per_month > current_tier.orders_per_month
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if next_tier.orders_per_month < 0:
|
||||||
|
benefits.append("Unlimited orders per month")
|
||||||
|
else:
|
||||||
|
benefits.append(f"{next_tier.orders_per_month:,} orders/month")
|
||||||
|
|
||||||
|
if next_tier.products_limit and (
|
||||||
|
not current_tier
|
||||||
|
or (
|
||||||
|
current_tier.products_limit
|
||||||
|
and next_tier.products_limit > current_tier.products_limit
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if next_tier.products_limit < 0:
|
||||||
|
benefits.append("Unlimited products")
|
||||||
|
else:
|
||||||
|
benefits.append(f"{next_tier.products_limit:,} products")
|
||||||
|
|
||||||
|
if next_tier.team_members and (
|
||||||
|
not current_tier
|
||||||
|
or (
|
||||||
|
current_tier.team_members
|
||||||
|
and next_tier.team_members > current_tier.team_members
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if next_tier.team_members < 0:
|
||||||
|
benefits.append("Unlimited team members")
|
||||||
|
else:
|
||||||
|
benefits.append(f"{next_tier.team_members} team members")
|
||||||
|
|
||||||
|
# Feature benefits
|
||||||
|
current_features = (
|
||||||
|
set(current_tier.features) if current_tier and current_tier.features else set()
|
||||||
|
)
|
||||||
|
next_features = set(next_tier.features) if next_tier.features else set()
|
||||||
|
new_features = next_features - current_features
|
||||||
|
|
||||||
|
feature_names = {
|
||||||
|
"analytics_dashboard": "Advanced Analytics",
|
||||||
|
"api_access": "API Access",
|
||||||
|
"automation_rules": "Automation Rules",
|
||||||
|
"team_roles": "Team Roles & Permissions",
|
||||||
|
"custom_domain": "Custom Domain",
|
||||||
|
"webhooks": "Webhooks",
|
||||||
|
"accounting_export": "Accounting Export",
|
||||||
|
}
|
||||||
|
for feature in list(new_features)[:3]:
|
||||||
|
if feature in feature_names:
|
||||||
|
benefits.append(feature_names[feature])
|
||||||
|
|
||||||
|
current_price = current_tier.price_monthly_cents if current_tier else 0
|
||||||
|
|
||||||
|
return UpgradeTierData(
|
||||||
|
code=next_tier.code,
|
||||||
|
name=next_tier.name,
|
||||||
|
price_monthly_cents=next_tier.price_monthly_cents,
|
||||||
|
price_increase_cents=next_tier.price_monthly_cents - current_price,
|
||||||
|
benefits=benefits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_upgrade_reasons(
|
||||||
|
self,
|
||||||
|
usage_metrics: list[UsageMetricData],
|
||||||
|
has_limits_reached: bool,
|
||||||
|
has_limits_approaching: bool,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build upgrade reasons based on usage."""
|
||||||
|
reasons = []
|
||||||
|
|
||||||
|
if has_limits_reached:
|
||||||
|
for m in usage_metrics:
|
||||||
|
if m.is_at_limit:
|
||||||
|
reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit")
|
||||||
|
elif has_limits_approaching:
|
||||||
|
for m in usage_metrics:
|
||||||
|
if m.is_approaching_limit:
|
||||||
|
reasons.append(
|
||||||
|
f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return reasons
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
usage_service = UsageService()
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"usage_service",
|
||||||
|
"UsageService",
|
||||||
|
"UsageData",
|
||||||
|
"UsageMetricData",
|
||||||
|
"TierInfoData",
|
||||||
|
"UpgradeTierData",
|
||||||
|
"LimitCheckData",
|
||||||
|
]
|
||||||
367
tests/unit/services/test_feature_service.py
Normal file
367
tests/unit/services/test_feature_service.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
# tests/unit/services/test_feature_service.py
|
||||||
|
"""Unit tests for FeatureService."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.exceptions import FeatureNotFoundError, InvalidFeatureCodesError, TierNotFoundError
|
||||||
|
from app.services.feature_service import FeatureService, feature_service
|
||||||
|
from models.database.feature import Feature
|
||||||
|
from models.database.subscription import SubscriptionTier, VendorSubscription
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.features
|
||||||
|
class TestFeatureServiceAvailability:
|
||||||
|
"""Test suite for feature availability checking."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = FeatureService()
|
||||||
|
|
||||||
|
def test_has_feature_true(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test has_feature returns True for available feature."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
result = self.service.has_feature(db, vendor_id, "basic_reports")
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
def test_has_feature_false(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test has_feature returns False for unavailable feature."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
result = self.service.has_feature(db, vendor_id, "api_access")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_has_feature_no_subscription(self, db, test_vendor):
|
||||||
|
"""Test has_feature returns False for vendor without subscription."""
|
||||||
|
result = self.service.has_feature(db, test_vendor.id, "basic_reports")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_get_vendor_feature_codes(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test getting all feature codes for vendor."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
features = self.service.get_vendor_feature_codes(db, vendor_id)
|
||||||
|
|
||||||
|
assert isinstance(features, set)
|
||||||
|
assert "basic_reports" in features
|
||||||
|
assert "api_access" not in features
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.features
|
||||||
|
class TestFeatureServiceListing:
|
||||||
|
"""Test suite for feature listing operations."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = FeatureService()
|
||||||
|
|
||||||
|
def test_get_vendor_features(self, db, test_vendor_with_subscription, test_features):
|
||||||
|
"""Test getting all features with availability."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
features = self.service.get_vendor_features(db, vendor_id)
|
||||||
|
|
||||||
|
assert len(features) > 0
|
||||||
|
basic_reports = next((f for f in features if f.code == "basic_reports"), None)
|
||||||
|
assert basic_reports is not None
|
||||||
|
assert basic_reports.is_available is True
|
||||||
|
|
||||||
|
api_access = next((f for f in features if f.code == "api_access"), None)
|
||||||
|
assert api_access is not None
|
||||||
|
assert api_access.is_available is False
|
||||||
|
|
||||||
|
def test_get_vendor_features_by_category(
|
||||||
|
self, db, test_vendor_with_subscription, test_features
|
||||||
|
):
|
||||||
|
"""Test filtering features by category."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
features = self.service.get_vendor_features(db, vendor_id, category="analytics")
|
||||||
|
|
||||||
|
assert all(f.category == "analytics" for f in features)
|
||||||
|
|
||||||
|
def test_get_vendor_features_available_only(
|
||||||
|
self, db, test_vendor_with_subscription, test_features
|
||||||
|
):
|
||||||
|
"""Test getting only available features."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
features = self.service.get_vendor_features(
|
||||||
|
db, vendor_id, include_unavailable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all(f.is_available for f in features)
|
||||||
|
|
||||||
|
def test_get_available_feature_codes(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test getting simple list of available codes."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
codes = self.service.get_available_feature_codes(db, vendor_id)
|
||||||
|
|
||||||
|
assert isinstance(codes, list)
|
||||||
|
assert "basic_reports" in codes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.features
|
||||||
|
class TestFeatureServiceMetadata:
|
||||||
|
"""Test suite for feature metadata operations."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = FeatureService()
|
||||||
|
|
||||||
|
def test_get_feature_by_code(self, db, test_features):
|
||||||
|
"""Test getting feature by code."""
|
||||||
|
feature = self.service.get_feature_by_code(db, "basic_reports")
|
||||||
|
|
||||||
|
assert feature is not None
|
||||||
|
assert feature.code == "basic_reports"
|
||||||
|
assert feature.name == "Basic Reports"
|
||||||
|
|
||||||
|
def test_get_feature_by_code_not_found(self, db, test_features):
|
||||||
|
"""Test getting non-existent feature returns None."""
|
||||||
|
feature = self.service.get_feature_by_code(db, "nonexistent")
|
||||||
|
assert feature is None
|
||||||
|
|
||||||
|
def test_get_feature_upgrade_info(self, db, test_features, test_subscription_tiers):
|
||||||
|
"""Test getting upgrade info for locked feature."""
|
||||||
|
info = self.service.get_feature_upgrade_info(db, "api_access")
|
||||||
|
|
||||||
|
assert info is not None
|
||||||
|
assert info.feature_code == "api_access"
|
||||||
|
assert info.required_tier_code == "professional"
|
||||||
|
|
||||||
|
def test_get_feature_upgrade_info_no_minimum_tier(self, db, test_features):
|
||||||
|
"""Test upgrade info for feature without minimum tier."""
|
||||||
|
# basic_reports has no minimum tier in fixtures
|
||||||
|
info = self.service.get_feature_upgrade_info(db, "basic_reports")
|
||||||
|
assert info is None
|
||||||
|
|
||||||
|
def test_get_all_features(self, db, test_features):
|
||||||
|
"""Test getting all features for admin."""
|
||||||
|
features = self.service.get_all_features(db)
|
||||||
|
assert len(features) >= 3
|
||||||
|
|
||||||
|
def test_get_all_features_by_category(self, db, test_features):
|
||||||
|
"""Test filtering features by category."""
|
||||||
|
features = self.service.get_all_features(db, category="analytics")
|
||||||
|
assert all(f.category == "analytics" for f in features)
|
||||||
|
|
||||||
|
def test_get_categories(self, db, test_features):
|
||||||
|
"""Test getting unique categories."""
|
||||||
|
categories = self.service.get_categories(db)
|
||||||
|
assert "analytics" in categories
|
||||||
|
assert "integrations" in categories
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.features
|
||||||
|
class TestFeatureServiceCache:
|
||||||
|
"""Test suite for cache operations."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = FeatureService()
|
||||||
|
|
||||||
|
def test_cache_invalidation(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test cache invalidation for vendor."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
|
||||||
|
# Prime the cache
|
||||||
|
self.service.get_vendor_feature_codes(db, vendor_id)
|
||||||
|
assert self.service._cache.get(vendor_id) is not None
|
||||||
|
|
||||||
|
# Invalidate
|
||||||
|
self.service.invalidate_vendor_cache(vendor_id)
|
||||||
|
assert self.service._cache.get(vendor_id) is None
|
||||||
|
|
||||||
|
def test_cache_invalidate_all(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test invalidating entire cache."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
|
||||||
|
# Prime the cache
|
||||||
|
self.service.get_vendor_feature_codes(db, vendor_id)
|
||||||
|
|
||||||
|
# Invalidate all
|
||||||
|
self.service.invalidate_all_cache()
|
||||||
|
assert self.service._cache.get(vendor_id) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.features
|
||||||
|
class TestFeatureServiceAdmin:
|
||||||
|
"""Test suite for admin operations."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = FeatureService()
|
||||||
|
|
||||||
|
def test_get_all_tiers_with_features(self, db, test_subscription_tiers):
|
||||||
|
"""Test getting all tiers."""
|
||||||
|
tiers = self.service.get_all_tiers_with_features(db)
|
||||||
|
|
||||||
|
assert len(tiers) == 2
|
||||||
|
assert tiers[0].code == "essential"
|
||||||
|
assert tiers[1].code == "professional"
|
||||||
|
|
||||||
|
def test_update_tier_features(self, db, test_subscription_tiers, test_features):
|
||||||
|
"""Test updating tier features."""
|
||||||
|
tier = self.service.update_tier_features(
|
||||||
|
db, "essential", ["basic_reports", "api_access"]
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
assert "api_access" in tier.features
|
||||||
|
|
||||||
|
def test_update_tier_features_invalid_codes(
|
||||||
|
self, db, test_subscription_tiers, test_features
|
||||||
|
):
|
||||||
|
"""Test updating tier with invalid feature codes."""
|
||||||
|
with pytest.raises(InvalidFeatureCodesError) as exc_info:
|
||||||
|
self.service.update_tier_features(
|
||||||
|
db, "essential", ["basic_reports", "nonexistent_feature"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "nonexistent_feature" in exc_info.value.invalid_codes
|
||||||
|
|
||||||
|
def test_update_tier_features_tier_not_found(self, db, test_features):
|
||||||
|
"""Test updating non-existent tier."""
|
||||||
|
with pytest.raises(TierNotFoundError) as exc_info:
|
||||||
|
self.service.update_tier_features(db, "nonexistent", ["basic_reports"])
|
||||||
|
|
||||||
|
assert exc_info.value.tier_code == "nonexistent"
|
||||||
|
|
||||||
|
def test_update_feature(self, db, test_features):
|
||||||
|
"""Test updating feature metadata."""
|
||||||
|
feature = self.service.update_feature(
|
||||||
|
db, "basic_reports", name="Updated Reports", description="New description"
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
assert feature.name == "Updated Reports"
|
||||||
|
assert feature.description == "New description"
|
||||||
|
|
||||||
|
def test_update_feature_not_found(self, db, test_features):
|
||||||
|
"""Test updating non-existent feature."""
|
||||||
|
with pytest.raises(FeatureNotFoundError) as exc_info:
|
||||||
|
self.service.update_feature(db, "nonexistent", name="Test")
|
||||||
|
|
||||||
|
assert exc_info.value.feature_code == "nonexistent"
|
||||||
|
|
||||||
|
def test_update_feature_minimum_tier(
|
||||||
|
self, db, test_features, test_subscription_tiers
|
||||||
|
):
|
||||||
|
"""Test updating feature minimum tier."""
|
||||||
|
feature = self.service.update_feature(
|
||||||
|
db, "basic_reports", minimum_tier_code="professional"
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(feature)
|
||||||
|
|
||||||
|
assert feature.minimum_tier is not None
|
||||||
|
assert feature.minimum_tier.code == "professional"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_subscription_tiers(db):
|
||||||
|
"""Create multiple subscription tiers."""
|
||||||
|
tiers = [
|
||||||
|
SubscriptionTier(
|
||||||
|
code="essential",
|
||||||
|
name="Essential",
|
||||||
|
description="Essential plan",
|
||||||
|
price_monthly_cents=4900,
|
||||||
|
price_annual_cents=49000,
|
||||||
|
orders_per_month=100,
|
||||||
|
products_limit=500,
|
||||||
|
team_members=2,
|
||||||
|
features=["basic_reports"],
|
||||||
|
is_active=True,
|
||||||
|
display_order=1,
|
||||||
|
),
|
||||||
|
SubscriptionTier(
|
||||||
|
code="professional",
|
||||||
|
name="Professional",
|
||||||
|
description="Professional plan",
|
||||||
|
price_monthly_cents=9900,
|
||||||
|
price_annual_cents=99000,
|
||||||
|
orders_per_month=500,
|
||||||
|
products_limit=2000,
|
||||||
|
team_members=5,
|
||||||
|
features=["basic_reports", "api_access", "analytics_dashboard"],
|
||||||
|
is_active=True,
|
||||||
|
display_order=2,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for tier in tiers:
|
||||||
|
db.add(tier)
|
||||||
|
db.commit()
|
||||||
|
for tier in tiers:
|
||||||
|
db.refresh(tier)
|
||||||
|
return tiers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_with_subscription(db, test_vendor, test_subscription_tiers):
|
||||||
|
"""Create a vendor with an active subscription."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
essential_tier = test_subscription_tiers[0] # Use the essential tier from tiers list
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=test_vendor.id,
|
||||||
|
tier="essential",
|
||||||
|
tier_id=essential_tier.id,
|
||||||
|
status="active",
|
||||||
|
period_start=now,
|
||||||
|
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
|
||||||
|
orders_this_period=10,
|
||||||
|
)
|
||||||
|
db.add(subscription)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test_vendor)
|
||||||
|
return test_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_features(db, test_subscription_tiers):
|
||||||
|
"""Create test features."""
|
||||||
|
features = [
|
||||||
|
Feature(
|
||||||
|
code="basic_reports",
|
||||||
|
name="Basic Reports",
|
||||||
|
description="View basic analytics reports",
|
||||||
|
category="analytics",
|
||||||
|
ui_location="sidebar",
|
||||||
|
ui_icon="chart-bar",
|
||||||
|
is_active=True,
|
||||||
|
display_order=1,
|
||||||
|
),
|
||||||
|
Feature(
|
||||||
|
code="api_access",
|
||||||
|
name="API Access",
|
||||||
|
description="Access the REST API",
|
||||||
|
category="integrations",
|
||||||
|
ui_location="settings",
|
||||||
|
ui_icon="code",
|
||||||
|
minimum_tier_id=test_subscription_tiers[1].id, # Professional
|
||||||
|
is_active=True,
|
||||||
|
display_order=2,
|
||||||
|
),
|
||||||
|
Feature(
|
||||||
|
code="analytics_dashboard",
|
||||||
|
name="Analytics Dashboard",
|
||||||
|
description="Advanced analytics dashboard",
|
||||||
|
category="analytics",
|
||||||
|
ui_location="sidebar",
|
||||||
|
ui_icon="presentation-chart-line",
|
||||||
|
minimum_tier_id=test_subscription_tiers[1].id,
|
||||||
|
is_active=True,
|
||||||
|
display_order=3,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for feature in features:
|
||||||
|
db.add(feature)
|
||||||
|
db.commit()
|
||||||
|
for feature in features:
|
||||||
|
db.refresh(feature)
|
||||||
|
return features
|
||||||
308
tests/unit/services/test_usage_service.py
Normal file
308
tests/unit/services/test_usage_service.py
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
# tests/unit/services/test_usage_service.py
|
||||||
|
"""Unit tests for UsageService."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.usage_service import UsageService, usage_service
|
||||||
|
from models.database.product import Product
|
||||||
|
from models.database.subscription import SubscriptionTier, VendorSubscription
|
||||||
|
from models.database.vendor import VendorUser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.usage
|
||||||
|
class TestUsageServiceGetUsage:
|
||||||
|
"""Test suite for get_vendor_usage operation."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = UsageService()
|
||||||
|
|
||||||
|
def test_get_vendor_usage_basic(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test getting basic usage data."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
assert usage.tier.code == "essential"
|
||||||
|
assert usage.tier.name == "Essential"
|
||||||
|
assert len(usage.usage) == 3
|
||||||
|
|
||||||
|
def test_get_vendor_usage_metrics(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test usage metrics are calculated correctly."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
|
||||||
|
assert orders_metric is not None
|
||||||
|
assert orders_metric.current == 10
|
||||||
|
assert orders_metric.limit == 100
|
||||||
|
assert orders_metric.percentage == 10.0
|
||||||
|
assert orders_metric.is_unlimited is False
|
||||||
|
|
||||||
|
def test_get_vendor_usage_at_limit(self, db, test_vendor_at_limit):
|
||||||
|
"""Test usage shows at limit correctly."""
|
||||||
|
vendor_id = test_vendor_at_limit.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
|
||||||
|
assert orders_metric.is_at_limit is True
|
||||||
|
assert usage.has_limits_reached is True
|
||||||
|
|
||||||
|
def test_get_vendor_usage_approaching_limit(self, db, test_vendor_approaching_limit):
|
||||||
|
"""Test usage shows approaching limit correctly."""
|
||||||
|
vendor_id = test_vendor_approaching_limit.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
orders_metric = next((m for m in usage.usage if m.name == "orders"), None)
|
||||||
|
assert orders_metric.is_approaching_limit is True
|
||||||
|
assert usage.has_limits_approaching is True
|
||||||
|
|
||||||
|
def test_get_vendor_usage_upgrade_available(
|
||||||
|
self, db, test_vendor_with_subscription, test_professional_tier
|
||||||
|
):
|
||||||
|
"""Test upgrade info when not on highest tier."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
assert usage.upgrade_available is True
|
||||||
|
assert usage.upgrade_tier is not None
|
||||||
|
assert usage.upgrade_tier.code == "professional"
|
||||||
|
|
||||||
|
def test_get_vendor_usage_highest_tier(self, db, test_vendor_on_professional):
|
||||||
|
"""Test no upgrade when on highest tier."""
|
||||||
|
vendor_id = test_vendor_on_professional.id
|
||||||
|
usage = self.service.get_vendor_usage(db, vendor_id)
|
||||||
|
|
||||||
|
assert usage.tier.is_highest_tier is True
|
||||||
|
assert usage.upgrade_available is False
|
||||||
|
assert usage.upgrade_tier is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.usage
|
||||||
|
class TestUsageServiceCheckLimit:
|
||||||
|
"""Test suite for check_limit operation."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Initialize service instance before each test."""
|
||||||
|
self.service = UsageService()
|
||||||
|
|
||||||
|
def test_check_orders_limit_can_proceed(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test checking orders limit when under limit."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
result = self.service.check_limit(db, vendor_id, "orders")
|
||||||
|
|
||||||
|
assert result.can_proceed is True
|
||||||
|
assert result.current == 10
|
||||||
|
assert result.limit == 100
|
||||||
|
|
||||||
|
def test_check_products_limit(self, db, test_vendor_with_products):
|
||||||
|
"""Test checking products limit."""
|
||||||
|
vendor_id = test_vendor_with_products.id
|
||||||
|
result = self.service.check_limit(db, vendor_id, "products")
|
||||||
|
|
||||||
|
assert result.can_proceed is True
|
||||||
|
assert result.current == 5
|
||||||
|
assert result.limit == 500
|
||||||
|
|
||||||
|
def test_check_team_members_limit(self, db, test_vendor_with_team):
|
||||||
|
"""Test checking team members limit when at limit."""
|
||||||
|
vendor_id = test_vendor_with_team.id
|
||||||
|
result = self.service.check_limit(db, vendor_id, "team_members")
|
||||||
|
|
||||||
|
# At limit (2/2) - can_proceed should be False
|
||||||
|
assert result.can_proceed is False
|
||||||
|
assert result.current == 2
|
||||||
|
assert result.limit == 2
|
||||||
|
assert result.percentage == 100.0
|
||||||
|
|
||||||
|
def test_check_unknown_limit_type(self, db, test_vendor_with_subscription):
|
||||||
|
"""Test checking unknown limit type."""
|
||||||
|
vendor_id = test_vendor_with_subscription.id
|
||||||
|
result = self.service.check_limit(db, vendor_id, "unknown")
|
||||||
|
|
||||||
|
assert result.can_proceed is True
|
||||||
|
assert "Unknown limit type" in result.message
|
||||||
|
|
||||||
|
def test_check_limit_upgrade_info_when_blocked(self, db, test_vendor_at_limit):
|
||||||
|
"""Test upgrade info is provided when at limit."""
|
||||||
|
vendor_id = test_vendor_at_limit.id
|
||||||
|
result = self.service.check_limit(db, vendor_id, "orders")
|
||||||
|
|
||||||
|
assert result.can_proceed is False
|
||||||
|
assert result.upgrade_tier_code == "professional"
|
||||||
|
assert result.upgrade_tier_name == "Professional"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_essential_tier(db):
|
||||||
|
"""Create essential subscription tier."""
|
||||||
|
tier = SubscriptionTier(
|
||||||
|
code="essential",
|
||||||
|
name="Essential",
|
||||||
|
description="Essential plan",
|
||||||
|
price_monthly_cents=4900,
|
||||||
|
price_annual_cents=49000,
|
||||||
|
orders_per_month=100,
|
||||||
|
products_limit=500,
|
||||||
|
team_members=2,
|
||||||
|
features=["basic_reports"],
|
||||||
|
is_active=True,
|
||||||
|
display_order=1,
|
||||||
|
)
|
||||||
|
db.add(tier)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(tier)
|
||||||
|
return tier
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_professional_tier(db, test_essential_tier):
|
||||||
|
"""Create professional subscription tier."""
|
||||||
|
tier = SubscriptionTier(
|
||||||
|
code="professional",
|
||||||
|
name="Professional",
|
||||||
|
description="Professional plan",
|
||||||
|
price_monthly_cents=9900,
|
||||||
|
price_annual_cents=99000,
|
||||||
|
orders_per_month=500,
|
||||||
|
products_limit=2000,
|
||||||
|
team_members=10,
|
||||||
|
features=["basic_reports", "api_access", "analytics_dashboard"],
|
||||||
|
is_active=True,
|
||||||
|
display_order=2,
|
||||||
|
)
|
||||||
|
db.add(tier)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(tier)
|
||||||
|
return tier
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_with_subscription(db, test_vendor, test_essential_tier):
|
||||||
|
"""Create vendor with active subscription."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=test_vendor.id,
|
||||||
|
tier="essential",
|
||||||
|
tier_id=test_essential_tier.id,
|
||||||
|
status="active",
|
||||||
|
period_start=now,
|
||||||
|
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
|
||||||
|
orders_this_period=10,
|
||||||
|
)
|
||||||
|
db.add(subscription)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test_vendor)
|
||||||
|
return test_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_at_limit(db, test_vendor, test_essential_tier, test_professional_tier):
|
||||||
|
"""Create vendor at order limit."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=test_vendor.id,
|
||||||
|
tier="essential",
|
||||||
|
tier_id=test_essential_tier.id,
|
||||||
|
status="active",
|
||||||
|
period_start=now,
|
||||||
|
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
|
||||||
|
orders_this_period=100, # At limit
|
||||||
|
)
|
||||||
|
db.add(subscription)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test_vendor)
|
||||||
|
return test_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_approaching_limit(db, test_vendor, test_essential_tier):
|
||||||
|
"""Create vendor approaching order limit (>=80%)."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=test_vendor.id,
|
||||||
|
tier="essential",
|
||||||
|
tier_id=test_essential_tier.id,
|
||||||
|
status="active",
|
||||||
|
period_start=now,
|
||||||
|
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
|
||||||
|
orders_this_period=85, # 85% of 100
|
||||||
|
)
|
||||||
|
db.add(subscription)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test_vendor)
|
||||||
|
return test_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_on_professional(db, test_vendor, test_professional_tier):
|
||||||
|
"""Create vendor on highest tier."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=test_vendor.id,
|
||||||
|
tier="professional",
|
||||||
|
tier_id=test_professional_tier.id,
|
||||||
|
status="active",
|
||||||
|
period_start=now,
|
||||||
|
period_end=now.replace(month=now.month + 1 if now.month < 12 else 1),
|
||||||
|
orders_this_period=50,
|
||||||
|
)
|
||||||
|
db.add(subscription)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(test_vendor)
|
||||||
|
return test_vendor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_with_products(db, test_vendor_with_subscription, marketplace_product_factory):
|
||||||
|
"""Create vendor with products."""
|
||||||
|
for i in range(5):
|
||||||
|
# Create marketplace product first
|
||||||
|
mp = marketplace_product_factory(db, title=f"Test Product {i}")
|
||||||
|
product = Product(
|
||||||
|
vendor_id=test_vendor_with_subscription.id,
|
||||||
|
marketplace_product_id=mp.id,
|
||||||
|
price_cents=1000,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db.add(product)
|
||||||
|
db.commit()
|
||||||
|
return test_vendor_with_subscription
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vendor_with_team(db, test_vendor_with_subscription, test_user, other_user):
|
||||||
|
"""Create vendor with team members (owner + team member = 2)."""
|
||||||
|
from models.database.vendor import VendorUserType
|
||||||
|
|
||||||
|
# Add owner
|
||||||
|
owner = VendorUser(
|
||||||
|
vendor_id=test_vendor_with_subscription.id,
|
||||||
|
user_id=test_user.id,
|
||||||
|
user_type=VendorUserType.OWNER.value,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db.add(owner)
|
||||||
|
|
||||||
|
# Add team member
|
||||||
|
team_member = VendorUser(
|
||||||
|
vendor_id=test_vendor_with_subscription.id,
|
||||||
|
user_id=other_user.id,
|
||||||
|
user_type=VendorUserType.TEAM_MEMBER.value,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db.add(team_member)
|
||||||
|
db.commit()
|
||||||
|
return test_vendor_with_subscription
|
||||||
Reference in New Issue
Block a user