feat: complete billing module migration (Phase 5)

Migrates billing module to self-contained structure:
- Create app/modules/billing/services/ with subscription, stripe, admin services
- Create app/modules/billing/models/ re-exporting from central location
- Create app/modules/billing/schemas/ re-exporting from central location
- Create app/modules/billing/tasks/ with 4 scheduled Celery tasks
- Create app/modules/billing/exceptions.py with module-specific exceptions
- Update definition.py with is_self_contained=True and scheduled_tasks

Celery task migration:
- reset_period_counters -> billing module
- check_trial_expirations -> billing module
- sync_stripe_status -> billing module
- cleanup_stale_subscriptions -> billing module
- capture_capacity_snapshot remains in legacy (will go to monitoring)

Backward compatibility:
- Create re-exports in app/services/ for subscription, stripe, admin services
- Old import paths continue to work
- Update celery_config.py to use module-defined schedules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-27 23:06:23 +01:00
parent f1f91abe51
commit 4f379b472b
17 changed files with 2198 additions and 1931 deletions

View File

@@ -0,0 +1,28 @@
# app/modules/billing/services/__init__.py
"""
Billing module services.
Provides subscription management, Stripe integration, and admin operations.
"""
from app.modules.billing.services.subscription_service import (
SubscriptionService,
subscription_service,
)
from app.modules.billing.services.stripe_service import (
StripeService,
stripe_service,
)
from app.modules.billing.services.admin_subscription_service import (
AdminSubscriptionService,
admin_subscription_service,
)
__all__ = [
"SubscriptionService",
"subscription_service",
"StripeService",
"stripe_service",
"AdminSubscriptionService",
"admin_subscription_service",
]

View File

