feat: complete billing module migration (Phase 5)
Migrates billing module to self-contained structure: - Create app/modules/billing/services/ with subscription, stripe, admin services - Create app/modules/billing/models/ re-exporting from central location - Create app/modules/billing/schemas/ re-exporting from central location - Create app/modules/billing/tasks/ with 4 scheduled Celery tasks - Create app/modules/billing/exceptions.py with module-specific exceptions - Update definition.py with is_self_contained=True and scheduled_tasks Celery task migration: - reset_period_counters -> billing module - check_trial_expirations -> billing module - sync_stripe_status -> billing module - cleanup_stale_subscriptions -> billing module - capture_capacity_snapshot remains in legacy (will go to monitoring) Backward compatibility: - Create re-exports in app/services/ for subscription, stripe, admin services - Old import paths continue to work - Update celery_config.py to use module-defined schedules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -49,10 +49,12 @@ if SENTRY_DSN:
|
||||
# TASK DISCOVERY
|
||||
# =============================================================================
|
||||
# Legacy tasks (will be migrated to modules over time)
|
||||
# NOTE: Most subscription tasks have been migrated to app.modules.billing.tasks
|
||||
# The subscription module is kept for capture_capacity_snapshot (will move to monitoring)
|
||||
LEGACY_TASK_MODULES = [
|
||||
"app.tasks.celery_tasks.marketplace",
|
||||
"app.tasks.celery_tasks.letzshop",
|
||||
"app.tasks.celery_tasks.subscription",
|
||||
"app.tasks.celery_tasks.subscription", # Kept for capture_capacity_snapshot only
|
||||
"app.tasks.celery_tasks.export",
|
||||
"app.tasks.celery_tasks.code_quality",
|
||||
"app.tasks.celery_tasks.test_runner",
|
||||
@@ -141,32 +143,9 @@ celery_app.conf.task_routes = {
|
||||
# =============================================================================
|
||||
|
||||
# Legacy scheduled tasks (will be migrated to module definitions)
|
||||
# NOTE: Subscription tasks have been migrated to billing module (see definition.py)
|
||||
LEGACY_BEAT_SCHEDULE = {
|
||||
# Reset usage counters at start of each period
|
||||
"reset-period-counters-daily": {
|
||||
"task": "app.tasks.celery_tasks.subscription.reset_period_counters",
|
||||
"schedule": crontab(hour=0, minute=5), # 00:05 daily
|
||||
"options": {"queue": "scheduled"},
|
||||
},
|
||||
# Check for expiring trials and send notifications
|
||||
"check-trial-expirations-daily": {
|
||||
"task": "app.tasks.celery_tasks.subscription.check_trial_expirations",
|
||||
"schedule": crontab(hour=1, minute=0), # 01:00 daily
|
||||
"options": {"queue": "scheduled"},
|
||||
},
|
||||
# Sync subscription status with Stripe
|
||||
"sync-stripe-status-hourly": {
|
||||
"task": "app.tasks.celery_tasks.subscription.sync_stripe_status",
|
||||
"schedule": crontab(minute=30), # Every hour at :30
|
||||
"options": {"queue": "scheduled"},
|
||||
},
|
||||
# Clean up stale/orphaned subscriptions
|
||||
"cleanup-stale-subscriptions-weekly": {
|
||||
"task": "app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
|
||||
"schedule": crontab(hour=3, minute=0, day_of_week=0), # Sunday 03:00
|
||||
"options": {"queue": "scheduled"},
|
||||
},
|
||||
# Capture daily capacity snapshot for analytics
|
||||
# Capacity snapshot - will be migrated to monitoring module
|
||||
"capture-capacity-snapshot-daily": {
|
||||
"task": "app.tasks.celery_tasks.subscription.capture_capacity_snapshot",
|
||||
"schedule": crontab(hour=0, minute=0), # Midnight daily
|
||||
|
||||
@@ -7,6 +7,7 @@ This module provides:
|
||||
- Vendor subscription CRUD
|
||||
- Billing history and invoices
|
||||
- Stripe integration
|
||||
- Scheduled tasks for subscription lifecycle
|
||||
|
||||
Routes:
|
||||
- Admin: /api/v1/admin/subscriptions/*
|
||||
@@ -15,8 +16,17 @@ Routes:
|
||||
Menu Items:
|
||||
- Admin: subscription-tiers, subscriptions, billing-history
|
||||
- Vendor: billing, invoices
|
||||
|
||||
Usage:
|
||||
from app.modules.billing import billing_module
|
||||
from app.modules.billing.services import subscription_service, stripe_service
|
||||
from app.modules.billing.models import VendorSubscription, SubscriptionTier
|
||||
from app.modules.billing.exceptions import TierLimitExceededException
|
||||
"""
|
||||
|
||||
from app.modules.billing.definition import billing_module
|
||||
from app.modules.billing.definition import billing_module, get_billing_module_with_routers
|
||||
|
||||
__all__ = ["billing_module"]
|
||||
__all__ = [
|
||||
"billing_module",
|
||||
"get_billing_module_with_routers",
|
||||
]
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
Billing module definition.
|
||||
|
||||
Defines the billing module including its features, menu items,
|
||||
and route configurations.
|
||||
route configurations, and scheduled tasks.
|
||||
"""
|
||||
|
||||
from app.modules.base import ModuleDefinition
|
||||
from app.modules.base import ModuleDefinition, ScheduledTask
|
||||
from models.database.admin_menu_config import FrontendType
|
||||
|
||||
|
||||
@@ -54,6 +54,44 @@ billing_module = ModuleDefinition(
|
||||
],
|
||||
},
|
||||
is_core=False, # Billing can be disabled (e.g., internal platforms)
|
||||
# =========================================================================
|
||||
# Self-Contained Module Configuration
|
||||
# =========================================================================
|
||||
is_self_contained=True,
|
||||
services_path="app.modules.billing.services",
|
||||
models_path="app.modules.billing.models",
|
||||
schemas_path="app.modules.billing.schemas",
|
||||
exceptions_path="app.modules.billing.exceptions",
|
||||
tasks_path="app.modules.billing.tasks",
|
||||
# =========================================================================
|
||||
# Scheduled Tasks
|
||||
# =========================================================================
|
||||
scheduled_tasks=[
|
||||
ScheduledTask(
|
||||
name="billing.reset_period_counters",
|
||||
task="app.modules.billing.tasks.subscription.reset_period_counters",
|
||||
schedule="5 0 * * *", # Daily at 00:05
|
||||
options={"queue": "scheduled"},
|
||||
),
|
||||
ScheduledTask(
|
||||
name="billing.check_trial_expirations",
|
||||
task="app.modules.billing.tasks.subscription.check_trial_expirations",
|
||||
schedule="0 1 * * *", # Daily at 01:00
|
||||
options={"queue": "scheduled"},
|
||||
),
|
||||
ScheduledTask(
|
||||
name="billing.sync_stripe_status",
|
||||
task="app.modules.billing.tasks.subscription.sync_stripe_status",
|
||||
schedule="30 * * * *", # Hourly at :30
|
||||
options={"queue": "scheduled"},
|
||||
),
|
||||
ScheduledTask(
|
||||
name="billing.cleanup_stale_subscriptions",
|
||||
task="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
|
||||
schedule="0 3 * * 0", # Weekly on Sunday at 03:00
|
||||
options={"queue": "scheduled"},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
83
app/modules/billing/exceptions.py
Normal file
83
app/modules/billing/exceptions.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# app/modules/billing/exceptions.py
|
||||
"""
|
||||
Billing module exceptions.
|
||||
|
||||
Custom exceptions for subscription, billing, and payment operations.
|
||||
"""
|
||||
|
||||
from app.exceptions import BusinessLogicException, ResourceNotFoundException
|
||||
|
||||
|
||||
class BillingException(BusinessLogicException):
|
||||
"""Base exception for billing module errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubscriptionNotFoundException(ResourceNotFoundException):
|
||||
"""Raised when a subscription is not found."""
|
||||
|
||||
def __init__(self, vendor_id: int):
|
||||
super().__init__("Subscription", str(vendor_id))
|
||||
|
||||
|
||||
class TierNotFoundException(ResourceNotFoundException):
|
||||
"""Raised when a subscription tier is not found."""
|
||||
|
||||
def __init__(self, tier_code: str):
|
||||
super().__init__("SubscriptionTier", tier_code)
|
||||
|
||||
|
||||
class TierLimitExceededException(BillingException):
|
||||
"""Raised when a tier limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str, limit_type: str, current: int, limit: int):
|
||||
super().__init__(message)
|
||||
self.limit_type = limit_type
|
||||
self.current = current
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class FeatureNotAvailableException(BillingException):
|
||||
"""Raised when a feature is not available in current tier."""
|
||||
|
||||
def __init__(self, feature: str, current_tier: str, required_tier: str):
|
||||
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
|
||||
super().__init__(message)
|
||||
self.feature = feature
|
||||
self.current_tier = current_tier
|
||||
self.required_tier = required_tier
|
||||
|
||||
|
||||
class StripeNotConfiguredException(BillingException):
|
||||
"""Raised when Stripe is not configured."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Stripe is not configured")
|
||||
|
||||
|
||||
class PaymentFailedException(BillingException):
|
||||
"""Raised when a payment fails."""
|
||||
|
||||
def __init__(self, message: str, stripe_error: str | None = None):
|
||||
super().__init__(message)
|
||||
self.stripe_error = stripe_error
|
||||
|
||||
|
||||
class WebhookVerificationException(BillingException):
|
||||
"""Raised when webhook signature verification fails."""
|
||||
|
||||
def __init__(self, message: str = "Invalid webhook signature"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BillingException",
|
||||
"SubscriptionNotFoundException",
|
||||
"TierNotFoundException",
|
||||
"TierLimitExceededException",
|
||||
"FeatureNotAvailableException",
|
||||
"StripeNotConfiguredException",
|
||||
"PaymentFailedException",
|
||||
"WebhookVerificationException",
|
||||
]
|
||||
52
app/modules/billing/models/__init__.py
Normal file
52
app/modules/billing/models/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# app/modules/billing/models/__init__.py
|
||||
"""
|
||||
Billing module models.
|
||||
|
||||
Re-exports subscription models from the central models location.
|
||||
Models remain in models/database/ for now to avoid breaking existing imports
|
||||
across the codebase. This provides a module-local import path.
|
||||
|
||||
Usage:
|
||||
from app.modules.billing.models import (
|
||||
VendorSubscription,
|
||||
SubscriptionTier,
|
||||
SubscriptionStatus,
|
||||
TierCode,
|
||||
)
|
||||
"""
|
||||
|
||||
from models.database.subscription import (
|
||||
# Enums
|
||||
TierCode,
|
||||
SubscriptionStatus,
|
||||
AddOnCategory,
|
||||
BillingPeriod,
|
||||
# Models
|
||||
SubscriptionTier,
|
||||
AddOnProduct,
|
||||
VendorAddOn,
|
||||
StripeWebhookEvent,
|
||||
BillingHistory,
|
||||
VendorSubscription,
|
||||
CapacitySnapshot,
|
||||
# Legacy constants
|
||||
TIER_LIMITS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"TierCode",
|
||||
"SubscriptionStatus",
|
||||
"AddOnCategory",
|
||||
"BillingPeriod",
|
||||
# Models
|
||||
"SubscriptionTier",
|
||||
"AddOnProduct",
|
||||
"VendorAddOn",
|
||||
"StripeWebhookEvent",
|
||||
"BillingHistory",
|
||||
"VendorSubscription",
|
||||
"CapacitySnapshot",
|
||||
# Legacy constants
|
||||
"TIER_LIMITS",
|
||||
]
|
||||
56
app/modules/billing/schemas/__init__.py
Normal file
56
app/modules/billing/schemas/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# app/modules/billing/schemas/__init__.py
|
||||
"""
|
||||
Billing module Pydantic schemas.
|
||||
|
||||
Re-exports subscription schemas from the central schemas location.
|
||||
Provides a module-local import path while maintaining backwards compatibility.
|
||||
|
||||
Usage:
|
||||
from app.modules.billing.schemas import (
|
||||
SubscriptionCreate,
|
||||
SubscriptionResponse,
|
||||
TierInfo,
|
||||
)
|
||||
"""
|
||||
|
||||
from models.schema.subscription import (
|
||||
# Tier schemas
|
||||
TierFeatures,
|
||||
TierLimits,
|
||||
TierInfo,
|
||||
# Subscription CRUD schemas
|
||||
SubscriptionCreate,
|
||||
SubscriptionUpdate,
|
||||
SubscriptionResponse,
|
||||
# Usage schemas
|
||||
SubscriptionUsage,
|
||||
UsageSummary,
|
||||
SubscriptionStatusResponse,
|
||||
# Limit check schemas
|
||||
LimitCheckResult,
|
||||
CanCreateOrderResponse,
|
||||
CanAddProductResponse,
|
||||
CanAddTeamMemberResponse,
|
||||
FeatureCheckResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Tier schemas
|
||||
"TierFeatures",
|
||||
"TierLimits",
|
||||
"TierInfo",
|
||||
# Subscription CRUD schemas
|
||||
"SubscriptionCreate",
|
||||
"SubscriptionUpdate",
|
||||
"SubscriptionResponse",
|
||||
# Usage schemas
|
||||
"SubscriptionUsage",
|
||||
"UsageSummary",
|
||||
"SubscriptionStatusResponse",
|
||||
# Limit check schemas
|
||||
"LimitCheckResult",
|
||||
"CanCreateOrderResponse",
|
||||
"CanAddProductResponse",
|
||||
"CanAddTeamMemberResponse",
|
||||
"FeatureCheckResponse",
|
||||
]
|
||||
28
app/modules/billing/services/__init__.py
Normal file
28
app/modules/billing/services/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# app/modules/billing/services/__init__.py
|
||||
"""
|
||||
Billing module services.
|
||||
|
||||
Provides subscription management, Stripe integration, and admin operations.
|
||||
"""
|
||||
|
||||
from app.modules.billing.services.subscription_service import (
|
||||
SubscriptionService,
|
||||
subscription_service,
|
||||
)
|
||||
from app.modules.billing.services.stripe_service import (
|
||||
StripeService,
|
||||
stripe_service,
|
||||
)
|
||||
from app.modules.billing.services.admin_subscription_service import (
|
||||
AdminSubscriptionService,
|
||||
admin_subscription_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SubscriptionService",
|
||||
"subscription_service",
|
||||
"StripeService",
|
||||
"stripe_service",
|
||||
"AdminSubscriptionService",
|
||||
"admin_subscription_service",
|
||||
]
|
||||
352
app/modules/billing/services/admin_subscription_service.py
Normal file
352
app/modules/billing/services/admin_subscription_service.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# app/modules/billing/services/admin_subscription_service.py
|
||||
"""
|
||||
Admin Subscription Service.
|
||||
|
||||
Handles subscription management operations for platform administrators:
|
||||
- Subscription tier CRUD
|
||||
- Vendor subscription management
|
||||
- Billing history queries
|
||||
- Subscription analytics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from math import ceil
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
BusinessLogicException,
|
||||
ConflictException,
|
||||
ResourceNotFoundException,
|
||||
)
|
||||
from app.modules.billing.exceptions import TierNotFoundException
|
||||
from app.modules.billing.models import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminSubscriptionService:
|
||||
"""Service for admin subscription management operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Tiers
|
||||
# =========================================================================
|
||||
|
||||
def get_tiers(
|
||||
self, db: Session, include_inactive: bool = False
|
||||
) -> list[SubscriptionTier]:
|
||||
"""Get all subscription tiers."""
|
||||
query = db.query(SubscriptionTier)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
|
||||
|
||||
return query.order_by(SubscriptionTier.display_order).all()
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||
"""Get a subscription tier by code."""
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tier:
|
||||
raise TierNotFoundException(tier_code)
|
||||
|
||||
return tier
|
||||
|
||||
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
|
||||
"""Create a new subscription tier."""
|
||||
# Check for duplicate code
|
||||
existing = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_data["code"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ConflictException(
|
||||
f"Tier with code '{tier_data['code']}' already exists"
|
||||
)
|
||||
|
||||
tier = SubscriptionTier(**tier_data)
|
||||
db.add(tier)
|
||||
|
||||
logger.info(f"Created subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def update_tier(
|
||||
self, db: Session, tier_code: str, update_data: dict
|
||||
) -> SubscriptionTier:
|
||||
"""Update a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(tier, field, value)
|
||||
|
||||
logger.info(f"Updated subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def deactivate_tier(self, db: Session, tier_code: str) -> None:
|
||||
"""Soft-delete a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
# Check if any active subscriptions use this tier
|
||||
active_subs = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.tier == tier_code,
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
]),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if active_subs > 0:
|
||||
raise BusinessLogicException(
|
||||
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
||||
)
|
||||
|
||||
tier.is_active = False
|
||||
|
||||
logger.info(f"Soft-deleted subscription tier: {tier.code}")
|
||||
|
||||
# =========================================================================
|
||||
# Vendor Subscriptions
|
||||
# =========================================================================
|
||||
|
||||
def list_subscriptions(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
status: str | None = None,
|
||||
tier: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> dict:
|
||||
"""List vendor subscriptions with filtering and pagination."""
|
||||
query = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if status:
|
||||
query = query.filter(VendorSubscription.status == status)
|
||||
if tier:
|
||||
query = query.filter(VendorSubscription.tier == tier)
|
||||
if search:
|
||||
query = query.filter(Vendor.name.ilike(f"%{search}%"))
|
||||
|
||||
# Count total
|
||||
total = query.count()
|
||||
|
||||
# Paginate
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(VendorSubscription.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
|
||||
"""Get subscription for a specific vendor."""
|
||||
result = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ResourceNotFoundException("Subscription", str(vendor_id))
|
||||
|
||||
return result
|
||||
|
||||
def update_subscription(
|
||||
self, db: Session, vendor_id: int, update_data: dict
|
||||
) -> tuple:
|
||||
"""Update a vendor's subscription."""
|
||||
result = self.get_subscription(db, vendor_id)
|
||||
sub, vendor = result
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(sub, field, value)
|
||||
|
||||
logger.info(
|
||||
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
|
||||
)
|
||||
|
||||
return sub, vendor
|
||||
|
||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""Get a vendor by ID."""
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
|
||||
if not vendor:
|
||||
raise ResourceNotFoundException("Vendor", str(vendor_id))
|
||||
|
||||
return vendor
|
||||
|
||||
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
|
||||
"""Get usage counts (products and team members) for a vendor."""
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor_id,
|
||||
VendorUser.is_active == True, # noqa: E712
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"products_count": products_count,
|
||||
"team_count": team_count,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Billing History
|
||||
# =========================================================================
|
||||
|
||||
def list_billing_history(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
vendor_id: int | None = None,
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""List billing history across all vendors."""
|
||||
query = (
|
||||
db.query(BillingHistory, Vendor)
|
||||
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(BillingHistory.vendor_id == vendor_id)
|
||||
if status:
|
||||
query = query.filter(BillingHistory.status == status)
|
||||
|
||||
total = query.count()
|
||||
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(BillingHistory.invoice_date.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self, db: Session) -> dict:
|
||||
"""Get subscription statistics for admin dashboard."""
|
||||
# Count by status
|
||||
status_counts = (
|
||||
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
|
||||
.group_by(VendorSubscription.status)
|
||||
.all()
|
||||
)
|
||||
|
||||
stats = {
|
||||
"total_subscriptions": 0,
|
||||
"active_count": 0,
|
||||
"trial_count": 0,
|
||||
"past_due_count": 0,
|
||||
"cancelled_count": 0,
|
||||
"expired_count": 0,
|
||||
}
|
||||
|
||||
for status, count in status_counts:
|
||||
stats["total_subscriptions"] += count
|
||||
if status == SubscriptionStatus.ACTIVE.value:
|
||||
stats["active_count"] = count
|
||||
elif status == SubscriptionStatus.TRIAL.value:
|
||||
stats["trial_count"] = count
|
||||
elif status == SubscriptionStatus.PAST_DUE.value:
|
||||
stats["past_due_count"] = count
|
||||
elif status == SubscriptionStatus.CANCELLED.value:
|
||||
stats["cancelled_count"] = count
|
||||
elif status == SubscriptionStatus.EXPIRED.value:
|
||||
stats["expired_count"] = count
|
||||
|
||||
# Count by tier
|
||||
tier_counts = (
|
||||
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
|
||||
.filter(
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
])
|
||||
)
|
||||
.group_by(VendorSubscription.tier)
|
||||
.all()
|
||||
)
|
||||
|
||||
tier_distribution = {tier: count for tier, count in tier_counts}
|
||||
|
||||
# Calculate MRR (Monthly Recurring Revenue)
|
||||
mrr_cents = 0
|
||||
arr_cents = 0
|
||||
|
||||
active_subs = (
|
||||
db.query(VendorSubscription, SubscriptionTier)
|
||||
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
|
||||
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
|
||||
.all()
|
||||
)
|
||||
|
||||
for sub, tier in active_subs:
|
||||
if sub.is_annual and tier.price_annual_cents:
|
||||
mrr_cents += tier.price_annual_cents // 12
|
||||
arr_cents += tier.price_annual_cents
|
||||
else:
|
||||
mrr_cents += tier.price_monthly_cents
|
||||
arr_cents += tier.price_monthly_cents * 12
|
||||
|
||||
stats["tier_distribution"] = tier_distribution
|
||||
stats["mrr_cents"] = mrr_cents
|
||||
stats["arr_cents"] = arr_cents
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Singleton instance
|
||||
admin_subscription_service = AdminSubscriptionService()
|
||||
582
app/modules/billing/services/stripe_service.py
Normal file
582
app/modules/billing/services/stripe_service.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# app/modules/billing/services/stripe_service.py
|
||||
"""
|
||||
Stripe payment integration service.
|
||||
|
||||
Provides:
|
||||
- Customer management
|
||||
- Subscription management
|
||||
- Checkout session creation
|
||||
- Customer portal access
|
||||
- Webhook event construction
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import stripe
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.modules.billing.exceptions import (
|
||||
StripeNotConfiguredException,
|
||||
WebhookVerificationException,
|
||||
)
|
||||
from app.modules.billing.models import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Service for Stripe payment operations."""
|
||||
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._configure()
|
||||
|
||||
def _configure(self):
|
||||
"""Configure Stripe with API key."""
|
||||
if settings.stripe_secret_key:
|
||||
stripe.api_key = settings.stripe_secret_key
|
||||
self._configured = True
|
||||
else:
|
||||
logger.warning("Stripe API key not configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Stripe is properly configured."""
|
||||
return self._configured and bool(settings.stripe_secret_key)
|
||||
|
||||
def _check_configured(self) -> None:
|
||||
"""Raise exception if Stripe is not configured."""
|
||||
if not self.is_configured:
|
||||
raise StripeNotConfiguredException()
|
||||
|
||||
# =========================================================================
|
||||
# Customer Management
|
||||
# =========================================================================
|
||||
|
||||
def create_customer(
|
||||
self,
|
||||
vendor: Vendor,
|
||||
email: str,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe customer for a vendor.
|
||||
|
||||
Returns the Stripe customer ID.
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
customer_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
name=name or vendor.name,
|
||||
metadata=customer_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
||||
)
|
||||
return customer.id
|
||||
|
||||
def get_customer(self, customer_id: str) -> stripe.Customer:
|
||||
"""Get a Stripe customer by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Customer.retrieve(customer_id)
|
||||
|
||||
def update_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
email: str | None = None,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Customer:
|
||||
"""Update a Stripe customer."""
|
||||
self._check_configured()
|
||||
|
||||
update_data = {}
|
||||
if email:
|
||||
update_data["email"] = email
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
return stripe.Customer.modify(customer_id, **update_data)
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Management
|
||||
# =========================================================================
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create a new Stripe subscription.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
price_id: Stripe price ID for the subscription
|
||||
trial_days: Optional trial period in days
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription_data = {
|
||||
"customer": customer_id,
|
||||
"items": [{"price": price_id}],
|
||||
"metadata": metadata or {},
|
||||
"payment_behavior": "default_incomplete",
|
||||
"expand": ["latest_invoice.payment_intent"],
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
subscription_data["trial_period_days"] = trial_days
|
||||
|
||||
subscription = stripe.Subscription.create(**subscription_data)
|
||||
logger.info(
|
||||
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""Get a Stripe subscription by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Subscription.retrieve(subscription_id)
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_price_id: str | None = None,
|
||||
proration_behavior: str = "create_prorations",
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Update a Stripe subscription (e.g., change tier).
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
new_price_id: New price ID for tier change
|
||||
proration_behavior: How to handle prorations
|
||||
metadata: Optional metadata to update
|
||||
|
||||
Returns:
|
||||
Updated Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
update_data = {"proration_behavior": proration_behavior}
|
||||
|
||||
if new_price_id:
|
||||
# Get the subscription to find the item ID
|
||||
subscription = stripe.Subscription.retrieve(subscription_id)
|
||||
item_id = subscription["items"]["data"][0]["id"]
|
||||
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
||||
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
||||
logger.info(f"Updated Stripe subscription {subscription_id}")
|
||||
return updated
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
immediately: bool = False,
|
||||
cancellation_reason: str | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Cancel a Stripe subscription.
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
immediately: If True, cancel now. If False, cancel at period end.
|
||||
cancellation_reason: Optional reason for cancellation
|
||||
|
||||
Returns:
|
||||
Cancelled Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
if immediately:
|
||||
subscription = stripe.Subscription.cancel(subscription_id)
|
||||
else:
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=True,
|
||||
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Cancelled Stripe subscription {subscription_id} "
|
||||
f"(immediately={immediately})"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""
|
||||
Reactivate a cancelled subscription (if not yet ended).
|
||||
|
||||
Returns:
|
||||
Reactivated Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=False,
|
||||
)
|
||||
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription_item(self, subscription_item_id: str) -> None:
|
||||
"""
|
||||
Cancel a subscription item (used for add-ons).
|
||||
|
||||
Args:
|
||||
subscription_item_id: Stripe subscription item ID
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
stripe.SubscriptionItem.delete(subscription_item_id)
|
||||
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Checkout & Portal
|
||||
# =========================================================================
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
price_id: str,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
trial_days: int | None = None,
|
||||
quantity: int = 1,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.checkout.Session:
|
||||
"""
|
||||
Create a Stripe Checkout session for subscription signup.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor: Vendor to create checkout for
|
||||
price_id: Stripe price ID
|
||||
success_url: URL to redirect on success
|
||||
cancel_url: URL to redirect on cancel
|
||||
trial_days: Optional trial period
|
||||
quantity: Number of items (default 1)
|
||||
metadata: Additional metadata to store
|
||||
|
||||
Returns:
|
||||
Stripe Checkout Session object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
# Get or create Stripe customer
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
else:
|
||||
# Get vendor owner email
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
owner = (
|
||||
db.query(VendorUser)
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor.id,
|
||||
VendorUser.is_owner == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
email = owner.user.email if owner and owner.user else None
|
||||
|
||||
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
||||
|
||||
# Store the customer ID
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
db.flush()
|
||||
|
||||
# Build metadata
|
||||
session_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
}
|
||||
if metadata:
|
||||
session_metadata.update(metadata)
|
||||
|
||||
session_data = {
|
||||
"customer": customer_id,
|
||||
"line_items": [{"price": price_id, "quantity": quantity}],
|
||||
"mode": "subscription",
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"metadata": session_metadata,
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
||||
|
||||
session = stripe.checkout.Session.create(**session_data)
|
||||
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
||||
return session
|
||||
|
||||
def create_portal_session(
|
||||
self,
|
||||
customer_id: str,
|
||||
return_url: str,
|
||||
) -> stripe.billing_portal.Session:
|
||||
"""
|
||||
Create a Stripe Customer Portal session.
|
||||
|
||||
Allows customers to manage their subscription, payment methods, and invoices.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
return_url: URL to return to after portal
|
||||
|
||||
Returns:
|
||||
Stripe Portal Session object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
logger.info(f"Created portal session for customer {customer_id}")
|
||||
return session
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Management
|
||||
# =========================================================================
|
||||
|
||||
def get_invoices(
|
||||
self,
|
||||
customer_id: str,
|
||||
limit: int = 10,
|
||||
) -> list[stripe.Invoice]:
|
||||
"""Get invoices for a customer."""
|
||||
self._check_configured()
|
||||
|
||||
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
||||
return list(invoices.data)
|
||||
|
||||
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
||||
"""Get the upcoming invoice for a customer."""
|
||||
self._check_configured()
|
||||
|
||||
try:
|
||||
return stripe.Invoice.upcoming(customer=customer_id)
|
||||
except stripe.error.InvalidRequestError:
|
||||
# No upcoming invoice
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Webhook Handling
|
||||
# =========================================================================
|
||||
|
||||
def construct_event(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
) -> stripe.Event:
|
||||
"""
|
||||
Construct and verify a Stripe webhook event.
|
||||
|
||||
Args:
|
||||
payload: Raw request body
|
||||
sig_header: Stripe-Signature header
|
||||
|
||||
Returns:
|
||||
Verified Stripe Event object
|
||||
|
||||
Raises:
|
||||
WebhookVerificationException: If signature verification fails
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
raise WebhookVerificationException("Stripe webhook secret not configured")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload,
|
||||
sig_header,
|
||||
settings.stripe_webhook_secret,
|
||||
)
|
||||
return event
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Webhook signature verification failed: {e}")
|
||||
raise WebhookVerificationException("Invalid webhook signature")
|
||||
|
||||
# =========================================================================
|
||||
# SetupIntent & Payment Method Management
|
||||
# =========================================================================
|
||||
|
||||
def create_setup_intent(
|
||||
self,
|
||||
customer_id: str,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.SetupIntent:
|
||||
"""
|
||||
Create a SetupIntent to collect card without charging.
|
||||
|
||||
Used for trial signups where we collect card upfront
|
||||
but don't charge until trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe SetupIntent object with client_secret for frontend
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
setup_intent = stripe.SetupIntent.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
|
||||
return setup_intent
|
||||
|
||||
def attach_payment_method_to_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
payment_method_id: str,
|
||||
set_as_default: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Attach a payment method to customer and optionally set as default.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
payment_method_id: Payment method ID from confirmed SetupIntent
|
||||
set_as_default: Whether to set as default payment method
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
# Attach the payment method to the customer
|
||||
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
|
||||
|
||||
if set_as_default:
|
||||
stripe.Customer.modify(
|
||||
customer_id,
|
||||
invoice_settings={"default_payment_method": payment_method_id},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Attached payment method {payment_method_id} to customer {customer_id} "
|
||||
f"(default={set_as_default})"
|
||||
)
|
||||
|
||||
def create_subscription_with_trial(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int = 30,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create subscription with trial period.
|
||||
|
||||
Customer must have a default payment method attached.
|
||||
Card will be charged automatically after trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID (must have default payment method)
|
||||
price_id: Stripe price ID for the subscription tier
|
||||
trial_days: Number of trial days (default 30)
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
self._check_configured()
|
||||
|
||||
subscription = stripe.Subscription.create(
|
||||
customer=customer_id,
|
||||
items=[{"price": price_id}],
|
||||
trial_period_days=trial_days,
|
||||
metadata=metadata or {},
|
||||
# Use default payment method for future charges
|
||||
default_payment_method=None, # Uses customer's default
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created subscription {subscription.id} with {trial_days}-day trial "
|
||||
f"for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
|
||||
"""Get a SetupIntent by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.SetupIntent.retrieve(setup_intent_id)
|
||||
|
||||
# =========================================================================
|
||||
# Price/Product Management
|
||||
# =========================================================================
|
||||
|
||||
def get_price(self, price_id: str) -> stripe.Price:
|
||||
"""Get a Stripe price by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Price.retrieve(price_id)
|
||||
|
||||
def get_product(self, product_id: str) -> stripe.Product:
|
||||
"""Get a Stripe product by ID."""
|
||||
self._check_configured()
|
||||
|
||||
return stripe.Product.retrieve(product_id)
|
||||
|
||||
def list_prices(
|
||||
self,
|
||||
product_id: str | None = None,
|
||||
active: bool = True,
|
||||
) -> list[stripe.Price]:
|
||||
"""List Stripe prices, optionally filtered by product."""
|
||||
self._check_configured()
|
||||
|
||||
params = {"active": active}
|
||||
if product_id:
|
||||
params["product"] = product_id
|
||||
|
||||
prices = stripe.Price.list(**params)
|
||||
return list(prices.data)
|
||||
|
||||
|
||||
# Create service instance
|
||||
stripe_service = StripeService()
|
||||
631
app/modules/billing/services/subscription_service.py
Normal file
631
app/modules/billing/services/subscription_service.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# app/modules/billing/services/subscription_service.py
|
||||
"""
|
||||
Subscription service for tier-based access control.
|
||||
|
||||
Handles:
|
||||
- Subscription creation and management
|
||||
- Tier limit enforcement
|
||||
- Usage tracking
|
||||
- Feature gating
|
||||
|
||||
Usage:
|
||||
from app.modules.billing.services import subscription_service
|
||||
|
||||
# Check if vendor can create an order
|
||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||
|
||||
# Increment order counter after successful order
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.billing.exceptions import (
|
||||
FeatureNotAvailableException,
|
||||
SubscriptionNotFoundException,
|
||||
TierLimitExceededException,
|
||||
)
|
||||
from app.modules.billing.models import (
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
TIER_LIMITS,
|
||||
TierCode,
|
||||
VendorSubscription,
|
||||
)
|
||||
from app.modules.billing.schemas import (
|
||||
SubscriptionCreate,
|
||||
SubscriptionUpdate,
|
||||
SubscriptionUsage,
|
||||
TierInfo,
|
||||
TierLimits,
|
||||
UsageSummary,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""Service for subscription and tier limit operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Tier Information
|
||||
# =========================================================================
|
||||
|
||||
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
|
||||
"""
|
||||
Get full tier information.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
# Try database first if session provided
|
||||
if db is not None:
|
||||
db_tier = self.get_tier_by_code(db, tier_code)
|
||||
if db_tier:
|
||||
return TierInfo(
|
||||
code=db_tier.code,
|
||||
name=db_tier.name,
|
||||
price_monthly_cents=db_tier.price_monthly_cents,
|
||||
price_annual_cents=db_tier.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=db_tier.orders_per_month,
|
||||
products_limit=db_tier.products_limit,
|
||||
team_members=db_tier.team_members,
|
||||
order_history_months=db_tier.order_history_months,
|
||||
),
|
||||
features=db_tier.features or [],
|
||||
)
|
||||
|
||||
# Fallback to hardcoded TIER_LIMITS
|
||||
return self._get_tier_from_legacy(tier_code)
|
||||
|
||||
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
|
||||
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
|
||||
try:
|
||||
tier = TierCode(tier_code)
|
||||
except ValueError:
|
||||
tier = TierCode.ESSENTIAL
|
||||
|
||||
limits = TIER_LIMITS[tier]
|
||||
return TierInfo(
|
||||
code=tier.value,
|
||||
name=limits["name"],
|
||||
price_monthly_cents=limits["price_monthly_cents"],
|
||||
price_annual_cents=limits.get("price_annual_cents"),
|
||||
limits=TierLimits(
|
||||
orders_per_month=limits.get("orders_per_month"),
|
||||
products_limit=limits.get("products_limit"),
|
||||
team_members=limits.get("team_members"),
|
||||
order_history_months=limits.get("order_history_months"),
|
||||
),
|
||||
features=limits.get("features", []),
|
||||
)
|
||||
|
||||
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
|
||||
"""
|
||||
Get information for all tiers.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
if db is not None:
|
||||
db_tiers = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
SubscriptionTier.is_public == True, # noqa: E712
|
||||
)
|
||||
.order_by(SubscriptionTier.display_order)
|
||||
.all()
|
||||
)
|
||||
if db_tiers:
|
||||
return [
|
||||
TierInfo(
|
||||
code=t.code,
|
||||
name=t.name,
|
||||
price_monthly_cents=t.price_monthly_cents,
|
||||
price_annual_cents=t.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=t.orders_per_month,
|
||||
products_limit=t.products_limit,
|
||||
team_members=t.team_members,
|
||||
order_history_months=t.order_history_months,
|
||||
),
|
||||
features=t.features or [],
|
||||
)
|
||||
for t in db_tiers
|
||||
]
|
||||
|
||||
# Fallback to hardcoded
|
||||
return [
|
||||
self._get_tier_from_legacy(tier.value)
|
||||
for tier in TierCode
|
||||
]
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
|
||||
"""Get subscription tier by code."""
|
||||
return (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
|
||||
"""Get tier ID from tier code. Returns None if tier not found."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
return tier.id if tier else None
|
||||
|
||||
# =========================================================================
|
||||
# Subscription CRUD
|
||||
# =========================================================================
|
||||
|
||||
def get_subscription(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription | None:
|
||||
"""Get vendor subscription."""
|
||||
return (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_subscription_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription:
|
||||
"""Get vendor subscription or raise exception."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if not subscription:
|
||||
raise SubscriptionNotFoundException(vendor_id)
|
||||
return subscription
|
||||
|
||||
def get_current_tier(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> TierCode | None:
|
||||
"""Get vendor's current subscription tier code."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
try:
|
||||
return TierCode(subscription.tier)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_or_create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier: str = TierCode.ESSENTIAL.value,
|
||||
trial_days: int = 14,
|
||||
) -> VendorSubscription:
|
||||
"""
|
||||
Get existing subscription or create a new trial subscription.
|
||||
|
||||
Used when a vendor first accesses the system.
|
||||
"""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
return subscription
|
||||
|
||||
# Create new trial subscription
|
||||
now = datetime.now(UTC)
|
||||
trial_end = now + timedelta(days=trial_days)
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=tier,
|
||||
tier_id=tier_id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=trial_end,
|
||||
trial_ends_at=trial_end,
|
||||
is_annual=False,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(
|
||||
f"Created trial subscription for vendor {vendor_id} "
|
||||
f"(tier={tier}, trial_ends={trial_end})"
|
||||
)
|
||||
|
||||
return subscription
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionCreate,
|
||||
) -> VendorSubscription:
|
||||
"""Create a subscription for a vendor."""
|
||||
# Check if subscription exists
|
||||
existing = self.get_subscription(db, vendor_id)
|
||||
if existing:
|
||||
raise ValueError("Vendor already has a subscription")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate period end based on billing cycle
|
||||
if data.is_annual:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
# Handle trial
|
||||
trial_ends_at = None
|
||||
status = SubscriptionStatus.ACTIVE.value
|
||||
if data.trial_days > 0:
|
||||
trial_ends_at = now + timedelta(days=data.trial_days)
|
||||
status = SubscriptionStatus.TRIAL.value
|
||||
period_end = trial_ends_at
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, data.tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=data.tier,
|
||||
tier_id=tier_id,
|
||||
status=status,
|
||||
period_start=now,
|
||||
period_end=period_end,
|
||||
trial_ends_at=trial_ends_at,
|
||||
is_annual=data.is_annual,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
||||
return subscription
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionUpdate,
|
||||
) -> VendorSubscription:
|
||||
"""Update a vendor subscription."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# If tier is being updated, also update tier_id
|
||||
if "tier" in update_data:
|
||||
tier_id = self.get_tier_id(db, update_data["tier"])
|
||||
update_data["tier_id"] = tier_id
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(subscription, key, value)
|
||||
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Updated subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
def upgrade_tier(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
new_tier: str,
|
||||
) -> VendorSubscription:
|
||||
"""Upgrade vendor to a new tier."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
old_tier = subscription.tier
|
||||
subscription.tier = new_tier
|
||||
subscription.tier_id = self.get_tier_id(db, new_tier)
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
# If upgrading from trial, mark as active
|
||||
if subscription.status == SubscriptionStatus.TRIAL.value:
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reason: str | None = None,
|
||||
) -> VendorSubscription:
|
||||
"""Cancel a vendor subscription (access until period end)."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
subscription.status = SubscriptionStatus.CANCELLED.value
|
||||
subscription.cancelled_at = datetime.now(UTC)
|
||||
subscription.cancellation_reason = reason
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
# =========================================================================
|
||||
# Usage Tracking
|
||||
# =========================================================================
|
||||
|
||||
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
||||
"""Get current subscription usage statistics."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Calculate usage stats
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
def calc_percent(current: int, limit: int | None) -> float | None:
|
||||
if limit is None or limit == 0:
|
||||
return None
|
||||
return min(100.0, (current / limit) * 100)
|
||||
|
||||
return SubscriptionUsage(
|
||||
orders_used=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
||||
products_used=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
products_percent_used=calc_percent(products_count, products_limit),
|
||||
team_members_used=team_count,
|
||||
team_members_limit=team_limit,
|
||||
team_members_remaining=calc_remaining(team_count, team_limit),
|
||||
team_members_percent_used=calc_percent(team_count, team_limit),
|
||||
)
|
||||
|
||||
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
|
||||
"""Get usage summary for billing page display."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Get limits
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
return UsageSummary(
|
||||
orders_this_period=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
products_count=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
team_count=team_count,
|
||||
team_limit=team_limit,
|
||||
team_remaining=calc_remaining(team_count, team_limit),
|
||||
)
|
||||
|
||||
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Increment the order counter for the current period.
|
||||
|
||||
Call this after successfully creating/importing an order.
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
subscription.increment_order_count()
|
||||
db.flush()
|
||||
|
||||
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
||||
"""Reset counters for a new billing period."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
subscription.reset_period_counters()
|
||||
db.flush()
|
||||
logger.info(f"Reset period counters for vendor {vendor_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Limit Checks
|
||||
# =========================================================================
|
||||
|
||||
def can_create_order(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can create/import another order.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.can_create_order()
|
||||
|
||||
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check order limit and raise exception if exceeded.
|
||||
|
||||
Use this in order creation flows.
|
||||
"""
|
||||
can_create, message = self.can_create_order(db, vendor_id)
|
||||
if not can_create:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Order limit exceeded",
|
||||
limit_type="orders",
|
||||
current=subscription.orders_this_period if subscription else 0,
|
||||
limit=subscription.orders_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_product(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another product.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_product(products_count)
|
||||
|
||||
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check product limit and raise exception if exceeded.
|
||||
|
||||
Use this in product creation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_product(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Product limit exceeded",
|
||||
limit_type="products",
|
||||
current=products_count,
|
||||
limit=subscription.products_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_team_member(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another team member.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_team_member(team_count)
|
||||
|
||||
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check team member limit and raise exception if exceeded.
|
||||
|
||||
Use this in team member invitation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_team_member(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Team member limit exceeded",
|
||||
limit_type="team_members",
|
||||
current=team_count,
|
||||
limit=subscription.team_members_limit if subscription else 0,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Feature Gating
|
||||
# =========================================================================
|
||||
|
||||
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
||||
"""Check if vendor has access to a feature."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.has_feature(feature)
|
||||
|
||||
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
||||
"""
|
||||
Check feature access and raise exception if not available.
|
||||
|
||||
Use this to gate premium features.
|
||||
"""
|
||||
if not self.has_feature(db, vendor_id, feature):
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Find which tier has this feature
|
||||
required_tier = None
|
||||
for tier_code, limits in TIER_LIMITS.items():
|
||||
if feature in limits.get("features", []):
|
||||
required_tier = limits["name"]
|
||||
break
|
||||
|
||||
raise FeatureNotAvailableException(
|
||||
feature=feature,
|
||||
current_tier=subscription.tier,
|
||||
required_tier=required_tier or "higher",
|
||||
)
|
||||
|
||||
def get_feature_tier(self, feature: str) -> str | None:
|
||||
"""Get the minimum tier required for a feature."""
|
||||
for tier_code in [
|
||||
TierCode.ESSENTIAL,
|
||||
TierCode.PROFESSIONAL,
|
||||
TierCode.BUSINESS,
|
||||
TierCode.ENTERPRISE,
|
||||
]:
|
||||
if feature in TIER_LIMITS[tier_code].get("features", []):
|
||||
return tier_code.value
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
subscription_service = SubscriptionService()
|
||||
26
app/modules/billing/tasks/__init__.py
Normal file
26
app/modules/billing/tasks/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# app/modules/billing/tasks/__init__.py
|
||||
"""
|
||||
Billing module Celery tasks.
|
||||
|
||||
Scheduled tasks for:
|
||||
- Resetting period counters
|
||||
- Checking trial expirations
|
||||
- Syncing with Stripe
|
||||
- Cleaning up stale subscriptions
|
||||
|
||||
Note: capture_capacity_snapshot moved to monitoring module.
|
||||
"""
|
||||
|
||||
from app.modules.billing.tasks.subscription import (
|
||||
reset_period_counters,
|
||||
check_trial_expirations,
|
||||
sync_stripe_status,
|
||||
cleanup_stale_subscriptions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"reset_period_counters",
|
||||
"check_trial_expirations",
|
||||
"sync_stripe_status",
|
||||
"cleanup_stale_subscriptions",
|
||||
]
|
||||
255
app/modules/billing/tasks/subscription.py
Normal file
255
app/modules/billing/tasks/subscription.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# app/modules/billing/tasks/subscription.py
|
||||
"""
|
||||
Celery tasks for subscription management.
|
||||
|
||||
Scheduled tasks for:
|
||||
- Resetting period counters
|
||||
- Checking trial expirations
|
||||
- Syncing with Stripe
|
||||
- Cleaning up stale subscriptions
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.core.celery_config import celery_app
|
||||
from app.modules.billing.models import SubscriptionStatus, VendorSubscription
|
||||
from app.modules.billing.services import stripe_service
|
||||
from app.modules.task_base import ModuleTask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=ModuleTask,
|
||||
name="app.modules.billing.tasks.subscription.reset_period_counters",
|
||||
)
|
||||
def reset_period_counters(self):
|
||||
"""
|
||||
Reset order counters for subscriptions whose billing period has ended.
|
||||
|
||||
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
reset_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find subscriptions where period has ended
|
||||
expired_periods = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.period_end <= now,
|
||||
VendorSubscription.status.in_(["active", "trial"]),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in expired_periods:
|
||||
old_period_end = subscription.period_end
|
||||
|
||||
# Reset counters
|
||||
subscription.orders_this_period = 0
|
||||
subscription.orders_limit_reached_at = None
|
||||
|
||||
# Set new period dates
|
||||
if subscription.is_annual:
|
||||
subscription.period_start = now
|
||||
subscription.period_end = now + timedelta(days=365)
|
||||
else:
|
||||
subscription.period_start = now
|
||||
subscription.period_end = now + timedelta(days=30)
|
||||
|
||||
subscription.updated_at = now
|
||||
reset_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Reset period counters for vendor {subscription.vendor_id}: "
|
||||
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
|
||||
)
|
||||
|
||||
logger.info(f"Reset period counters for {reset_count} subscriptions")
|
||||
|
||||
return {"reset_count": reset_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=ModuleTask,
|
||||
name="app.modules.billing.tasks.subscription.check_trial_expirations",
|
||||
)
|
||||
def check_trial_expirations(self):
|
||||
"""
|
||||
Check for expired trials and update their status.
|
||||
|
||||
Runs daily at 01:00.
|
||||
- Trials without payment method -> expired
|
||||
- Trials with payment method -> active
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
expired_count = 0
|
||||
activated_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find expired trials
|
||||
expired_trials = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
|
||||
VendorSubscription.trial_ends_at <= now,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in expired_trials:
|
||||
if subscription.stripe_payment_method_id:
|
||||
# Has payment method - activate
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
activated_count += 1
|
||||
logger.info(
|
||||
f"Activated subscription for vendor {subscription.vendor_id} "
|
||||
f"(trial ended with payment method)"
|
||||
)
|
||||
else:
|
||||
# No payment method - expire
|
||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||
expired_count += 1
|
||||
logger.info(
|
||||
f"Expired trial for vendor {subscription.vendor_id} "
|
||||
f"(no payment method)"
|
||||
)
|
||||
|
||||
subscription.updated_at = now
|
||||
|
||||
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
|
||||
|
||||
return {"expired_count": expired_count, "activated_count": activated_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=ModuleTask,
|
||||
name="app.modules.billing.tasks.subscription.sync_stripe_status",
|
||||
max_retries=3,
|
||||
default_retry_delay=300,
|
||||
)
|
||||
def sync_stripe_status(self):
|
||||
"""
|
||||
Sync subscription status with Stripe.
|
||||
|
||||
Runs hourly at :30. Fetches current status from Stripe and updates local records.
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
logger.warning("Stripe not configured, skipping sync")
|
||||
return {"synced": 0, "skipped": True}
|
||||
|
||||
synced_count = 0
|
||||
error_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find subscriptions with Stripe IDs
|
||||
subscriptions = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in subscriptions:
|
||||
try:
|
||||
# Fetch from Stripe
|
||||
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
|
||||
|
||||
if not stripe_sub:
|
||||
logger.warning(
|
||||
f"Stripe subscription {subscription.stripe_subscription_id} "
|
||||
f"not found for vendor {subscription.vendor_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Map Stripe status to local status
|
||||
status_map = {
|
||||
"active": SubscriptionStatus.ACTIVE.value,
|
||||
"trialing": SubscriptionStatus.TRIAL.value,
|
||||
"past_due": SubscriptionStatus.PAST_DUE.value,
|
||||
"canceled": SubscriptionStatus.CANCELLED.value,
|
||||
"unpaid": SubscriptionStatus.PAST_DUE.value,
|
||||
"incomplete": SubscriptionStatus.TRIAL.value,
|
||||
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
|
||||
}
|
||||
|
||||
new_status = status_map.get(stripe_sub.status)
|
||||
if new_status and new_status != subscription.status:
|
||||
old_status = subscription.status
|
||||
subscription.status = new_status
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
logger.info(
|
||||
f"Updated vendor {subscription.vendor_id} status: "
|
||||
f"{old_status} -> {new_status} (from Stripe)"
|
||||
)
|
||||
|
||||
# Update period dates from Stripe
|
||||
if stripe_sub.current_period_start:
|
||||
subscription.period_start = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_start, tz=UTC
|
||||
)
|
||||
if stripe_sub.current_period_end:
|
||||
subscription.period_end = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_end, tz=UTC
|
||||
)
|
||||
|
||||
# Update payment method
|
||||
if stripe_sub.default_payment_method:
|
||||
subscription.stripe_payment_method_id = (
|
||||
stripe_sub.default_payment_method
|
||||
if isinstance(stripe_sub.default_payment_method, str)
|
||||
else stripe_sub.default_payment_method.id
|
||||
)
|
||||
|
||||
synced_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
|
||||
error_count += 1
|
||||
|
||||
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
|
||||
|
||||
return {"synced_count": synced_count, "error_count": error_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=ModuleTask,
|
||||
name="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
|
||||
)
|
||||
def cleanup_stale_subscriptions(self):
|
||||
"""
|
||||
Clean up subscriptions in inconsistent states.
|
||||
|
||||
Runs weekly on Sunday at 03:00.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
cleaned_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find cancelled subscriptions past their period end
|
||||
stale_cancelled = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
|
||||
VendorSubscription.period_end < now - timedelta(days=30),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in stale_cancelled:
|
||||
# Mark as expired (fully terminated)
|
||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||
subscription.updated_at = now
|
||||
cleaned_count += 1
|
||||
logger.info(
|
||||
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
|
||||
|
||||
return {"cleaned_count": cleaned_count}
|
||||
@@ -2,351 +2,21 @@
|
||||
"""
|
||||
Admin Subscription Service.
|
||||
|
||||
Handles subscription management operations for platform administrators:
|
||||
- Subscription tier CRUD
|
||||
- Vendor subscription management
|
||||
- Billing history queries
|
||||
- Subscription analytics
|
||||
DEPRECATED: This file is maintained for backward compatibility.
|
||||
Import from app.modules.billing.services instead:
|
||||
|
||||
from app.modules.billing.services import admin_subscription_service
|
||||
|
||||
This file re-exports the service from its new location in the billing module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from math import ceil
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
BusinessLogicException,
|
||||
ConflictException,
|
||||
ResourceNotFoundException,
|
||||
TierNotFoundException,
|
||||
# Re-export from new location for backward compatibility
|
||||
from app.modules.billing.services.admin_subscription_service import (
|
||||
AdminSubscriptionService,
|
||||
admin_subscription_service,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.subscription import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminSubscriptionService:
|
||||
"""Service for admin subscription management operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Tiers
|
||||
# =========================================================================
|
||||
|
||||
def get_tiers(
|
||||
self, db: Session, include_inactive: bool = False
|
||||
) -> list[SubscriptionTier]:
|
||||
"""Get all subscription tiers."""
|
||||
query = db.query(SubscriptionTier)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
|
||||
|
||||
return query.order_by(SubscriptionTier.display_order).all()
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||
"""Get a subscription tier by code."""
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tier:
|
||||
raise TierNotFoundException(tier_code)
|
||||
|
||||
return tier
|
||||
|
||||
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
|
||||
"""Create a new subscription tier."""
|
||||
# Check for duplicate code
|
||||
existing = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_data["code"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ConflictException(
|
||||
f"Tier with code '{tier_data['code']}' already exists"
|
||||
)
|
||||
|
||||
tier = SubscriptionTier(**tier_data)
|
||||
db.add(tier)
|
||||
|
||||
logger.info(f"Created subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def update_tier(
|
||||
self, db: Session, tier_code: str, update_data: dict
|
||||
) -> SubscriptionTier:
|
||||
"""Update a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(tier, field, value)
|
||||
|
||||
logger.info(f"Updated subscription tier: {tier.code}")
|
||||
return tier
|
||||
|
||||
def deactivate_tier(self, db: Session, tier_code: str) -> None:
|
||||
"""Soft-delete a subscription tier."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
# Check if any active subscriptions use this tier
|
||||
active_subs = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.tier == tier_code,
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
]),
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if active_subs > 0:
|
||||
raise BusinessLogicException(
|
||||
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
||||
)
|
||||
|
||||
tier.is_active = False
|
||||
|
||||
logger.info(f"Soft-deleted subscription tier: {tier.code}")
|
||||
|
||||
# =========================================================================
|
||||
# Vendor Subscriptions
|
||||
# =========================================================================
|
||||
|
||||
def list_subscriptions(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
status: str | None = None,
|
||||
tier: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> dict:
|
||||
"""List vendor subscriptions with filtering and pagination."""
|
||||
query = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if status:
|
||||
query = query.filter(VendorSubscription.status == status)
|
||||
if tier:
|
||||
query = query.filter(VendorSubscription.tier == tier)
|
||||
if search:
|
||||
query = query.filter(Vendor.name.ilike(f"%{search}%"))
|
||||
|
||||
# Count total
|
||||
total = query.count()
|
||||
|
||||
# Paginate
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(VendorSubscription.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
|
||||
"""Get subscription for a specific vendor."""
|
||||
result = (
|
||||
db.query(VendorSubscription, Vendor)
|
||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ResourceNotFoundException("Subscription", str(vendor_id))
|
||||
|
||||
return result
|
||||
|
||||
def update_subscription(
|
||||
self, db: Session, vendor_id: int, update_data: dict
|
||||
) -> tuple:
|
||||
"""Update a vendor's subscription."""
|
||||
result = self.get_subscription(db, vendor_id)
|
||||
sub, vendor = result
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(sub, field, value)
|
||||
|
||||
logger.info(
|
||||
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
|
||||
)
|
||||
|
||||
return sub, vendor
|
||||
|
||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""Get a vendor by ID."""
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
|
||||
if not vendor:
|
||||
raise ResourceNotFoundException("Vendor", str(vendor_id))
|
||||
|
||||
return vendor
|
||||
|
||||
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
|
||||
"""Get usage counts (products and team members) for a vendor."""
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor_id,
|
||||
VendorUser.is_active == True, # noqa: E712
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"products_count": products_count,
|
||||
"team_count": team_count,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Billing History
|
||||
# =========================================================================
|
||||
|
||||
def list_billing_history(
|
||||
self,
|
||||
db: Session,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
vendor_id: int | None = None,
|
||||
status: str | None = None,
|
||||
) -> dict:
|
||||
"""List billing history across all vendors."""
|
||||
query = (
|
||||
db.query(BillingHistory, Vendor)
|
||||
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(BillingHistory.vendor_id == vendor_id)
|
||||
if status:
|
||||
query = query.filter(BillingHistory.status == status)
|
||||
|
||||
total = query.count()
|
||||
|
||||
offset = (page - 1) * per_page
|
||||
results = (
|
||||
query.order_by(BillingHistory.invoice_date.desc())
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_stats(self, db: Session) -> dict:
|
||||
"""Get subscription statistics for admin dashboard."""
|
||||
# Count by status
|
||||
status_counts = (
|
||||
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
|
||||
.group_by(VendorSubscription.status)
|
||||
.all()
|
||||
)
|
||||
|
||||
stats = {
|
||||
"total_subscriptions": 0,
|
||||
"active_count": 0,
|
||||
"trial_count": 0,
|
||||
"past_due_count": 0,
|
||||
"cancelled_count": 0,
|
||||
"expired_count": 0,
|
||||
}
|
||||
|
||||
for status, count in status_counts:
|
||||
stats["total_subscriptions"] += count
|
||||
if status == SubscriptionStatus.ACTIVE.value:
|
||||
stats["active_count"] = count
|
||||
elif status == SubscriptionStatus.TRIAL.value:
|
||||
stats["trial_count"] = count
|
||||
elif status == SubscriptionStatus.PAST_DUE.value:
|
||||
stats["past_due_count"] = count
|
||||
elif status == SubscriptionStatus.CANCELLED.value:
|
||||
stats["cancelled_count"] = count
|
||||
elif status == SubscriptionStatus.EXPIRED.value:
|
||||
stats["expired_count"] = count
|
||||
|
||||
# Count by tier
|
||||
tier_counts = (
|
||||
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
|
||||
.filter(
|
||||
VendorSubscription.status.in_([
|
||||
SubscriptionStatus.ACTIVE.value,
|
||||
SubscriptionStatus.TRIAL.value,
|
||||
])
|
||||
)
|
||||
.group_by(VendorSubscription.tier)
|
||||
.all()
|
||||
)
|
||||
|
||||
tier_distribution = {tier: count for tier, count in tier_counts}
|
||||
|
||||
# Calculate MRR (Monthly Recurring Revenue)
|
||||
mrr_cents = 0
|
||||
arr_cents = 0
|
||||
|
||||
active_subs = (
|
||||
db.query(VendorSubscription, SubscriptionTier)
|
||||
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
|
||||
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
|
||||
.all()
|
||||
)
|
||||
|
||||
for sub, tier in active_subs:
|
||||
if sub.is_annual and tier.price_annual_cents:
|
||||
mrr_cents += tier.price_annual_cents // 12
|
||||
arr_cents += tier.price_annual_cents
|
||||
else:
|
||||
mrr_cents += tier.price_monthly_cents
|
||||
arr_cents += tier.price_monthly_cents * 12
|
||||
|
||||
stats["tier_distribution"] = tier_distribution
|
||||
stats["mrr_cents"] = mrr_cents
|
||||
stats["arr_cents"] = arr_cents
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Singleton instance
|
||||
admin_subscription_service = AdminSubscriptionService()
|
||||
__all__ = [
|
||||
"AdminSubscriptionService",
|
||||
"admin_subscription_service",
|
||||
]
|
||||
|
||||
@@ -2,592 +2,21 @@
|
||||
"""
|
||||
Stripe payment integration service.
|
||||
|
||||
Provides:
|
||||
- Customer management
|
||||
- Subscription management
|
||||
- Checkout session creation
|
||||
- Customer portal access
|
||||
- Webhook event construction
|
||||
DEPRECATED: This file is maintained for backward compatibility.
|
||||
Import from app.modules.billing.services instead:
|
||||
|
||||
from app.modules.billing.services import stripe_service
|
||||
|
||||
This file re-exports the service from its new location in the billing module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import stripe
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from models.database.subscription import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
# Re-export from new location for backward compatibility
|
||||
from app.modules.billing.services.stripe_service import (
|
||||
StripeService,
|
||||
stripe_service,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Service for Stripe payment operations."""
|
||||
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._configure()
|
||||
|
||||
def _configure(self):
|
||||
"""Configure Stripe with API key."""
|
||||
if settings.stripe_secret_key:
|
||||
stripe.api_key = settings.stripe_secret_key
|
||||
self._configured = True
|
||||
else:
|
||||
logger.warning("Stripe API key not configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Stripe is properly configured."""
|
||||
return self._configured and bool(settings.stripe_secret_key)
|
||||
|
||||
# =========================================================================
|
||||
# Customer Management
|
||||
# =========================================================================
|
||||
|
||||
def create_customer(
|
||||
self,
|
||||
vendor: Vendor,
|
||||
email: str,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe customer for a vendor.
|
||||
|
||||
Returns the Stripe customer ID.
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
customer_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
name=name or vendor.name,
|
||||
metadata=customer_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
||||
)
|
||||
return customer.id
|
||||
|
||||
def get_customer(self, customer_id: str) -> stripe.Customer:
|
||||
"""Get a Stripe customer by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Customer.retrieve(customer_id)
|
||||
|
||||
def update_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
email: str | None = None,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Customer:
|
||||
"""Update a Stripe customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
update_data = {}
|
||||
if email:
|
||||
update_data["email"] = email
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
return stripe.Customer.modify(customer_id, **update_data)
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Management
|
||||
# =========================================================================
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create a new Stripe subscription.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
price_id: Stripe price ID for the subscription
|
||||
trial_days: Optional trial period in days
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
subscription_data = {
|
||||
"customer": customer_id,
|
||||
"items": [{"price": price_id}],
|
||||
"metadata": metadata or {},
|
||||
"payment_behavior": "default_incomplete",
|
||||
"expand": ["latest_invoice.payment_intent"],
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
subscription_data["trial_period_days"] = trial_days
|
||||
|
||||
subscription = stripe.Subscription.create(**subscription_data)
|
||||
logger.info(
|
||||
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""Get a Stripe subscription by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Subscription.retrieve(subscription_id)
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_price_id: str | None = None,
|
||||
proration_behavior: str = "create_prorations",
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Update a Stripe subscription (e.g., change tier).
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
new_price_id: New price ID for tier change
|
||||
proration_behavior: How to handle prorations
|
||||
metadata: Optional metadata to update
|
||||
|
||||
Returns:
|
||||
Updated Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
update_data = {"proration_behavior": proration_behavior}
|
||||
|
||||
if new_price_id:
|
||||
# Get the subscription to find the item ID
|
||||
subscription = stripe.Subscription.retrieve(subscription_id)
|
||||
item_id = subscription["items"]["data"][0]["id"]
|
||||
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
||||
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
||||
logger.info(f"Updated Stripe subscription {subscription_id}")
|
||||
return updated
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
immediately: bool = False,
|
||||
cancellation_reason: str | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Cancel a Stripe subscription.
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
immediately: If True, cancel now. If False, cancel at period end.
|
||||
cancellation_reason: Optional reason for cancellation
|
||||
|
||||
Returns:
|
||||
Cancelled Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
if immediately:
|
||||
subscription = stripe.Subscription.cancel(subscription_id)
|
||||
else:
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=True,
|
||||
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Cancelled Stripe subscription {subscription_id} "
|
||||
f"(immediately={immediately})"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""
|
||||
Reactivate a cancelled subscription (if not yet ended).
|
||||
|
||||
Returns:
|
||||
Reactivated Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=False,
|
||||
)
|
||||
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription_item(self, subscription_item_id: str) -> None:
|
||||
"""
|
||||
Cancel a subscription item (used for add-ons).
|
||||
|
||||
Args:
|
||||
subscription_item_id: Stripe subscription item ID
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
stripe.SubscriptionItem.delete(subscription_item_id)
|
||||
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Checkout & Portal
|
||||
# =========================================================================
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
price_id: str,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
trial_days: int | None = None,
|
||||
quantity: int = 1,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.checkout.Session:
|
||||
"""
|
||||
Create a Stripe Checkout session for subscription signup.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor: Vendor to create checkout for
|
||||
price_id: Stripe price ID
|
||||
success_url: URL to redirect on success
|
||||
cancel_url: URL to redirect on cancel
|
||||
trial_days: Optional trial period
|
||||
quantity: Number of items (default 1)
|
||||
metadata: Additional metadata to store
|
||||
|
||||
Returns:
|
||||
Stripe Checkout Session object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
# Get or create Stripe customer
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
else:
|
||||
# Get vendor owner email
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
owner = (
|
||||
db.query(VendorUser)
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor.id,
|
||||
VendorUser.is_owner == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
email = owner.user.email if owner and owner.user else None
|
||||
|
||||
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
||||
|
||||
# Store the customer ID
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
db.flush()
|
||||
|
||||
# Build metadata
|
||||
session_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
}
|
||||
if metadata:
|
||||
session_metadata.update(metadata)
|
||||
|
||||
session_data = {
|
||||
"customer": customer_id,
|
||||
"line_items": [{"price": price_id, "quantity": quantity}],
|
||||
"mode": "subscription",
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"metadata": session_metadata,
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
||||
|
||||
session = stripe.checkout.Session.create(**session_data)
|
||||
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
||||
return session
|
||||
|
||||
def create_portal_session(
|
||||
self,
|
||||
customer_id: str,
|
||||
return_url: str,
|
||||
) -> stripe.billing_portal.Session:
|
||||
"""
|
||||
Create a Stripe Customer Portal session.
|
||||
|
||||
Allows customers to manage their subscription, payment methods, and invoices.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
return_url: URL to return to after portal
|
||||
|
||||
Returns:
|
||||
Stripe Portal Session object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
logger.info(f"Created portal session for customer {customer_id}")
|
||||
return session
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Management
|
||||
# =========================================================================
|
||||
|
||||
def get_invoices(
|
||||
self,
|
||||
customer_id: str,
|
||||
limit: int = 10,
|
||||
) -> list[stripe.Invoice]:
|
||||
"""Get invoices for a customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
||||
return list(invoices.data)
|
||||
|
||||
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
||||
"""Get the upcoming invoice for a customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
try:
|
||||
return stripe.Invoice.upcoming(customer=customer_id)
|
||||
except stripe.error.InvalidRequestError:
|
||||
# No upcoming invoice
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Webhook Handling
|
||||
# =========================================================================
|
||||
|
||||
def construct_event(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
) -> stripe.Event:
|
||||
"""
|
||||
Construct and verify a Stripe webhook event.
|
||||
|
||||
Args:
|
||||
payload: Raw request body
|
||||
sig_header: Stripe-Signature header
|
||||
|
||||
Returns:
|
||||
Verified Stripe Event object
|
||||
|
||||
Raises:
|
||||
ValueError: If signature verification fails
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
raise ValueError("Stripe webhook secret not configured")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload,
|
||||
sig_header,
|
||||
settings.stripe_webhook_secret,
|
||||
)
|
||||
return event
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Webhook signature verification failed: {e}")
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
# =========================================================================
|
||||
# SetupIntent & Payment Method Management
|
||||
# =========================================================================
|
||||
|
||||
def create_setup_intent(
|
||||
self,
|
||||
customer_id: str,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.SetupIntent:
|
||||
"""
|
||||
Create a SetupIntent to collect card without charging.
|
||||
|
||||
Used for trial signups where we collect card upfront
|
||||
but don't charge until trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe SetupIntent object with client_secret for frontend
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
setup_intent = stripe.SetupIntent.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
|
||||
return setup_intent
|
||||
|
||||
def attach_payment_method_to_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
payment_method_id: str,
|
||||
set_as_default: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Attach a payment method to customer and optionally set as default.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
payment_method_id: Payment method ID from confirmed SetupIntent
|
||||
set_as_default: Whether to set as default payment method
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
# Attach the payment method to the customer
|
||||
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
|
||||
|
||||
if set_as_default:
|
||||
stripe.Customer.modify(
|
||||
customer_id,
|
||||
invoice_settings={"default_payment_method": payment_method_id},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Attached payment method {payment_method_id} to customer {customer_id} "
|
||||
f"(default={set_as_default})"
|
||||
)
|
||||
|
||||
def create_subscription_with_trial(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int = 30,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create subscription with trial period.
|
||||
|
||||
Customer must have a default payment method attached.
|
||||
Card will be charged automatically after trial ends.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID (must have default payment method)
|
||||
price_id: Stripe price ID for the subscription tier
|
||||
trial_days: Number of trial days (default 30)
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
subscription = stripe.Subscription.create(
|
||||
customer=customer_id,
|
||||
items=[{"price": price_id}],
|
||||
trial_period_days=trial_days,
|
||||
metadata=metadata or {},
|
||||
# Use default payment method for future charges
|
||||
default_payment_method=None, # Uses customer's default
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created subscription {subscription.id} with {trial_days}-day trial "
|
||||
f"for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
|
||||
"""Get a SetupIntent by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.SetupIntent.retrieve(setup_intent_id)
|
||||
|
||||
# =========================================================================
|
||||
# Price/Product Management
|
||||
# =========================================================================
|
||||
|
||||
def get_price(self, price_id: str) -> stripe.Price:
|
||||
"""Get a Stripe price by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Price.retrieve(price_id)
|
||||
|
||||
def get_product(self, product_id: str) -> stripe.Product:
|
||||
"""Get a Stripe product by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Product.retrieve(product_id)
|
||||
|
||||
def list_prices(
|
||||
self,
|
||||
product_id: str | None = None,
|
||||
active: bool = True,
|
||||
) -> list[stripe.Price]:
|
||||
"""List Stripe prices, optionally filtered by product."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
params = {"active": active}
|
||||
if product_id:
|
||||
params["product"] = product_id
|
||||
|
||||
prices = stripe.Price.list(**params)
|
||||
return list(prices.data)
|
||||
|
||||
|
||||
# Create service instance
|
||||
stripe_service = StripeService()
|
||||
__all__ = [
|
||||
"StripeService",
|
||||
"stripe_service",
|
||||
]
|
||||
|
||||
@@ -2,654 +2,29 @@
|
||||
"""
|
||||
Subscription service for tier-based access control.
|
||||
|
||||
Handles:
|
||||
- Subscription creation and management
|
||||
- Tier limit enforcement
|
||||
- Usage tracking
|
||||
- Feature gating
|
||||
DEPRECATED: This file is maintained for backward compatibility.
|
||||
Import from app.modules.billing.services instead:
|
||||
|
||||
Usage:
|
||||
from app.services.subscription_service import subscription_service
|
||||
from app.modules.billing.services import subscription_service
|
||||
|
||||
# Check if vendor can create an order
|
||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||
|
||||
# Increment order counter after successful order
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
This file re-exports the service from its new location in the billing module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.product import Product
|
||||
from models.database.subscription import (
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
TIER_LIMITS,
|
||||
TierCode,
|
||||
VendorSubscription,
|
||||
# Re-export from new location for backward compatibility
|
||||
from app.modules.billing.services.subscription_service import (
|
||||
SubscriptionService,
|
||||
subscription_service,
|
||||
)
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
from models.schema.subscription import (
|
||||
SubscriptionCreate,
|
||||
SubscriptionUpdate,
|
||||
SubscriptionUsage,
|
||||
TierInfo,
|
||||
TierLimits,
|
||||
UsageSummary,
|
||||
from app.modules.billing.exceptions import (
|
||||
SubscriptionNotFoundException,
|
||||
TierLimitExceededException,
|
||||
FeatureNotAvailableException,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionNotFoundException(Exception):
|
||||
"""Raised when subscription not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TierLimitExceededException(Exception):
|
||||
"""Raised when a tier limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str, limit_type: str, current: int, limit: int):
|
||||
super().__init__(message)
|
||||
self.limit_type = limit_type
|
||||
self.current = current
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class FeatureNotAvailableException(Exception):
|
||||
"""Raised when a feature is not available in current tier."""
|
||||
|
||||
def __init__(self, feature: str, current_tier: str, required_tier: str):
|
||||
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
|
||||
super().__init__(message)
|
||||
self.feature = feature
|
||||
self.current_tier = current_tier
|
||||
self.required_tier = required_tier
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""Service for subscription and tier limit operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Tier Information
|
||||
# =========================================================================
|
||||
|
||||
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
|
||||
"""
|
||||
Get full tier information.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
# Try database first if session provided
|
||||
if db is not None:
|
||||
db_tier = self.get_tier_by_code(db, tier_code)
|
||||
if db_tier:
|
||||
return TierInfo(
|
||||
code=db_tier.code,
|
||||
name=db_tier.name,
|
||||
price_monthly_cents=db_tier.price_monthly_cents,
|
||||
price_annual_cents=db_tier.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=db_tier.orders_per_month,
|
||||
products_limit=db_tier.products_limit,
|
||||
team_members=db_tier.team_members,
|
||||
order_history_months=db_tier.order_history_months,
|
||||
),
|
||||
features=db_tier.features or [],
|
||||
)
|
||||
|
||||
# Fallback to hardcoded TIER_LIMITS
|
||||
return self._get_tier_from_legacy(tier_code)
|
||||
|
||||
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
|
||||
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
|
||||
try:
|
||||
tier = TierCode(tier_code)
|
||||
except ValueError:
|
||||
tier = TierCode.ESSENTIAL
|
||||
|
||||
limits = TIER_LIMITS[tier]
|
||||
return TierInfo(
|
||||
code=tier.value,
|
||||
name=limits["name"],
|
||||
price_monthly_cents=limits["price_monthly_cents"],
|
||||
price_annual_cents=limits.get("price_annual_cents"),
|
||||
limits=TierLimits(
|
||||
orders_per_month=limits.get("orders_per_month"),
|
||||
products_limit=limits.get("products_limit"),
|
||||
team_members=limits.get("team_members"),
|
||||
order_history_months=limits.get("order_history_months"),
|
||||
),
|
||||
features=limits.get("features", []),
|
||||
)
|
||||
|
||||
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
|
||||
"""
|
||||
Get information for all tiers.
|
||||
|
||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||
"""
|
||||
if db is not None:
|
||||
db_tiers = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
SubscriptionTier.is_public == True, # noqa: E712
|
||||
)
|
||||
.order_by(SubscriptionTier.display_order)
|
||||
.all()
|
||||
)
|
||||
if db_tiers:
|
||||
return [
|
||||
TierInfo(
|
||||
code=t.code,
|
||||
name=t.name,
|
||||
price_monthly_cents=t.price_monthly_cents,
|
||||
price_annual_cents=t.price_annual_cents,
|
||||
limits=TierLimits(
|
||||
orders_per_month=t.orders_per_month,
|
||||
products_limit=t.products_limit,
|
||||
team_members=t.team_members,
|
||||
order_history_months=t.order_history_months,
|
||||
),
|
||||
features=t.features or [],
|
||||
)
|
||||
for t in db_tiers
|
||||
]
|
||||
|
||||
# Fallback to hardcoded
|
||||
return [
|
||||
self._get_tier_from_legacy(tier.value)
|
||||
for tier in TierCode
|
||||
]
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
|
||||
"""Get subscription tier by code."""
|
||||
return (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == tier_code)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
|
||||
"""Get tier ID from tier code. Returns None if tier not found."""
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
return tier.id if tier else None
|
||||
|
||||
# =========================================================================
|
||||
# Subscription CRUD
|
||||
# =========================================================================
|
||||
|
||||
def get_subscription(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription | None:
|
||||
"""Get vendor subscription."""
|
||||
return (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_subscription_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription:
|
||||
"""Get vendor subscription or raise exception."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if not subscription:
|
||||
raise SubscriptionNotFoundException(
|
||||
f"No subscription found for vendor {vendor_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_current_tier(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> TierCode | None:
|
||||
"""Get vendor's current subscription tier code."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
try:
|
||||
return TierCode(subscription.tier)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_or_create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier: str = TierCode.ESSENTIAL.value,
|
||||
trial_days: int = 14,
|
||||
) -> VendorSubscription:
|
||||
"""
|
||||
Get existing subscription or create a new trial subscription.
|
||||
|
||||
Used when a vendor first accesses the system.
|
||||
"""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
return subscription
|
||||
|
||||
# Create new trial subscription
|
||||
now = datetime.now(UTC)
|
||||
trial_end = now + timedelta(days=trial_days)
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=tier,
|
||||
tier_id=tier_id,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=trial_end,
|
||||
trial_ends_at=trial_end,
|
||||
is_annual=False,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(
|
||||
f"Created trial subscription for vendor {vendor_id} "
|
||||
f"(tier={tier}, trial_ends={trial_end})"
|
||||
)
|
||||
|
||||
return subscription
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionCreate,
|
||||
) -> VendorSubscription:
|
||||
"""Create a subscription for a vendor."""
|
||||
# Check if subscription exists
|
||||
existing = self.get_subscription(db, vendor_id)
|
||||
if existing:
|
||||
raise ValueError("Vendor already has a subscription")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate period end based on billing cycle
|
||||
if data.is_annual:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
# Handle trial
|
||||
trial_ends_at = None
|
||||
status = SubscriptionStatus.ACTIVE.value
|
||||
if data.trial_days > 0:
|
||||
trial_ends_at = now + timedelta(days=data.trial_days)
|
||||
status = SubscriptionStatus.TRIAL.value
|
||||
period_end = trial_ends_at
|
||||
|
||||
# Lookup tier_id from tier code
|
||||
tier_id = self.get_tier_id(db, data.tier)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=data.tier,
|
||||
tier_id=tier_id,
|
||||
status=status,
|
||||
period_start=now,
|
||||
period_end=period_end,
|
||||
trial_ends_at=trial_ends_at,
|
||||
is_annual=data.is_annual,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
||||
return subscription
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionUpdate,
|
||||
) -> VendorSubscription:
|
||||
"""Update a vendor subscription."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# If tier is being updated, also update tier_id
|
||||
if "tier" in update_data:
|
||||
tier_id = self.get_tier_id(db, update_data["tier"])
|
||||
update_data["tier_id"] = tier_id
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(subscription, key, value)
|
||||
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Updated subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
def upgrade_tier(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
new_tier: str,
|
||||
) -> VendorSubscription:
|
||||
"""Upgrade vendor to a new tier."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
old_tier = subscription.tier
|
||||
subscription.tier = new_tier
|
||||
subscription.tier_id = self.get_tier_id(db, new_tier)
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
# If upgrading from trial, mark as active
|
||||
if subscription.status == SubscriptionStatus.TRIAL.value:
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reason: str | None = None,
|
||||
) -> VendorSubscription:
|
||||
"""Cancel a vendor subscription (access until period end)."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
subscription.status = SubscriptionStatus.CANCELLED.value
|
||||
subscription.cancelled_at = datetime.now(UTC)
|
||||
subscription.cancellation_reason = reason
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
db.flush()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
# =========================================================================
|
||||
# Usage Tracking
|
||||
# =========================================================================
|
||||
|
||||
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
||||
"""Get current subscription usage statistics."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Calculate usage stats
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
def calc_percent(current: int, limit: int | None) -> float | None:
|
||||
if limit is None or limit == 0:
|
||||
return None
|
||||
return min(100.0, (current / limit) * 100)
|
||||
|
||||
return SubscriptionUsage(
|
||||
orders_used=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
||||
products_used=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
products_percent_used=calc_percent(products_count, products_limit),
|
||||
team_members_used=team_count,
|
||||
team_members_limit=team_limit,
|
||||
team_members_remaining=calc_remaining(team_count, team_limit),
|
||||
team_members_percent_used=calc_percent(team_count, team_limit),
|
||||
)
|
||||
|
||||
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
|
||||
"""Get usage summary for billing page display."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Get limits
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
return UsageSummary(
|
||||
orders_this_period=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
products_count=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
team_count=team_count,
|
||||
team_limit=team_limit,
|
||||
team_remaining=calc_remaining(team_count, team_limit),
|
||||
)
|
||||
|
||||
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Increment the order counter for the current period.
|
||||
|
||||
Call this after successfully creating/importing an order.
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
subscription.increment_order_count()
|
||||
db.flush()
|
||||
|
||||
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
||||
"""Reset counters for a new billing period."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
subscription.reset_period_counters()
|
||||
db.flush()
|
||||
logger.info(f"Reset period counters for vendor {vendor_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Limit Checks
|
||||
# =========================================================================
|
||||
|
||||
def can_create_order(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can create/import another order.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.can_create_order()
|
||||
|
||||
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check order limit and raise exception if exceeded.
|
||||
|
||||
Use this in order creation flows.
|
||||
"""
|
||||
can_create, message = self.can_create_order(db, vendor_id)
|
||||
if not can_create:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Order limit exceeded",
|
||||
limit_type="orders",
|
||||
current=subscription.orders_this_period if subscription else 0,
|
||||
limit=subscription.orders_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_product(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another product.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_product(products_count)
|
||||
|
||||
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check product limit and raise exception if exceeded.
|
||||
|
||||
Use this in product creation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_product(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Product limit exceeded",
|
||||
limit_type="products",
|
||||
current=products_count,
|
||||
limit=subscription.products_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_team_member(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another team member.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_team_member(team_count)
|
||||
|
||||
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check team member limit and raise exception if exceeded.
|
||||
|
||||
Use this in team member invitation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_team_member(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Team member limit exceeded",
|
||||
limit_type="team_members",
|
||||
current=team_count,
|
||||
limit=subscription.team_members_limit if subscription else 0,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Feature Gating
|
||||
# =========================================================================
|
||||
|
||||
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
||||
"""Check if vendor has access to a feature."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.has_feature(feature)
|
||||
|
||||
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
||||
"""
|
||||
Check feature access and raise exception if not available.
|
||||
|
||||
Use this to gate premium features.
|
||||
"""
|
||||
if not self.has_feature(db, vendor_id, feature):
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Find which tier has this feature
|
||||
required_tier = None
|
||||
for tier_code, limits in TIER_LIMITS.items():
|
||||
if feature in limits.get("features", []):
|
||||
required_tier = limits["name"]
|
||||
break
|
||||
|
||||
raise FeatureNotAvailableException(
|
||||
feature=feature,
|
||||
current_tier=subscription.tier,
|
||||
required_tier=required_tier or "higher",
|
||||
)
|
||||
|
||||
def get_feature_tier(self, feature: str) -> str | None:
|
||||
"""Get the minimum tier required for a feature."""
|
||||
for tier_code in [
|
||||
TierCode.ESSENTIAL,
|
||||
TierCode.PROFESSIONAL,
|
||||
TierCode.BUSINESS,
|
||||
TierCode.ENTERPRISE,
|
||||
]:
|
||||
if feature in TIER_LIMITS[tier_code].get("features", []):
|
||||
return tier_code.value
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
subscription_service = SubscriptionService()
|
||||
__all__ = [
|
||||
"SubscriptionService",
|
||||
"subscription_service",
|
||||
"SubscriptionNotFoundException",
|
||||
"TierLimitExceededException",
|
||||
"FeatureNotAvailableException",
|
||||
]
|
||||
|
||||
@@ -1,265 +1,27 @@
|
||||
# app/tasks/celery_tasks/subscription.py
|
||||
"""
|
||||
Celery tasks for subscription management.
|
||||
Legacy subscription tasks.
|
||||
|
||||
Scheduled tasks for:
|
||||
- Resetting period counters
|
||||
- Checking trial expirations
|
||||
- Syncing with Stripe
|
||||
- Cleaning up stale subscriptions
|
||||
- Capturing capacity snapshots
|
||||
MOSTLY MIGRATED: Most tasks have been migrated to app.modules.billing.tasks.
|
||||
|
||||
The following tasks now live in the billing module:
|
||||
- reset_period_counters -> app.modules.billing.tasks.subscription
|
||||
- check_trial_expirations -> app.modules.billing.tasks.subscription
|
||||
- sync_stripe_status -> app.modules.billing.tasks.subscription
|
||||
- cleanup_stale_subscriptions -> app.modules.billing.tasks.subscription
|
||||
|
||||
Remaining task (to be migrated to monitoring module):
|
||||
- capture_capacity_snapshot
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from app.core.celery_config import celery_app
|
||||
from app.services.stripe_service import stripe_service
|
||||
from app.tasks.celery_tasks.base import DatabaseTask
|
||||
from models.database.subscription import SubscriptionStatus, VendorSubscription
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=DatabaseTask,
|
||||
name="app.tasks.celery_tasks.subscription.reset_period_counters",
|
||||
)
|
||||
def reset_period_counters(self):
|
||||
"""
|
||||
Reset order counters for subscriptions whose billing period has ended.
|
||||
|
||||
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
reset_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find subscriptions where period has ended
|
||||
expired_periods = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.period_end <= now,
|
||||
VendorSubscription.status.in_(["active", "trial"]),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in expired_periods:
|
||||
old_period_end = subscription.period_end
|
||||
|
||||
# Reset counters
|
||||
subscription.orders_this_period = 0
|
||||
subscription.orders_limit_reached_at = None
|
||||
|
||||
# Set new period dates
|
||||
if subscription.is_annual:
|
||||
subscription.period_start = now
|
||||
subscription.period_end = now + timedelta(days=365)
|
||||
else:
|
||||
subscription.period_start = now
|
||||
subscription.period_end = now + timedelta(days=30)
|
||||
|
||||
subscription.updated_at = now
|
||||
reset_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Reset period counters for vendor {subscription.vendor_id}: "
|
||||
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Reset period counters for {reset_count} subscriptions")
|
||||
|
||||
return {"reset_count": reset_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=DatabaseTask,
|
||||
name="app.tasks.celery_tasks.subscription.check_trial_expirations",
|
||||
)
|
||||
def check_trial_expirations(self):
|
||||
"""
|
||||
Check for expired trials and update their status.
|
||||
|
||||
Runs daily at 01:00.
|
||||
- Trials without payment method -> expired
|
||||
- Trials with payment method -> active
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
expired_count = 0
|
||||
activated_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find expired trials
|
||||
expired_trials = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
|
||||
VendorSubscription.trial_ends_at <= now,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in expired_trials:
|
||||
if subscription.stripe_payment_method_id:
|
||||
# Has payment method - activate
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
activated_count += 1
|
||||
logger.info(
|
||||
f"Activated subscription for vendor {subscription.vendor_id} "
|
||||
f"(trial ended with payment method)"
|
||||
)
|
||||
else:
|
||||
# No payment method - expire
|
||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||
expired_count += 1
|
||||
logger.info(
|
||||
f"Expired trial for vendor {subscription.vendor_id} "
|
||||
f"(no payment method)"
|
||||
)
|
||||
|
||||
subscription.updated_at = now
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
|
||||
|
||||
return {"expired_count": expired_count, "activated_count": activated_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=DatabaseTask,
|
||||
name="app.tasks.celery_tasks.subscription.sync_stripe_status",
|
||||
max_retries=3,
|
||||
default_retry_delay=300,
|
||||
)
|
||||
def sync_stripe_status(self):
|
||||
"""
|
||||
Sync subscription status with Stripe.
|
||||
|
||||
Runs hourly at :30. Fetches current status from Stripe and updates local records.
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
logger.warning("Stripe not configured, skipping sync")
|
||||
return {"synced": 0, "skipped": True}
|
||||
|
||||
synced_count = 0
|
||||
error_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find subscriptions with Stripe IDs
|
||||
subscriptions = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in subscriptions:
|
||||
try:
|
||||
# Fetch from Stripe
|
||||
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
|
||||
|
||||
if not stripe_sub:
|
||||
logger.warning(
|
||||
f"Stripe subscription {subscription.stripe_subscription_id} "
|
||||
f"not found for vendor {subscription.vendor_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Map Stripe status to local status
|
||||
status_map = {
|
||||
"active": SubscriptionStatus.ACTIVE.value,
|
||||
"trialing": SubscriptionStatus.TRIAL.value,
|
||||
"past_due": SubscriptionStatus.PAST_DUE.value,
|
||||
"canceled": SubscriptionStatus.CANCELLED.value,
|
||||
"unpaid": SubscriptionStatus.PAST_DUE.value,
|
||||
"incomplete": SubscriptionStatus.TRIAL.value,
|
||||
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
|
||||
}
|
||||
|
||||
new_status = status_map.get(stripe_sub.status)
|
||||
if new_status and new_status != subscription.status:
|
||||
old_status = subscription.status
|
||||
subscription.status = new_status
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
logger.info(
|
||||
f"Updated vendor {subscription.vendor_id} status: "
|
||||
f"{old_status} -> {new_status} (from Stripe)"
|
||||
)
|
||||
|
||||
# Update period dates from Stripe
|
||||
if stripe_sub.current_period_start:
|
||||
subscription.period_start = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_start, tz=UTC
|
||||
)
|
||||
if stripe_sub.current_period_end:
|
||||
subscription.period_end = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_end, tz=UTC
|
||||
)
|
||||
|
||||
# Update payment method
|
||||
if stripe_sub.default_payment_method:
|
||||
subscription.stripe_payment_method_id = (
|
||||
stripe_sub.default_payment_method
|
||||
if isinstance(stripe_sub.default_payment_method, str)
|
||||
else stripe_sub.default_payment_method.id
|
||||
)
|
||||
|
||||
synced_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
|
||||
error_count += 1
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
|
||||
|
||||
return {"synced_count": synced_count, "error_count": error_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=DatabaseTask,
|
||||
name="app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
|
||||
)
|
||||
def cleanup_stale_subscriptions(self):
|
||||
"""
|
||||
Clean up subscriptions in inconsistent states.
|
||||
|
||||
Runs weekly on Sunday at 03:00.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
cleaned_count = 0
|
||||
|
||||
with self.get_db() as db:
|
||||
# Find cancelled subscriptions past their period end
|
||||
stale_cancelled = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(
|
||||
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
|
||||
VendorSubscription.period_end < now - timedelta(days=30),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for subscription in stale_cancelled:
|
||||
# Mark as expired (fully terminated)
|
||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||
subscription.updated_at = now
|
||||
cleaned_count += 1
|
||||
logger.info(
|
||||
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
|
||||
|
||||
return {"cleaned_count": cleaned_count}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
base=DatabaseTask,
|
||||
@@ -270,6 +32,8 @@ def capture_capacity_snapshot(self):
|
||||
Capture a daily snapshot of platform capacity metrics.
|
||||
|
||||
Runs daily at midnight.
|
||||
|
||||
TODO: Migrate to app.modules.monitoring.tasks
|
||||
"""
|
||||
from app.services.capacity_forecast_service import capacity_forecast_service
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ Transform the platform from a monolithic structure to self-contained modules whe
|
||||
|--------|---------------|---------------|----------|--------|-------|--------|
|
||||
| `cms` | Core | ✅ **Complete** | ✅ | ✅ | - | Done |
|
||||
| `payments` | Optional | 🟡 Partial | ✅ | ✅ | - | Done |
|
||||
| `billing` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
|
||||
| `billing` | Optional | ✅ **Complete** | ✅ | ✅ | ✅ | Done |
|
||||
| `marketplace` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
|
||||
| `orders` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
||||
| `inventory` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
||||
@@ -110,85 +110,22 @@ app/tasks/celery_tasks/ # → Move to respective modules
|
||||
- Observability docs
|
||||
- Creating modules developer guide
|
||||
|
||||
---
|
||||
### ✅ Phase 4: Celery Task Infrastructure
|
||||
- Added `ScheduledTask` dataclass to `ModuleDefinition`
|
||||
- Added `tasks_path` and `scheduled_tasks` fields
|
||||
- Created `app/modules/task_base.py` with `ModuleTask` base class
|
||||
- Created `app/modules/tasks.py` with discovery utilities
|
||||
- Updated `celery_config.py` to use module discovery
|
||||
|
||||
## Infrastructure Work (Remaining)
|
||||
|
||||
### Phase 4: Celery Task Infrastructure
|
||||
|
||||
Add task support to module system before migrating modules.
|
||||
|
||||
#### 4.1 Update ModuleDefinition
|
||||
|
||||
```python
|
||||
# app/modules/base.py - Add fields
|
||||
|
||||
@dataclass
|
||||
class ScheduledTask:
|
||||
"""Celery Beat scheduled task definition."""
|
||||
name: str # e.g., "billing.reset_counters"
|
||||
task: str # Full task path
|
||||
schedule: str | dict # Cron string or crontab dict
|
||||
args: tuple = ()
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
|
||||
@dataclass
|
||||
class ModuleDefinition:
|
||||
# ... existing fields ...
|
||||
|
||||
# Task configuration (NEW)
|
||||
tasks_path: str | None = None
|
||||
scheduled_tasks: list[ScheduledTask] = field(default_factory=list)
|
||||
```
|
||||
|
||||
#### 4.2 Create Task Discovery
|
||||
|
||||
```python
|
||||
# app/modules/tasks.py (NEW)
|
||||
|
||||
def discover_module_tasks() -> list[str]:
|
||||
"""Discover task modules from all registered modules."""
|
||||
...
|
||||
|
||||
def build_beat_schedule() -> dict:
|
||||
"""Build Celery Beat schedule from module definitions."""
|
||||
...
|
||||
```
|
||||
|
||||
#### 4.3 Create Module Task Base
|
||||
|
||||
```python
|
||||
# app/modules/task_base.py (NEW)
|
||||
|
||||
class ModuleTask(Task):
|
||||
"""Base Celery task with DB session management."""
|
||||
|
||||
@contextmanager
|
||||
def get_db(self):
|
||||
...
|
||||
```
|
||||
|
||||
#### 4.4 Update Celery Config
|
||||
|
||||
```python
|
||||
# app/core/celery_config.py
|
||||
|
||||
from app.modules.tasks import discover_module_tasks, build_beat_schedule
|
||||
|
||||
# Auto-discover from modules
|
||||
celery_app.autodiscover_tasks(discover_module_tasks())
|
||||
|
||||
# Build schedule from modules
|
||||
celery_app.conf.beat_schedule = build_beat_schedule()
|
||||
```
|
||||
|
||||
**Files to create:**
|
||||
- `app/modules/tasks.py`
|
||||
- `app/modules/task_base.py`
|
||||
|
||||
**Files to modify:**
|
||||
- `app/modules/base.py`
|
||||
- `app/core/celery_config.py`
|
||||
### ✅ Phase 5: Billing Module Migration
|
||||
- Created `app/modules/billing/services/` with subscription, stripe, admin services
|
||||
- Created `app/modules/billing/models/` re-exporting from central location
|
||||
- Created `app/modules/billing/schemas/` re-exporting from central location
|
||||
- Created `app/modules/billing/tasks/` with 4 scheduled tasks
|
||||
- Created `app/modules/billing/exceptions.py`
|
||||
- Updated `definition.py` with self-contained configuration
|
||||
- Created backward-compatible re-exports in `app/services/`
|
||||
- Updated legacy celery_config.py to not duplicate scheduled tasks
|
||||
|
||||
---
|
||||
|
||||
|
||||
Reference in New Issue
Block a user