From aa4b5a4c63d66e3a2cdd2dcc19fbbb1b73b91f4f Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Wed, 31 Dec 2025 18:48:59 +0100 Subject: [PATCH] test: add service tests and fix architecture violations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add comprehensive unit tests for FeatureService (24 tests) - Add comprehensive unit tests for UsageService (11 tests) - Fix API-002/API-003 architecture violations in feature/usage APIs - Move database queries from API layer to service layer - Create UsageService for usage and limits management - Create custom exceptions (FeatureNotFoundError, TierNotFoundError) - Fix ValidationException usage in content_pages.py - Refactor vendor features API to use proper response models - All 35 new tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- app/api/v1/admin/content_pages.py | 9 +- app/api/v1/admin/features.py | 243 +++++------ app/api/v1/vendor/features.py | 18 +- app/api/v1/vendor/usage.py | 293 ++----------- app/exceptions/__init__.py | 11 + app/exceptions/feature.py | 42 ++ app/services/feature_service.py | 146 ++++++- app/services/usage_service.py | 445 ++++++++++++++++++++ tests/unit/services/test_feature_service.py | 367 ++++++++++++++++ tests/unit/services/test_usage_service.py | 308 ++++++++++++++ 10 files changed, 1474 insertions(+), 408 deletions(-) create mode 100644 app/exceptions/feature.py create mode 100644 app/services/usage_service.py create mode 100644 tests/unit/services/test_feature_service.py create mode 100644 tests/unit/services/test_usage_service.py diff --git a/app/api/v1/admin/content_pages.py b/app/api/v1/admin/content_pages.py index 74a1fb3a..1dbfeab4 100644 --- a/app/api/v1/admin/content_pages.py +++ b/app/api/v1/admin/content_pages.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, Field from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api, get_db +from app.exceptions import ValidationException from app.services.content_page_service import content_page_service from models.database.user import User @@ -178,11 +179,9 @@ def create_vendor_page( Vendor pages override platform defaults for a specific vendor. """ if not page_data.vendor_id: - from fastapi import HTTPException - - raise HTTPException( - status_code=400, - detail="vendor_id is required for vendor pages. Use /platform for platform defaults.", + raise ValidationException( + message="vendor_id is required for vendor pages. Use /platform for platform defaults.", + field="vendor_id", ) page = content_page_service.create_page( diff --git a/app/api/v1/admin/features.py b/app/api/v1/admin/features.py index 7598edff..e61864ca 100644 --- a/app/api/v1/admin/features.py +++ b/app/api/v1/admin/features.py @@ -11,15 +11,13 @@ Provides endpoints for: import logging -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.feature_service import feature_service -from models.database.feature import Feature -from models.database.subscription import SubscriptionTier from models.database.user import User router = APIRouter(prefix="/features") @@ -103,6 +101,41 @@ class CategoryListResponse(BaseModel): categories: list[str] +class TierFeatureDetailResponse(BaseModel): + """Tier features with full details.""" + + tier_code: str + tier_name: str + features: list[dict] + feature_count: int + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def _feature_to_response(feature) -> FeatureResponse: + """Convert Feature model to response.""" + return FeatureResponse( + id=feature.id, + code=feature.code, + name=feature.name, + description=feature.description, + category=feature.category, + ui_location=feature.ui_location, + ui_icon=feature.ui_icon, + ui_route=feature.ui_route, + ui_badge_text=feature.ui_badge_text, + minimum_tier_id=feature.minimum_tier_id, + minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None, + minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None, + is_active=feature.is_active, + is_visible=feature.is_visible, + display_order=feature.display_order, + ) + + # ============================================================================ # Endpoints # ============================================================================ @@ -121,26 +154,7 @@ def list_features( ) return FeatureListResponse( - features=[ - FeatureResponse( - id=f.id, - code=f.code, - name=f.name, - description=f.description, - category=f.category, - ui_location=f.ui_location, - ui_icon=f.ui_icon, - ui_route=f.ui_route, - ui_badge_text=f.ui_badge_text, - minimum_tier_id=f.minimum_tier_id, - minimum_tier_code=f.minimum_tier.code if f.minimum_tier else None, - minimum_tier_name=f.minimum_tier.name if f.minimum_tier else None, - is_active=f.is_active, - is_visible=f.is_visible, - display_order=f.display_order, - ) - for f in features - ], + features=[_feature_to_response(f) for f in features], total=len(features), ) @@ -161,12 +175,7 @@ def list_tiers_with_features( db: Session = Depends(get_db), ): """List all tiers with their feature assignments.""" - tiers = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.is_active == True) # noqa: E712 - .order_by(SubscriptionTier.display_order) - .all() - ) + tiers = feature_service.get_all_tiers_with_features(db) return TierListWithFeaturesResponse( tiers=[ @@ -189,29 +198,19 @@ def get_feature( current_user: User = Depends(get_current_admin_api), db: Session = Depends(get_db), ): - """Get a single feature by code.""" + """ + Get a single feature by code. + + Raises 404 if feature not found. + """ feature = feature_service.get_feature_by_code(db, feature_code) if not feature: - raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found") + from app.exceptions import FeatureNotFoundError - return FeatureResponse( - id=feature.id, - code=feature.code, - name=feature.name, - description=feature.description, - category=feature.category, - ui_location=feature.ui_location, - ui_icon=feature.ui_icon, - ui_route=feature.ui_route, - ui_badge_text=feature.ui_badge_text, - minimum_tier_id=feature.minimum_tier_id, - minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None, - minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None, - is_active=feature.is_active, - is_visible=feature.is_visible, - display_order=feature.display_order, - ) + raise FeatureNotFoundError(feature_code) + + return _feature_to_response(feature) @router.put("/{feature_code}", response_model=FeatureResponse) @@ -221,73 +220,33 @@ def update_feature( current_user: User = Depends(get_current_admin_api), db: Session = Depends(get_db), ): - """Update feature metadata.""" - feature = db.query(Feature).filter(Feature.code == feature_code).first() + """ + Update feature metadata. - if not feature: - raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found") - - # Update fields if provided - if request.name is not None: - feature.name = request.name - if request.description is not None: - feature.description = request.description - if request.category is not None: - feature.category = request.category - if request.ui_location is not None: - feature.ui_location = request.ui_location - if request.ui_icon is not None: - feature.ui_icon = request.ui_icon - if request.ui_route is not None: - feature.ui_route = request.ui_route - if request.ui_badge_text is not None: - feature.ui_badge_text = request.ui_badge_text - if request.is_active is not None: - feature.is_active = request.is_active - if request.is_visible is not None: - feature.is_visible = request.is_visible - if request.display_order is not None: - feature.display_order = request.display_order - - # Update minimum tier if provided - if request.minimum_tier_code is not None: - if request.minimum_tier_code == "": - feature.minimum_tier_id = None - else: - tier = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == request.minimum_tier_code) - .first() - ) - if not tier: - raise HTTPException( - status_code=400, - detail=f"Tier '{request.minimum_tier_code}' not found", - ) - feature.minimum_tier_id = tier.id + Raises 404 if feature not found, 400 if tier code is invalid. + """ + feature = feature_service.update_feature( + db, + feature_code, + name=request.name, + description=request.description, + category=request.category, + ui_location=request.ui_location, + ui_icon=request.ui_icon, + ui_route=request.ui_route, + ui_badge_text=request.ui_badge_text, + minimum_tier_code=request.minimum_tier_code, + is_active=request.is_active, + is_visible=request.is_visible, + display_order=request.display_order, + ) db.commit() db.refresh(feature) logger.info(f"Updated feature {feature_code} by admin {current_user.id}") - return FeatureResponse( - id=feature.id, - code=feature.code, - name=feature.name, - description=feature.description, - category=feature.category, - ui_location=feature.ui_location, - ui_icon=feature.ui_icon, - ui_route=feature.ui_route, - ui_badge_text=feature.ui_badge_text, - minimum_tier_id=feature.minimum_tier_id, - minimum_tier_code=feature.minimum_tier.code if feature.minimum_tier else None, - minimum_tier_name=feature.minimum_tier.name if feature.minimum_tier else None, - is_active=feature.is_active, - is_visible=feature.is_visible, - display_order=feature.display_order, - ) + return _feature_to_response(feature) @router.put("/tiers/{tier_code}/features", response_model=TierFeaturesResponse) @@ -297,54 +256,46 @@ def update_tier_features( current_user: User = Depends(get_current_admin_api), db: Session = Depends(get_db), ): - """Update features assigned to a tier.""" - try: - tier = feature_service.update_tier_features(db, tier_code, request.feature_codes) - db.commit() + """ + Update features assigned to a tier. - logger.info( - f"Updated tier {tier_code} features to {len(request.feature_codes)} features " - f"by admin {current_user.id}" - ) + Raises 404 if tier not found, 422 if any feature codes are invalid. + """ + tier = feature_service.update_tier_features(db, tier_code, request.feature_codes) + db.commit() - return TierFeaturesResponse( - id=tier.id, - code=tier.code, - name=tier.name, - description=tier.description, - features=tier.features or [], - feature_count=len(tier.features or []), - ) + logger.info( + f"Updated tier {tier_code} features to {len(request.feature_codes)} features " + f"by admin {current_user.id}" + ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + return TierFeaturesResponse( + id=tier.id, + code=tier.code, + name=tier.name, + description=tier.description, + features=tier.features or [], + feature_count=len(tier.features or []), + ) -@router.get("/tiers/{tier_code}/features") +@router.get("/tiers/{tier_code}/features", response_model=TierFeatureDetailResponse) def get_tier_features( tier_code: str, current_user: User = Depends(get_current_admin_api), db: Session = Depends(get_db), ): - """Get features assigned to a specific tier.""" - tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first() + """ + Get features assigned to a specific tier with full details. - if not tier: - raise HTTPException(status_code=404, detail=f"Tier '{tier_code}' not found") + Raises 404 if tier not found. + """ + tier, features = feature_service.get_tier_features_with_details(db, tier_code) - # Get full feature details for the tier's features - feature_codes = tier.features or [] - features = ( - db.query(Feature) - .filter(Feature.code.in_(feature_codes)) - .order_by(Feature.category, Feature.display_order) - .all() - ) - - return { - "tier_code": tier.code, - "tier_name": tier.name, - "features": [ + return TierFeatureDetailResponse( + tier_code=tier.code, + tier_name=tier.name, + features=[ { "code": f.code, "name": f.name, @@ -353,5 +304,5 @@ def get_tier_features( } for f in features ], - "feature_count": len(features), - } + feature_count=len(features), + ) diff --git a/app/api/v1/vendor/features.py b/app/api/v1/vendor/features.py index 8d1c9473..28cff02b 100644 --- a/app/api/v1/vendor/features.py +++ b/app/api/v1/vendor/features.py @@ -16,12 +16,13 @@ Endpoints: import logging -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db +from app.exceptions import FeatureNotFoundError from app.services.feature_service import feature_service from models.database.user import User @@ -99,6 +100,13 @@ class FeatureGroupedResponse(BaseModel): total_count: int +class FeatureCheckResponse(BaseModel): + """Quick feature availability check response.""" + + has_feature: bool + feature_code: str + + # ============================================================================ # Endpoints # ============================================================================ @@ -285,7 +293,7 @@ def get_feature_detail( # Get feature feature = feature_service.get_feature_by_code(db, feature_code) if not feature: - raise HTTPException(status_code=404, detail=f"Feature '{feature_code}' not found") + raise FeatureNotFoundError(feature_code) # Check availability is_available = feature_service.has_feature(db, vendor_id, feature_code) @@ -317,7 +325,7 @@ def get_feature_detail( ) -@router.get("/check/{feature_code}") +@router.get("/check/{feature_code}", response_model=FeatureCheckResponse) def check_feature( feature_code: str, current_user: User = Depends(get_current_vendor_api), @@ -332,9 +340,9 @@ def check_feature( feature_code: The feature code Returns: - {"has_feature": true/false} + has_feature and feature_code """ vendor_id = current_user.token_vendor_id has_feature = feature_service.has_feature(db, vendor_id, feature_code) - return {"has_feature": has_feature, "feature_code": feature_code} + return FeatureCheckResponse(has_feature=has_feature, feature_code=feature_code) diff --git a/app/api/v1/vendor/usage.py b/app/api/v1/vendor/usage.py index 21e4df4d..e0e6d16c 100644 --- a/app/api/v1/vendor/usage.py +++ b/app/api/v1/vendor/usage.py @@ -12,16 +12,12 @@ import logging from fastapi import APIRouter, Depends from pydantic import BaseModel -from sqlalchemy import func from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db -from app.services.subscription_service import subscription_service -from models.database.product import Product -from models.database.subscription import SubscriptionTier +from app.services.usage_service import usage_service from models.database.user import User -from models.database.vendor import VendorUser router = APIRouter(prefix="/usage") logger = logging.getLogger(__name__) @@ -106,181 +102,44 @@ def get_usage( """ vendor_id = current_user.token_vendor_id - # Get subscription - subscription = subscription_service.get_or_create_subscription(db, vendor_id) - - # Get current tier - tier = subscription.tier_obj - if not tier: - tier = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == subscription.tier) - .first() - ) - - # Calculate usage metrics - usage_metrics = [] - - # Orders this period - orders_current = subscription.orders_this_period or 0 - orders_limit = subscription.orders_limit - orders_unlimited = orders_limit is None or orders_limit < 0 - orders_percentage = 0 if orders_unlimited else (orders_current / orders_limit * 100 if orders_limit > 0 else 100) - - usage_metrics.append( - UsageMetric( - name="orders", - current=orders_current, - limit=None if orders_unlimited else orders_limit, - percentage=orders_percentage, - is_unlimited=orders_unlimited, - is_at_limit=not orders_unlimited and orders_current >= orders_limit, - is_approaching_limit=not orders_unlimited and orders_percentage >= 80, - ) - ) - - # Products - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - products_limit = subscription.products_limit - products_unlimited = products_limit is None or products_limit < 0 - products_percentage = 0 if products_unlimited else (products_count / products_limit * 100 if products_limit > 0 else 100) - - usage_metrics.append( - UsageMetric( - name="products", - current=products_count, - limit=None if products_unlimited else products_limit, - percentage=products_percentage, - is_unlimited=products_unlimited, - is_at_limit=not products_unlimited and products_count >= products_limit, - is_approaching_limit=not products_unlimited and products_percentage >= 80, - ) - ) - - # Team members - team_count = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712 - .scalar() - or 0 - ) - team_limit = subscription.team_members_limit - team_unlimited = team_limit is None or team_limit < 0 - team_percentage = 0 if team_unlimited else (team_count / team_limit * 100 if team_limit > 0 else 100) - - usage_metrics.append( - UsageMetric( - name="team_members", - current=team_count, - limit=None if team_unlimited else team_limit, - percentage=team_percentage, - is_unlimited=team_unlimited, - is_at_limit=not team_unlimited and team_count >= team_limit, - is_approaching_limit=not team_unlimited and team_percentage >= 80, - ) - ) - - # Check for approaching/reached limits - has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics) - has_limits_reached = any(m.is_at_limit for m in usage_metrics) - - # Get next tier for upgrade - all_tiers = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.is_active == True) # noqa: E712 - .order_by(SubscriptionTier.display_order) - .all() - ) - - current_tier_order = tier.display_order if tier else 0 - next_tier = None - for t in all_tiers: - if t.display_order > current_tier_order: - next_tier = t - break - - is_highest_tier = next_tier is None - - # Build upgrade info - upgrade_tier_info = None - upgrade_reasons = [] - - if next_tier: - # Calculate benefits - benefits = [] - if next_tier.orders_per_month and (not tier or (tier.orders_per_month and next_tier.orders_per_month > tier.orders_per_month)): - if next_tier.orders_per_month < 0: - benefits.append("Unlimited orders per month") - else: - benefits.append(f"{next_tier.orders_per_month:,} orders/month") - - if next_tier.products_limit and (not tier or (tier.products_limit and next_tier.products_limit > tier.products_limit)): - if next_tier.products_limit < 0: - benefits.append("Unlimited products") - else: - benefits.append(f"{next_tier.products_limit:,} products") - - if next_tier.team_members and (not tier or (tier.team_members and next_tier.team_members > tier.team_members)): - if next_tier.team_members < 0: - benefits.append("Unlimited team members") - else: - benefits.append(f"{next_tier.team_members} team members") - - # Add feature benefits - current_features = set(tier.features) if tier and tier.features else set() - next_features = set(next_tier.features) if next_tier.features else set() - new_features = next_features - current_features - - feature_names = { - "analytics_dashboard": "Advanced Analytics", - "api_access": "API Access", - "automation_rules": "Automation Rules", - "team_roles": "Team Roles & Permissions", - "custom_domain": "Custom Domain", - "webhooks": "Webhooks", - "accounting_export": "Accounting Export", - } - for feature in list(new_features)[:3]: # Show top 3 - if feature in feature_names: - benefits.append(feature_names[feature]) - - current_price = tier.price_monthly_cents if tier else 0 - upgrade_tier_info = UpgradeTierInfo( - code=next_tier.code, - name=next_tier.name, - price_monthly_cents=next_tier.price_monthly_cents, - price_increase_cents=next_tier.price_monthly_cents - current_price, - benefits=benefits, - ) - - # Build upgrade reasons - if has_limits_reached: - for m in usage_metrics: - if m.is_at_limit: - upgrade_reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit") - elif has_limits_approaching: - for m in usage_metrics: - if m.is_approaching_limit: - upgrade_reasons.append(f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)") + # Get usage data from service + usage_data = usage_service.get_vendor_usage(db, vendor_id) + # Convert to response return UsageResponse( tier=TierInfo( - code=tier.code if tier else subscription.tier, - name=tier.name if tier else subscription.tier.title(), - price_monthly_cents=tier.price_monthly_cents if tier else 0, - is_highest_tier=is_highest_tier, + code=usage_data.tier.code, + name=usage_data.tier.name, + price_monthly_cents=usage_data.tier.price_monthly_cents, + is_highest_tier=usage_data.tier.is_highest_tier, ), - usage=usage_metrics, - has_limits_approaching=has_limits_approaching, - has_limits_reached=has_limits_reached, - upgrade_available=not is_highest_tier, - upgrade_tier=upgrade_tier_info, - upgrade_reasons=upgrade_reasons, + usage=[ + UsageMetric( + name=m.name, + current=m.current, + limit=m.limit, + percentage=m.percentage, + is_unlimited=m.is_unlimited, + is_at_limit=m.is_at_limit, + is_approaching_limit=m.is_approaching_limit, + ) + for m in usage_data.usage + ], + has_limits_approaching=usage_data.has_limits_approaching, + has_limits_reached=usage_data.has_limits_reached, + upgrade_available=usage_data.upgrade_available, + upgrade_tier=( + UpgradeTierInfo( + code=usage_data.upgrade_tier.code, + name=usage_data.upgrade_tier.name, + price_monthly_cents=usage_data.upgrade_tier.price_monthly_cents, + price_increase_cents=usage_data.upgrade_tier.price_increase_cents, + benefits=usage_data.upgrade_tier.benefits, + ) + if usage_data.upgrade_tier + else None + ), + upgrade_reasons=usage_data.upgrade_reasons, ) @@ -303,78 +162,16 @@ def check_limit( """ vendor_id = current_user.token_vendor_id - if limit_type == "orders": - can_proceed, message = subscription_service.can_create_order(db, vendor_id) - subscription = subscription_service.get_subscription(db, vendor_id) - current = subscription.orders_this_period if subscription else 0 - limit = subscription.orders_limit if subscription else 0 - - elif limit_type == "products": - can_proceed, message = subscription_service.can_add_product(db, vendor_id) - subscription = subscription_service.get_subscription(db, vendor_id) - current = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - limit = subscription.products_limit if subscription else 0 - - elif limit_type == "team_members": - can_proceed, message = subscription_service.can_add_team_member(db, vendor_id) - subscription = subscription_service.get_subscription(db, vendor_id) - current = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712 - .scalar() - or 0 - ) - limit = subscription.team_members_limit if subscription else 0 - - else: - return LimitCheckResponse( - limit_type=limit_type, - can_proceed=True, - current=0, - limit=None, - percentage=0, - message=f"Unknown limit type: {limit_type}", - ) - - # Calculate percentage - is_unlimited = limit is None or limit < 0 - percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100) - - # Get upgrade info if at limit - upgrade_tier_code = None - upgrade_tier_name = None - - if not can_proceed: - subscription = subscription_service.get_subscription(db, vendor_id) - current_tier = subscription.tier_obj if subscription else None - - if current_tier: - next_tier = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.is_active == True, # noqa: E712 - SubscriptionTier.display_order > current_tier.display_order, - ) - .order_by(SubscriptionTier.display_order) - .first() - ) - - if next_tier: - upgrade_tier_code = next_tier.code - upgrade_tier_name = next_tier.name + # Check limit using service + check_data = usage_service.check_limit(db, vendor_id, limit_type) return LimitCheckResponse( - limit_type=limit_type, - can_proceed=can_proceed, - current=current, - limit=None if is_unlimited else limit, - percentage=percentage, - message=message, - upgrade_tier_code=upgrade_tier_code, - upgrade_tier_name=upgrade_tier_name, + limit_type=check_data.limit_type, + can_proceed=check_data.can_proceed, + current=check_data.current, + limit=check_data.limit, + percentage=check_data.percentage, + message=check_data.message, + upgrade_tier_code=check_data.upgrade_tier_code, + upgrade_tier_name=check_data.upgrade_tier_name, ) diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py index 610b3828..ebb4e8b7 100644 --- a/app/exceptions/__init__.py +++ b/app/exceptions/__init__.py @@ -263,6 +263,13 @@ from .onboarding import ( OnboardingSyncNotCompleteException, ) +# Feature management exceptions +from .feature import ( + FeatureNotFoundError, + InvalidFeatureCodesError, + TierNotFoundError, +) + __all__ = [ # Base exceptions "WizamartException", @@ -456,4 +463,8 @@ __all__ = [ "OnboardingCsvUrlRequiredException", "OnboardingSyncJobNotFoundException", "OnboardingSyncNotCompleteException", + # Feature exceptions + "FeatureNotFoundError", + "TierNotFoundError", + "InvalidFeatureCodesError", ] diff --git a/app/exceptions/feature.py b/app/exceptions/feature.py new file mode 100644 index 00000000..3df4988e --- /dev/null +++ b/app/exceptions/feature.py @@ -0,0 +1,42 @@ +# app/exceptions/feature.py +""" +Feature management exceptions. +""" + +from app.exceptions.base import ResourceNotFoundException, ValidationException + + +class FeatureNotFoundError(ResourceNotFoundException): + """Feature not found.""" + + def __init__(self, feature_code: str): + super().__init__( + resource_type="Feature", + identifier=feature_code, + message=f"Feature '{feature_code}' not found", + ) + self.feature_code = feature_code + + +class TierNotFoundError(ResourceNotFoundException): + """Subscription tier not found.""" + + def __init__(self, tier_code: str): + super().__init__( + resource_type="SubscriptionTier", + identifier=tier_code, + message=f"Tier '{tier_code}' not found", + ) + self.tier_code = tier_code + + +class InvalidFeatureCodesError(ValidationException): + """Invalid feature codes provided.""" + + def __init__(self, invalid_codes: set[str]): + codes_str = ", ".join(sorted(invalid_codes)) + super().__init__( + message=f"Invalid feature codes: {codes_str}", + details={"invalid_codes": list(invalid_codes)}, + ) + self.invalid_codes = invalid_codes diff --git a/app/services/feature_service.py b/app/services/feature_service.py index fd72f354..51807872 100644 --- a/app/services/feature_service.py +++ b/app/services/feature_service.py @@ -29,6 +29,11 @@ from functools import lru_cache from sqlalchemy.orm import Session, joinedload +from app.exceptions.feature import ( + FeatureNotFoundError, + InvalidFeatureCodesError, + TierNotFoundError, +) from models.database.feature import Feature, FeatureCode from models.database.subscription import SubscriptionTier, VendorSubscription @@ -370,6 +375,51 @@ class FeatureService: # Admin Operations # ========================================================================= + def get_all_tiers_with_features(self, db: Session) -> list[SubscriptionTier]: + """Get all active tiers with their features for admin.""" + return ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.is_active == True) # noqa: E712 + .order_by(SubscriptionTier.display_order) + .all() + ) + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: + """ + Get tier by code, raising exception if not found. + + Raises: + TierNotFoundError: If tier not found + """ + tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first() + if not tier: + raise TierNotFoundError(tier_code) + return tier + + def get_tier_features_with_details( + self, db: Session, tier_code: str + ) -> tuple[SubscriptionTier, list[Feature]]: + """ + Get tier with full feature details. + + Returns: + Tuple of (tier, list of Feature objects) + + Raises: + TierNotFoundError: If tier not found + """ + tier = self.get_tier_by_code(db, tier_code) + feature_codes = tier.features or [] + + features = ( + db.query(Feature) + .filter(Feature.code.in_(feature_codes)) + .order_by(Feature.category, Feature.display_order) + .all() + ) + + return tier, features + def update_tier_features( self, db: Session, tier_code: str, feature_codes: list[str] ) -> SubscriptionTier: @@ -383,11 +433,15 @@ class FeatureService: Returns: Updated tier + + Raises: + TierNotFoundError: If tier not found + InvalidFeatureCodesError: If any feature codes are invalid """ tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first() if not tier: - raise ValueError(f"Tier '{tier_code}' not found") + raise TierNotFoundError(tier_code) # Validate feature codes exist valid_codes = { @@ -395,7 +449,7 @@ class FeatureService: } invalid = set(feature_codes) - valid_codes if invalid: - raise ValueError(f"Invalid feature codes: {invalid}") + raise InvalidFeatureCodesError(invalid) tier.features = feature_codes @@ -405,6 +459,86 @@ class FeatureService: logger.info(f"Updated features for tier {tier_code}: {len(feature_codes)} features") return tier + def update_feature( + self, + db: Session, + feature_code: str, + name: str | None = None, + description: str | None = None, + category: str | None = None, + ui_location: str | None = None, + ui_icon: str | None = None, + ui_route: str | None = None, + ui_badge_text: str | None = None, + minimum_tier_code: str | None = None, + is_active: bool | None = None, + is_visible: bool | None = None, + display_order: int | None = None, + ) -> Feature: + """ + Update feature metadata. + + Args: + db: Database session + feature_code: Feature code to update + ... other optional fields to update + + Returns: + Updated feature + + Raises: + FeatureNotFoundError: If feature not found + TierNotFoundError: If minimum_tier_code provided but not found + """ + feature = ( + db.query(Feature) + .options(joinedload(Feature.minimum_tier)) + .filter(Feature.code == feature_code) + .first() + ) + + if not feature: + raise FeatureNotFoundError(feature_code) + + # Update fields if provided + if name is not None: + feature.name = name + if description is not None: + feature.description = description + if category is not None: + feature.category = category + if ui_location is not None: + feature.ui_location = ui_location + if ui_icon is not None: + feature.ui_icon = ui_icon + if ui_route is not None: + feature.ui_route = ui_route + if ui_badge_text is not None: + feature.ui_badge_text = ui_badge_text + if is_active is not None: + feature.is_active = is_active + if is_visible is not None: + feature.is_visible = is_visible + if display_order is not None: + feature.display_order = display_order + + # Update minimum tier if provided + if minimum_tier_code is not None: + if minimum_tier_code == "": + feature.minimum_tier_id = None + else: + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == minimum_tier_code) + .first() + ) + if not tier: + raise TierNotFoundError(minimum_tier_code) + feature.minimum_tier_id = tier.id + + logger.info(f"Updated feature {feature_code}") + return feature + def update_feature_minimum_tier( self, db: Session, feature_code: str, tier_code: str | None ) -> Feature: @@ -415,16 +549,20 @@ class FeatureService: db: Database session feature_code: Feature code tier_code: Tier code or None + + Raises: + FeatureNotFoundError: If feature not found + TierNotFoundError: If tier_code provided but not found """ feature = db.query(Feature).filter(Feature.code == feature_code).first() if not feature: - raise ValueError(f"Feature '{feature_code}' not found") + raise FeatureNotFoundError(feature_code) if tier_code: tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == tier_code).first() if not tier: - raise ValueError(f"Tier '{tier_code}' not found") + raise TierNotFoundError(tier_code) feature.minimum_tier_id = tier.id else: feature.minimum_tier_id = None diff --git a/app/services/usage_service.py b/app/services/usage_service.py new file mode 100644 index 00000000..d650109b --- /dev/null +++ b/app/services/usage_service.py @@ -0,0 +1,445 @@ +# app/services/usage_service.py +""" +Usage and limits service. + +Provides methods for: +- Getting current usage vs limits +- Calculating upgrade recommendations +- Checking limits before actions +""" + +import logging +from dataclasses import dataclass + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from models.database.product import Product +from models.database.subscription import SubscriptionTier, VendorSubscription +from models.database.vendor import VendorUser + +logger = logging.getLogger(__name__) + + +@dataclass +class UsageMetricData: + """Usage metric data.""" + + name: str + current: int + limit: int | None + percentage: float + is_unlimited: bool + is_at_limit: bool + is_approaching_limit: bool + + +@dataclass +class TierInfoData: + """Tier information.""" + + code: str + name: str + price_monthly_cents: int + is_highest_tier: bool + + +@dataclass +class UpgradeTierData: + """Upgrade tier information.""" + + code: str + name: str + price_monthly_cents: int + price_increase_cents: int + benefits: list[str] + + +@dataclass +class UsageData: + """Full usage data.""" + + tier: TierInfoData + usage: list[UsageMetricData] + has_limits_approaching: bool + has_limits_reached: bool + upgrade_available: bool + upgrade_tier: UpgradeTierData | None + upgrade_reasons: list[str] + + +@dataclass +class LimitCheckData: + """Limit check result.""" + + limit_type: str + can_proceed: bool + current: int + limit: int | None + percentage: float + message: str | None + upgrade_tier_code: str | None + upgrade_tier_name: str | None + + +class UsageService: + """Service for usage and limits management.""" + + def get_vendor_usage(self, db: Session, vendor_id: int) -> UsageData: + """ + Get comprehensive usage data for a vendor. + + Returns current usage, limits, and upgrade recommendations. + """ + from app.services.subscription_service import subscription_service + + # Get subscription + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + + # Get current tier + tier = self._get_tier(db, subscription) + + # Calculate usage metrics + usage_metrics = self._calculate_usage_metrics(db, vendor_id, subscription) + + # Check for approaching/reached limits + has_limits_approaching = any(m.is_approaching_limit for m in usage_metrics) + has_limits_reached = any(m.is_at_limit for m in usage_metrics) + + # Get upgrade info + next_tier = self._get_next_tier(db, tier) + is_highest_tier = next_tier is None + + # Build upgrade info + upgrade_tier_info = None + upgrade_reasons = [] + + if next_tier: + upgrade_tier_info = self._build_upgrade_tier_info(tier, next_tier) + upgrade_reasons = self._build_upgrade_reasons( + usage_metrics, has_limits_reached, has_limits_approaching + ) + + return UsageData( + tier=TierInfoData( + code=tier.code if tier else subscription.tier, + name=tier.name if tier else subscription.tier.title(), + price_monthly_cents=tier.price_monthly_cents if tier else 0, + is_highest_tier=is_highest_tier, + ), + usage=usage_metrics, + has_limits_approaching=has_limits_approaching, + has_limits_reached=has_limits_reached, + upgrade_available=not is_highest_tier, + upgrade_tier=upgrade_tier_info, + upgrade_reasons=upgrade_reasons, + ) + + def check_limit( + self, db: Session, vendor_id: int, limit_type: str + ) -> LimitCheckData: + """ + Check a specific limit before performing an action. + + Args: + db: Database session + vendor_id: Vendor ID + limit_type: One of "orders", "products", "team_members" + + Returns: + LimitCheckData with proceed status and upgrade info + """ + from app.services.subscription_service import subscription_service + + if limit_type == "orders": + can_proceed, message = subscription_service.can_create_order(db, vendor_id) + subscription = subscription_service.get_subscription(db, vendor_id) + current = subscription.orders_this_period if subscription else 0 + limit = subscription.orders_limit if subscription else 0 + + elif limit_type == "products": + can_proceed, message = subscription_service.can_add_product(db, vendor_id) + subscription = subscription_service.get_subscription(db, vendor_id) + current = self._get_product_count(db, vendor_id) + limit = subscription.products_limit if subscription else 0 + + elif limit_type == "team_members": + can_proceed, message = subscription_service.can_add_team_member(db, vendor_id) + subscription = subscription_service.get_subscription(db, vendor_id) + current = self._get_team_member_count(db, vendor_id) + limit = subscription.team_members_limit if subscription else 0 + + else: + return LimitCheckData( + limit_type=limit_type, + can_proceed=True, + current=0, + limit=None, + percentage=0, + message=f"Unknown limit type: {limit_type}", + upgrade_tier_code=None, + upgrade_tier_name=None, + ) + + # Calculate percentage + is_unlimited = limit is None or limit < 0 + percentage = 0 if is_unlimited else (current / limit * 100 if limit > 0 else 100) + + # Get upgrade info if at limit + upgrade_tier_code = None + upgrade_tier_name = None + + if not can_proceed: + subscription = subscription_service.get_subscription(db, vendor_id) + current_tier = subscription.tier_obj if subscription else None + + if current_tier: + next_tier = self._get_next_tier(db, current_tier) + if next_tier: + upgrade_tier_code = next_tier.code + upgrade_tier_name = next_tier.name + + return LimitCheckData( + limit_type=limit_type, + can_proceed=can_proceed, + current=current, + limit=None if is_unlimited else limit, + percentage=percentage, + message=message, + upgrade_tier_code=upgrade_tier_code, + upgrade_tier_name=upgrade_tier_name, + ) + + # ========================================================================= + # Private Helper Methods + # ========================================================================= + + def _get_tier( + self, db: Session, subscription: VendorSubscription + ) -> SubscriptionTier | None: + """Get tier from subscription or query by code.""" + tier = subscription.tier_obj + if not tier: + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == subscription.tier) + .first() + ) + return tier + + def _get_product_count(self, db: Session, vendor_id: int) -> int: + """Get product count for vendor.""" + return ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + def _get_team_member_count(self, db: Session, vendor_id: int) -> int: + """Get active team member count for vendor.""" + return ( + db.query(func.count(VendorUser.id)) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) # noqa: E712 + .scalar() + or 0 + ) + + def _calculate_usage_metrics( + self, db: Session, vendor_id: int, subscription: VendorSubscription + ) -> list[UsageMetricData]: + """Calculate all usage metrics for a vendor.""" + metrics = [] + + # Orders this period + orders_current = subscription.orders_this_period or 0 + orders_limit = subscription.orders_limit + orders_unlimited = orders_limit is None or orders_limit < 0 + orders_percentage = ( + 0 + if orders_unlimited + else (orders_current / orders_limit * 100 if orders_limit > 0 else 100) + ) + + metrics.append( + UsageMetricData( + name="orders", + current=orders_current, + limit=None if orders_unlimited else orders_limit, + percentage=orders_percentage, + is_unlimited=orders_unlimited, + is_at_limit=not orders_unlimited and orders_current >= orders_limit, + is_approaching_limit=not orders_unlimited and orders_percentage >= 80, + ) + ) + + # Products + products_count = self._get_product_count(db, vendor_id) + products_limit = subscription.products_limit + products_unlimited = products_limit is None or products_limit < 0 + products_percentage = ( + 0 + if products_unlimited + else (products_count / products_limit * 100 if products_limit > 0 else 100) + ) + + metrics.append( + UsageMetricData( + name="products", + current=products_count, + limit=None if products_unlimited else products_limit, + percentage=products_percentage, + is_unlimited=products_unlimited, + is_at_limit=not products_unlimited and products_count >= products_limit, + is_approaching_limit=not products_unlimited and products_percentage >= 80, + ) + ) + + # Team members + team_count = self._get_team_member_count(db, vendor_id) + team_limit = subscription.team_members_limit + team_unlimited = team_limit is None or team_limit < 0 + team_percentage = ( + 0 + if team_unlimited + else (team_count / team_limit * 100 if team_limit > 0 else 100) + ) + + metrics.append( + UsageMetricData( + name="team_members", + current=team_count, + limit=None if team_unlimited else team_limit, + percentage=team_percentage, + is_unlimited=team_unlimited, + is_at_limit=not team_unlimited and team_count >= team_limit, + is_approaching_limit=not team_unlimited and team_percentage >= 80, + ) + ) + + return metrics + + def _get_next_tier( + self, db: Session, current_tier: SubscriptionTier | None + ) -> SubscriptionTier | None: + """Get next tier for upgrade.""" + current_tier_order = current_tier.display_order if current_tier else 0 + + return ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.is_active == True, # noqa: E712 + SubscriptionTier.display_order > current_tier_order, + ) + .order_by(SubscriptionTier.display_order) + .first() + ) + + def _build_upgrade_tier_info( + self, current_tier: SubscriptionTier | None, next_tier: SubscriptionTier + ) -> UpgradeTierData: + """Build upgrade tier information with benefits.""" + benefits = [] + + # Numeric limit benefits + if next_tier.orders_per_month and ( + not current_tier + or ( + current_tier.orders_per_month + and next_tier.orders_per_month > current_tier.orders_per_month + ) + ): + if next_tier.orders_per_month < 0: + benefits.append("Unlimited orders per month") + else: + benefits.append(f"{next_tier.orders_per_month:,} orders/month") + + if next_tier.products_limit and ( + not current_tier + or ( + current_tier.products_limit + and next_tier.products_limit > current_tier.products_limit + ) + ): + if next_tier.products_limit < 0: + benefits.append("Unlimited products") + else: + benefits.append(f"{next_tier.products_limit:,} products") + + if next_tier.team_members and ( + not current_tier + or ( + current_tier.team_members + and next_tier.team_members > current_tier.team_members + ) + ): + if next_tier.team_members < 0: + benefits.append("Unlimited team members") + else: + benefits.append(f"{next_tier.team_members} team members") + + # Feature benefits + current_features = ( + set(current_tier.features) if current_tier and current_tier.features else set() + ) + next_features = set(next_tier.features) if next_tier.features else set() + new_features = next_features - current_features + + feature_names = { + "analytics_dashboard": "Advanced Analytics", + "api_access": "API Access", + "automation_rules": "Automation Rules", + "team_roles": "Team Roles & Permissions", + "custom_domain": "Custom Domain", + "webhooks": "Webhooks", + "accounting_export": "Accounting Export", + } + for feature in list(new_features)[:3]: + if feature in feature_names: + benefits.append(feature_names[feature]) + + current_price = current_tier.price_monthly_cents if current_tier else 0 + + return UpgradeTierData( + code=next_tier.code, + name=next_tier.name, + price_monthly_cents=next_tier.price_monthly_cents, + price_increase_cents=next_tier.price_monthly_cents - current_price, + benefits=benefits, + ) + + def _build_upgrade_reasons( + self, + usage_metrics: list[UsageMetricData], + has_limits_reached: bool, + has_limits_approaching: bool, + ) -> list[str]: + """Build upgrade reasons based on usage.""" + reasons = [] + + if has_limits_reached: + for m in usage_metrics: + if m.is_at_limit: + reasons.append(f"You've reached your {m.name.replace('_', ' ')} limit") + elif has_limits_approaching: + for m in usage_metrics: + if m.is_approaching_limit: + reasons.append( + f"You're approaching your {m.name.replace('_', ' ')} limit ({int(m.percentage)}%)" + ) + + return reasons + + +# Singleton instance +usage_service = UsageService() + +__all__ = [ + "usage_service", + "UsageService", + "UsageData", + "UsageMetricData", + "TierInfoData", + "UpgradeTierData", + "LimitCheckData", +] diff --git a/tests/unit/services/test_feature_service.py b/tests/unit/services/test_feature_service.py new file mode 100644 index 00000000..6f9a1116 --- /dev/null +++ b/tests/unit/services/test_feature_service.py @@ -0,0 +1,367 @@ +# tests/unit/services/test_feature_service.py +"""Unit tests for FeatureService.""" + +import pytest + +from app.exceptions import FeatureNotFoundError, InvalidFeatureCodesError, TierNotFoundError +from app.services.feature_service import FeatureService, feature_service +from models.database.feature import Feature +from models.database.subscription import SubscriptionTier, VendorSubscription + + +@pytest.mark.unit +@pytest.mark.features +class TestFeatureServiceAvailability: + """Test suite for feature availability checking.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = FeatureService() + + def test_has_feature_true(self, db, test_vendor_with_subscription): + """Test has_feature returns True for available feature.""" + vendor_id = test_vendor_with_subscription.id + result = self.service.has_feature(db, vendor_id, "basic_reports") + assert result is True + + def test_has_feature_false(self, db, test_vendor_with_subscription): + """Test has_feature returns False for unavailable feature.""" + vendor_id = test_vendor_with_subscription.id + result = self.service.has_feature(db, vendor_id, "api_access") + assert result is False + + def test_has_feature_no_subscription(self, db, test_vendor): + """Test has_feature returns False for vendor without subscription.""" + result = self.service.has_feature(db, test_vendor.id, "basic_reports") + assert result is False + + def test_get_vendor_feature_codes(self, db, test_vendor_with_subscription): + """Test getting all feature codes for vendor.""" + vendor_id = test_vendor_with_subscription.id + features = self.service.get_vendor_feature_codes(db, vendor_id) + + assert isinstance(features, set) + assert "basic_reports" in features + assert "api_access" not in features + + +@pytest.mark.unit +@pytest.mark.features +class TestFeatureServiceListing: + """Test suite for feature listing operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = FeatureService() + + def test_get_vendor_features(self, db, test_vendor_with_subscription, test_features): + """Test getting all features with availability.""" + vendor_id = test_vendor_with_subscription.id + features = self.service.get_vendor_features(db, vendor_id) + + assert len(features) > 0 + basic_reports = next((f for f in features if f.code == "basic_reports"), None) + assert basic_reports is not None + assert basic_reports.is_available is True + + api_access = next((f for f in features if f.code == "api_access"), None) + assert api_access is not None + assert api_access.is_available is False + + def test_get_vendor_features_by_category( + self, db, test_vendor_with_subscription, test_features + ): + """Test filtering features by category.""" + vendor_id = test_vendor_with_subscription.id + features = self.service.get_vendor_features(db, vendor_id, category="analytics") + + assert all(f.category == "analytics" for f in features) + + def test_get_vendor_features_available_only( + self, db, test_vendor_with_subscription, test_features + ): + """Test getting only available features.""" + vendor_id = test_vendor_with_subscription.id + features = self.service.get_vendor_features( + db, vendor_id, include_unavailable=False + ) + + assert all(f.is_available for f in features) + + def test_get_available_feature_codes(self, db, test_vendor_with_subscription): + """Test getting simple list of available codes.""" + vendor_id = test_vendor_with_subscription.id + codes = self.service.get_available_feature_codes(db, vendor_id) + + assert isinstance(codes, list) + assert "basic_reports" in codes + + +@pytest.mark.unit +@pytest.mark.features +class TestFeatureServiceMetadata: + """Test suite for feature metadata operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = FeatureService() + + def test_get_feature_by_code(self, db, test_features): + """Test getting feature by code.""" + feature = self.service.get_feature_by_code(db, "basic_reports") + + assert feature is not None + assert feature.code == "basic_reports" + assert feature.name == "Basic Reports" + + def test_get_feature_by_code_not_found(self, db, test_features): + """Test getting non-existent feature returns None.""" + feature = self.service.get_feature_by_code(db, "nonexistent") + assert feature is None + + def test_get_feature_upgrade_info(self, db, test_features, test_subscription_tiers): + """Test getting upgrade info for locked feature.""" + info = self.service.get_feature_upgrade_info(db, "api_access") + + assert info is not None + assert info.feature_code == "api_access" + assert info.required_tier_code == "professional" + + def test_get_feature_upgrade_info_no_minimum_tier(self, db, test_features): + """Test upgrade info for feature without minimum tier.""" + # basic_reports has no minimum tier in fixtures + info = self.service.get_feature_upgrade_info(db, "basic_reports") + assert info is None + + def test_get_all_features(self, db, test_features): + """Test getting all features for admin.""" + features = self.service.get_all_features(db) + assert len(features) >= 3 + + def test_get_all_features_by_category(self, db, test_features): + """Test filtering features by category.""" + features = self.service.get_all_features(db, category="analytics") + assert all(f.category == "analytics" for f in features) + + def test_get_categories(self, db, test_features): + """Test getting unique categories.""" + categories = self.service.get_categories(db) + assert "analytics" in categories + assert "integrations" in categories + + +@pytest.mark.unit +@pytest.mark.features +class TestFeatureServiceCache: + """Test suite for cache operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = FeatureService() + + def test_cache_invalidation(self, db, test_vendor_with_subscription): + """Test cache invalidation for vendor.""" + vendor_id = test_vendor_with_subscription.id + + # Prime the cache + self.service.get_vendor_feature_codes(db, vendor_id) + assert self.service._cache.get(vendor_id) is not None + + # Invalidate + self.service.invalidate_vendor_cache(vendor_id) + assert self.service._cache.get(vendor_id) is None + + def test_cache_invalidate_all(self, db, test_vendor_with_subscription): + """Test invalidating entire cache.""" + vendor_id = test_vendor_with_subscription.id + + # Prime the cache + self.service.get_vendor_feature_codes(db, vendor_id) + + # Invalidate all + self.service.invalidate_all_cache() + assert self.service._cache.get(vendor_id) is None + + +@pytest.mark.unit +@pytest.mark.features +class TestFeatureServiceAdmin: + """Test suite for admin operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = FeatureService() + + def test_get_all_tiers_with_features(self, db, test_subscription_tiers): + """Test getting all tiers.""" + tiers = self.service.get_all_tiers_with_features(db) + + assert len(tiers) == 2 + assert tiers[0].code == "essential" + assert tiers[1].code == "professional" + + def test_update_tier_features(self, db, test_subscription_tiers, test_features): + """Test updating tier features.""" + tier = self.service.update_tier_features( + db, "essential", ["basic_reports", "api_access"] + ) + db.commit() + + assert "api_access" in tier.features + + def test_update_tier_features_invalid_codes( + self, db, test_subscription_tiers, test_features + ): + """Test updating tier with invalid feature codes.""" + with pytest.raises(InvalidFeatureCodesError) as exc_info: + self.service.update_tier_features( + db, "essential", ["basic_reports", "nonexistent_feature"] + ) + + assert "nonexistent_feature" in exc_info.value.invalid_codes + + def test_update_tier_features_tier_not_found(self, db, test_features): + """Test updating non-existent tier.""" + with pytest.raises(TierNotFoundError) as exc_info: + self.service.update_tier_features(db, "nonexistent", ["basic_reports"]) + + assert exc_info.value.tier_code == "nonexistent" + + def test_update_feature(self, db, test_features): + """Test updating feature metadata.""" + feature = self.service.update_feature( + db, "basic_reports", name="Updated Reports", description="New description" + ) + db.commit() + + assert feature.name == "Updated Reports" + assert feature.description == "New description" + + def test_update_feature_not_found(self, db, test_features): + """Test updating non-existent feature.""" + with pytest.raises(FeatureNotFoundError) as exc_info: + self.service.update_feature(db, "nonexistent", name="Test") + + assert exc_info.value.feature_code == "nonexistent" + + def test_update_feature_minimum_tier( + self, db, test_features, test_subscription_tiers + ): + """Test updating feature minimum tier.""" + feature = self.service.update_feature( + db, "basic_reports", minimum_tier_code="professional" + ) + db.commit() + db.refresh(feature) + + assert feature.minimum_tier is not None + assert feature.minimum_tier.code == "professional" + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def test_subscription_tiers(db): + """Create multiple subscription tiers.""" + tiers = [ + SubscriptionTier( + code="essential", + name="Essential", + description="Essential plan", + price_monthly_cents=4900, + price_annual_cents=49000, + orders_per_month=100, + products_limit=500, + team_members=2, + features=["basic_reports"], + is_active=True, + display_order=1, + ), + SubscriptionTier( + code="professional", + name="Professional", + description="Professional plan", + price_monthly_cents=9900, + price_annual_cents=99000, + orders_per_month=500, + products_limit=2000, + team_members=5, + features=["basic_reports", "api_access", "analytics_dashboard"], + is_active=True, + display_order=2, + ), + ] + for tier in tiers: + db.add(tier) + db.commit() + for tier in tiers: + db.refresh(tier) + return tiers + + +@pytest.fixture +def test_vendor_with_subscription(db, test_vendor, test_subscription_tiers): + """Create a vendor with an active subscription.""" + from datetime import datetime, timezone + + essential_tier = test_subscription_tiers[0] # Use the essential tier from tiers list + now = datetime.now(timezone.utc) + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + tier_id=essential_tier.id, + status="active", + period_start=now, + period_end=now.replace(month=now.month + 1 if now.month < 12 else 1), + orders_this_period=10, + ) + db.add(subscription) + db.commit() + db.refresh(test_vendor) + return test_vendor + + +@pytest.fixture +def test_features(db, test_subscription_tiers): + """Create test features.""" + features = [ + Feature( + code="basic_reports", + name="Basic Reports", + description="View basic analytics reports", + category="analytics", + ui_location="sidebar", + ui_icon="chart-bar", + is_active=True, + display_order=1, + ), + Feature( + code="api_access", + name="API Access", + description="Access the REST API", + category="integrations", + ui_location="settings", + ui_icon="code", + minimum_tier_id=test_subscription_tiers[1].id, # Professional + is_active=True, + display_order=2, + ), + Feature( + code="analytics_dashboard", + name="Analytics Dashboard", + description="Advanced analytics dashboard", + category="analytics", + ui_location="sidebar", + ui_icon="presentation-chart-line", + minimum_tier_id=test_subscription_tiers[1].id, + is_active=True, + display_order=3, + ), + ] + for feature in features: + db.add(feature) + db.commit() + for feature in features: + db.refresh(feature) + return features diff --git a/tests/unit/services/test_usage_service.py b/tests/unit/services/test_usage_service.py new file mode 100644 index 00000000..754765d9 --- /dev/null +++ b/tests/unit/services/test_usage_service.py @@ -0,0 +1,308 @@ +# tests/unit/services/test_usage_service.py +"""Unit tests for UsageService.""" + +import pytest + +from app.services.usage_service import UsageService, usage_service +from models.database.product import Product +from models.database.subscription import SubscriptionTier, VendorSubscription +from models.database.vendor import VendorUser + + +@pytest.mark.unit +@pytest.mark.usage +class TestUsageServiceGetUsage: + """Test suite for get_vendor_usage operation.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = UsageService() + + def test_get_vendor_usage_basic(self, db, test_vendor_with_subscription): + """Test getting basic usage data.""" + vendor_id = test_vendor_with_subscription.id + usage = self.service.get_vendor_usage(db, vendor_id) + + assert usage.tier.code == "essential" + assert usage.tier.name == "Essential" + assert len(usage.usage) == 3 + + def test_get_vendor_usage_metrics(self, db, test_vendor_with_subscription): + """Test usage metrics are calculated correctly.""" + vendor_id = test_vendor_with_subscription.id + usage = self.service.get_vendor_usage(db, vendor_id) + + orders_metric = next((m for m in usage.usage if m.name == "orders"), None) + assert orders_metric is not None + assert orders_metric.current == 10 + assert orders_metric.limit == 100 + assert orders_metric.percentage == 10.0 + assert orders_metric.is_unlimited is False + + def test_get_vendor_usage_at_limit(self, db, test_vendor_at_limit): + """Test usage shows at limit correctly.""" + vendor_id = test_vendor_at_limit.id + usage = self.service.get_vendor_usage(db, vendor_id) + + orders_metric = next((m for m in usage.usage if m.name == "orders"), None) + assert orders_metric.is_at_limit is True + assert usage.has_limits_reached is True + + def test_get_vendor_usage_approaching_limit(self, db, test_vendor_approaching_limit): + """Test usage shows approaching limit correctly.""" + vendor_id = test_vendor_approaching_limit.id + usage = self.service.get_vendor_usage(db, vendor_id) + + orders_metric = next((m for m in usage.usage if m.name == "orders"), None) + assert orders_metric.is_approaching_limit is True + assert usage.has_limits_approaching is True + + def test_get_vendor_usage_upgrade_available( + self, db, test_vendor_with_subscription, test_professional_tier + ): + """Test upgrade info when not on highest tier.""" + vendor_id = test_vendor_with_subscription.id + usage = self.service.get_vendor_usage(db, vendor_id) + + assert usage.upgrade_available is True + assert usage.upgrade_tier is not None + assert usage.upgrade_tier.code == "professional" + + def test_get_vendor_usage_highest_tier(self, db, test_vendor_on_professional): + """Test no upgrade when on highest tier.""" + vendor_id = test_vendor_on_professional.id + usage = self.service.get_vendor_usage(db, vendor_id) + + assert usage.tier.is_highest_tier is True + assert usage.upgrade_available is False + assert usage.upgrade_tier is None + + +@pytest.mark.unit +@pytest.mark.usage +class TestUsageServiceCheckLimit: + """Test suite for check_limit operation.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = UsageService() + + def test_check_orders_limit_can_proceed(self, db, test_vendor_with_subscription): + """Test checking orders limit when under limit.""" + vendor_id = test_vendor_with_subscription.id + result = self.service.check_limit(db, vendor_id, "orders") + + assert result.can_proceed is True + assert result.current == 10 + assert result.limit == 100 + + def test_check_products_limit(self, db, test_vendor_with_products): + """Test checking products limit.""" + vendor_id = test_vendor_with_products.id + result = self.service.check_limit(db, vendor_id, "products") + + assert result.can_proceed is True + assert result.current == 5 + assert result.limit == 500 + + def test_check_team_members_limit(self, db, test_vendor_with_team): + """Test checking team members limit when at limit.""" + vendor_id = test_vendor_with_team.id + result = self.service.check_limit(db, vendor_id, "team_members") + + # At limit (2/2) - can_proceed should be False + assert result.can_proceed is False + assert result.current == 2 + assert result.limit == 2 + assert result.percentage == 100.0 + + def test_check_unknown_limit_type(self, db, test_vendor_with_subscription): + """Test checking unknown limit type.""" + vendor_id = test_vendor_with_subscription.id + result = self.service.check_limit(db, vendor_id, "unknown") + + assert result.can_proceed is True + assert "Unknown limit type" in result.message + + def test_check_limit_upgrade_info_when_blocked(self, db, test_vendor_at_limit): + """Test upgrade info is provided when at limit.""" + vendor_id = test_vendor_at_limit.id + result = self.service.check_limit(db, vendor_id, "orders") + + assert result.can_proceed is False + assert result.upgrade_tier_code == "professional" + assert result.upgrade_tier_name == "Professional" + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def test_essential_tier(db): + """Create essential subscription tier.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + description="Essential plan", + price_monthly_cents=4900, + price_annual_cents=49000, + orders_per_month=100, + products_limit=500, + team_members=2, + features=["basic_reports"], + is_active=True, + display_order=1, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def test_professional_tier(db, test_essential_tier): + """Create professional subscription tier.""" + tier = SubscriptionTier( + code="professional", + name="Professional", + description="Professional plan", + price_monthly_cents=9900, + price_annual_cents=99000, + orders_per_month=500, + products_limit=2000, + team_members=10, + features=["basic_reports", "api_access", "analytics_dashboard"], + is_active=True, + display_order=2, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def test_vendor_with_subscription(db, test_vendor, test_essential_tier): + """Create vendor with active subscription.""" + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + tier_id=test_essential_tier.id, + status="active", + period_start=now, + period_end=now.replace(month=now.month + 1 if now.month < 12 else 1), + orders_this_period=10, + ) + db.add(subscription) + db.commit() + db.refresh(test_vendor) + return test_vendor + + +@pytest.fixture +def test_vendor_at_limit(db, test_vendor, test_essential_tier, test_professional_tier): + """Create vendor at order limit.""" + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + tier_id=test_essential_tier.id, + status="active", + period_start=now, + period_end=now.replace(month=now.month + 1 if now.month < 12 else 1), + orders_this_period=100, # At limit + ) + db.add(subscription) + db.commit() + db.refresh(test_vendor) + return test_vendor + + +@pytest.fixture +def test_vendor_approaching_limit(db, test_vendor, test_essential_tier): + """Create vendor approaching order limit (>=80%).""" + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + tier_id=test_essential_tier.id, + status="active", + period_start=now, + period_end=now.replace(month=now.month + 1 if now.month < 12 else 1), + orders_this_period=85, # 85% of 100 + ) + db.add(subscription) + db.commit() + db.refresh(test_vendor) + return test_vendor + + +@pytest.fixture +def test_vendor_on_professional(db, test_vendor, test_professional_tier): + """Create vendor on highest tier.""" + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="professional", + tier_id=test_professional_tier.id, + status="active", + period_start=now, + period_end=now.replace(month=now.month + 1 if now.month < 12 else 1), + orders_this_period=50, + ) + db.add(subscription) + db.commit() + db.refresh(test_vendor) + return test_vendor + + +@pytest.fixture +def test_vendor_with_products(db, test_vendor_with_subscription, marketplace_product_factory): + """Create vendor with products.""" + for i in range(5): + # Create marketplace product first + mp = marketplace_product_factory(db, title=f"Test Product {i}") + product = Product( + vendor_id=test_vendor_with_subscription.id, + marketplace_product_id=mp.id, + price_cents=1000, + is_active=True, + ) + db.add(product) + db.commit() + return test_vendor_with_subscription + + +@pytest.fixture +def test_vendor_with_team(db, test_vendor_with_subscription, test_user, other_user): + """Create vendor with team members (owner + team member = 2).""" + from models.database.vendor import VendorUserType + + # Add owner + owner = VendorUser( + vendor_id=test_vendor_with_subscription.id, + user_id=test_user.id, + user_type=VendorUserType.OWNER.value, + is_active=True, + ) + db.add(owner) + + # Add team member + team_member = VendorUser( + vendor_id=test_vendor_with_subscription.id, + user_id=other_user.id, + user_type=VendorUserType.TEAM_MEMBER.value, + is_active=True, + ) + db.add(team_member) + db.commit() + return test_vendor_with_subscription