@@ -0,0 +1,352 @@
# app/modules/billing/services/admin_subscription_service.py
"""
Admin Subscription Service.
Handles subscription management operations for platform administrators:
- Subscription tier CRUD
- Vendor subscription management
- Billing history queries
- Subscription analytics
"""
import logging
from math import ceil
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
BusinessLogicException,
ConflictException,
ResourceNotFoundException,
)
from app.modules.billing.exceptions import TierNotFoundException
from app.modules.billing.models import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
)
from models.database.product import Product
from models.database.vendor import Vendor, VendorUser
logger = logging.getLogger(__name__)
class AdminSubscriptionService:
"""Service for admin subscription management operations."""
# =========================================================================
# Subscription Tiers
# =========================================================================
def get_tiers(
self, db: Session, include_inactive: bool = False
) -> list[SubscriptionTier]:
"""Get all subscription tiers."""
query = db.query(SubscriptionTier)
if not include_inactive:
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
return query.order_by(SubscriptionTier.display_order).all()
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
"""Get a subscription tier by code."""
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
if not tier:
raise TierNotFoundException(tier_code)
return tier
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
"""Create a new subscription tier."""
# Check for duplicate code
existing = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_data["code"])
.first()
)
if existing:
raise ConflictException(
f"Tier with code '{tier_data['code']}' already exists"
)
tier = SubscriptionTier(**tier_data)
db.add(tier)
logger.info(f"Created subscription tier: {tier.code}")
return tier
def update_tier(
self, db: Session, tier_code: str, update_data: dict
) -> SubscriptionTier:
"""Update a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
for field, value in update_data.items():
setattr(tier, field, value)
logger.info(f"Updated subscription tier: {tier.code}")
return tier
def deactivate_tier(self, db: Session, tier_code: str) -> None:
"""Soft-delete a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
# Check if any active subscriptions use this tier
active_subs = (
db.query(VendorSubscription)
.filter(
VendorSubscription.tier == tier_code,
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
]),
)
.count()
)
if active_subs > 0:
raise BusinessLogicException(
f"Cannot delete tier: {active_subs} active subscriptions are using it"
)
tier.is_active = False
logger.info(f"Soft-deleted subscription tier: {tier.code}")
# =========================================================================
# Vendor Subscriptions
# =========================================================================
def list_subscriptions(
self,
db: Session,
page: int = 1,
per_page: int = 20,
status: str | None = None,
tier: str | None = None,
search: str | None = None,
) -> dict:
"""List vendor subscriptions with filtering and pagination."""
query = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
)
# Apply filters
if status:
query = query.filter(VendorSubscription.status == status)
if tier:
query = query.filter(VendorSubscription.tier == tier)
if search:
query = query.filter(Vendor.name.ilike(f"%{search}%"))
# Count total
total = query.count()
# Paginate
offset = (page - 1) * per_page
results = (
query.order_by(VendorSubscription.created_at.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
"""Get subscription for a specific vendor."""
result = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
if not result:
raise ResourceNotFoundException("Subscription", str(vendor_id))
return result
def update_subscription(
self, db: Session, vendor_id: int, update_data: dict
) -> tuple:
"""Update a vendor's subscription."""
result = self.get_subscription(db, vendor_id)
sub, vendor = result
for field, value in update_data.items():
setattr(sub, field, value)
logger.info(
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
)
return sub, vendor
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
"""Get a vendor by ID."""
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise ResourceNotFoundException("Vendor", str(vendor_id))
return vendor
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
"""Get usage counts (products and team members) for a vendor."""
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.is_active == True, # noqa: E712
)
.scalar()
or 0
)
return {
"products_count": products_count,
"team_count": team_count,
}
# =========================================================================
# Billing History
# =========================================================================
def list_billing_history(
self,
db: Session,
page: int = 1,
per_page: int = 20,
vendor_id: int | None = None,
status: str | None = None,
) -> dict:
"""List billing history across all vendors."""
query = (
db.query(BillingHistory, Vendor)
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
)
if vendor_id:
query = query.filter(BillingHistory.vendor_id == vendor_id)
if status:
query = query.filter(BillingHistory.status == status)
total = query.count()
offset = (page - 1) * per_page
results = (
query.order_by(BillingHistory.invoice_date.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
# =========================================================================
# Statistics
# =========================================================================
def get_stats(self, db: Session) -> dict:
"""Get subscription statistics for admin dashboard."""
# Count by status
status_counts = (
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
.group_by(VendorSubscription.status)
.all()
)
stats = {
"total_subscriptions": 0,
"active_count": 0,
"trial_count": 0,
"past_due_count": 0,
"cancelled_count": 0,
"expired_count": 0,
}
for status, count in status_counts:
stats["total_subscriptions"] += count
if status == SubscriptionStatus.ACTIVE.value:
stats["active_count"] = count
elif status == SubscriptionStatus.TRIAL.value:
stats["trial_count"] = count
elif status == SubscriptionStatus.PAST_DUE.value:
stats["past_due_count"] = count
elif status == SubscriptionStatus.CANCELLED.value:
stats["cancelled_count"] = count
elif status == SubscriptionStatus.EXPIRED.value:
stats["expired_count"] = count
# Count by tier
tier_counts = (
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
.filter(
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
])
)
.group_by(VendorSubscription.tier)
.all()
)
tier_distribution = {tier: count for tier, count in tier_counts}
# Calculate MRR (Monthly Recurring Revenue)
mrr_cents = 0
arr_cents = 0
active_subs = (
db.query(VendorSubscription, SubscriptionTier)
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
.all()
)
for sub, tier in active_subs:
if sub.is_annual and tier.price_annual_cents:
mrr_cents += tier.price_annual_cents // 12
arr_cents += tier.price_annual_cents
else:
mrr_cents += tier.price_monthly_cents
arr_cents += tier.price_monthly_cents * 12
stats["tier_distribution"] = tier_distribution
stats["mrr_cents"] = mrr_cents
stats["arr_cents"] = arr_cents
return stats
# Singleton instance
admin_subscription_service = AdminSubscriptionService()

View File

