diff --git a/app/api/v1/admin/subscriptions.py b/app/api/v1/admin/subscriptions.py index 2ba9dd37..b181c700 100644 --- a/app/api/v1/admin/subscriptions.py +++ b/app/api/v1/admin/subscriptions.py @@ -12,16 +12,13 @@ Provides endpoints for platform administrators to manage: import logging from fastapi import APIRouter, Depends, Path, Query -from sqlalchemy import func from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api from app.core.database import get_db from app.services.admin_subscription_service import admin_subscription_service -from models.database.product import Product -from models.database.user import User -from models.database.vendor import VendorUser from app.services.subscription_service import subscription_service +from models.database.user import User from models.schema.billing import ( BillingHistoryListResponse, BillingHistoryWithVendor, @@ -246,13 +243,8 @@ def create_vendor_subscription( Creates a new subscription with the specified tier and status. Defaults to Essential tier with trial status. """ - from models.database.vendor import Vendor - # Verify vendor exists - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - from app.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("Vendor", str(vendor_id)) + vendor = admin_subscription_service.get_vendor(db, vendor_id) # Create subscription using the subscription service sub = subscription_service.get_or_create_subscription( @@ -272,22 +264,7 @@ def create_vendor_subscription( db.refresh(sub) # Get usage counts - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter( - VendorUser.vendor_id == vendor_id, - VendorUser.is_active == True, # noqa: E712 - ) - .scalar() - or 0 - ) + usage = admin_subscription_service.get_vendor_usage_counts(db, vendor_id) logger.info(f"Admin created subscription for vendor {vendor_id}: tier={create_data.tier}") @@ -295,8 +272,8 @@ def create_vendor_subscription( **VendorSubscriptionResponse.model_validate(sub).model_dump(), vendor_name=vendor.name, vendor_code=vendor.subdomain, - products_count=products_count, - team_count=team_count, + products_count=usage["products_count"], + team_count=usage["team_count"], ) @@ -310,29 +287,14 @@ def get_vendor_subscription( sub, vendor = admin_subscription_service.get_subscription(db, vendor_id) # Get usage counts - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter( - VendorUser.vendor_id == vendor_id, - VendorUser.is_active == True, # noqa: E712 - ) - .scalar() - or 0 - ) + usage = admin_subscription_service.get_vendor_usage_counts(db, vendor_id) return VendorSubscriptionWithVendor( **VendorSubscriptionResponse.model_validate(sub).model_dump(), vendor_name=vendor.name, vendor_code=vendor.subdomain, - products_count=products_count, - team_count=team_count, + products_count=usage["products_count"], + team_count=usage["team_count"], ) @@ -358,27 +320,12 @@ def update_vendor_subscription( db.refresh(sub) # Get usage counts - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter( - VendorUser.vendor_id == vendor_id, - VendorUser.is_active == True, # noqa: E712 - ) - .scalar() - or 0 - ) + usage = admin_subscription_service.get_vendor_usage_counts(db, vendor_id) return VendorSubscriptionWithVendor( **VendorSubscriptionResponse.model_validate(sub).model_dump(), vendor_name=vendor.name, vendor_code=vendor.subdomain, - products_count=products_count, - team_count=team_count, + products_count=usage["products_count"], + team_count=usage["team_count"], ) diff --git a/app/services/admin_subscription_service.py b/app/services/admin_subscription_service.py index 5f77fd62..9e5d37fc 100644 --- a/app/services/admin_subscription_service.py +++ b/app/services/admin_subscription_service.py @@ -21,13 +21,14 @@ from app.exceptions import ( ResourceNotFoundException, TierNotFoundException, ) +from models.database.product import Product from models.database.subscription import ( BillingHistory, SubscriptionStatus, SubscriptionTier, VendorSubscription, ) -from models.database.vendor import Vendor +from models.database.vendor import Vendor, VendorUser logger = logging.getLogger(__name__) @@ -197,6 +198,39 @@ class AdminSubscriptionService: return sub, vendor + def get_vendor(self, db: Session, vendor_id: int) -> Vendor: + """Get a vendor by ID.""" + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + + if not vendor: + raise ResourceNotFoundException("Vendor", str(vendor_id)) + + return vendor + + def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict: + """Get usage counts (products and team members) for a vendor.""" + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + team_count = ( + db.query(func.count(VendorUser.id)) + .filter( + VendorUser.vendor_id == vendor_id, + VendorUser.is_active == True, # noqa: E712 + ) + .scalar() + or 0 + ) + + return { + "products_count": products_count, + "team_count": team_count, + } + # ========================================================================= # Billing History # =========================================================================