feat: complete billing module migration (Phase 5)
Migrates billing module to self-contained structure: - Create app/modules/billing/services/ with subscription, stripe, admin services - Create app/modules/billing/models/ re-exporting from central location - Create app/modules/billing/schemas/ re-exporting from central location - Create app/modules/billing/tasks/ with 4 scheduled Celery tasks - Create app/modules/billing/exceptions.py with module-specific exceptions - Update definition.py with is_self_contained=True and scheduled_tasks Celery task migration: - reset_period_counters -> billing module - check_trial_expirations -> billing module - sync_stripe_status -> billing module - cleanup_stale_subscriptions -> billing module - capture_capacity_snapshot remains in legacy (will go to monitoring) Backward compatibility: - Create re-exports in app/services/ for subscription, stripe, admin services - Old import paths continue to work - Update celery_config.py to use module-defined schedules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
28
app/modules/billing/services/__init__.py
Normal file
28
app/modules/billing/services/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# app/modules/billing/services/__init__.py
|
||||
"""
|
||||
Billing module services.
|
||||
|
||||
Provides subscription management, Stripe integration, and admin operations.
|
||||
"""
|
||||
|
||||
from app.modules.billing.services.subscription_service import (
|
||||
SubscriptionService,
|
||||
subscription_service,
|
||||
)
|
||||
from app.modules.billing.services.stripe_service import (
|
||||
StripeService,
|
||||
stripe_service,
|
||||
)
|
||||
from app.modules.billing.services.admin_subscription_service import (
|
||||
AdminSubscriptionService,
|
||||
admin_subscription_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SubscriptionService",
|
||||
"subscription_service",
|
||||
"StripeService",
|
||||
"stripe_service",
|
||||
"AdminSubscriptionService",
|
||||
"admin_subscription_service",
|
||||
]
|
||||
352
app/modules/billing/services/admin_subscription_service.py
Normal file
352
app/modules/billing/services/admin_subscription_service.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# app/modules/billing/services/admin_subscription_service.py
|
||||
"""
|
||||
Admin Subscription Service.
|
||||
|
||||
Handles subscription management operations for platform administrators:
|
||||
- Subscription tier CRUD
|
||||
- Vendor subscription management
|
||||
- Billing history queries
|
||||
- Subscription analytics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from math import ceil
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
BusinessLogicException,
|
||||
ConflictException,
|
||||
ResourceNotFoundException,
|
||||
)
|
||||
from app.modules.billing.exceptions import TierNotFoundException
|
||||
from app.modules.billing.models import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminSubscriptionService:
|
||||
"""Service for admin subscription management operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Tiers
|
||||
# =========================================================================
|
||||
|
||||
def get_tiers(
|
||||
self, db: Session, include_inactive: bool = False
|
||||
) -> list[SubscriptionTier]:
|
||||
"""Get all subscription tiers."""
|
||||
query = db.query(SubscriptionTier)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
|
||||
|
||||
return query.order_by(SubscriptionTier.display_order).all()
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||
"""Get a subscription tier by code."""
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tier:
|
||||
raise TierNotFoundException(tier_code)
|
||||
|
||||
return tier
|
||||
|
||||
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
|
||||
"""Create a new subscription tier."""
|
||||
# Check for duplicate code
|
||||
existing = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_data["code"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ConflictException(
|
||||
f"Tier with code '{tier_data['code']}' already exists"
|
||||
)
|
||||
|
||||
tier = SubscriptionTier(**tier_data)
|
||||
db.add(tier)
|
||||
|
||||
logger.info(f"Created subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def update_tier(
|
||||
self, db: Session, tier_code: str, update_data: dict
|
||||
) -> SubscriptionTier:
|
||||
"""Update a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(tier, field, value)
|
||||
|
||||
logger.info(f"Updated subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def deactivate_tier(self, db: Session, tier_code: str) -> None:
|
||||
"""Soft-delete a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
# Check if any active subscriptions use this tier
|
||||
active_subs = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.tier == tier_code,
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
]),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if active_subs > 0:
|
||||
raise BusinessLogicException(
|
||||
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
||||
)
|
||||
|
||||
tier.is_active = False
|
||||
|
||||
logger.info(f"Soft-deleted subscription tier: {tier.code}")
|
||||
|
||||
# =========================================================================
|
||||
# Vendor Subscriptions
|
||||
# =========================================================================
|
||||
|
||||
def list_subscriptions(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
status: str | None = None,
|
||||
tier: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> dict:
|
||||
"""List vendor subscriptions with filtering and pagination."""
|
||||
query = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if status:
|
||||
query = query.filter(VendorSubscription.status == status)
|
||||
if tier:
|
||||
query = query.filter(VendorSubscription.tier == tier)
|
||||
if search:
|
||||
query = query.filter(Vendor.name.ilike(f"%{search}%"))
|
||||
|
||||
# Count total
|
||||
total = query.count()
|
||||
|
||||
# Paginate
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(VendorSubscription.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
|
||||
"""Get subscription for a specific vendor."""
|
||||
result = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ResourceNotFoundException("Subscription", str(vendor_id))
|
||||
|
||||
return result
|
||||
|
||||
def update_subscription(
|
||||
self, db: Session, vendor_id: int, update_data: dict
|
||||
) -> tuple:
|
||||
"""Update a vendor's subscription."""
|
||||
result = self.get_subscription(db, vendor_id)
|
||||
sub, vendor = result
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(sub, field, value)
|
||||
|
||||
logger.info(
|
||||
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
|
||||
)
|
||||
|
||||
return sub, vendor
|
||||
|
||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""Get a vendor by ID."""
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
|
||||
if not vendor:
|
||||
raise ResourceNotFoundException("Vendor", str(vendor_id))
|
||||
|
||||
return vendor
|
||||
|
||||
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
|
||||
"""Get usage counts (products and team members) for a vendor."""
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor_id,
|
||||
VendorUser.is_active == True, # noqa: E712
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"products_count": products_count,
|
||||
"team_count": team_count,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Billing History
|
||||
# =========================================================================
|
||||
|
||||
def list_billing_history(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
vendor_id: int | None = None,
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""List billing history across all vendors."""
|
||||
query = (
|
||||
db.query(BillingHistory, Vendor)
|
||||
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(BillingHistory.vendor_id == vendor_id)
|
||||
if status:
|
||||
query = query.filter(BillingHistory.status == status)
|
||||
|
||||
total = query.count()
|
||||
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(BillingHistory.invoice_date.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self, db: Session) -> dict:
|
||||
"""Get subscription statistics for admin dashboard."""
|
||||
# Count by status
|
||||
status_counts = (
|
||||
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
|
||||
.group_by(VendorSubscription.status)
|
||||
.all()
|
||||
)
|
||||
|
||||
stats = {
|
||||
"total_subscriptions": 0,
|
||||
"active_count": 0,
|
||||
"trial_count": 0,
|
||||
"past_due_count": 0,
|
||||
"cancelled_count": 0,
|
||||
"expired_count": 0,
|
||||
}
|
||||
|
||||
for status, count in status_counts:
|
||||
stats["total_subscriptions"] += count
|
||||
if status == SubscriptionStatus.ACTIVE.value:
|
||||
stats["active_count"] = count
|
||||
elif status == SubscriptionStatus.TRIAL.value:
|
||||
stats["trial_count"] = count
|
||||
elif status == SubscriptionStatus.PAST_DUE.value:
|
||||
stats["past_due_count"] = count
|
||||
elif status == SubscriptionStatus.CANCELLED.value:
|
||||
stats["cancelled_count"] = count
|
||||
elif status == SubscriptionStatus.EXPIRED.value:
|
||||
stats["expired_count"] = count
|
||||
|
||||
# Count by tier
|
||||
tier_counts = (
|
||||
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
|
||||
.filter(
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
])
|
||||
)
|
||||
.group_by(VendorSubscription.tier)
|
||||
.all()
|
||||
)
|
||||
|
||||
tier_distribution = {tier: count for tier, count in tier_counts}
|
||||
|
||||
# Calculate MRR (Monthly Recurring Revenue)
|
||||
mrr_cents = 0
|
||||
arr_cents = 0
|
||||
|
||||
active_subs = (
|
||||
db.query(VendorSubscription, SubscriptionTier)
|
||||
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
|
||||
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
|
||||
.all()
|
||||
)
|
||||
|
||||
for sub, tier in active_subs:
|
||||
if sub.is_annual and tier.price_annual_cents:
|
||||
mrr_cents += tier.price_annual_cents // 12
|
||||
arr_cents += tier.price_annual_cents
|
||||
else:
|
||||
mrr_cents += tier.price_monthly_cents
|
||||
arr_cents += tier.price_monthly_cents * 12
|
||||
|
||||
stats["tier_distribution"] = tier_distribution
|
||||
stats["mrr_cents"] = mrr_cents
|
||||
stats["arr_cents"] = arr_cents
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Singleton instance
|
||||
admin_subscription_service = AdminSubscriptionService()
|
||||
582
app/modules/billing/services/stripe_service.py
Normal file
582
app/modules/billing/services/stripe_service.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# app/modules/billing/services/stripe_service.py
|
||||
"""
|
||||
Stripe payment integration service.
|
||||
|
||||
Provides:
|
||||
- Customer management
|
||||
- Subscription management
|
||||
- Checkout session creation
|
||||
- Customer portal access
|
||||
- Webhook event construction
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import stripe
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.modules.billing.exceptions import (
|
||||
StripeNotConfiguredException,
|
||||
WebhookVerificationException,
|
||||
)
|
||||
from app.modules.billing.models import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Service for Stripe payment operations."""
|
||||
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._configure()
|
||||
|
||||
def _configure(self):
|
||||
"""Configure Stripe with API key."""
|
||||
if settings.stripe_secret_key:
|
||||
stripe.api_key = settings.stripe_secret_key
|
||||
self._configured = True
|
||||
else:
|
||||
logger.warning("Stripe API key not configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Stripe is properly configured."""
|
||||
return self._configured and bool(settings.stripe_secret_key)
|
||||
|
||||
def _check_configured(self) -> None:
|
||||
"""Raise exception if Stripe is not configured."""
|
||||
if not self.is_configured:
|
||||
raise StripeNotConfiguredException()
|
||||
|
||||
# =========================================================================
|
||||
# Customer Management
|
||||
# =========================================================================
|
||||
|
||||
def create_customer(
|
||||
self,
|
||||
vendor: Vendor,
|
||||
email: str,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe customer for a vendor.
|
||||
|
||||
Returns the Stripe customer ID.
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
customer_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
name=name or vendor.name,
|
||||
metadata=customer_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
||||
)
|
||||
return customer.id
|
||||
|
||||
def get_customer(self, customer_id: str) -> stripe.Customer:
|
||||
"""Get a Stripe customer by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Customer.retrieve(customer_id)
|
||||
|
||||
def update_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
email: str | None = None,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Customer:
|
||||
"""Update a Stripe customer."""
|
||||
self._check_configured()
|
||||
|
||||
update_data = {}
|
||||
if email:
|
||||
update_data["email"] = email
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
return stripe.Customer.modify(customer_id, **update_data)
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Management
|
||||
# =========================================================================
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create a new Stripe subscription.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
price_id: Stripe price ID for the subscription
|
||||
trial_days: Optional trial period in days
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription_data = {
|
||||
"customer": customer_id,
|
||||
"items": [{"price": price_id}],
|
||||
"metadata": metadata or {},
|
||||
"payment_behavior": "default_incomplete",
|
||||
"expand": ["latest_invoice.payment_intent"],
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
subscription_data["trial_period_days"] = trial_days
|
||||
|
||||
subscription = stripe.Subscription.create(**subscription_data)
|
||||
logger.info(
|
||||
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""Get a Stripe subscription by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Subscription.retrieve(subscription_id)
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_price_id: str | None = None,
|
||||
proration_behavior: str = "create_prorations",
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Update a Stripe subscription (e.g., change tier).
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
new_price_id: New price ID for tier change
|
||||
proration_behavior: How to handle prorations
|
||||
metadata: Optional metadata to update
|
||||
|
||||
Returns:
|
||||
Updated Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
update_data = {"proration_behavior": proration_behavior}
|
||||
|
||||
if new_price_id:
|
||||
# Get the subscription to find the item ID
|
||||
subscription = stripe.Subscription.retrieve(subscription_id)
|
||||
item_id = subscription["items"]["data"][0]["id"]
|
||||
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
||||
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
||||
logger.info(f"Updated Stripe subscription {subscription_id}")
|
||||
return updated
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
immediately: bool = False,
|
||||
cancellation_reason: str | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Cancel a Stripe subscription.
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
immediately: If True, cancel now. If False, cancel at period end.
|
||||
cancellation_reason: Optional reason for cancellation
|
||||
|
||||
Returns:
|
||||
Cancelled Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
if immediately:
|
||||
subscription = stripe.Subscription.cancel(subscription_id)
|
||||
else:
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=True,
|
||||
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Cancelled Stripe subscription {subscription_id} "
|
||||
f"(immediately={immediately})"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""
|
||||
Reactivate a cancelled subscription (if not yet ended).
|
||||
|
||||
Returns:
|
||||
Reactivated Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=False,
|
||||
)
|
||||
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription_item(self, subscription_item_id: str) -> None:
|
||||
"""
|
||||
Cancel a subscription item (used for add-ons).
|
||||
|
||||
Args:
|
||||
subscription_item_id: Stripe subscription item ID
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
stripe.SubscriptionItem.delete(subscription_item_id)
|
||||
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Checkout & Portal
|
||||
# =========================================================================
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
price_id: str,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
trial_days: int | None = None,
|
||||
quantity: int = 1,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.checkout.Session:
|
||||
"""
|
||||
Create a Stripe Checkout session for subscription signup.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor: Vendor to create checkout for
|
||||
price_id: Stripe price ID
|
||||
success_url: URL to redirect on success
|
||||
cancel_url: URL to redirect on cancel
|
||||
trial_days: Optional trial period
|
||||
quantity: Number of items (default 1)
|
||||
metadata: Additional metadata to store
|
||||
|
||||
Returns:
|
||||
Stripe Checkout Session object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
# Get or create Stripe customer
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
else:
|
||||
# Get vendor owner email
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
owner = (
|
||||
db.query(VendorUser)
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor.id,
|
||||
VendorUser.is_owner == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
email = owner.user.email if owner and owner.user else None
|
||||
|
||||
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
||||
|
||||
# Store the customer ID
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
db.flush()
|
||||
|
||||
# Build metadata
|
||||
session_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
}
|
||||
if metadata:
|
||||
session_metadata.update(metadata)
|
||||
|
||||
session_data = {
|
||||
"customer": customer_id,
|
||||
"line_items": [{"price": price_id, "quantity": quantity}],
|
||||
"mode": "subscription",
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"metadata": session_metadata,
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
||||
|
||||
session = stripe.checkout.Session.create(**session_data)
|
||||
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
||||
return session
|
||||
|
||||
def create_portal_session(
|
||||
self,
|
||||
customer_id: str,
|
||||
return_url: str,
|
||||
) -> stripe.billing_portal.Session:
|
||||
"""
|
||||
Create a Stripe Customer Portal session.
|
||||
|
||||
Allows customers to manage their subscription, payment methods, and invoices.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
return_url: URL to return to after portal
|
||||
|
||||
Returns:
|
||||
Stripe Portal Session object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
logger.info(f"Created portal session for customer {customer_id}")
|
||||
return session
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Management
|
||||
# =========================================================================
|
||||
|
||||
def get_invoices(
|
||||
self,
|
||||
customer_id: str,
|
||||
limit: int = 10,
|
||||
) -> list[stripe.Invoice]:
|
||||
"""Get invoices for a customer."""
|
||||
self._check_configured()
|
||||
|
||||
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
||||
return list(invoices.data)
|
||||
|
||||
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
||||
"""Get the upcoming invoice for a customer."""
|
||||
self._check_configured()
|
||||
|
||||
try:
|
||||
return stripe.Invoice.upcoming(customer=customer_id)
|
||||
except stripe.error.InvalidRequestError:
|
||||
# No upcoming invoice
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Webhook Handling
|
||||
# =========================================================================
|
||||
|
||||
def construct_event(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
) -> stripe.Event:
|
||||
"""
|
||||
Construct and verify a Stripe webhook event.
|
||||
|
||||
Args:
|
||||
payload: Raw request body
|
||||
sig_header: Stripe-Signature header
|
||||
|
||||
Returns:
|
||||
Verified Stripe Event object
|
||||
|
||||
Raises:
|
||||
WebhookVerificationException: If signature verification fails
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
raise WebhookVerificationException("Stripe webhook secret not configured")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload,
|
||||
sig_header,
|
||||
settings.stripe_webhook_secret,
|
||||
)
|
||||
return event
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Webhook signature verification failed: {e}")
|
||||
raise WebhookVerificationException("Invalid webhook signature")
|
||||
|
||||
# =========================================================================
|
||||
# SetupIntent & Payment Method Management
|
||||
# =========================================================================
|
||||
|
||||
def create_setup_intent(
|
||||
self,
|
||||
customer_id: str,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.SetupIntent:
|
||||
"""
|
||||
Create a SetupIntent to collect card without charging.
|
||||
|
||||
Used for trial signups where we collect card upfront
|
||||
but don't charge until trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe SetupIntent object with client_secret for frontend
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
setup_intent = stripe.SetupIntent.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
|
||||
return setup_intent
|
||||
|
||||
def attach_payment_method_to_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
payment_method_id: str,
|
||||
set_as_default: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Attach a payment method to customer and optionally set as default.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
payment_method_id: Payment method ID from confirmed SetupIntent
|
||||
set_as_default: Whether to set as default payment method
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
# Attach the payment method to the customer
|
||||
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
|
||||
|
||||
if set_as_default:
|
||||
stripe.Customer.modify(
|
||||
customer_id,
|
||||
invoice_settings={"default_payment_method": payment_method_id},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Attached payment method {payment_method_id} to customer {customer_id} "
|
||||
f"(default={set_as_default})"
|
||||
)
|
||||
|
||||
def create_subscription_with_trial(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int = 30,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create subscription with trial period.
|
||||
|
||||
Customer must have a default payment method attached.
|
||||
Card will be charged automatically after trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID (must have default payment method)
|
||||
price_id: Stripe price ID for the subscription tier
|
||||
trial_days: Number of trial days (default 30)
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription = stripe.Subscription.create(
|
||||
customer=customer_id,
|
||||
items=[{"price": price_id}],
|
||||
trial_period_days=trial_days,
|
||||
metadata=metadata or {},
|
||||
# Use default payment method for future charges
|
||||
default_payment_method=None, # Uses customer's default
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created subscription {subscription.id} with {trial_days}-day trial "
|
||||
f"for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
|
||||
"""Get a SetupIntent by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.SetupIntent.retrieve(setup_intent_id)
|
||||
|
||||
# =========================================================================
|
||||
# Price/Product Management
|
||||
# =========================================================================
|
||||
|
||||
def get_price(self, price_id: str) -> stripe.Price:
|
||||
"""Get a Stripe price by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Price.retrieve(price_id)
|
||||
|
||||
def get_product(self, product_id: str) -> stripe.Product:
|
||||
"""Get a Stripe product by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Product.retrieve(product_id)
|
||||
|
||||
def list_prices(
|
||||
self,
|
||||
product_id: str | None = None,
|
||||
active: bool = True,
|
||||
) -> list[stripe.Price]:
|
||||
"""List Stripe prices, optionally filtered by product."""
|
||||
self._check_configured()
|
||||
|
||||
params = {"active": active}
|
||||
if product_id:
|
||||
params["product"] = product_id
|
||||
|
||||
prices = stripe.Price.list(**params)
|
||||
return list(prices.data)
|
||||
|
||||
|
||||
# Create service instance
|
||||
stripe_service = StripeService()
|
||||
631
app/modules/billing/services/subscription_service.py
Normal file
631
app/modules/billing/services/subscription_service.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# app/modules/billing/services/subscription_service.py
|
||||
"""
|
||||
Subscription service for tier-based access control.
|
||||
|
||||
Handles:
|
||||
- Subscription creation and management
|
||||
- Tier limit enforcement
|
||||
- Usage tracking
|
||||
- Feature gating
|
||||
|
||||
Usage:
|
||||
from app.modules.billing.services import subscription_service
|
||||
|
||||
# Check if vendor can create an order
|
||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||
|
||||
# Increment order counter after successful order
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.billing.exceptions import (
|
||||
FeatureNotAvailableException,
|
||||
SubscriptionNotFoundException,
|
||||
TierLimitExceededException,
|
||||
)
|
||||
from app.modules.billing.models import (
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
TIER_LIMITS,
|
||||
TierCode,
|
||||
VendorSubscription,
|
||||
)
|
||||
from app.modules.billing.schemas import (
|
||||
SubscriptionCreate,
|
||||
SubscriptionUpdate,
|
||||
SubscriptionUsage,
|
||||
TierInfo,
|
||||
TierLimits,
|
||||
UsageSummary,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""Service for subscription and tier limit operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Tier Information
|
||||
# =========================================================================
|
||||
|
||||
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
|
||||
"""
|
||||
Get full tier information.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
# Try database first if session provided
|
||||
if db is not None:
|
||||
db_tier = self.get_tier_by_code(db, tier_code)
|
||||
if db_tier:
|
||||
return TierInfo(
|
||||
code=db_tier.code,
|
||||
name=db_tier.name,
|
||||
price_monthly_cents=db_tier.price_monthly_cents,
|
||||
price_annual_cents=db_tier.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=db_tier.orders_per_month,
|
||||
products_limit=db_tier.products_limit,
|
||||
team_members=db_tier.team_members,
|
||||
order_history_months=db_tier.order_history_months,
|
||||
),
|
||||
features=db_tier.features or [],
|
||||
)
|
||||
|
||||
# Fallback to hardcoded TIER_LIMITS
|
||||
return self._get_tier_from_legacy(tier_code)
|
||||
|
||||
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
|
||||
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
|
||||
try:
|
||||
tier = TierCode(tier_code)
|
||||
except ValueError:
|
||||
tier = TierCode.ESSENTIAL
|
||||
|
||||
limits = TIER_LIMITS[tier]
|
||||
return TierInfo(
|
||||
code=tier.value,
|
||||
name=limits["name"],
|
||||
price_monthly_cents=limits["price_monthly_cents"],
|
||||
price_annual_cents=limits.get("price_annual_cents"),
|
||||
limits=TierLimits(
|
||||
orders_per_month=limits.get("orders_per_month"),
|
||||
products_limit=limits.get("products_limit"),
|
||||
team_members=limits.get("team_members"),
|
||||
order_history_months=limits.get("order_history_months"),
|
||||
),
|
||||
features=limits.get("features", []),
|
||||
)
|
||||
|
||||
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
|
||||
"""
|
||||
Get information for all tiers.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
if db is not None:
|
||||
db_tiers = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
SubscriptionTier.is_public == True, # noqa: E712
|
||||
)
|
||||
.order_by(SubscriptionTier.display_order)
|
||||
.all()
|
||||
)
|
||||
if db_tiers:
|
||||
return [
|
||||
TierInfo(
|
||||
code=t.code,
|
||||
name=t.name,
|
||||
price_monthly_cents=t.price_monthly_cents,
|
||||
price_annual_cents=t.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=t.orders_per_month,
|
||||
products_limit=t.products_limit,
|
||||
team_members=t.team_members,
|
||||
order_history_months=t.order_history_months,
|
||||
),
|
||||
features=t.features or [],
|
||||
)
|
||||
for t in db_tiers
|
||||
]
|
||||
|
||||
# Fallback to hardcoded
|
||||
return [
|
||||
self._get_tier_from_legacy(tier.value)
|
||||
for tier in TierCode
|
||||
]
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
|
||||
"""Get subscription tier by code."""
|
||||
return (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
|
||||
"""Get tier ID from tier code. Returns None if tier not found."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
return tier.id if tier else None
|
||||
|
||||
# =========================================================================
|
||||
# Subscription CRUD
|
||||
# =========================================================================
|
||||
|
||||
def get_subscription(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription | None:
|
||||
"""Get vendor subscription."""
|
||||
return (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_subscription_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription:
|
||||
"""Get vendor subscription or raise exception."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if not subscription:
|
||||
raise SubscriptionNotFoundException(vendor_id)
|
||||
return subscription
|
||||
|
||||
def get_current_tier(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> TierCode | None:
|
||||
"""Get vendor's current subscription tier code."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
try:
|
||||
return TierCode(subscription.tier)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_or_create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier: str = TierCode.ESSENTIAL.value,
|
||||
trial_days: int = 14,
|
||||
) -> VendorSubscription:
|
||||
"""
|
||||
Get existing subscription or create a new trial subscription.
|
||||
|
||||
Used when a vendor first accesses the system.
|
||||
"""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
return subscription
|
||||
|
||||
# Create new trial subscription
|
||||
now = datetime.now(UTC)
|
||||
trial_end = now + timedelta(days=trial_days)
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=tier,
|
||||
tier_id=tier_id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=trial_end,
|
||||
trial_ends_at=trial_end,
|
||||
is_annual=False,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(
|
||||
f"Created trial subscription for vendor {vendor_id} "
|
||||
f"(tier={tier}, trial_ends={trial_end})"
|
||||
)
|
||||
|
||||
return subscription
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionCreate,
|
||||
) -> VendorSubscription:
|
||||
"""Create a subscription for a vendor."""
|
||||
# Check if subscription exists
|
||||
existing = self.get_subscription(db, vendor_id)
|
||||
if existing:
|
||||
raise ValueError("Vendor already has a subscription")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate period end based on billing cycle
|
||||
if data.is_annual:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
# Handle trial
|
||||
trial_ends_at = None
|
||||
status = SubscriptionStatus.ACTIVE.value
|
||||
if data.trial_days > 0:
|
||||
trial_ends_at = now + timedelta(days=data.trial_days)
|
||||
status = SubscriptionStatus.TRIAL.value
|
||||
period_end = trial_ends_at
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, data.tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=data.tier,
|
||||
tier_id=tier_id,
|
||||
status=status,
|
||||
period_start=now,
|
||||
period_end=period_end,
|
||||
trial_ends_at=trial_ends_at,
|
||||
is_annual=data.is_annual,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
||||
return subscription
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionUpdate,
|
||||
) -> VendorSubscription:
|
||||
"""Update a vendor subscription."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# If tier is being updated, also update tier_id
|
||||
if "tier" in update_data:
|
||||
tier_id = self.get_tier_id(db, update_data["tier"])
|
||||
update_data["tier_id"] = tier_id
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(subscription, key, value)
|
||||
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Updated subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
def upgrade_tier(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
new_tier: str,
|
||||
) -> VendorSubscription:
|
||||
"""Upgrade vendor to a new tier."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
old_tier = subscription.tier
|
||||
subscription.tier = new_tier
|
||||
subscription.tier_id = self.get_tier_id(db, new_tier)
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
# If upgrading from trial, mark as active
|
||||
if subscription.status == SubscriptionStatus.TRIAL.value:
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reason: str | None = None,
|
||||
) -> VendorSubscription:
|
||||
"""Cancel a vendor subscription (access until period end)."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
subscription.status = SubscriptionStatus.CANCELLED.value
|
||||
subscription.cancelled_at = datetime.now(UTC)
|
||||
subscription.cancellation_reason = reason
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
# =========================================================================
|
||||
# Usage Tracking
|
||||
# =========================================================================
|
||||
|
||||
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
||||
"""Get current subscription usage statistics."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Calculate usage stats
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
def calc_percent(current: int, limit: int | None) -> float | None:
|
||||
if limit is None or limit == 0:
|
||||
return None
|
||||
return min(100.0, (current / limit) * 100)
|
||||
|
||||
return SubscriptionUsage(
|
||||
orders_used=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
||||
products_used=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
products_percent_used=calc_percent(products_count, products_limit),
|
||||
team_members_used=team_count,
|
||||
team_members_limit=team_limit,
|
||||
team_members_remaining=calc_remaining(team_count, team_limit),
|
||||
team_members_percent_used=calc_percent(team_count, team_limit),
|
||||
)
|
||||
|
||||
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
|
||||
"""Get usage summary for billing page display."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Get limits
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
return UsageSummary(
|
||||
orders_this_period=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
products_count=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
team_count=team_count,
|
||||
team_limit=team_limit,
|
||||
team_remaining=calc_remaining(team_count, team_limit),
|
||||
)
|
||||
|
||||
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Increment the order counter for the current period.
|
||||
|
||||
Call this after successfully creating/importing an order.
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
subscription.increment_order_count()
|
||||
db.flush()
|
||||
|
||||
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
||||
"""Reset counters for a new billing period."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
subscription.reset_period_counters()
|
||||
db.flush()
|
||||
logger.info(f"Reset period counters for vendor {vendor_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Limit Checks
|
||||
# =========================================================================
|
||||
|
||||
def can_create_order(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can create/import another order.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.can_create_order()
|
||||
|
||||
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check order limit and raise exception if exceeded.
|
||||
|
||||
Use this in order creation flows.
|
||||
"""
|
||||
can_create, message = self.can_create_order(db, vendor_id)
|
||||
if not can_create:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Order limit exceeded",
|
||||
limit_type="orders",
|
||||
current=subscription.orders_this_period if subscription else 0,
|
||||
limit=subscription.orders_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_product(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another product.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_product(products_count)
|
||||
|
||||
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check product limit and raise exception if exceeded.
|
||||
|
||||
Use this in product creation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_product(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Product limit exceeded",
|
||||
limit_type="products",
|
||||
current=products_count,
|
||||
limit=subscription.products_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_team_member(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another team member.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_team_member(team_count)
|
||||
|
||||
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check team member limit and raise exception if exceeded.
|
||||
|
||||
Use this in team member invitation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_team_member(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Team member limit exceeded",
|
||||
limit_type="team_members",
|
||||
current=team_count,
|
||||
limit=subscription.team_members_limit if subscription else 0,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Feature Gating
|
||||
# =========================================================================
|
||||
|
||||
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
||||
"""Check if vendor has access to a feature."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.has_feature(feature)
|
||||
|
||||
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
||||
"""
|
||||
Check feature access and raise exception if not available.
|
||||
|
||||
Use this to gate premium features.
|
||||
"""
|
||||
if not self.has_feature(db, vendor_id, feature):
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Find which tier has this feature
|
||||
required_tier = None
|
||||
for tier_code, limits in TIER_LIMITS.items():
|
||||
if feature in limits.get("features", []):
|
||||
required_tier = limits["name"]
|
||||
break
|
||||
|
||||
raise FeatureNotAvailableException(
|
||||
feature=feature,
|
||||
current_tier=subscription.tier,
|
||||
required_tier=required_tier or "higher",
|
||||
)
|
||||
|
||||
def get_feature_tier(self, feature: str) -> str | None:
|
||||
"""Get the minimum tier required for a feature."""
|
||||
for tier_code in [
|
||||
TierCode.ESSENTIAL,
|
||||
TierCode.PROFESSIONAL,
|
||||
TierCode.BUSINESS,
|
||||
TierCode.ENTERPRISE,
|
||||
]:
|
||||
if feature in TIER_LIMITS[tier_code].get("features", []):
|
||||
return tier_code.value
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
subscription_service = SubscriptionService()
|
||||
Reference in New Issue
Block a user