@@ -0,0 +1,582 @@
# app/modules/billing/services/stripe_service.py
"""
Stripe payment integration service.
Provides:
- Customer management
- Subscription management
- Checkout session creation
- Customer portal access
- Webhook event construction
"""
import logging
from datetime import datetime
import stripe
from sqlalchemy.orm import Session
from app.core.config import settings
from app.modules.billing.exceptions import (
StripeNotConfiguredException,
WebhookVerificationException,
)
from app.modules.billing.models import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class StripeService:
"""Service for Stripe payment operations."""
def __init__(self):
self._configured = False
self._configure()
def _configure(self):
"""Configure Stripe with API key."""
if settings.stripe_secret_key:
stripe.api_key = settings.stripe_secret_key
self._configured = True
else:
logger.warning("Stripe API key not configured")
@property
def is_configured(self) -> bool:
"""Check if Stripe is properly configured."""
return self._configured and bool(settings.stripe_secret_key)
def _check_configured(self) -> None:
"""Raise exception if Stripe is not configured."""
if not self.is_configured:
raise StripeNotConfiguredException()
# =========================================================================
# Customer Management
# =========================================================================
def create_customer(
self,
vendor: Vendor,
email: str,
name: str | None = None,
metadata: dict | None = None,
) -> str:
"""
Create a Stripe customer for a vendor.
Returns the Stripe customer ID.
"""
self._check_configured()
customer_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
**(metadata or {}),
}
customer = stripe.Customer.create(
email=email,
name=name or vendor.name,
metadata=customer_metadata,
)
logger.info(
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
)
return customer.id
def get_customer(self, customer_id: str) -> stripe.Customer:
"""Get a Stripe customer by ID."""
self._check_configured()
return stripe.Customer.retrieve(customer_id)
def update_customer(
self,
customer_id: str,
email: str | None = None,
name: str | None = None,
metadata: dict | None = None,
) -> stripe.Customer:
"""Update a Stripe customer."""
self._check_configured()
update_data = {}
if email:
update_data["email"] = email
if name:
update_data["name"] = name
if metadata:
update_data["metadata"] = metadata
return stripe.Customer.modify(customer_id, **update_data)
# =========================================================================
# Subscription Management
# =========================================================================
def create_subscription(
self,
customer_id: str,
price_id: str,
trial_days: int | None = None,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create a new Stripe subscription.
Args:
customer_id: Stripe customer ID
price_id: Stripe price ID for the subscription
trial_days: Optional trial period in days
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
self._check_configured()
subscription_data = {
"customer": customer_id,
"items": [{"price": price_id}],
"metadata": metadata or {},
"payment_behavior": "default_incomplete",
"expand": ["latest_invoice.payment_intent"],
}
if trial_days:
subscription_data["trial_period_days"] = trial_days
subscription = stripe.Subscription.create(**subscription_data)
logger.info(
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
)
return subscription
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
"""Get a Stripe subscription by ID."""
self._check_configured()
return stripe.Subscription.retrieve(subscription_id)
def update_subscription(
self,
subscription_id: str,
new_price_id: str | None = None,
proration_behavior: str = "create_prorations",
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Update a Stripe subscription (e.g., change tier).
Args:
subscription_id: Stripe subscription ID
new_price_id: New price ID for tier change
proration_behavior: How to handle prorations
metadata: Optional metadata to update
Returns:
Updated Stripe Subscription object
"""
self._check_configured()
update_data = {"proration_behavior": proration_behavior}
if new_price_id:
# Get the subscription to find the item ID
subscription = stripe.Subscription.retrieve(subscription_id)
item_id = subscription["items"]["data"][0]["id"]
update_data["items"] = [{"id": item_id, "price": new_price_id}]
if metadata:
update_data["metadata"] = metadata
updated = stripe.Subscription.modify(subscription_id, **update_data)
logger.info(f"Updated Stripe subscription {subscription_id}")
return updated
def cancel_subscription(
self,
subscription_id: str,
immediately: bool = False,
cancellation_reason: str | None = None,
) -> stripe.Subscription:
"""
Cancel a Stripe subscription.
Args:
subscription_id: Stripe subscription ID
immediately: If True, cancel now. If False, cancel at period end.
cancellation_reason: Optional reason for cancellation
Returns:
Cancelled Stripe Subscription object
"""
self._check_configured()
if immediately:
subscription = stripe.Subscription.cancel(subscription_id)
else:
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=True,
metadata={"cancellation_reason": cancellation_reason or "user_request"},
)
logger.info(
f"Cancelled Stripe subscription {subscription_id} "
f"(immediately={immediately})"
)
return subscription
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
"""
Reactivate a cancelled subscription (if not yet ended).
Returns:
Reactivated Stripe Subscription object
"""
self._check_configured()
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=False,
)
logger.info(f"Reactivated Stripe subscription {subscription_id}")
return subscription
def cancel_subscription_item(self, subscription_item_id: str) -> None:
"""
Cancel a subscription item (used for add-ons).
Args:
subscription_item_id: Stripe subscription item ID
"""
self._check_configured()
stripe.SubscriptionItem.delete(subscription_item_id)
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
# =========================================================================
# Checkout & Portal
# =========================================================================
def create_checkout_session(
self,
db: Session,
vendor: Vendor,
price_id: str,
success_url: str,
cancel_url: str,
trial_days: int | None = None,
quantity: int = 1,
metadata: dict | None = None,
) -> stripe.checkout.Session:
"""
Create a Stripe Checkout session for subscription signup.
Args:
db: Database session
vendor: Vendor to create checkout for
price_id: Stripe price ID
success_url: URL to redirect on success
cancel_url: URL to redirect on cancel
trial_days: Optional trial period
quantity: Number of items (default 1)
metadata: Additional metadata to store
Returns:
Stripe Checkout Session object
"""
self._check_configured()
# Get or create Stripe customer
subscription = (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor.id)
.first()
)
if subscription and subscription.stripe_customer_id:
customer_id = subscription.stripe_customer_id
else:
# Get vendor owner email
from models.database.vendor import VendorUser
owner = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_owner == True,
)
.first()
)
email = owner.user.email if owner and owner.user else None
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
# Store the customer ID
if subscription:
subscription.stripe_customer_id = customer_id
db.flush()
# Build metadata
session_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
}
if metadata:
session_metadata.update(metadata)
session_data = {
"customer": customer_id,
"line_items": [{"price": price_id, "quantity": quantity}],
"mode": "subscription",
"success_url": success_url,
"cancel_url": cancel_url,
"metadata": session_metadata,
}
if trial_days:
session_data["subscription_data"] = {"trial_period_days": trial_days}
session = stripe.checkout.Session.create(**session_data)
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
return session
def create_portal_session(
self,
customer_id: str,
return_url: str,
) -> stripe.billing_portal.Session:
"""
Create a Stripe Customer Portal session.
Allows customers to manage their subscription, payment methods, and invoices.
Args:
customer_id: Stripe customer ID
return_url: URL to return to after portal
Returns:
Stripe Portal Session object
"""
self._check_configured()
session = stripe.billing_portal.Session.create(
customer=customer_id,
return_url=return_url,
)
logger.info(f"Created portal session for customer {customer_id}")
return session
# =========================================================================
# Invoice Management
# =========================================================================
def get_invoices(
self,
customer_id: str,
limit: int = 10,
) -> list[stripe.Invoice]:
"""Get invoices for a customer."""
self._check_configured()
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
return list(invoices.data)
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
"""Get the upcoming invoice for a customer."""
self._check_configured()
try:
return stripe.Invoice.upcoming(customer=customer_id)
except stripe.error.InvalidRequestError:
# No upcoming invoice
return None
# =========================================================================
# Webhook Handling
# =========================================================================
def construct_event(
self,
payload: bytes,
sig_header: str,
) -> stripe.Event:
"""
Construct and verify a Stripe webhook event.
Args:
payload: Raw request body
sig_header: Stripe-Signature header
Returns:
Verified Stripe Event object
Raises:
WebhookVerificationException: If signature verification fails
"""
if not settings.stripe_webhook_secret:
raise WebhookVerificationException("Stripe webhook secret not configured")
try:
event = stripe.Webhook.construct_event(
payload,
sig_header,
settings.stripe_webhook_secret,
)
return event
except stripe.error.SignatureVerificationError as e:
logger.error(f"Webhook signature verification failed: {e}")
raise WebhookVerificationException("Invalid webhook signature")
# =========================================================================
# SetupIntent & Payment Method Management
# =========================================================================
def create_setup_intent(
self,
customer_id: str,
metadata: dict | None = None,
) -> stripe.SetupIntent:
"""
Create a SetupIntent to collect card without charging.
Used for trial signups where we collect card upfront
but don't charge until trial ends.
Args:
customer_id: Stripe customer ID
metadata: Optional metadata to attach
Returns:
Stripe SetupIntent object with client_secret for frontend
"""
self._check_configured()
setup_intent = stripe.SetupIntent.create(
customer=customer_id,
payment_method_types=["card"],
metadata=metadata or {},
)
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
return setup_intent
def attach_payment_method_to_customer(
self,
customer_id: str,
payment_method_id: str,
set_as_default: bool = True,
) -> None:
"""
Attach a payment method to customer and optionally set as default.
Args:
customer_id: Stripe customer ID
payment_method_id: Payment method ID from confirmed SetupIntent
set_as_default: Whether to set as default payment method
"""
self._check_configured()
# Attach the payment method to the customer
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
if set_as_default:
stripe.Customer.modify(
customer_id,
invoice_settings={"default_payment_method": payment_method_id},
)
logger.info(
f"Attached payment method {payment_method_id} to customer {customer_id} "
f"(default={set_as_default})"
)
def create_subscription_with_trial(
self,
customer_id: str,
price_id: str,
trial_days: int = 30,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create subscription with trial period.
Customer must have a default payment method attached.
Card will be charged automatically after trial ends.
Args:
customer_id: Stripe customer ID (must have default payment method)
price_id: Stripe price ID for the subscription tier
trial_days: Number of trial days (default 30)
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
self._check_configured()
subscription = stripe.Subscription.create(
customer=customer_id,
items=[{"price": price_id}],
trial_period_days=trial_days,
metadata=metadata or {},
# Use default payment method for future charges
default_payment_method=None, # Uses customer's default
)
logger.info(
f"Created subscription {subscription.id} with {trial_days}-day trial "
f"for customer {customer_id}"
)
return subscription
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
"""Get a SetupIntent by ID."""
self._check_configured()
return stripe.SetupIntent.retrieve(setup_intent_id)
# =========================================================================
# Price/Product Management
# =========================================================================
def get_price(self, price_id: str) -> stripe.Price:
"""Get a Stripe price by ID."""
self._check_configured()
return stripe.Price.retrieve(price_id)
def get_product(self, product_id: str) -> stripe.Product:
"""Get a Stripe product by ID."""
self._check_configured()
return stripe.Product.retrieve(product_id)
def list_prices(
self,
product_id: str | None = None,
active: bool = True,
) -> list[stripe.Price]:
"""List Stripe prices, optionally filtered by product."""
self._check_configured()
params = {"active": active}
if product_id:
params["product"] = product_id
prices = stripe.Price.list(**params)
return list(prices.data)
# Create service instance
stripe_service = StripeService()

View File

@@ -0,0 +1,631 @@
# app/modules/billing/services/subscription_service.py
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
Usage:
from app.modules.billing.services import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.modules.billing.exceptions import (
FeatureNotAvailableException,
SubscriptionNotFoundException,
TierLimitExceededException,
)
from app.modules.billing.models import (
SubscriptionStatus,
SubscriptionTier,
TIER_LIMITS,
TierCode,
VendorSubscription,
)
from app.modules.billing.schemas import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
UsageSummary,
)
from models.database.product import Product
from models.database.vendor import Vendor, VendorUser
logger = logging.getLogger(__name__)
class SubscriptionService:
"""Service for subscription and tier limit operations."""
# =========================================================================
# Tier Information
# =========================================================================
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
"""
Get full tier information.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
# Try database first if session provided
if db is not None:
db_tier = self.get_tier_by_code(db, tier_code)
if db_tier:
return TierInfo(
code=db_tier.code,
name=db_tier.name,
price_monthly_cents=db_tier.price_monthly_cents,
price_annual_cents=db_tier.price_annual_cents,
limits=TierLimits(
orders_per_month=db_tier.orders_per_month,
products_limit=db_tier.products_limit,
team_members=db_tier.team_members,
order_history_months=db_tier.order_history_months,
),
features=db_tier.features or [],
)
# Fallback to hardcoded TIER_LIMITS
return self._get_tier_from_legacy(tier_code)
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
try:
tier = TierCode(tier_code)
except ValueError:
tier = TierCode.ESSENTIAL
limits = TIER_LIMITS[tier]
return TierInfo(
code=tier.value,
name=limits["name"],
price_monthly_cents=limits["price_monthly_cents"],
price_annual_cents=limits.get("price_annual_cents"),
limits=TierLimits(
orders_per_month=limits.get("orders_per_month"),
products_limit=limits.get("products_limit"),
team_members=limits.get("team_members"),
order_history_months=limits.get("order_history_months"),
),
features=limits.get("features", []),
)
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
"""
Get information for all tiers.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
if db is not None:
db_tiers = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.is_public == True, # noqa: E712
)
.order_by(SubscriptionTier.display_order)
.all()
)
if db_tiers:
return [
TierInfo(
code=t.code,
name=t.name,
price_monthly_cents=t.price_monthly_cents,
price_annual_cents=t.price_annual_cents,
limits=TierLimits(
orders_per_month=t.orders_per_month,
products_limit=t.products_limit,
team_members=t.team_members,
order_history_months=t.order_history_months,
),
features=t.features or [],
)
for t in db_tiers
]
# Fallback to hardcoded
return [
self._get_tier_from_legacy(tier.value)
for tier in TierCode
]
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
"""Get subscription tier by code."""
return (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
"""Get tier ID from tier code. Returns None if tier not found."""
tier = self.get_tier_by_code(db, tier_code)
return tier.id if tier else None
# =========================================================================
# Subscription CRUD
# =========================================================================
def get_subscription(
self, db: Session, vendor_id: int
) -> VendorSubscription | None:
"""Get vendor subscription."""
return (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
def get_subscription_or_raise(
self, db: Session, vendor_id: int
) -> VendorSubscription:
"""Get vendor subscription or raise exception."""
subscription = self.get_subscription(db, vendor_id)
if not subscription:
raise SubscriptionNotFoundException(vendor_id)
return subscription
def get_current_tier(
self, db: Session, vendor_id: int
) -> TierCode | None:
"""Get vendor's current subscription tier code."""
subscription = self.get_subscription(db, vendor_id)
if subscription:
try:
return TierCode(subscription.tier)
except ValueError:
return None
return None
def get_or_create_subscription(
self,
db: Session,
vendor_id: int,
tier: str = TierCode.ESSENTIAL.value,
trial_days: int = 14,
) -> VendorSubscription:
"""
Get existing subscription or create a new trial subscription.
Used when a vendor first accesses the system.
"""
subscription = self.get_subscription(db, vendor_id)
if subscription:
return subscription
# Create new trial subscription
now = datetime.now(UTC)
trial_end = now + timedelta(days=trial_days)
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=tier,
tier_id=tier_id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=trial_end,
trial_ends_at=trial_end,
is_annual=False,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(
f"Created trial subscription for vendor {vendor_id} "
f"(tier={tier}, trial_ends={trial_end})"
)
return subscription
def create_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionCreate,
) -> VendorSubscription:
"""Create a subscription for a vendor."""
# Check if subscription exists
existing = self.get_subscription(db, vendor_id)
if existing:
raise ValueError("Vendor already has a subscription")
now = datetime.now(UTC)
# Calculate period end based on billing cycle
if data.is_annual:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
# Handle trial
trial_ends_at = None
status = SubscriptionStatus.ACTIVE.value
if data.trial_days > 0:
trial_ends_at = now + timedelta(days=data.trial_days)
status = SubscriptionStatus.TRIAL.value
period_end = trial_ends_at
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, data.tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=data.tier,
tier_id=tier_id,
status=status,
period_start=now,
period_end=period_end,
trial_ends_at=trial_ends_at,
is_annual=data.is_annual,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
return subscription
def update_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionUpdate,
) -> VendorSubscription:
"""Update a vendor subscription."""
subscription = self.get_subscription_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
# If tier is being updated, also update tier_id
if "tier" in update_data:
tier_id = self.get_tier_id(db, update_data["tier"])
update_data["tier_id"] = tier_id
for key, value in update_data.items():
setattr(subscription, key, value)
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Updated subscription for vendor {vendor_id}")
return subscription
def upgrade_tier(
self,
db: Session,
vendor_id: int,
new_tier: str,
) -> VendorSubscription:
"""Upgrade vendor to a new tier."""
subscription = self.get_subscription_or_raise(db, vendor_id)
old_tier = subscription.tier
subscription.tier = new_tier
subscription.tier_id = self.get_tier_id(db, new_tier)
subscription.updated_at = datetime.now(UTC)
# If upgrading from trial, mark as active
if subscription.status == SubscriptionStatus.TRIAL.value:
subscription.status = SubscriptionStatus.ACTIVE.value
db.flush()
db.refresh(subscription)
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
return subscription
def cancel_subscription(
self,
db: Session,
vendor_id: int,
reason: str | None = None,
) -> VendorSubscription:
"""Cancel a vendor subscription (access until period end)."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.status = SubscriptionStatus.CANCELLED.value
subscription.cancelled_at = datetime.now(UTC)
subscription.cancellation_reason = reason
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Cancelled subscription for vendor {vendor_id}")
return subscription
# =========================================================================
# Usage Tracking
# =========================================================================
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
"""Get current subscription usage statistics."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Calculate usage stats
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
def calc_percent(current: int, limit: int | None) -> float | None:
if limit is None or limit == 0:
return None
return min(100.0, (current / limit) * 100)
return SubscriptionUsage(
orders_used=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
products_used=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
products_percent_used=calc_percent(products_count, products_limit),
team_members_used=team_count,
team_members_limit=team_limit,
team_members_remaining=calc_remaining(team_count, team_limit),
team_members_percent_used=calc_percent(team_count, team_limit),
)
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
"""Get usage summary for billing page display."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Get limits
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
return UsageSummary(
orders_this_period=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
products_count=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
team_count=team_count,
team_limit=team_limit,
team_remaining=calc_remaining(team_count, team_limit),
)
def increment_order_count(self, db: Session, vendor_id: int) -> None:
"""
Increment the order counter for the current period.
Call this after successfully creating/importing an order.
"""
subscription = self.get_or_create_subscription(db, vendor_id)
subscription.increment_order_count()
db.flush()
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
"""Reset counters for a new billing period."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.reset_period_counters()
db.flush()
logger.info(f"Reset period counters for vendor {vendor_id}")
# =========================================================================
# Limit Checks
# =========================================================================
def can_create_order(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can create/import another order.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.can_create_order()
def check_order_limit(self, db: Session, vendor_id: int) -> None:
"""
Check order limit and raise exception if exceeded.
Use this in order creation flows.
"""
can_create, message = self.can_create_order(db, vendor_id)
if not can_create:
subscription = self.get_subscription(db, vendor_id)
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=subscription.orders_this_period if subscription else 0,
limit=subscription.orders_limit if subscription else 0,
)
def can_add_product(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another product.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
return subscription.can_add_product(products_count)
def check_product_limit(self, db: Session, vendor_id: int) -> None:
"""
Check product limit and raise exception if exceeded.
Use this in product creation flows.
"""
can_add, message = self.can_add_product(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Product limit exceeded",
limit_type="products",
current=products_count,
limit=subscription.products_limit if subscription else 0,
)
def can_add_team_member(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another team member.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
return subscription.can_add_team_member(team_count)
def check_team_limit(self, db: Session, vendor_id: int) -> None:
"""
Check team member limit and raise exception if exceeded.
Use this in team member invitation flows.
"""
can_add, message = self.can_add_team_member(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Team member limit exceeded",
limit_type="team_members",
current=team_count,
limit=subscription.team_members_limit if subscription else 0,
)
# =========================================================================
# Feature Gating
# =========================================================================
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
"""Check if vendor has access to a feature."""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.has_feature(feature)
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
"""
Check feature access and raise exception if not available.
Use this to gate premium features.
"""
if not self.has_feature(db, vendor_id, feature):
subscription = self.get_or_create_subscription(db, vendor_id)
# Find which tier has this feature
required_tier = None
for tier_code, limits in TIER_LIMITS.items():
if feature in limits.get("features", []):
required_tier = limits["name"]
break
raise FeatureNotAvailableException(
feature=feature,
current_tier=subscription.tier,
required_tier=required_tier or "higher",
)
def get_feature_tier(self, feature: str) -> str | None:
"""Get the minimum tier required for a feature."""
for tier_code in [
TierCode.ESSENTIAL,
TierCode.PROFESSIONAL,
TierCode.BUSINESS,
TierCode.ENTERPRISE,
]:
if feature in TIER_LIMITS[tier_code].get("features", []):
return tier_code.value
return None
# Singleton instance
subscription_service = SubscriptionService()