feat: complete billing module migration (Phase 5)

Migrates billing module to self-contained structure:
- Create app/modules/billing/services/ with subscription, stripe, admin services
- Create app/modules/billing/models/ re-exporting from central location
- Create app/modules/billing/schemas/ re-exporting from central location
- Create app/modules/billing/tasks/ with 4 scheduled Celery tasks
- Create app/modules/billing/exceptions.py with module-specific exceptions
- Update definition.py with is_self_contained=True and scheduled_tasks

Celery task migration:
- reset_period_counters -> billing module
- check_trial_expirations -> billing module
- sync_stripe_status -> billing module
- cleanup_stale_subscriptions -> billing module
- capture_capacity_snapshot remains in legacy (will go to monitoring)

Backward compatibility:
- Create re-exports in app/services/ for subscription, stripe, admin services
- Old import paths continue to work
- Update celery_config.py to use module-defined schedules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-27 23:06:23 +01:00
parent f1f91abe51
commit 4f379b472b
17 changed files with 2198 additions and 1931 deletions

View File

@@ -2,654 +2,29 @@
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
DEPRECATED: This file is maintained for backward compatibility.
Import from app.modules.billing.services instead:
Usage:
from app.services.subscription_service import subscription_service
from app.modules.billing.services import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
This file re-exports the service from its new location in the billing module.
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.subscription import (
SubscriptionStatus,
SubscriptionTier,
TIER_LIMITS,
TierCode,
VendorSubscription,
# Re-export from new location for backward compatibility
from app.modules.billing.services.subscription_service import (
SubscriptionService,
subscription_service,
)
from models.database.vendor import Vendor, VendorUser
from models.schema.subscription import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
UsageSummary,
from app.modules.billing.exceptions import (
SubscriptionNotFoundException,
TierLimitExceededException,
FeatureNotAvailableException,
)
logger = logging.getLogger(__name__)
class SubscriptionNotFoundException(Exception):
"""Raised when subscription not found."""
pass
class TierLimitExceededException(Exception):
"""Raised when a tier limit is exceeded."""
def __init__(self, message: str, limit_type: str, current: int, limit: int):
super().__init__(message)
self.limit_type = limit_type
self.current = current
self.limit = limit
class FeatureNotAvailableException(Exception):
"""Raised when a feature is not available in current tier."""
def __init__(self, feature: str, current_tier: str, required_tier: str):
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
super().__init__(message)
self.feature = feature
self.current_tier = current_tier
self.required_tier = required_tier
class SubscriptionService:
"""Service for subscription and tier limit operations."""
# =========================================================================
# Tier Information
# =========================================================================
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
"""
Get full tier information.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
# Try database first if session provided
if db is not None:
db_tier = self.get_tier_by_code(db, tier_code)
if db_tier:
return TierInfo(
code=db_tier.code,
name=db_tier.name,
price_monthly_cents=db_tier.price_monthly_cents,
price_annual_cents=db_tier.price_annual_cents,
limits=TierLimits(
orders_per_month=db_tier.orders_per_month,
products_limit=db_tier.products_limit,
team_members=db_tier.team_members,
order_history_months=db_tier.order_history_months,
),
features=db_tier.features or [],
)
# Fallback to hardcoded TIER_LIMITS
return self._get_tier_from_legacy(tier_code)
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
try:
tier = TierCode(tier_code)
except ValueError:
tier = TierCode.ESSENTIAL
limits = TIER_LIMITS[tier]
return TierInfo(
code=tier.value,
name=limits["name"],
price_monthly_cents=limits["price_monthly_cents"],
price_annual_cents=limits.get("price_annual_cents"),
limits=TierLimits(
orders_per_month=limits.get("orders_per_month"),
products_limit=limits.get("products_limit"),
team_members=limits.get("team_members"),
order_history_months=limits.get("order_history_months"),
),
features=limits.get("features", []),
)
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
"""
Get information for all tiers.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
if db is not None:
db_tiers = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.is_public == True, # noqa: E712
)
.order_by(SubscriptionTier.display_order)
.all()
)
if db_tiers:
return [
TierInfo(
code=t.code,
name=t.name,
price_monthly_cents=t.price_monthly_cents,
price_annual_cents=t.price_annual_cents,
limits=TierLimits(
orders_per_month=t.orders_per_month,
products_limit=t.products_limit,
team_members=t.team_members,
order_history_months=t.order_history_months,
),
features=t.features or [],
)
for t in db_tiers
]
# Fallback to hardcoded
return [
self._get_tier_from_legacy(tier.value)
for tier in TierCode
]
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
"""Get subscription tier by code."""
return (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
"""Get tier ID from tier code. Returns None if tier not found."""
tier = self.get_tier_by_code(db, tier_code)
return tier.id if tier else None
# =========================================================================
# Subscription CRUD
# =========================================================================
def get_subscription(
self, db: Session, vendor_id: int
) -> VendorSubscription | None:
"""Get vendor subscription."""
return (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
def get_subscription_or_raise(
self, db: Session, vendor_id: int
) -> VendorSubscription:
"""Get vendor subscription or raise exception."""
subscription = self.get_subscription(db, vendor_id)
if not subscription:
raise SubscriptionNotFoundException(
f"No subscription found for vendor {vendor_id}"
)
return subscription
def get_current_tier(
self, db: Session, vendor_id: int
) -> TierCode | None:
"""Get vendor's current subscription tier code."""
subscription = self.get_subscription(db, vendor_id)
if subscription:
try:
return TierCode(subscription.tier)
except ValueError:
return None
return None
def get_or_create_subscription(
self,
db: Session,
vendor_id: int,
tier: str = TierCode.ESSENTIAL.value,
trial_days: int = 14,
) -> VendorSubscription:
"""
Get existing subscription or create a new trial subscription.
Used when a vendor first accesses the system.
"""
subscription = self.get_subscription(db, vendor_id)
if subscription:
return subscription
# Create new trial subscription
now = datetime.now(UTC)
trial_end = now + timedelta(days=trial_days)
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=tier,
tier_id=tier_id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=trial_end,
trial_ends_at=trial_end,
is_annual=False,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(
f"Created trial subscription for vendor {vendor_id} "
f"(tier={tier}, trial_ends={trial_end})"
)
return subscription
def create_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionCreate,
) -> VendorSubscription:
"""Create a subscription for a vendor."""
# Check if subscription exists
existing = self.get_subscription(db, vendor_id)
if existing:
raise ValueError("Vendor already has a subscription")
now = datetime.now(UTC)
# Calculate period end based on billing cycle
if data.is_annual:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
# Handle trial
trial_ends_at = None
status = SubscriptionStatus.ACTIVE.value
if data.trial_days > 0:
trial_ends_at = now + timedelta(days=data.trial_days)
status = SubscriptionStatus.TRIAL.value
period_end = trial_ends_at
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, data.tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=data.tier,
tier_id=tier_id,
status=status,
period_start=now,
period_end=period_end,
trial_ends_at=trial_ends_at,
is_annual=data.is_annual,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
return subscription
def update_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionUpdate,
) -> VendorSubscription:
"""Update a vendor subscription."""
subscription = self.get_subscription_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
# If tier is being updated, also update tier_id
if "tier" in update_data:
tier_id = self.get_tier_id(db, update_data["tier"])
update_data["tier_id"] = tier_id
for key, value in update_data.items():
setattr(subscription, key, value)
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Updated subscription for vendor {vendor_id}")
return subscription
def upgrade_tier(
self,
db: Session,
vendor_id: int,
new_tier: str,
) -> VendorSubscription:
"""Upgrade vendor to a new tier."""
subscription = self.get_subscription_or_raise(db, vendor_id)
old_tier = subscription.tier
subscription.tier = new_tier
subscription.tier_id = self.get_tier_id(db, new_tier)
subscription.updated_at = datetime.now(UTC)
# If upgrading from trial, mark as active
if subscription.status == SubscriptionStatus.TRIAL.value:
subscription.status = SubscriptionStatus.ACTIVE.value
db.flush()
db.refresh(subscription)
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
return subscription
def cancel_subscription(
self,
db: Session,
vendor_id: int,
reason: str | None = None,
) -> VendorSubscription:
"""Cancel a vendor subscription (access until period end)."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.status = SubscriptionStatus.CANCELLED.value
subscription.cancelled_at = datetime.now(UTC)
subscription.cancellation_reason = reason
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Cancelled subscription for vendor {vendor_id}")
return subscription
# =========================================================================
# Usage Tracking
# =========================================================================
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
"""Get current subscription usage statistics."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Calculate usage stats
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
def calc_percent(current: int, limit: int | None) -> float | None:
if limit is None or limit == 0:
return None
return min(100.0, (current / limit) * 100)
return SubscriptionUsage(
orders_used=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
products_used=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
products_percent_used=calc_percent(products_count, products_limit),
team_members_used=team_count,
team_members_limit=team_limit,
team_members_remaining=calc_remaining(team_count, team_limit),
team_members_percent_used=calc_percent(team_count, team_limit),
)
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
"""Get usage summary for billing page display."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Get limits
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
return UsageSummary(
orders_this_period=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
products_count=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
team_count=team_count,
team_limit=team_limit,
team_remaining=calc_remaining(team_count, team_limit),
)
def increment_order_count(self, db: Session, vendor_id: int) -> None:
"""
Increment the order counter for the current period.
Call this after successfully creating/importing an order.
"""
subscription = self.get_or_create_subscription(db, vendor_id)
subscription.increment_order_count()
db.flush()
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
"""Reset counters for a new billing period."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.reset_period_counters()
db.flush()
logger.info(f"Reset period counters for vendor {vendor_id}")
# =========================================================================
# Limit Checks
# =========================================================================
def can_create_order(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can create/import another order.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.can_create_order()
def check_order_limit(self, db: Session, vendor_id: int) -> None:
"""
Check order limit and raise exception if exceeded.
Use this in order creation flows.
"""
can_create, message = self.can_create_order(db, vendor_id)
if not can_create:
subscription = self.get_subscription(db, vendor_id)
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=subscription.orders_this_period if subscription else 0,
limit=subscription.orders_limit if subscription else 0,
)
def can_add_product(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another product.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
return subscription.can_add_product(products_count)
def check_product_limit(self, db: Session, vendor_id: int) -> None:
"""
Check product limit and raise exception if exceeded.
Use this in product creation flows.
"""
can_add, message = self.can_add_product(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Product limit exceeded",
limit_type="products",
current=products_count,
limit=subscription.products_limit if subscription else 0,
)
def can_add_team_member(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another team member.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
return subscription.can_add_team_member(team_count)
def check_team_limit(self, db: Session, vendor_id: int) -> None:
"""
Check team member limit and raise exception if exceeded.
Use this in team member invitation flows.
"""
can_add, message = self.can_add_team_member(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Team member limit exceeded",
limit_type="team_members",
current=team_count,
limit=subscription.team_members_limit if subscription else 0,
)
# =========================================================================
# Feature Gating
# =========================================================================
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
"""Check if vendor has access to a feature."""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.has_feature(feature)
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
"""
Check feature access and raise exception if not available.
Use this to gate premium features.
"""
if not self.has_feature(db, vendor_id, feature):
subscription = self.get_or_create_subscription(db, vendor_id)
# Find which tier has this feature
required_tier = None
for tier_code, limits in TIER_LIMITS.items():
if feature in limits.get("features", []):
required_tier = limits["name"]
break
raise FeatureNotAvailableException(
feature=feature,
current_tier=subscription.tier,
required_tier=required_tier or "higher",
)
def get_feature_tier(self, feature: str) -> str | None:
"""Get the minimum tier required for a feature."""
for tier_code in [
TierCode.ESSENTIAL,
TierCode.PROFESSIONAL,
TierCode.BUSINESS,
TierCode.ENTERPRISE,
]:
if feature in TIER_LIMITS[tier_code].get("features", []):
return tier_code.value
return None
# Singleton instance
subscription_service = SubscriptionService()
__all__ = [
"SubscriptionService",
"subscription_service",
"SubscriptionNotFoundException",
"TierLimitExceededException",
"FeatureNotAvailableException",
]