feat: complete billing module migration (Phase 5)
Migrates billing module to self-contained structure: - Create app/modules/billing/services/ with subscription, stripe, admin services - Create app/modules/billing/models/ re-exporting from central location - Create app/modules/billing/schemas/ re-exporting from central location - Create app/modules/billing/tasks/ with 4 scheduled Celery tasks - Create app/modules/billing/exceptions.py with module-specific exceptions - Update definition.py with is_self_contained=True and scheduled_tasks Celery task migration: - reset_period_counters -> billing module - check_trial_expirations -> billing module - sync_stripe_status -> billing module - cleanup_stale_subscriptions -> billing module - capture_capacity_snapshot remains in legacy (will go to monitoring) Backward compatibility: - Create re-exports in app/services/ for subscription, stripe, admin services - Old import paths continue to work - Update celery_config.py to use module-defined schedules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -49,10 +49,12 @@ if SENTRY_DSN:
|
|||||||
# TASK DISCOVERY
|
# TASK DISCOVERY
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Legacy tasks (will be migrated to modules over time)
|
# Legacy tasks (will be migrated to modules over time)
|
||||||
|
# NOTE: Most subscription tasks have been migrated to app.modules.billing.tasks
|
||||||
|
# The subscription module is kept for capture_capacity_snapshot (will move to monitoring)
|
||||||
LEGACY_TASK_MODULES = [
|
LEGACY_TASK_MODULES = [
|
||||||
"app.tasks.celery_tasks.marketplace",
|
"app.tasks.celery_tasks.marketplace",
|
||||||
"app.tasks.celery_tasks.letzshop",
|
"app.tasks.celery_tasks.letzshop",
|
||||||
"app.tasks.celery_tasks.subscription",
|
"app.tasks.celery_tasks.subscription", # Kept for capture_capacity_snapshot only
|
||||||
"app.tasks.celery_tasks.export",
|
"app.tasks.celery_tasks.export",
|
||||||
"app.tasks.celery_tasks.code_quality",
|
"app.tasks.celery_tasks.code_quality",
|
||||||
"app.tasks.celery_tasks.test_runner",
|
"app.tasks.celery_tasks.test_runner",
|
||||||
@@ -141,32 +143,9 @@ celery_app.conf.task_routes = {
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# Legacy scheduled tasks (will be migrated to module definitions)
|
# Legacy scheduled tasks (will be migrated to module definitions)
|
||||||
|
# NOTE: Subscription tasks have been migrated to billing module (see definition.py)
|
||||||
LEGACY_BEAT_SCHEDULE = {
|
LEGACY_BEAT_SCHEDULE = {
|
||||||
# Reset usage counters at start of each period
|
# Capacity snapshot - will be migrated to monitoring module
|
||||||
"reset-period-counters-daily": {
|
|
||||||
"task": "app.tasks.celery_tasks.subscription.reset_period_counters",
|
|
||||||
"schedule": crontab(hour=0, minute=5), # 00:05 daily
|
|
||||||
"options": {"queue": "scheduled"},
|
|
||||||
},
|
|
||||||
# Check for expiring trials and send notifications
|
|
||||||
"check-trial-expirations-daily": {
|
|
||||||
"task": "app.tasks.celery_tasks.subscription.check_trial_expirations",
|
|
||||||
"schedule": crontab(hour=1, minute=0), # 01:00 daily
|
|
||||||
"options": {"queue": "scheduled"},
|
|
||||||
},
|
|
||||||
# Sync subscription status with Stripe
|
|
||||||
"sync-stripe-status-hourly": {
|
|
||||||
"task": "app.tasks.celery_tasks.subscription.sync_stripe_status",
|
|
||||||
"schedule": crontab(minute=30), # Every hour at :30
|
|
||||||
"options": {"queue": "scheduled"},
|
|
||||||
},
|
|
||||||
# Clean up stale/orphaned subscriptions
|
|
||||||
"cleanup-stale-subscriptions-weekly": {
|
|
||||||
"task": "app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
|
|
||||||
"schedule": crontab(hour=3, minute=0, day_of_week=0), # Sunday 03:00
|
|
||||||
"options": {"queue": "scheduled"},
|
|
||||||
},
|
|
||||||
# Capture daily capacity snapshot for analytics
|
|
||||||
"capture-capacity-snapshot-daily": {
|
"capture-capacity-snapshot-daily": {
|
||||||
"task": "app.tasks.celery_tasks.subscription.capture_capacity_snapshot",
|
"task": "app.tasks.celery_tasks.subscription.capture_capacity_snapshot",
|
||||||
"schedule": crontab(hour=0, minute=0), # Midnight daily
|
"schedule": crontab(hour=0, minute=0), # Midnight daily
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ This module provides:
|
|||||||
- Vendor subscription CRUD
|
- Vendor subscription CRUD
|
||||||
- Billing history and invoices
|
- Billing history and invoices
|
||||||
- Stripe integration
|
- Stripe integration
|
||||||
|
- Scheduled tasks for subscription lifecycle
|
||||||
|
|
||||||
Routes:
|
Routes:
|
||||||
- Admin: /api/v1/admin/subscriptions/*
|
- Admin: /api/v1/admin/subscriptions/*
|
||||||
@@ -15,8 +16,17 @@ Routes:
|
|||||||
Menu Items:
|
Menu Items:
|
||||||
- Admin: subscription-tiers, subscriptions, billing-history
|
- Admin: subscription-tiers, subscriptions, billing-history
|
||||||
- Vendor: billing, invoices
|
- Vendor: billing, invoices
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.modules.billing import billing_module
|
||||||
|
from app.modules.billing.services import subscription_service, stripe_service
|
||||||
|
from app.modules.billing.models import VendorSubscription, SubscriptionTier
|
||||||
|
from app.modules.billing.exceptions import TierLimitExceededException
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.modules.billing.definition import billing_module
|
from app.modules.billing.definition import billing_module, get_billing_module_with_routers
|
||||||
|
|
||||||
__all__ = ["billing_module"]
|
__all__ = [
|
||||||
|
"billing_module",
|
||||||
|
"get_billing_module_with_routers",
|
||||||
|
]
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
Billing module definition.
|
Billing module definition.
|
||||||
|
|
||||||
Defines the billing module including its features, menu items,
|
Defines the billing module including its features, menu items,
|
||||||
and route configurations.
|
route configurations, and scheduled tasks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from app.modules.base import ModuleDefinition
|
from app.modules.base import ModuleDefinition, ScheduledTask
|
||||||
from models.database.admin_menu_config import FrontendType
|
from models.database.admin_menu_config import FrontendType
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +54,44 @@ billing_module = ModuleDefinition(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
is_core=False, # Billing can be disabled (e.g., internal platforms)
|
is_core=False, # Billing can be disabled (e.g., internal platforms)
|
||||||
|
# =========================================================================
|
||||||
|
# Self-Contained Module Configuration
|
||||||
|
# =========================================================================
|
||||||
|
is_self_contained=True,
|
||||||
|
services_path="app.modules.billing.services",
|
||||||
|
models_path="app.modules.billing.models",
|
||||||
|
schemas_path="app.modules.billing.schemas",
|
||||||
|
exceptions_path="app.modules.billing.exceptions",
|
||||||
|
tasks_path="app.modules.billing.tasks",
|
||||||
|
# =========================================================================
|
||||||
|
# Scheduled Tasks
|
||||||
|
# =========================================================================
|
||||||
|
scheduled_tasks=[
|
||||||
|
ScheduledTask(
|
||||||
|
name="billing.reset_period_counters",
|
||||||
|
task="app.modules.billing.tasks.subscription.reset_period_counters",
|
||||||
|
schedule="5 0 * * *", # Daily at 00:05
|
||||||
|
options={"queue": "scheduled"},
|
||||||
|
),
|
||||||
|
ScheduledTask(
|
||||||
|
name="billing.check_trial_expirations",
|
||||||
|
task="app.modules.billing.tasks.subscription.check_trial_expirations",
|
||||||
|
schedule="0 1 * * *", # Daily at 01:00
|
||||||
|
options={"queue": "scheduled"},
|
||||||
|
),
|
||||||
|
ScheduledTask(
|
||||||
|
name="billing.sync_stripe_status",
|
||||||
|
task="app.modules.billing.tasks.subscription.sync_stripe_status",
|
||||||
|
schedule="30 * * * *", # Hourly at :30
|
||||||
|
options={"queue": "scheduled"},
|
||||||
|
),
|
||||||
|
ScheduledTask(
|
||||||
|
name="billing.cleanup_stale_subscriptions",
|
||||||
|
task="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
|
||||||
|
schedule="0 3 * * 0", # Weekly on Sunday at 03:00
|
||||||
|
options={"queue": "scheduled"},
|
||||||
|
),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
83
app/modules/billing/exceptions.py
Normal file
83
app/modules/billing/exceptions.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# app/modules/billing/exceptions.py
|
||||||
|
"""
|
||||||
|
Billing module exceptions.
|
||||||
|
|
||||||
|
Custom exceptions for subscription, billing, and payment operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.exceptions import BusinessLogicException, ResourceNotFoundException
|
||||||
|
|
||||||
|
|
||||||
|
class BillingException(BusinessLogicException):
|
||||||
|
"""Base exception for billing module errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionNotFoundException(ResourceNotFoundException):
|
||||||
|
"""Raised when a subscription is not found."""
|
||||||
|
|
||||||
|
def __init__(self, vendor_id: int):
|
||||||
|
super().__init__("Subscription", str(vendor_id))
|
||||||
|
|
||||||
|
|
||||||
|
class TierNotFoundException(ResourceNotFoundException):
|
||||||
|
"""Raised when a subscription tier is not found."""
|
||||||
|
|
||||||
|
def __init__(self, tier_code: str):
|
||||||
|
super().__init__("SubscriptionTier", tier_code)
|
||||||
|
|
||||||
|
|
||||||
|
class TierLimitExceededException(BillingException):
|
||||||
|
"""Raised when a tier limit is exceeded."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, limit_type: str, current: int, limit: int):
|
||||||
|
super().__init__(message)
|
||||||
|
self.limit_type = limit_type
|
||||||
|
self.current = current
|
||||||
|
self.limit = limit
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureNotAvailableException(BillingException):
|
||||||
|
"""Raised when a feature is not available in current tier."""
|
||||||
|
|
||||||
|
def __init__(self, feature: str, current_tier: str, required_tier: str):
|
||||||
|
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
|
||||||
|
super().__init__(message)
|
||||||
|
self.feature = feature
|
||||||
|
self.current_tier = current_tier
|
||||||
|
self.required_tier = required_tier
|
||||||
|
|
||||||
|
|
||||||
|
class StripeNotConfiguredException(BillingException):
|
||||||
|
"""Raised when Stripe is not configured."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("Stripe is not configured")
|
||||||
|
|
||||||
|
|
||||||
|
class PaymentFailedException(BillingException):
|
||||||
|
"""Raised when a payment fails."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, stripe_error: str | None = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.stripe_error = stripe_error
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookVerificationException(BillingException):
|
||||||
|
"""Raised when webhook signature verification fails."""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Invalid webhook signature"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BillingException",
|
||||||
|
"SubscriptionNotFoundException",
|
||||||
|
"TierNotFoundException",
|
||||||
|
"TierLimitExceededException",
|
||||||
|
"FeatureNotAvailableException",
|
||||||
|
"StripeNotConfiguredException",
|
||||||
|
"PaymentFailedException",
|
||||||
|
"WebhookVerificationException",
|
||||||
|
]
|
||||||
52
app/modules/billing/models/__init__.py
Normal file
52
app/modules/billing/models/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# app/modules/billing/models/__init__.py
|
||||||
|
"""
|
||||||
|
Billing module models.
|
||||||
|
|
||||||
|
Re-exports subscription models from the central models location.
|
||||||
|
Models remain in models/database/ for now to avoid breaking existing imports
|
||||||
|
across the codebase. This provides a module-local import path.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.modules.billing.models import (
|
||||||
|
VendorSubscription,
|
||||||
|
SubscriptionTier,
|
||||||
|
SubscriptionStatus,
|
||||||
|
TierCode,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from models.database.subscription import (
|
||||||
|
# Enums
|
||||||
|
TierCode,
|
||||||
|
SubscriptionStatus,
|
||||||
|
AddOnCategory,
|
||||||
|
BillingPeriod,
|
||||||
|
# Models
|
||||||
|
SubscriptionTier,
|
||||||
|
AddOnProduct,
|
||||||
|
VendorAddOn,
|
||||||
|
StripeWebhookEvent,
|
||||||
|
BillingHistory,
|
||||||
|
VendorSubscription,
|
||||||
|
CapacitySnapshot,
|
||||||
|
# Legacy constants
|
||||||
|
TIER_LIMITS,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Enums
|
||||||
|
"TierCode",
|
||||||
|
"SubscriptionStatus",
|
||||||
|
"AddOnCategory",
|
||||||
|
"BillingPeriod",
|
||||||
|
# Models
|
||||||
|
"SubscriptionTier",
|
||||||
|
"AddOnProduct",
|
||||||
|
"VendorAddOn",
|
||||||
|
"StripeWebhookEvent",
|
||||||
|
"BillingHistory",
|
||||||
|
"VendorSubscription",
|
||||||
|
"CapacitySnapshot",
|
||||||
|
# Legacy constants
|
||||||
|
"TIER_LIMITS",
|
||||||
|
]
|
||||||
56
app/modules/billing/schemas/__init__.py
Normal file
56
app/modules/billing/schemas/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# app/modules/billing/schemas/__init__.py
|
||||||
|
"""
|
||||||
|
Billing module Pydantic schemas.
|
||||||
|
|
||||||
|
Re-exports subscription schemas from the central schemas location.
|
||||||
|
Provides a module-local import path while maintaining backwards compatibility.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.modules.billing.schemas import (
|
||||||
|
SubscriptionCreate,
|
||||||
|
SubscriptionResponse,
|
||||||
|
TierInfo,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from models.schema.subscription import (
|
||||||
|
# Tier schemas
|
||||||
|
TierFeatures,
|
||||||
|
TierLimits,
|
||||||
|
TierInfo,
|
||||||
|
# Subscription CRUD schemas
|
||||||
|
SubscriptionCreate,
|
||||||
|
SubscriptionUpdate,
|
||||||
|
SubscriptionResponse,
|
||||||
|
# Usage schemas
|
||||||
|
SubscriptionUsage,
|
||||||
|
UsageSummary,
|
||||||
|
SubscriptionStatusResponse,
|
||||||
|
# Limit check schemas
|
||||||
|
LimitCheckResult,
|
||||||
|
CanCreateOrderResponse,
|
||||||
|
CanAddProductResponse,
|
||||||
|
CanAddTeamMemberResponse,
|
||||||
|
FeatureCheckResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Tier schemas
|
||||||
|
"TierFeatures",
|
||||||
|
"TierLimits",
|
||||||
|
"TierInfo",
|
||||||
|
# Subscription CRUD schemas
|
||||||
|
"SubscriptionCreate",
|
||||||
|
"SubscriptionUpdate",
|
||||||
|
"SubscriptionResponse",
|
||||||
|
# Usage schemas
|
||||||
|
"SubscriptionUsage",
|
||||||
|
"UsageSummary",
|
||||||
|
"SubscriptionStatusResponse",
|
||||||
|
# Limit check schemas
|
||||||
|
"LimitCheckResult",
|
||||||
|
"CanCreateOrderResponse",
|
||||||
|
"CanAddProductResponse",
|
||||||
|
"CanAddTeamMemberResponse",
|
||||||
|
"FeatureCheckResponse",
|
||||||
|
]
|
||||||
28
app/modules/billing/services/__init__.py
Normal file
28
app/modules/billing/services/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# app/modules/billing/services/__init__.py
|
||||||
|
"""
|
||||||
|
Billing module services.
|
||||||
|
|
||||||
|
Provides subscription management, Stripe integration, and admin operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.modules.billing.services.subscription_service import (
|
||||||
|
SubscriptionService,
|
||||||
|
subscription_service,
|
||||||
|
)
|
||||||
|
from app.modules.billing.services.stripe_service import (
|
||||||
|
StripeService,
|
||||||
|
stripe_service,
|
||||||
|
)
|
||||||
|
from app.modules.billing.services.admin_subscription_service import (
|
||||||
|
AdminSubscriptionService,
|
||||||
|
admin_subscription_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SubscriptionService",
|
||||||
|
"subscription_service",
|
||||||
|
"StripeService",
|
||||||
|
"stripe_service",
|
||||||
|
"AdminSubscriptionService",
|
||||||
|
"admin_subscription_service",
|
||||||
|
]
|
||||||
352
app/modules/billing/services/admin_subscription_service.py
Normal file
352
app/modules/billing/services/admin_subscription_service.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
# app/modules/billing/services/admin_subscription_service.py
|
||||||
|
"""
|
||||||
|
Admin Subscription Service.
|
||||||
|
|
||||||
|
Handles subscription management operations for platform administrators:
|
||||||
|
- Subscription tier CRUD
|
||||||
|
- Vendor subscription management
|
||||||
|
- Billing history queries
|
||||||
|
- Subscription analytics
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from math import ceil
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.exceptions import (
|
||||||
|
BusinessLogicException,
|
||||||
|
ConflictException,
|
||||||
|
ResourceNotFoundException,
|
||||||
|
)
|
||||||
|
from app.modules.billing.exceptions import TierNotFoundException
|
||||||
|
from app.modules.billing.models import (
|
||||||
|
BillingHistory,
|
||||||
|
SubscriptionStatus,
|
||||||
|
SubscriptionTier,
|
||||||
|
VendorSubscription,
|
||||||
|
)
|
||||||
|
from models.database.product import Product
|
||||||
|
from models.database.vendor import Vendor, VendorUser
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdminSubscriptionService:
|
||||||
|
"""Service for admin subscription management operations."""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Subscription Tiers
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_tiers(
|
||||||
|
self, db: Session, include_inactive: bool = False
|
||||||
|
) -> list[SubscriptionTier]:
|
||||||
|
"""Get all subscription tiers."""
|
||||||
|
query = db.query(SubscriptionTier)
|
||||||
|
|
||||||
|
if not include_inactive:
|
||||||
|
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
|
||||||
|
|
||||||
|
return query.order_by(SubscriptionTier.display_order).all()
|
||||||
|
|
||||||
|
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||||
|
"""Get a subscription tier by code."""
|
||||||
|
tier = (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.code == tier_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tier:
|
||||||
|
raise TierNotFoundException(tier_code)
|
||||||
|
|
||||||
|
return tier
|
||||||
|
|
||||||
|
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
|
||||||
|
"""Create a new subscription tier."""
|
||||||
|
# Check for duplicate code
|
||||||
|
existing = (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.code == tier_data["code"])
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ConflictException(
|
||||||
|
f"Tier with code '{tier_data['code']}' already exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
tier = SubscriptionTier(**tier_data)
|
||||||
|
db.add(tier)
|
||||||
|
|
||||||
|
logger.info(f"Created subscription tier: {tier.code}")
|
||||||
|
return tier
|
||||||
|
|
||||||
|
def update_tier(
|
||||||
|
self, db: Session, tier_code: str, update_data: dict
|
||||||
|
) -> SubscriptionTier:
|
||||||
|
"""Update a subscription tier."""
|
||||||
|
tier = self.get_tier_by_code(db, tier_code)
|
||||||
|
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(tier, field, value)
|
||||||
|
|
||||||
|
logger.info(f"Updated subscription tier: {tier.code}")
|
||||||
|
return tier
|
||||||
|
|
||||||
|
def deactivate_tier(self, db: Session, tier_code: str) -> None:
|
||||||
|
"""Soft-delete a subscription tier."""
|
||||||
|
tier = self.get_tier_by_code(db, tier_code)
|
||||||
|
|
||||||
|
# Check if any active subscriptions use this tier
|
||||||
|
active_subs = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(
|
||||||
|
VendorSubscription.tier == tier_code,
|
||||||
|
VendorSubscription.status.in_([
|
||||||
|
SubscriptionStatus.ACTIVE.value,
|
||||||
|
SubscriptionStatus.TRIAL.value,
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
if active_subs > 0:
|
||||||
|
raise BusinessLogicException(
|
||||||
|
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
||||||
|
)
|
||||||
|
|
||||||
|
tier.is_active = False
|
||||||
|
|
||||||
|
logger.info(f"Soft-deleted subscription tier: {tier.code}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Vendor Subscriptions
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def list_subscriptions(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
page: int = 1,
|
||||||
|
per_page: int = 20,
|
||||||
|
status: str | None = None,
|
||||||
|
tier: str | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""List vendor subscriptions with filtering and pagination."""
|
||||||
|
query = (
|
||||||
|
db.query(VendorSubscription, Vendor)
|
||||||
|
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if status:
|
||||||
|
query = query.filter(VendorSubscription.status == status)
|
||||||
|
if tier:
|
||||||
|
query = query.filter(VendorSubscription.tier == tier)
|
||||||
|
if search:
|
||||||
|
query = query.filter(Vendor.name.ilike(f"%{search}%"))
|
||||||
|
|
||||||
|
# Count total
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
# Paginate
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
results = (
|
||||||
|
query.order_by(VendorSubscription.created_at.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(per_page)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"results": results,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
|
||||||
|
"""Get subscription for a specific vendor."""
|
||||||
|
result = (
|
||||||
|
db.query(VendorSubscription, Vendor)
|
||||||
|
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
||||||
|
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
raise ResourceNotFoundException("Subscription", str(vendor_id))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def update_subscription(
|
||||||
|
self, db: Session, vendor_id: int, update_data: dict
|
||||||
|
) -> tuple:
|
||||||
|
"""Update a vendor's subscription."""
|
||||||
|
result = self.get_subscription(db, vendor_id)
|
||||||
|
sub, vendor = result
|
||||||
|
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(sub, field, value)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sub, vendor
|
||||||
|
|
||||||
|
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||||
|
"""Get a vendor by ID."""
|
||||||
|
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||||
|
|
||||||
|
if not vendor:
|
||||||
|
raise ResourceNotFoundException("Vendor", str(vendor_id))
|
||||||
|
|
||||||
|
return vendor
|
||||||
|
|
||||||
|
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
|
||||||
|
"""Get usage counts (products and team members) for a vendor."""
|
||||||
|
products_count = (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
team_count = (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(
|
||||||
|
VendorUser.vendor_id == vendor_id,
|
||||||
|
VendorUser.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"products_count": products_count,
|
||||||
|
"team_count": team_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Billing History
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def list_billing_history(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
page: int = 1,
|
||||||
|
per_page: int = 20,
|
||||||
|
vendor_id: int | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""List billing history across all vendors."""
|
||||||
|
query = (
|
||||||
|
db.query(BillingHistory, Vendor)
|
||||||
|
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if vendor_id:
|
||||||
|
query = query.filter(BillingHistory.vendor_id == vendor_id)
|
||||||
|
if status:
|
||||||
|
query = query.filter(BillingHistory.status == status)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
results = (
|
||||||
|
query.order_by(BillingHistory.invoice_date.desc())
|
||||||
|
.offset(offset)
|
||||||
|
.limit(per_page)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"results": results,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"per_page": per_page,
|
||||||
|
"pages": ceil(total / per_page) if total > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Statistics
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_stats(self, db: Session) -> dict:
|
||||||
|
"""Get subscription statistics for admin dashboard."""
|
||||||
|
# Count by status
|
||||||
|
status_counts = (
|
||||||
|
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
|
||||||
|
.group_by(VendorSubscription.status)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total_subscriptions": 0,
|
||||||
|
"active_count": 0,
|
||||||
|
"trial_count": 0,
|
||||||
|
"past_due_count": 0,
|
||||||
|
"cancelled_count": 0,
|
||||||
|
"expired_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
for status, count in status_counts:
|
||||||
|
stats["total_subscriptions"] += count
|
||||||
|
if status == SubscriptionStatus.ACTIVE.value:
|
||||||
|
stats["active_count"] = count
|
||||||
|
elif status == SubscriptionStatus.TRIAL.value:
|
||||||
|
stats["trial_count"] = count
|
||||||
|
elif status == SubscriptionStatus.PAST_DUE.value:
|
||||||
|
stats["past_due_count"] = count
|
||||||
|
elif status == SubscriptionStatus.CANCELLED.value:
|
||||||
|
stats["cancelled_count"] = count
|
||||||
|
elif status == SubscriptionStatus.EXPIRED.value:
|
||||||
|
stats["expired_count"] = count
|
||||||
|
|
||||||
|
# Count by tier
|
||||||
|
tier_counts = (
|
||||||
|
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
|
||||||
|
.filter(
|
||||||
|
VendorSubscription.status.in_([
|
||||||
|
SubscriptionStatus.ACTIVE.value,
|
||||||
|
SubscriptionStatus.TRIAL.value,
|
||||||
|
])
|
||||||
|
)
|
||||||
|
.group_by(VendorSubscription.tier)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
tier_distribution = {tier: count for tier, count in tier_counts}
|
||||||
|
|
||||||
|
# Calculate MRR (Monthly Recurring Revenue)
|
||||||
|
mrr_cents = 0
|
||||||
|
arr_cents = 0
|
||||||
|
|
||||||
|
active_subs = (
|
||||||
|
db.query(VendorSubscription, SubscriptionTier)
|
||||||
|
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
|
||||||
|
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for sub, tier in active_subs:
|
||||||
|
if sub.is_annual and tier.price_annual_cents:
|
||||||
|
mrr_cents += tier.price_annual_cents // 12
|
||||||
|
arr_cents += tier.price_annual_cents
|
||||||
|
else:
|
||||||
|
mrr_cents += tier.price_monthly_cents
|
||||||
|
arr_cents += tier.price_monthly_cents * 12
|
||||||
|
|
||||||
|
stats["tier_distribution"] = tier_distribution
|
||||||
|
stats["mrr_cents"] = mrr_cents
|
||||||
|
stats["arr_cents"] = arr_cents
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
admin_subscription_service = AdminSubscriptionService()
|
||||||
582
app/modules/billing/services/stripe_service.py
Normal file
582
app/modules/billing/services/stripe_service.py
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
# app/modules/billing/services/stripe_service.py
|
||||||
|
"""
|
||||||
|
Stripe payment integration service.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- Customer management
|
||||||
|
- Subscription management
|
||||||
|
- Checkout session creation
|
||||||
|
- Customer portal access
|
||||||
|
- Webhook event construction
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import stripe
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.modules.billing.exceptions import (
|
||||||
|
StripeNotConfiguredException,
|
||||||
|
WebhookVerificationException,
|
||||||
|
)
|
||||||
|
from app.modules.billing.models import (
|
||||||
|
BillingHistory,
|
||||||
|
SubscriptionStatus,
|
||||||
|
SubscriptionTier,
|
||||||
|
VendorSubscription,
|
||||||
|
)
|
||||||
|
from models.database.vendor import Vendor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Service for Stripe payment operations."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._configured = False
|
||||||
|
self._configure()
|
||||||
|
|
||||||
|
def _configure(self):
|
||||||
|
"""Configure Stripe with API key."""
|
||||||
|
if settings.stripe_secret_key:
|
||||||
|
stripe.api_key = settings.stripe_secret_key
|
||||||
|
self._configured = True
|
||||||
|
else:
|
||||||
|
logger.warning("Stripe API key not configured")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if Stripe is properly configured."""
|
||||||
|
return self._configured and bool(settings.stripe_secret_key)
|
||||||
|
|
||||||
|
def _check_configured(self) -> None:
|
||||||
|
"""Raise exception if Stripe is not configured."""
|
||||||
|
if not self.is_configured:
|
||||||
|
raise StripeNotConfiguredException()
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Customer Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def create_customer(
|
||||||
|
self,
|
||||||
|
vendor: Vendor,
|
||||||
|
email: str,
|
||||||
|
name: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a Stripe customer for a vendor.
|
||||||
|
|
||||||
|
Returns the Stripe customer ID.
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
customer_metadata = {
|
||||||
|
"vendor_id": str(vendor.id),
|
||||||
|
"vendor_code": vendor.vendor_code,
|
||||||
|
**(metadata or {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
customer = stripe.Customer.create(
|
||||||
|
email=email,
|
||||||
|
name=name or vendor.name,
|
||||||
|
metadata=customer_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
||||||
|
)
|
||||||
|
return customer.id
|
||||||
|
|
||||||
|
def get_customer(self, customer_id: str) -> stripe.Customer:
|
||||||
|
"""Get a Stripe customer by ID."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
return stripe.Customer.retrieve(customer_id)
|
||||||
|
|
||||||
|
def update_customer(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
email: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.Customer:
|
||||||
|
"""Update a Stripe customer."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
update_data = {}
|
||||||
|
if email:
|
||||||
|
update_data["email"] = email
|
||||||
|
if name:
|
||||||
|
update_data["name"] = name
|
||||||
|
if metadata:
|
||||||
|
update_data["metadata"] = metadata
|
||||||
|
|
||||||
|
return stripe.Customer.modify(customer_id, **update_data)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Subscription Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def create_subscription(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
price_id: str,
|
||||||
|
trial_days: int | None = None,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.Subscription:
|
||||||
|
"""
|
||||||
|
Create a new Stripe subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_id: Stripe customer ID
|
||||||
|
price_id: Stripe price ID for the subscription
|
||||||
|
trial_days: Optional trial period in days
|
||||||
|
metadata: Optional metadata to attach
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stripe Subscription object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
subscription_data = {
|
||||||
|
"customer": customer_id,
|
||||||
|
"items": [{"price": price_id}],
|
||||||
|
"metadata": metadata or {},
|
||||||
|
"payment_behavior": "default_incomplete",
|
||||||
|
"expand": ["latest_invoice.payment_intent"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if trial_days:
|
||||||
|
subscription_data["trial_period_days"] = trial_days
|
||||||
|
|
||||||
|
subscription = stripe.Subscription.create(**subscription_data)
|
||||||
|
logger.info(
|
||||||
|
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
||||||
|
)
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||||
|
"""Get a Stripe subscription by ID."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
return stripe.Subscription.retrieve(subscription_id)
|
||||||
|
|
||||||
|
def update_subscription(
|
||||||
|
self,
|
||||||
|
subscription_id: str,
|
||||||
|
new_price_id: str | None = None,
|
||||||
|
proration_behavior: str = "create_prorations",
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.Subscription:
|
||||||
|
"""
|
||||||
|
Update a Stripe subscription (e.g., change tier).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Stripe subscription ID
|
||||||
|
new_price_id: New price ID for tier change
|
||||||
|
proration_behavior: How to handle prorations
|
||||||
|
metadata: Optional metadata to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated Stripe Subscription object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
update_data = {"proration_behavior": proration_behavior}
|
||||||
|
|
||||||
|
if new_price_id:
|
||||||
|
# Get the subscription to find the item ID
|
||||||
|
subscription = stripe.Subscription.retrieve(subscription_id)
|
||||||
|
item_id = subscription["items"]["data"][0]["id"]
|
||||||
|
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
update_data["metadata"] = metadata
|
||||||
|
|
||||||
|
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
||||||
|
logger.info(f"Updated Stripe subscription {subscription_id}")
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def cancel_subscription(
|
||||||
|
self,
|
||||||
|
subscription_id: str,
|
||||||
|
immediately: bool = False,
|
||||||
|
cancellation_reason: str | None = None,
|
||||||
|
) -> stripe.Subscription:
|
||||||
|
"""
|
||||||
|
Cancel a Stripe subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_id: Stripe subscription ID
|
||||||
|
immediately: If True, cancel now. If False, cancel at period end.
|
||||||
|
cancellation_reason: Optional reason for cancellation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cancelled Stripe Subscription object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
if immediately:
|
||||||
|
subscription = stripe.Subscription.cancel(subscription_id)
|
||||||
|
else:
|
||||||
|
subscription = stripe.Subscription.modify(
|
||||||
|
subscription_id,
|
||||||
|
cancel_at_period_end=True,
|
||||||
|
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Cancelled Stripe subscription {subscription_id} "
|
||||||
|
f"(immediately={immediately})"
|
||||||
|
)
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||||
|
"""
|
||||||
|
Reactivate a cancelled subscription (if not yet ended).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reactivated Stripe Subscription object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
subscription = stripe.Subscription.modify(
|
||||||
|
subscription_id,
|
||||||
|
cancel_at_period_end=False,
|
||||||
|
)
|
||||||
|
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def cancel_subscription_item(self, subscription_item_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Cancel a subscription item (used for add-ons).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subscription_item_id: Stripe subscription item ID
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
stripe.SubscriptionItem.delete(subscription_item_id)
|
||||||
|
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Checkout & Portal
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor: Vendor,
|
||||||
|
price_id: str,
|
||||||
|
success_url: str,
|
||||||
|
cancel_url: str,
|
||||||
|
trial_days: int | None = None,
|
||||||
|
quantity: int = 1,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.checkout.Session:
|
||||||
|
"""
|
||||||
|
Create a Stripe Checkout session for subscription signup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
vendor: Vendor to create checkout for
|
||||||
|
price_id: Stripe price ID
|
||||||
|
success_url: URL to redirect on success
|
||||||
|
cancel_url: URL to redirect on cancel
|
||||||
|
trial_days: Optional trial period
|
||||||
|
quantity: Number of items (default 1)
|
||||||
|
metadata: Additional metadata to store
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stripe Checkout Session object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
# Get or create Stripe customer
|
||||||
|
subscription = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(VendorSubscription.vendor_id == vendor.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscription and subscription.stripe_customer_id:
|
||||||
|
customer_id = subscription.stripe_customer_id
|
||||||
|
else:
|
||||||
|
# Get vendor owner email
|
||||||
|
from models.database.vendor import VendorUser
|
||||||
|
|
||||||
|
owner = (
|
||||||
|
db.query(VendorUser)
|
||||||
|
.filter(
|
||||||
|
VendorUser.vendor_id == vendor.id,
|
||||||
|
VendorUser.is_owner == True,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
email = owner.user.email if owner and owner.user else None
|
||||||
|
|
||||||
|
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
||||||
|
|
||||||
|
# Store the customer ID
|
||||||
|
if subscription:
|
||||||
|
subscription.stripe_customer_id = customer_id
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
# Build metadata
|
||||||
|
session_metadata = {
|
||||||
|
"vendor_id": str(vendor.id),
|
||||||
|
"vendor_code": vendor.vendor_code,
|
||||||
|
}
|
||||||
|
if metadata:
|
||||||
|
session_metadata.update(metadata)
|
||||||
|
|
||||||
|
session_data = {
|
||||||
|
"customer": customer_id,
|
||||||
|
"line_items": [{"price": price_id, "quantity": quantity}],
|
||||||
|
"mode": "subscription",
|
||||||
|
"success_url": success_url,
|
||||||
|
"cancel_url": cancel_url,
|
||||||
|
"metadata": session_metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
if trial_days:
|
||||||
|
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
||||||
|
|
||||||
|
session = stripe.checkout.Session.create(**session_data)
|
||||||
|
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
||||||
|
return session
|
||||||
|
|
||||||
|
def create_portal_session(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
return_url: str,
|
||||||
|
) -> stripe.billing_portal.Session:
|
||||||
|
"""
|
||||||
|
Create a Stripe Customer Portal session.
|
||||||
|
|
||||||
|
Allows customers to manage their subscription, payment methods, and invoices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_id: Stripe customer ID
|
||||||
|
return_url: URL to return to after portal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stripe Portal Session object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
session = stripe.billing_portal.Session.create(
|
||||||
|
customer=customer_id,
|
||||||
|
return_url=return_url,
|
||||||
|
)
|
||||||
|
logger.info(f"Created portal session for customer {customer_id}")
|
||||||
|
return session
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Invoice Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_invoices(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[stripe.Invoice]:
|
||||||
|
"""Get invoices for a customer."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
||||||
|
return list(invoices.data)
|
||||||
|
|
||||||
|
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
||||||
|
"""Get the upcoming invoice for a customer."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return stripe.Invoice.upcoming(customer=customer_id)
|
||||||
|
except stripe.error.InvalidRequestError:
|
||||||
|
# No upcoming invoice
|
||||||
|
return None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Webhook Handling
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def construct_event(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
) -> stripe.Event:
|
||||||
|
"""
|
||||||
|
Construct and verify a Stripe webhook event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Raw request body
|
||||||
|
sig_header: Stripe-Signature header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Verified Stripe Event object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WebhookVerificationException: If signature verification fails
|
||||||
|
"""
|
||||||
|
if not settings.stripe_webhook_secret:
|
||||||
|
raise WebhookVerificationException("Stripe webhook secret not configured")
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = stripe.Webhook.construct_event(
|
||||||
|
payload,
|
||||||
|
sig_header,
|
||||||
|
settings.stripe_webhook_secret,
|
||||||
|
)
|
||||||
|
return event
|
||||||
|
except stripe.error.SignatureVerificationError as e:
|
||||||
|
logger.error(f"Webhook signature verification failed: {e}")
|
||||||
|
raise WebhookVerificationException("Invalid webhook signature")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# SetupIntent & Payment Method Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def create_setup_intent(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.SetupIntent:
|
||||||
|
"""
|
||||||
|
Create a SetupIntent to collect card without charging.
|
||||||
|
|
||||||
|
Used for trial signups where we collect card upfront
|
||||||
|
but don't charge until trial ends.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_id: Stripe customer ID
|
||||||
|
metadata: Optional metadata to attach
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stripe SetupIntent object with client_secret for frontend
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
setup_intent = stripe.SetupIntent.create(
|
||||||
|
customer=customer_id,
|
||||||
|
payment_method_types=["card"],
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
|
||||||
|
return setup_intent
|
||||||
|
|
||||||
|
def attach_payment_method_to_customer(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
payment_method_id: str,
|
||||||
|
set_as_default: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Attach a payment method to customer and optionally set as default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_id: Stripe customer ID
|
||||||
|
payment_method_id: Payment method ID from confirmed SetupIntent
|
||||||
|
set_as_default: Whether to set as default payment method
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
# Attach the payment method to the customer
|
||||||
|
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
|
||||||
|
|
||||||
|
if set_as_default:
|
||||||
|
stripe.Customer.modify(
|
||||||
|
customer_id,
|
||||||
|
invoice_settings={"default_payment_method": payment_method_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Attached payment method {payment_method_id} to customer {customer_id} "
|
||||||
|
f"(default={set_as_default})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_subscription_with_trial(
|
||||||
|
self,
|
||||||
|
customer_id: str,
|
||||||
|
price_id: str,
|
||||||
|
trial_days: int = 30,
|
||||||
|
metadata: dict | None = None,
|
||||||
|
) -> stripe.Subscription:
|
||||||
|
"""
|
||||||
|
Create subscription with trial period.
|
||||||
|
|
||||||
|
Customer must have a default payment method attached.
|
||||||
|
Card will be charged automatically after trial ends.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_id: Stripe customer ID (must have default payment method)
|
||||||
|
price_id: Stripe price ID for the subscription tier
|
||||||
|
trial_days: Number of trial days (default 30)
|
||||||
|
metadata: Optional metadata to attach
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stripe Subscription object
|
||||||
|
"""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
subscription = stripe.Subscription.create(
|
||||||
|
customer=customer_id,
|
||||||
|
items=[{"price": price_id}],
|
||||||
|
trial_period_days=trial_days,
|
||||||
|
metadata=metadata or {},
|
||||||
|
# Use default payment method for future charges
|
||||||
|
default_payment_method=None, # Uses customer's default
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created subscription {subscription.id} with {trial_days}-day trial "
|
||||||
|
f"for customer {customer_id}"
|
||||||
|
)
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
|
||||||
|
"""Get a SetupIntent by ID."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
return stripe.SetupIntent.retrieve(setup_intent_id)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Price/Product Management
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_price(self, price_id: str) -> stripe.Price:
|
||||||
|
"""Get a Stripe price by ID."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
return stripe.Price.retrieve(price_id)
|
||||||
|
|
||||||
|
def get_product(self, product_id: str) -> stripe.Product:
|
||||||
|
"""Get a Stripe product by ID."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
return stripe.Product.retrieve(product_id)
|
||||||
|
|
||||||
|
def list_prices(
|
||||||
|
self,
|
||||||
|
product_id: str | None = None,
|
||||||
|
active: bool = True,
|
||||||
|
) -> list[stripe.Price]:
|
||||||
|
"""List Stripe prices, optionally filtered by product."""
|
||||||
|
self._check_configured()
|
||||||
|
|
||||||
|
params = {"active": active}
|
||||||
|
if product_id:
|
||||||
|
params["product"] = product_id
|
||||||
|
|
||||||
|
prices = stripe.Price.list(**params)
|
||||||
|
return list(prices.data)
|
||||||
|
|
||||||
|
|
||||||
|
# Create service instance
|
||||||
|
stripe_service = StripeService()
|
||||||
631
app/modules/billing/services/subscription_service.py
Normal file
631
app/modules/billing/services/subscription_service.py
Normal file
@@ -0,0 +1,631 @@
|
|||||||
|
# app/modules/billing/services/subscription_service.py
|
||||||
|
"""
|
||||||
|
Subscription service for tier-based access control.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Subscription creation and management
|
||||||
|
- Tier limit enforcement
|
||||||
|
- Usage tracking
|
||||||
|
- Feature gating
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.modules.billing.services import subscription_service
|
||||||
|
|
||||||
|
# Check if vendor can create an order
|
||||||
|
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||||
|
|
||||||
|
# Increment order counter after successful order
|
||||||
|
subscription_service.increment_order_count(db, vendor_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.modules.billing.exceptions import (
|
||||||
|
FeatureNotAvailableException,
|
||||||
|
SubscriptionNotFoundException,
|
||||||
|
TierLimitExceededException,
|
||||||
|
)
|
||||||
|
from app.modules.billing.models import (
|
||||||
|
SubscriptionStatus,
|
||||||
|
SubscriptionTier,
|
||||||
|
TIER_LIMITS,
|
||||||
|
TierCode,
|
||||||
|
VendorSubscription,
|
||||||
|
)
|
||||||
|
from app.modules.billing.schemas import (
|
||||||
|
SubscriptionCreate,
|
||||||
|
SubscriptionUpdate,
|
||||||
|
SubscriptionUsage,
|
||||||
|
TierInfo,
|
||||||
|
TierLimits,
|
||||||
|
UsageSummary,
|
||||||
|
)
|
||||||
|
from models.database.product import Product
|
||||||
|
from models.database.vendor import Vendor, VendorUser
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionService:
|
||||||
|
"""Service for subscription and tier limit operations."""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Tier Information
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
|
||||||
|
"""
|
||||||
|
Get full tier information.
|
||||||
|
|
||||||
|
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||||
|
"""
|
||||||
|
# Try database first if session provided
|
||||||
|
if db is not None:
|
||||||
|
db_tier = self.get_tier_by_code(db, tier_code)
|
||||||
|
if db_tier:
|
||||||
|
return TierInfo(
|
||||||
|
code=db_tier.code,
|
||||||
|
name=db_tier.name,
|
||||||
|
price_monthly_cents=db_tier.price_monthly_cents,
|
||||||
|
price_annual_cents=db_tier.price_annual_cents,
|
||||||
|
limits=TierLimits(
|
||||||
|
orders_per_month=db_tier.orders_per_month,
|
||||||
|
products_limit=db_tier.products_limit,
|
||||||
|
team_members=db_tier.team_members,
|
||||||
|
order_history_months=db_tier.order_history_months,
|
||||||
|
),
|
||||||
|
features=db_tier.features or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback to hardcoded TIER_LIMITS
|
||||||
|
return self._get_tier_from_legacy(tier_code)
|
||||||
|
|
||||||
|
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
|
||||||
|
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
|
||||||
|
try:
|
||||||
|
tier = TierCode(tier_code)
|
||||||
|
except ValueError:
|
||||||
|
tier = TierCode.ESSENTIAL
|
||||||
|
|
||||||
|
limits = TIER_LIMITS[tier]
|
||||||
|
return TierInfo(
|
||||||
|
code=tier.value,
|
||||||
|
name=limits["name"],
|
||||||
|
price_monthly_cents=limits["price_monthly_cents"],
|
||||||
|
price_annual_cents=limits.get("price_annual_cents"),
|
||||||
|
limits=TierLimits(
|
||||||
|
orders_per_month=limits.get("orders_per_month"),
|
||||||
|
products_limit=limits.get("products_limit"),
|
||||||
|
team_members=limits.get("team_members"),
|
||||||
|
order_history_months=limits.get("order_history_months"),
|
||||||
|
),
|
||||||
|
features=limits.get("features", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
|
||||||
|
"""
|
||||||
|
Get information for all tiers.
|
||||||
|
|
||||||
|
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
||||||
|
"""
|
||||||
|
if db is not None:
|
||||||
|
db_tiers = (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(
|
||||||
|
SubscriptionTier.is_active == True, # noqa: E712
|
||||||
|
SubscriptionTier.is_public == True, # noqa: E712
|
||||||
|
)
|
||||||
|
.order_by(SubscriptionTier.display_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
if db_tiers:
|
||||||
|
return [
|
||||||
|
TierInfo(
|
||||||
|
code=t.code,
|
||||||
|
name=t.name,
|
||||||
|
price_monthly_cents=t.price_monthly_cents,
|
||||||
|
price_annual_cents=t.price_annual_cents,
|
||||||
|
limits=TierLimits(
|
||||||
|
orders_per_month=t.orders_per_month,
|
||||||
|
products_limit=t.products_limit,
|
||||||
|
team_members=t.team_members,
|
||||||
|
order_history_months=t.order_history_months,
|
||||||
|
),
|
||||||
|
features=t.features or [],
|
||||||
|
)
|
||||||
|
for t in db_tiers
|
||||||
|
]
|
||||||
|
|
||||||
|
# Fallback to hardcoded
|
||||||
|
return [
|
||||||
|
self._get_tier_from_legacy(tier.value)
|
||||||
|
for tier in TierCode
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
|
||||||
|
"""Get subscription tier by code."""
|
||||||
|
return (
|
||||||
|
db.query(SubscriptionTier)
|
||||||
|
.filter(SubscriptionTier.code == tier_code)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
|
||||||
|
"""Get tier ID from tier code. Returns None if tier not found."""
|
||||||
|
tier = self.get_tier_by_code(db, tier_code)
|
||||||
|
return tier.id if tier else None
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Subscription CRUD
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_subscription(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> VendorSubscription | None:
|
||||||
|
"""Get vendor subscription."""
|
||||||
|
return (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_subscription_or_raise(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""Get vendor subscription or raise exception."""
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
if not subscription:
|
||||||
|
raise SubscriptionNotFoundException(vendor_id)
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def get_current_tier(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> TierCode | None:
|
||||||
|
"""Get vendor's current subscription tier code."""
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
if subscription:
|
||||||
|
try:
|
||||||
|
return TierCode(subscription.tier)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_or_create_subscription(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor_id: int,
|
||||||
|
tier: str = TierCode.ESSENTIAL.value,
|
||||||
|
trial_days: int = 14,
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""
|
||||||
|
Get existing subscription or create a new trial subscription.
|
||||||
|
|
||||||
|
Used when a vendor first accesses the system.
|
||||||
|
"""
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
if subscription:
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
# Create new trial subscription
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
trial_end = now + timedelta(days=trial_days)
|
||||||
|
|
||||||
|
# Lookup tier_id from tier code
|
||||||
|
tier_id = self.get_tier_id(db, tier)
|
||||||
|
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=vendor_id,
|
||||||
|
tier=tier,
|
||||||
|
tier_id=tier_id,
|
||||||
|
status=SubscriptionStatus.TRIAL.value,
|
||||||
|
period_start=now,
|
||||||
|
period_end=trial_end,
|
||||||
|
trial_ends_at=trial_end,
|
||||||
|
is_annual=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(subscription)
|
||||||
|
db.flush()
|
||||||
|
db.refresh(subscription)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created trial subscription for vendor {vendor_id} "
|
||||||
|
f"(tier={tier}, trial_ends={trial_end})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def create_subscription(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor_id: int,
|
||||||
|
data: SubscriptionCreate,
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""Create a subscription for a vendor."""
|
||||||
|
# Check if subscription exists
|
||||||
|
existing = self.get_subscription(db, vendor_id)
|
||||||
|
if existing:
|
||||||
|
raise ValueError("Vendor already has a subscription")
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Calculate period end based on billing cycle
|
||||||
|
if data.is_annual:
|
||||||
|
period_end = now + timedelta(days=365)
|
||||||
|
else:
|
||||||
|
period_end = now + timedelta(days=30)
|
||||||
|
|
||||||
|
# Handle trial
|
||||||
|
trial_ends_at = None
|
||||||
|
status = SubscriptionStatus.ACTIVE.value
|
||||||
|
if data.trial_days > 0:
|
||||||
|
trial_ends_at = now + timedelta(days=data.trial_days)
|
||||||
|
status = SubscriptionStatus.TRIAL.value
|
||||||
|
period_end = trial_ends_at
|
||||||
|
|
||||||
|
# Lookup tier_id from tier code
|
||||||
|
tier_id = self.get_tier_id(db, data.tier)
|
||||||
|
|
||||||
|
subscription = VendorSubscription(
|
||||||
|
vendor_id=vendor_id,
|
||||||
|
tier=data.tier,
|
||||||
|
tier_id=tier_id,
|
||||||
|
status=status,
|
||||||
|
period_start=now,
|
||||||
|
period_end=period_end,
|
||||||
|
trial_ends_at=trial_ends_at,
|
||||||
|
is_annual=data.is_annual,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(subscription)
|
||||||
|
db.flush()
|
||||||
|
db.refresh(subscription)
|
||||||
|
|
||||||
|
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def update_subscription(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor_id: int,
|
||||||
|
data: SubscriptionUpdate,
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""Update a vendor subscription."""
|
||||||
|
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||||
|
|
||||||
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# If tier is being updated, also update tier_id
|
||||||
|
if "tier" in update_data:
|
||||||
|
tier_id = self.get_tier_id(db, update_data["tier"])
|
||||||
|
update_data["tier_id"] = tier_id
|
||||||
|
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(subscription, key, value)
|
||||||
|
|
||||||
|
subscription.updated_at = datetime.now(UTC)
|
||||||
|
db.flush()
|
||||||
|
db.refresh(subscription)
|
||||||
|
|
||||||
|
logger.info(f"Updated subscription for vendor {vendor_id}")
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def upgrade_tier(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor_id: int,
|
||||||
|
new_tier: str,
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""Upgrade vendor to a new tier."""
|
||||||
|
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||||
|
|
||||||
|
old_tier = subscription.tier
|
||||||
|
subscription.tier = new_tier
|
||||||
|
subscription.tier_id = self.get_tier_id(db, new_tier)
|
||||||
|
subscription.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
# If upgrading from trial, mark as active
|
||||||
|
if subscription.status == SubscriptionStatus.TRIAL.value:
|
||||||
|
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
db.refresh(subscription)
|
||||||
|
|
||||||
|
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
def cancel_subscription(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
vendor_id: int,
|
||||||
|
reason: str | None = None,
|
||||||
|
) -> VendorSubscription:
|
||||||
|
"""Cancel a vendor subscription (access until period end)."""
|
||||||
|
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||||
|
|
||||||
|
subscription.status = SubscriptionStatus.CANCELLED.value
|
||||||
|
subscription.cancelled_at = datetime.now(UTC)
|
||||||
|
subscription.cancellation_reason = reason
|
||||||
|
subscription.updated_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
db.flush()
|
||||||
|
db.refresh(subscription)
|
||||||
|
|
||||||
|
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Usage Tracking
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
||||||
|
"""Get current subscription usage statistics."""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
# Get actual counts
|
||||||
|
products_count = (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
team_count = (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate usage stats
|
||||||
|
orders_limit = subscription.orders_limit
|
||||||
|
products_limit = subscription.products_limit
|
||||||
|
team_limit = subscription.team_members_limit
|
||||||
|
|
||||||
|
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||||
|
if limit is None:
|
||||||
|
return None
|
||||||
|
return max(0, limit - current)
|
||||||
|
|
||||||
|
def calc_percent(current: int, limit: int | None) -> float | None:
|
||||||
|
if limit is None or limit == 0:
|
||||||
|
return None
|
||||||
|
return min(100.0, (current / limit) * 100)
|
||||||
|
|
||||||
|
return SubscriptionUsage(
|
||||||
|
orders_used=subscription.orders_this_period,
|
||||||
|
orders_limit=orders_limit,
|
||||||
|
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||||
|
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
||||||
|
products_used=products_count,
|
||||||
|
products_limit=products_limit,
|
||||||
|
products_remaining=calc_remaining(products_count, products_limit),
|
||||||
|
products_percent_used=calc_percent(products_count, products_limit),
|
||||||
|
team_members_used=team_count,
|
||||||
|
team_members_limit=team_limit,
|
||||||
|
team_members_remaining=calc_remaining(team_count, team_limit),
|
||||||
|
team_members_percent_used=calc_percent(team_count, team_limit),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
|
||||||
|
"""Get usage summary for billing page display."""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
# Get actual counts
|
||||||
|
products_count = (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
team_count = (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get limits
|
||||||
|
orders_limit = subscription.orders_limit
|
||||||
|
products_limit = subscription.products_limit
|
||||||
|
team_limit = subscription.team_members_limit
|
||||||
|
|
||||||
|
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||||
|
if limit is None:
|
||||||
|
return None
|
||||||
|
return max(0, limit - current)
|
||||||
|
|
||||||
|
return UsageSummary(
|
||||||
|
orders_this_period=subscription.orders_this_period,
|
||||||
|
orders_limit=orders_limit,
|
||||||
|
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||||
|
products_count=products_count,
|
||||||
|
products_limit=products_limit,
|
||||||
|
products_remaining=calc_remaining(products_count, products_limit),
|
||||||
|
team_count=team_count,
|
||||||
|
team_limit=team_limit,
|
||||||
|
team_remaining=calc_remaining(team_count, team_limit),
|
||||||
|
)
|
||||||
|
|
||||||
|
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Increment the order counter for the current period.
|
||||||
|
|
||||||
|
Call this after successfully creating/importing an order.
|
||||||
|
"""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
subscription.increment_order_count()
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
||||||
|
"""Reset counters for a new billing period."""
|
||||||
|
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||||
|
subscription.reset_period_counters()
|
||||||
|
db.flush()
|
||||||
|
logger.info(f"Reset period counters for vendor {vendor_id}")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Limit Checks
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def can_create_order(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""
|
||||||
|
Check if vendor can create/import another order.
|
||||||
|
|
||||||
|
Returns: (allowed, error_message)
|
||||||
|
"""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
return subscription.can_create_order()
|
||||||
|
|
||||||
|
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Check order limit and raise exception if exceeded.
|
||||||
|
|
||||||
|
Use this in order creation flows.
|
||||||
|
"""
|
||||||
|
can_create, message = self.can_create_order(db, vendor_id)
|
||||||
|
if not can_create:
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
raise TierLimitExceededException(
|
||||||
|
message=message or "Order limit exceeded",
|
||||||
|
limit_type="orders",
|
||||||
|
current=subscription.orders_this_period if subscription else 0,
|
||||||
|
limit=subscription.orders_limit if subscription else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_add_product(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""
|
||||||
|
Check if vendor can add another product.
|
||||||
|
|
||||||
|
Returns: (allowed, error_message)
|
||||||
|
"""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
products_count = (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return subscription.can_add_product(products_count)
|
||||||
|
|
||||||
|
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Check product limit and raise exception if exceeded.
|
||||||
|
|
||||||
|
Use this in product creation flows.
|
||||||
|
"""
|
||||||
|
can_add, message = self.can_add_product(db, vendor_id)
|
||||||
|
if not can_add:
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
products_count = (
|
||||||
|
db.query(func.count(Product.id))
|
||||||
|
.filter(Product.vendor_id == vendor_id)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
raise TierLimitExceededException(
|
||||||
|
message=message or "Product limit exceeded",
|
||||||
|
limit_type="products",
|
||||||
|
current=products_count,
|
||||||
|
limit=subscription.products_limit if subscription else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_add_team_member(
|
||||||
|
self, db: Session, vendor_id: int
|
||||||
|
) -> tuple[bool, str | None]:
|
||||||
|
"""
|
||||||
|
Check if vendor can add another team member.
|
||||||
|
|
||||||
|
Returns: (allowed, error_message)
|
||||||
|
"""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
team_count = (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return subscription.can_add_team_member(team_count)
|
||||||
|
|
||||||
|
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Check team member limit and raise exception if exceeded.
|
||||||
|
|
||||||
|
Use this in team member invitation flows.
|
||||||
|
"""
|
||||||
|
can_add, message = self.can_add_team_member(db, vendor_id)
|
||||||
|
if not can_add:
|
||||||
|
subscription = self.get_subscription(db, vendor_id)
|
||||||
|
team_count = (
|
||||||
|
db.query(func.count(VendorUser.id))
|
||||||
|
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||||
|
.scalar()
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
raise TierLimitExceededException(
|
||||||
|
message=message or "Team member limit exceeded",
|
||||||
|
limit_type="team_members",
|
||||||
|
current=team_count,
|
||||||
|
limit=subscription.team_members_limit if subscription else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Feature Gating
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
||||||
|
"""Check if vendor has access to a feature."""
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
return subscription.has_feature(feature)
|
||||||
|
|
||||||
|
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
||||||
|
"""
|
||||||
|
Check feature access and raise exception if not available.
|
||||||
|
|
||||||
|
Use this to gate premium features.
|
||||||
|
"""
|
||||||
|
if not self.has_feature(db, vendor_id, feature):
|
||||||
|
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||||
|
|
||||||
|
# Find which tier has this feature
|
||||||
|
required_tier = None
|
||||||
|
for tier_code, limits in TIER_LIMITS.items():
|
||||||
|
if feature in limits.get("features", []):
|
||||||
|
required_tier = limits["name"]
|
||||||
|
break
|
||||||
|
|
||||||
|
raise FeatureNotAvailableException(
|
||||||
|
feature=feature,
|
||||||
|
current_tier=subscription.tier,
|
||||||
|
required_tier=required_tier or "higher",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_feature_tier(self, feature: str) -> str | None:
|
||||||
|
"""Get the minimum tier required for a feature."""
|
||||||
|
for tier_code in [
|
||||||
|
TierCode.ESSENTIAL,
|
||||||
|
TierCode.PROFESSIONAL,
|
||||||
|
TierCode.BUSINESS,
|
||||||
|
TierCode.ENTERPRISE,
|
||||||
|
]:
|
||||||
|
if feature in TIER_LIMITS[tier_code].get("features", []):
|
||||||
|
return tier_code.value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
subscription_service = SubscriptionService()
|
||||||
26
app/modules/billing/tasks/__init__.py
Normal file
26
app/modules/billing/tasks/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# app/modules/billing/tasks/__init__.py
|
||||||
|
"""
|
||||||
|
Billing module Celery tasks.
|
||||||
|
|
||||||
|
Scheduled tasks for:
|
||||||
|
- Resetting period counters
|
||||||
|
- Checking trial expirations
|
||||||
|
- Syncing with Stripe
|
||||||
|
- Cleaning up stale subscriptions
|
||||||
|
|
||||||
|
Note: capture_capacity_snapshot moved to monitoring module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.modules.billing.tasks.subscription import (
|
||||||
|
reset_period_counters,
|
||||||
|
check_trial_expirations,
|
||||||
|
sync_stripe_status,
|
||||||
|
cleanup_stale_subscriptions,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"reset_period_counters",
|
||||||
|
"check_trial_expirations",
|
||||||
|
"sync_stripe_status",
|
||||||
|
"cleanup_stale_subscriptions",
|
||||||
|
]
|
||||||
255
app/modules/billing/tasks/subscription.py
Normal file
255
app/modules/billing/tasks/subscription.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
# app/modules/billing/tasks/subscription.py
|
||||||
|
"""
|
||||||
|
Celery tasks for subscription management.
|
||||||
|
|
||||||
|
Scheduled tasks for:
|
||||||
|
- Resetting period counters
|
||||||
|
- Checking trial expirations
|
||||||
|
- Syncing with Stripe
|
||||||
|
- Cleaning up stale subscriptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from app.core.celery_config import celery_app
|
||||||
|
from app.modules.billing.models import SubscriptionStatus, VendorSubscription
|
||||||
|
from app.modules.billing.services import stripe_service
|
||||||
|
from app.modules.task_base import ModuleTask
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
bind=True,
|
||||||
|
base=ModuleTask,
|
||||||
|
name="app.modules.billing.tasks.subscription.reset_period_counters",
|
||||||
|
)
|
||||||
|
def reset_period_counters(self):
|
||||||
|
"""
|
||||||
|
Reset order counters for subscriptions whose billing period has ended.
|
||||||
|
|
||||||
|
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
reset_count = 0
|
||||||
|
|
||||||
|
with self.get_db() as db:
|
||||||
|
# Find subscriptions where period has ended
|
||||||
|
expired_periods = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(
|
||||||
|
VendorSubscription.period_end <= now,
|
||||||
|
VendorSubscription.status.in_(["active", "trial"]),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for subscription in expired_periods:
|
||||||
|
old_period_end = subscription.period_end
|
||||||
|
|
||||||
|
# Reset counters
|
||||||
|
subscription.orders_this_period = 0
|
||||||
|
subscription.orders_limit_reached_at = None
|
||||||
|
|
||||||
|
# Set new period dates
|
||||||
|
if subscription.is_annual:
|
||||||
|
subscription.period_start = now
|
||||||
|
subscription.period_end = now + timedelta(days=365)
|
||||||
|
else:
|
||||||
|
subscription.period_start = now
|
||||||
|
subscription.period_end = now + timedelta(days=30)
|
||||||
|
|
||||||
|
subscription.updated_at = now
|
||||||
|
reset_count += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Reset period counters for vendor {subscription.vendor_id}: "
|
||||||
|
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Reset period counters for {reset_count} subscriptions")
|
||||||
|
|
||||||
|
return {"reset_count": reset_count}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
bind=True,
|
||||||
|
base=ModuleTask,
|
||||||
|
name="app.modules.billing.tasks.subscription.check_trial_expirations",
|
||||||
|
)
|
||||||
|
def check_trial_expirations(self):
|
||||||
|
"""
|
||||||
|
Check for expired trials and update their status.
|
||||||
|
|
||||||
|
Runs daily at 01:00.
|
||||||
|
- Trials without payment method -> expired
|
||||||
|
- Trials with payment method -> active
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expired_count = 0
|
||||||
|
activated_count = 0
|
||||||
|
|
||||||
|
with self.get_db() as db:
|
||||||
|
# Find expired trials
|
||||||
|
expired_trials = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(
|
||||||
|
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
|
||||||
|
VendorSubscription.trial_ends_at <= now,
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for subscription in expired_trials:
|
||||||
|
if subscription.stripe_payment_method_id:
|
||||||
|
# Has payment method - activate
|
||||||
|
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||||
|
activated_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Activated subscription for vendor {subscription.vendor_id} "
|
||||||
|
f"(trial ended with payment method)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No payment method - expire
|
||||||
|
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||||
|
expired_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Expired trial for vendor {subscription.vendor_id} "
|
||||||
|
f"(no payment method)"
|
||||||
|
)
|
||||||
|
|
||||||
|
subscription.updated_at = now
|
||||||
|
|
||||||
|
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
|
||||||
|
|
||||||
|
return {"expired_count": expired_count, "activated_count": activated_count}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
bind=True,
|
||||||
|
base=ModuleTask,
|
||||||
|
name="app.modules.billing.tasks.subscription.sync_stripe_status",
|
||||||
|
max_retries=3,
|
||||||
|
default_retry_delay=300,
|
||||||
|
)
|
||||||
|
def sync_stripe_status(self):
|
||||||
|
"""
|
||||||
|
Sync subscription status with Stripe.
|
||||||
|
|
||||||
|
Runs hourly at :30. Fetches current status from Stripe and updates local records.
|
||||||
|
"""
|
||||||
|
if not stripe_service.is_configured:
|
||||||
|
logger.warning("Stripe not configured, skipping sync")
|
||||||
|
return {"synced": 0, "skipped": True}
|
||||||
|
|
||||||
|
synced_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
with self.get_db() as db:
|
||||||
|
# Find subscriptions with Stripe IDs
|
||||||
|
subscriptions = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for subscription in subscriptions:
|
||||||
|
try:
|
||||||
|
# Fetch from Stripe
|
||||||
|
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
|
||||||
|
|
||||||
|
if not stripe_sub:
|
||||||
|
logger.warning(
|
||||||
|
f"Stripe subscription {subscription.stripe_subscription_id} "
|
||||||
|
f"not found for vendor {subscription.vendor_id}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Map Stripe status to local status
|
||||||
|
status_map = {
|
||||||
|
"active": SubscriptionStatus.ACTIVE.value,
|
||||||
|
"trialing": SubscriptionStatus.TRIAL.value,
|
||||||
|
"past_due": SubscriptionStatus.PAST_DUE.value,
|
||||||
|
"canceled": SubscriptionStatus.CANCELLED.value,
|
||||||
|
"unpaid": SubscriptionStatus.PAST_DUE.value,
|
||||||
|
"incomplete": SubscriptionStatus.TRIAL.value,
|
||||||
|
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
new_status = status_map.get(stripe_sub.status)
|
||||||
|
if new_status and new_status != subscription.status:
|
||||||
|
old_status = subscription.status
|
||||||
|
subscription.status = new_status
|
||||||
|
subscription.updated_at = datetime.now(UTC)
|
||||||
|
logger.info(
|
||||||
|
f"Updated vendor {subscription.vendor_id} status: "
|
||||||
|
f"{old_status} -> {new_status} (from Stripe)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update period dates from Stripe
|
||||||
|
if stripe_sub.current_period_start:
|
||||||
|
subscription.period_start = datetime.fromtimestamp(
|
||||||
|
stripe_sub.current_period_start, tz=UTC
|
||||||
|
)
|
||||||
|
if stripe_sub.current_period_end:
|
||||||
|
subscription.period_end = datetime.fromtimestamp(
|
||||||
|
stripe_sub.current_period_end, tz=UTC
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update payment method
|
||||||
|
if stripe_sub.default_payment_method:
|
||||||
|
subscription.stripe_payment_method_id = (
|
||||||
|
stripe_sub.default_payment_method
|
||||||
|
if isinstance(stripe_sub.default_payment_method, str)
|
||||||
|
else stripe_sub.default_payment_method.id
|
||||||
|
)
|
||||||
|
|
||||||
|
synced_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
|
||||||
|
error_count += 1
|
||||||
|
|
||||||
|
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
|
||||||
|
|
||||||
|
return {"synced_count": synced_count, "error_count": error_count}
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
bind=True,
|
||||||
|
base=ModuleTask,
|
||||||
|
name="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
|
||||||
|
)
|
||||||
|
def cleanup_stale_subscriptions(self):
|
||||||
|
"""
|
||||||
|
Clean up subscriptions in inconsistent states.
|
||||||
|
|
||||||
|
Runs weekly on Sunday at 03:00.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
cleaned_count = 0
|
||||||
|
|
||||||
|
with self.get_db() as db:
|
||||||
|
# Find cancelled subscriptions past their period end
|
||||||
|
stale_cancelled = (
|
||||||
|
db.query(VendorSubscription)
|
||||||
|
.filter(
|
||||||
|
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
|
||||||
|
VendorSubscription.period_end < now - timedelta(days=30),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
for subscription in stale_cancelled:
|
||||||
|
# Mark as expired (fully terminated)
|
||||||
|
subscription.status = SubscriptionStatus.EXPIRED.value
|
||||||
|
subscription.updated_at = now
|
||||||
|
cleaned_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
|
||||||
|
|
||||||
|
return {"cleaned_count": cleaned_count}
|
||||||
@@ -2,351 +2,21 @@
|
|||||||
"""
|
"""
|
||||||
Admin Subscription Service.
|
Admin Subscription Service.
|
||||||
|
|
||||||
Handles subscription management operations for platform administrators:
|
DEPRECATED: This file is maintained for backward compatibility.
|
||||||
- Subscription tier CRUD
|
Import from app.modules.billing.services instead:
|
||||||
- Vendor subscription management
|
|
||||||
- Billing history queries
|
from app.modules.billing.services import admin_subscription_service
|
||||||
- Subscription analytics
|
|
||||||
|
This file re-exports the service from its new location in the billing module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
# Re-export from new location for backward compatibility
|
||||||
from math import ceil
|
from app.modules.billing.services.admin_subscription_service import (
|
||||||
|
AdminSubscriptionService,
|
||||||
from sqlalchemy import func
|
admin_subscription_service,
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.exceptions import (
|
|
||||||
BusinessLogicException,
|
|
||||||
ConflictException,
|
|
||||||
ResourceNotFoundException,
|
|
||||||
TierNotFoundException,
|
|
||||||
)
|
)
|
||||||
from models.database.product import Product
|
|
||||||
from models.database.subscription import (
|
|
||||||
BillingHistory,
|
|
||||||
SubscriptionStatus,
|
|
||||||
SubscriptionTier,
|
|
||||||
VendorSubscription,
|
|
||||||
)
|
|
||||||
from models.database.vendor import Vendor, VendorUser
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
__all__ = [
|
||||||
|
"AdminSubscriptionService",
|
||||||
|
"admin_subscription_service",
|
||||||
class AdminSubscriptionService:
|
]
|
||||||
"""Service for admin subscription management operations."""
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Subscription Tiers
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_tiers(
|
|
||||||
self, db: Session, include_inactive: bool = False
|
|
||||||
) -> list[SubscriptionTier]:
|
|
||||||
"""Get all subscription tiers."""
|
|
||||||
query = db.query(SubscriptionTier)
|
|
||||||
|
|
||||||
if not include_inactive:
|
|
||||||
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
|
|
||||||
|
|
||||||
return query.order_by(SubscriptionTier.display_order).all()
|
|
||||||
|
|
||||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
|
||||||
"""Get a subscription tier by code."""
|
|
||||||
tier = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.code == tier_code)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tier:
|
|
||||||
raise TierNotFoundException(tier_code)
|
|
||||||
|
|
||||||
return tier
|
|
||||||
|
|
||||||
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
|
|
||||||
"""Create a new subscription tier."""
|
|
||||||
# Check for duplicate code
|
|
||||||
existing = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.code == tier_data["code"])
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
raise ConflictException(
|
|
||||||
f"Tier with code '{tier_data['code']}' already exists"
|
|
||||||
)
|
|
||||||
|
|
||||||
tier = SubscriptionTier(**tier_data)
|
|
||||||
db.add(tier)
|
|
||||||
|
|
||||||
logger.info(f"Created subscription tier: {tier.code}")
|
|
||||||
return tier
|
|
||||||
|
|
||||||
def update_tier(
|
|
||||||
self, db: Session, tier_code: str, update_data: dict
|
|
||||||
) -> SubscriptionTier:
|
|
||||||
"""Update a subscription tier."""
|
|
||||||
tier = self.get_tier_by_code(db, tier_code)
|
|
||||||
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(tier, field, value)
|
|
||||||
|
|
||||||
logger.info(f"Updated subscription tier: {tier.code}")
|
|
||||||
return tier
|
|
||||||
|
|
||||||
def deactivate_tier(self, db: Session, tier_code: str) -> None:
|
|
||||||
"""Soft-delete a subscription tier."""
|
|
||||||
tier = self.get_tier_by_code(db, tier_code)
|
|
||||||
|
|
||||||
# Check if any active subscriptions use this tier
|
|
||||||
active_subs = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(
|
|
||||||
VendorSubscription.tier == tier_code,
|
|
||||||
VendorSubscription.status.in_([
|
|
||||||
SubscriptionStatus.ACTIVE.value,
|
|
||||||
SubscriptionStatus.TRIAL.value,
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
if active_subs > 0:
|
|
||||||
raise BusinessLogicException(
|
|
||||||
f"Cannot delete tier: {active_subs} active subscriptions are using it"
|
|
||||||
)
|
|
||||||
|
|
||||||
tier.is_active = False
|
|
||||||
|
|
||||||
logger.info(f"Soft-deleted subscription tier: {tier.code}")
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Vendor Subscriptions
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def list_subscriptions(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
page: int = 1,
|
|
||||||
per_page: int = 20,
|
|
||||||
status: str | None = None,
|
|
||||||
tier: str | None = None,
|
|
||||||
search: str | None = None,
|
|
||||||
) -> dict:
|
|
||||||
"""List vendor subscriptions with filtering and pagination."""
|
|
||||||
query = (
|
|
||||||
db.query(VendorSubscription, Vendor)
|
|
||||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if status:
|
|
||||||
query = query.filter(VendorSubscription.status == status)
|
|
||||||
if tier:
|
|
||||||
query = query.filter(VendorSubscription.tier == tier)
|
|
||||||
if search:
|
|
||||||
query = query.filter(Vendor.name.ilike(f"%{search}%"))
|
|
||||||
|
|
||||||
# Count total
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# Paginate
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
results = (
|
|
||||||
query.order_by(VendorSubscription.created_at.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(per_page)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"results": results,
|
|
||||||
"total": total,
|
|
||||||
"page": page,
|
|
||||||
"per_page": per_page,
|
|
||||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
|
|
||||||
"""Get subscription for a specific vendor."""
|
|
||||||
result = (
|
|
||||||
db.query(VendorSubscription, Vendor)
|
|
||||||
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
|
|
||||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
raise ResourceNotFoundException("Subscription", str(vendor_id))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def update_subscription(
|
|
||||||
self, db: Session, vendor_id: int, update_data: dict
|
|
||||||
) -> tuple:
|
|
||||||
"""Update a vendor's subscription."""
|
|
||||||
result = self.get_subscription(db, vendor_id)
|
|
||||||
sub, vendor = result
|
|
||||||
|
|
||||||
for field, value in update_data.items():
|
|
||||||
setattr(sub, field, value)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return sub, vendor
|
|
||||||
|
|
||||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
|
||||||
"""Get a vendor by ID."""
|
|
||||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
|
||||||
|
|
||||||
if not vendor:
|
|
||||||
raise ResourceNotFoundException("Vendor", str(vendor_id))
|
|
||||||
|
|
||||||
return vendor
|
|
||||||
|
|
||||||
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
|
|
||||||
"""Get usage counts (products and team members) for a vendor."""
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(
|
|
||||||
VendorUser.vendor_id == vendor_id,
|
|
||||||
VendorUser.is_active == True, # noqa: E712
|
|
||||||
)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"products_count": products_count,
|
|
||||||
"team_count": team_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Billing History
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def list_billing_history(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
page: int = 1,
|
|
||||||
per_page: int = 20,
|
|
||||||
vendor_id: int | None = None,
|
|
||||||
status: str | None = None,
|
|
||||||
) -> dict:
|
|
||||||
"""List billing history across all vendors."""
|
|
||||||
query = (
|
|
||||||
db.query(BillingHistory, Vendor)
|
|
||||||
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
if vendor_id:
|
|
||||||
query = query.filter(BillingHistory.vendor_id == vendor_id)
|
|
||||||
if status:
|
|
||||||
query = query.filter(BillingHistory.status == status)
|
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
results = (
|
|
||||||
query.order_by(BillingHistory.invoice_date.desc())
|
|
||||||
.offset(offset)
|
|
||||||
.limit(per_page)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"results": results,
|
|
||||||
"total": total,
|
|
||||||
"page": page,
|
|
||||||
"per_page": per_page,
|
|
||||||
"pages": ceil(total / per_page) if total > 0 else 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Statistics
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_stats(self, db: Session) -> dict:
|
|
||||||
"""Get subscription statistics for admin dashboard."""
|
|
||||||
# Count by status
|
|
||||||
status_counts = (
|
|
||||||
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
|
|
||||||
.group_by(VendorSubscription.status)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
stats = {
|
|
||||||
"total_subscriptions": 0,
|
|
||||||
"active_count": 0,
|
|
||||||
"trial_count": 0,
|
|
||||||
"past_due_count": 0,
|
|
||||||
"cancelled_count": 0,
|
|
||||||
"expired_count": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
for status, count in status_counts:
|
|
||||||
stats["total_subscriptions"] += count
|
|
||||||
if status == SubscriptionStatus.ACTIVE.value:
|
|
||||||
stats["active_count"] = count
|
|
||||||
elif status == SubscriptionStatus.TRIAL.value:
|
|
||||||
stats["trial_count"] = count
|
|
||||||
elif status == SubscriptionStatus.PAST_DUE.value:
|
|
||||||
stats["past_due_count"] = count
|
|
||||||
elif status == SubscriptionStatus.CANCELLED.value:
|
|
||||||
stats["cancelled_count"] = count
|
|
||||||
elif status == SubscriptionStatus.EXPIRED.value:
|
|
||||||
stats["expired_count"] = count
|
|
||||||
|
|
||||||
# Count by tier
|
|
||||||
tier_counts = (
|
|
||||||
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
|
|
||||||
.filter(
|
|
||||||
VendorSubscription.status.in_([
|
|
||||||
SubscriptionStatus.ACTIVE.value,
|
|
||||||
SubscriptionStatus.TRIAL.value,
|
|
||||||
])
|
|
||||||
)
|
|
||||||
.group_by(VendorSubscription.tier)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
tier_distribution = {tier: count for tier, count in tier_counts}
|
|
||||||
|
|
||||||
# Calculate MRR (Monthly Recurring Revenue)
|
|
||||||
mrr_cents = 0
|
|
||||||
arr_cents = 0
|
|
||||||
|
|
||||||
active_subs = (
|
|
||||||
db.query(VendorSubscription, SubscriptionTier)
|
|
||||||
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
|
|
||||||
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for sub, tier in active_subs:
|
|
||||||
if sub.is_annual and tier.price_annual_cents:
|
|
||||||
mrr_cents += tier.price_annual_cents // 12
|
|
||||||
arr_cents += tier.price_annual_cents
|
|
||||||
else:
|
|
||||||
mrr_cents += tier.price_monthly_cents
|
|
||||||
arr_cents += tier.price_monthly_cents * 12
|
|
||||||
|
|
||||||
stats["tier_distribution"] = tier_distribution
|
|
||||||
stats["mrr_cents"] = mrr_cents
|
|
||||||
stats["arr_cents"] = arr_cents
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
# Singleton instance
|
|
||||||
admin_subscription_service = AdminSubscriptionService()
|
|
||||||
|
|||||||
@@ -2,592 +2,21 @@
|
|||||||
"""
|
"""
|
||||||
Stripe payment integration service.
|
Stripe payment integration service.
|
||||||
|
|
||||||
Provides:
|
DEPRECATED: This file is maintained for backward compatibility.
|
||||||
- Customer management
|
Import from app.modules.billing.services instead:
|
||||||
- Subscription management
|
|
||||||
- Checkout session creation
|
from app.modules.billing.services import stripe_service
|
||||||
- Customer portal access
|
|
||||||
- Webhook event construction
|
This file re-exports the service from its new location in the billing module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
# Re-export from new location for backward compatibility
|
||||||
from datetime import datetime
|
from app.modules.billing.services.stripe_service import (
|
||||||
|
StripeService,
|
||||||
import stripe
|
stripe_service,
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from models.database.subscription import (
|
|
||||||
BillingHistory,
|
|
||||||
SubscriptionStatus,
|
|
||||||
SubscriptionTier,
|
|
||||||
VendorSubscription,
|
|
||||||
)
|
)
|
||||||
from models.database.vendor import Vendor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
__all__ = [
|
||||||
|
"StripeService",
|
||||||
|
"stripe_service",
|
||||||
class StripeService:
|
]
|
||||||
"""Service for Stripe payment operations."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._configured = False
|
|
||||||
self._configure()
|
|
||||||
|
|
||||||
def _configure(self):
|
|
||||||
"""Configure Stripe with API key."""
|
|
||||||
if settings.stripe_secret_key:
|
|
||||||
stripe.api_key = settings.stripe_secret_key
|
|
||||||
self._configured = True
|
|
||||||
else:
|
|
||||||
logger.warning("Stripe API key not configured")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_configured(self) -> bool:
|
|
||||||
"""Check if Stripe is properly configured."""
|
|
||||||
return self._configured and bool(settings.stripe_secret_key)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Customer Management
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def create_customer(
|
|
||||||
self,
|
|
||||||
vendor: Vendor,
|
|
||||||
email: str,
|
|
||||||
name: str | None = None,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Create a Stripe customer for a vendor.
|
|
||||||
|
|
||||||
Returns the Stripe customer ID.
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
customer_metadata = {
|
|
||||||
"vendor_id": str(vendor.id),
|
|
||||||
"vendor_code": vendor.vendor_code,
|
|
||||||
**(metadata or {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
customer = stripe.Customer.create(
|
|
||||||
email=email,
|
|
||||||
name=name or vendor.name,
|
|
||||||
metadata=customer_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
|
||||||
)
|
|
||||||
return customer.id
|
|
||||||
|
|
||||||
def get_customer(self, customer_id: str) -> stripe.Customer:
|
|
||||||
"""Get a Stripe customer by ID."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
return stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
def update_customer(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
email: str | None = None,
|
|
||||||
name: str | None = None,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.Customer:
|
|
||||||
"""Update a Stripe customer."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
update_data = {}
|
|
||||||
if email:
|
|
||||||
update_data["email"] = email
|
|
||||||
if name:
|
|
||||||
update_data["name"] = name
|
|
||||||
if metadata:
|
|
||||||
update_data["metadata"] = metadata
|
|
||||||
|
|
||||||
return stripe.Customer.modify(customer_id, **update_data)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Subscription Management
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def create_subscription(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
price_id: str,
|
|
||||||
trial_days: int | None = None,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.Subscription:
|
|
||||||
"""
|
|
||||||
Create a new Stripe subscription.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
customer_id: Stripe customer ID
|
|
||||||
price_id: Stripe price ID for the subscription
|
|
||||||
trial_days: Optional trial period in days
|
|
||||||
metadata: Optional metadata to attach
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Stripe Subscription object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
subscription_data = {
|
|
||||||
"customer": customer_id,
|
|
||||||
"items": [{"price": price_id}],
|
|
||||||
"metadata": metadata or {},
|
|
||||||
"payment_behavior": "default_incomplete",
|
|
||||||
"expand": ["latest_invoice.payment_intent"],
|
|
||||||
}
|
|
||||||
|
|
||||||
if trial_days:
|
|
||||||
subscription_data["trial_period_days"] = trial_days
|
|
||||||
|
|
||||||
subscription = stripe.Subscription.create(**subscription_data)
|
|
||||||
logger.info(
|
|
||||||
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
|
||||||
)
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
|
||||||
"""Get a Stripe subscription by ID."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
return stripe.Subscription.retrieve(subscription_id)
|
|
||||||
|
|
||||||
def update_subscription(
|
|
||||||
self,
|
|
||||||
subscription_id: str,
|
|
||||||
new_price_id: str | None = None,
|
|
||||||
proration_behavior: str = "create_prorations",
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.Subscription:
|
|
||||||
"""
|
|
||||||
Update a Stripe subscription (e.g., change tier).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subscription_id: Stripe subscription ID
|
|
||||||
new_price_id: New price ID for tier change
|
|
||||||
proration_behavior: How to handle prorations
|
|
||||||
metadata: Optional metadata to update
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated Stripe Subscription object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
update_data = {"proration_behavior": proration_behavior}
|
|
||||||
|
|
||||||
if new_price_id:
|
|
||||||
# Get the subscription to find the item ID
|
|
||||||
subscription = stripe.Subscription.retrieve(subscription_id)
|
|
||||||
item_id = subscription["items"]["data"][0]["id"]
|
|
||||||
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
|
||||||
|
|
||||||
if metadata:
|
|
||||||
update_data["metadata"] = metadata
|
|
||||||
|
|
||||||
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
|
||||||
logger.info(f"Updated Stripe subscription {subscription_id}")
|
|
||||||
return updated
|
|
||||||
|
|
||||||
def cancel_subscription(
|
|
||||||
self,
|
|
||||||
subscription_id: str,
|
|
||||||
immediately: bool = False,
|
|
||||||
cancellation_reason: str | None = None,
|
|
||||||
) -> stripe.Subscription:
|
|
||||||
"""
|
|
||||||
Cancel a Stripe subscription.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subscription_id: Stripe subscription ID
|
|
||||||
immediately: If True, cancel now. If False, cancel at period end.
|
|
||||||
cancellation_reason: Optional reason for cancellation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cancelled Stripe Subscription object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
if immediately:
|
|
||||||
subscription = stripe.Subscription.cancel(subscription_id)
|
|
||||||
else:
|
|
||||||
subscription = stripe.Subscription.modify(
|
|
||||||
subscription_id,
|
|
||||||
cancel_at_period_end=True,
|
|
||||||
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Cancelled Stripe subscription {subscription_id} "
|
|
||||||
f"(immediately={immediately})"
|
|
||||||
)
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
|
||||||
"""
|
|
||||||
Reactivate a cancelled subscription (if not yet ended).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Reactivated Stripe Subscription object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
subscription = stripe.Subscription.modify(
|
|
||||||
subscription_id,
|
|
||||||
cancel_at_period_end=False,
|
|
||||||
)
|
|
||||||
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def cancel_subscription_item(self, subscription_item_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Cancel a subscription item (used for add-ons).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subscription_item_id: Stripe subscription item ID
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
stripe.SubscriptionItem.delete(subscription_item_id)
|
|
||||||
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Checkout & Portal
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def create_checkout_session(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor: Vendor,
|
|
||||||
price_id: str,
|
|
||||||
success_url: str,
|
|
||||||
cancel_url: str,
|
|
||||||
trial_days: int | None = None,
|
|
||||||
quantity: int = 1,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.checkout.Session:
|
|
||||||
"""
|
|
||||||
Create a Stripe Checkout session for subscription signup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
vendor: Vendor to create checkout for
|
|
||||||
price_id: Stripe price ID
|
|
||||||
success_url: URL to redirect on success
|
|
||||||
cancel_url: URL to redirect on cancel
|
|
||||||
trial_days: Optional trial period
|
|
||||||
quantity: Number of items (default 1)
|
|
||||||
metadata: Additional metadata to store
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Stripe Checkout Session object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
# Get or create Stripe customer
|
|
||||||
subscription = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(VendorSubscription.vendor_id == vendor.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if subscription and subscription.stripe_customer_id:
|
|
||||||
customer_id = subscription.stripe_customer_id
|
|
||||||
else:
|
|
||||||
# Get vendor owner email
|
|
||||||
from models.database.vendor import VendorUser
|
|
||||||
|
|
||||||
owner = (
|
|
||||||
db.query(VendorUser)
|
|
||||||
.filter(
|
|
||||||
VendorUser.vendor_id == vendor.id,
|
|
||||||
VendorUser.is_owner == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
email = owner.user.email if owner and owner.user else None
|
|
||||||
|
|
||||||
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
|
||||||
|
|
||||||
# Store the customer ID
|
|
||||||
if subscription:
|
|
||||||
subscription.stripe_customer_id = customer_id
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# Build metadata
|
|
||||||
session_metadata = {
|
|
||||||
"vendor_id": str(vendor.id),
|
|
||||||
"vendor_code": vendor.vendor_code,
|
|
||||||
}
|
|
||||||
if metadata:
|
|
||||||
session_metadata.update(metadata)
|
|
||||||
|
|
||||||
session_data = {
|
|
||||||
"customer": customer_id,
|
|
||||||
"line_items": [{"price": price_id, "quantity": quantity}],
|
|
||||||
"mode": "subscription",
|
|
||||||
"success_url": success_url,
|
|
||||||
"cancel_url": cancel_url,
|
|
||||||
"metadata": session_metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
if trial_days:
|
|
||||||
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
|
||||||
|
|
||||||
session = stripe.checkout.Session.create(**session_data)
|
|
||||||
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
|
||||||
return session
|
|
||||||
|
|
||||||
def create_portal_session(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
return_url: str,
|
|
||||||
) -> stripe.billing_portal.Session:
|
|
||||||
"""
|
|
||||||
Create a Stripe Customer Portal session.
|
|
||||||
|
|
||||||
Allows customers to manage their subscription, payment methods, and invoices.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
customer_id: Stripe customer ID
|
|
||||||
return_url: URL to return to after portal
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Stripe Portal Session object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
session = stripe.billing_portal.Session.create(
|
|
||||||
customer=customer_id,
|
|
||||||
return_url=return_url,
|
|
||||||
)
|
|
||||||
logger.info(f"Created portal session for customer {customer_id}")
|
|
||||||
return session
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Invoice Management
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_invoices(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> list[stripe.Invoice]:
|
|
||||||
"""Get invoices for a customer."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
|
||||||
return list(invoices.data)
|
|
||||||
|
|
||||||
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
|
||||||
"""Get the upcoming invoice for a customer."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
try:
|
|
||||||
return stripe.Invoice.upcoming(customer=customer_id)
|
|
||||||
except stripe.error.InvalidRequestError:
|
|
||||||
# No upcoming invoice
|
|
||||||
return None
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Webhook Handling
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def construct_event(
|
|
||||||
self,
|
|
||||||
payload: bytes,
|
|
||||||
sig_header: str,
|
|
||||||
) -> stripe.Event:
|
|
||||||
"""
|
|
||||||
Construct and verify a Stripe webhook event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
payload: Raw request body
|
|
||||||
sig_header: Stripe-Signature header
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Verified Stripe Event object
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If signature verification fails
|
|
||||||
"""
|
|
||||||
if not settings.stripe_webhook_secret:
|
|
||||||
raise ValueError("Stripe webhook secret not configured")
|
|
||||||
|
|
||||||
try:
|
|
||||||
event = stripe.Webhook.construct_event(
|
|
||||||
payload,
|
|
||||||
sig_header,
|
|
||||||
settings.stripe_webhook_secret,
|
|
||||||
)
|
|
||||||
return event
|
|
||||||
except stripe.error.SignatureVerificationError as e:
|
|
||||||
logger.error(f"Webhook signature verification failed: {e}")
|
|
||||||
raise ValueError("Invalid webhook signature")
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# SetupIntent & Payment Method Management
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def create_setup_intent(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.SetupIntent:
|
|
||||||
"""
|
|
||||||
Create a SetupIntent to collect card without charging.
|
|
||||||
|
|
||||||
Used for trial signups where we collect card upfront
|
|
||||||
but don't charge until trial ends.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
customer_id: Stripe customer ID
|
|
||||||
metadata: Optional metadata to attach
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Stripe SetupIntent object with client_secret for frontend
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
setup_intent = stripe.SetupIntent.create(
|
|
||||||
customer=customer_id,
|
|
||||||
payment_method_types=["card"],
|
|
||||||
metadata=metadata or {},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
|
|
||||||
return setup_intent
|
|
||||||
|
|
||||||
def attach_payment_method_to_customer(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
payment_method_id: str,
|
|
||||||
set_as_default: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Attach a payment method to customer and optionally set as default.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
customer_id: Stripe customer ID
|
|
||||||
payment_method_id: Payment method ID from confirmed SetupIntent
|
|
||||||
set_as_default: Whether to set as default payment method
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
# Attach the payment method to the customer
|
|
||||||
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
|
|
||||||
|
|
||||||
if set_as_default:
|
|
||||||
stripe.Customer.modify(
|
|
||||||
customer_id,
|
|
||||||
invoice_settings={"default_payment_method": payment_method_id},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Attached payment method {payment_method_id} to customer {customer_id} "
|
|
||||||
f"(default={set_as_default})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_subscription_with_trial(
|
|
||||||
self,
|
|
||||||
customer_id: str,
|
|
||||||
price_id: str,
|
|
||||||
trial_days: int = 30,
|
|
||||||
metadata: dict | None = None,
|
|
||||||
) -> stripe.Subscription:
|
|
||||||
"""
|
|
||||||
Create subscription with trial period.
|
|
||||||
|
|
||||||
Customer must have a default payment method attached.
|
|
||||||
Card will be charged automatically after trial ends.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
customer_id: Stripe customer ID (must have default payment method)
|
|
||||||
price_id: Stripe price ID for the subscription tier
|
|
||||||
trial_days: Number of trial days (default 30)
|
|
||||||
metadata: Optional metadata to attach
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Stripe Subscription object
|
|
||||||
"""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
subscription = stripe.Subscription.create(
|
|
||||||
customer=customer_id,
|
|
||||||
items=[{"price": price_id}],
|
|
||||||
trial_period_days=trial_days,
|
|
||||||
metadata=metadata or {},
|
|
||||||
# Use default payment method for future charges
|
|
||||||
default_payment_method=None, # Uses customer's default
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created subscription {subscription.id} with {trial_days}-day trial "
|
|
||||||
f"for customer {customer_id}"
|
|
||||||
)
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
|
|
||||||
"""Get a SetupIntent by ID."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
return stripe.SetupIntent.retrieve(setup_intent_id)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Price/Product Management
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_price(self, price_id: str) -> stripe.Price:
|
|
||||||
"""Get a Stripe price by ID."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
return stripe.Price.retrieve(price_id)
|
|
||||||
|
|
||||||
def get_product(self, product_id: str) -> stripe.Product:
|
|
||||||
"""Get a Stripe product by ID."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
return stripe.Product.retrieve(product_id)
|
|
||||||
|
|
||||||
def list_prices(
|
|
||||||
self,
|
|
||||||
product_id: str | None = None,
|
|
||||||
active: bool = True,
|
|
||||||
) -> list[stripe.Price]:
|
|
||||||
"""List Stripe prices, optionally filtered by product."""
|
|
||||||
if not self.is_configured:
|
|
||||||
raise ValueError("Stripe is not configured")
|
|
||||||
|
|
||||||
params = {"active": active}
|
|
||||||
if product_id:
|
|
||||||
params["product"] = product_id
|
|
||||||
|
|
||||||
prices = stripe.Price.list(**params)
|
|
||||||
return list(prices.data)
|
|
||||||
|
|
||||||
|
|
||||||
# Create service instance
|
|
||||||
stripe_service = StripeService()
|
|
||||||
|
|||||||
@@ -2,654 +2,29 @@
|
|||||||
"""
|
"""
|
||||||
Subscription service for tier-based access control.
|
Subscription service for tier-based access control.
|
||||||
|
|
||||||
Handles:
|
DEPRECATED: This file is maintained for backward compatibility.
|
||||||
- Subscription creation and management
|
Import from app.modules.billing.services instead:
|
||||||
- Tier limit enforcement
|
|
||||||
- Usage tracking
|
|
||||||
- Feature gating
|
|
||||||
|
|
||||||
Usage:
|
from app.modules.billing.services import subscription_service
|
||||||
from app.services.subscription_service import subscription_service
|
|
||||||
|
|
||||||
# Check if vendor can create an order
|
This file re-exports the service from its new location in the billing module.
|
||||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
|
||||||
|
|
||||||
# Increment order counter after successful order
|
|
||||||
subscription_service.increment_order_count(db, vendor_id)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
# Re-export from new location for backward compatibility
|
||||||
from datetime import UTC, datetime, timedelta
|
from app.modules.billing.services.subscription_service import (
|
||||||
from typing import Any
|
SubscriptionService,
|
||||||
|
subscription_service,
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from models.database.product import Product
|
|
||||||
from models.database.subscription import (
|
|
||||||
SubscriptionStatus,
|
|
||||||
SubscriptionTier,
|
|
||||||
TIER_LIMITS,
|
|
||||||
TierCode,
|
|
||||||
VendorSubscription,
|
|
||||||
)
|
)
|
||||||
from models.database.vendor import Vendor, VendorUser
|
from app.modules.billing.exceptions import (
|
||||||
from models.schema.subscription import (
|
SubscriptionNotFoundException,
|
||||||
SubscriptionCreate,
|
TierLimitExceededException,
|
||||||
SubscriptionUpdate,
|
FeatureNotAvailableException,
|
||||||
SubscriptionUsage,
|
|
||||||
TierInfo,
|
|
||||||
TierLimits,
|
|
||||||
UsageSummary,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
__all__ = [
|
||||||
|
"SubscriptionService",
|
||||||
|
"subscription_service",
|
||||||
class SubscriptionNotFoundException(Exception):
|
"SubscriptionNotFoundException",
|
||||||
"""Raised when subscription not found."""
|
"TierLimitExceededException",
|
||||||
|
"FeatureNotAvailableException",
|
||||||
pass
|
]
|
||||||
|
|
||||||
|
|
||||||
class TierLimitExceededException(Exception):
|
|
||||||
"""Raised when a tier limit is exceeded."""
|
|
||||||
|
|
||||||
def __init__(self, message: str, limit_type: str, current: int, limit: int):
|
|
||||||
super().__init__(message)
|
|
||||||
self.limit_type = limit_type
|
|
||||||
self.current = current
|
|
||||||
self.limit = limit
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureNotAvailableException(Exception):
|
|
||||||
"""Raised when a feature is not available in current tier."""
|
|
||||||
|
|
||||||
def __init__(self, feature: str, current_tier: str, required_tier: str):
|
|
||||||
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
|
|
||||||
super().__init__(message)
|
|
||||||
self.feature = feature
|
|
||||||
self.current_tier = current_tier
|
|
||||||
self.required_tier = required_tier
|
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionService:
|
|
||||||
"""Service for subscription and tier limit operations."""
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Tier Information
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
|
|
||||||
"""
|
|
||||||
Get full tier information.
|
|
||||||
|
|
||||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
|
||||||
"""
|
|
||||||
# Try database first if session provided
|
|
||||||
if db is not None:
|
|
||||||
db_tier = self.get_tier_by_code(db, tier_code)
|
|
||||||
if db_tier:
|
|
||||||
return TierInfo(
|
|
||||||
code=db_tier.code,
|
|
||||||
name=db_tier.name,
|
|
||||||
price_monthly_cents=db_tier.price_monthly_cents,
|
|
||||||
price_annual_cents=db_tier.price_annual_cents,
|
|
||||||
limits=TierLimits(
|
|
||||||
orders_per_month=db_tier.orders_per_month,
|
|
||||||
products_limit=db_tier.products_limit,
|
|
||||||
team_members=db_tier.team_members,
|
|
||||||
order_history_months=db_tier.order_history_months,
|
|
||||||
),
|
|
||||||
features=db_tier.features or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback to hardcoded TIER_LIMITS
|
|
||||||
return self._get_tier_from_legacy(tier_code)
|
|
||||||
|
|
||||||
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
|
|
||||||
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
|
|
||||||
try:
|
|
||||||
tier = TierCode(tier_code)
|
|
||||||
except ValueError:
|
|
||||||
tier = TierCode.ESSENTIAL
|
|
||||||
|
|
||||||
limits = TIER_LIMITS[tier]
|
|
||||||
return TierInfo(
|
|
||||||
code=tier.value,
|
|
||||||
name=limits["name"],
|
|
||||||
price_monthly_cents=limits["price_monthly_cents"],
|
|
||||||
price_annual_cents=limits.get("price_annual_cents"),
|
|
||||||
limits=TierLimits(
|
|
||||||
orders_per_month=limits.get("orders_per_month"),
|
|
||||||
products_limit=limits.get("products_limit"),
|
|
||||||
team_members=limits.get("team_members"),
|
|
||||||
order_history_months=limits.get("order_history_months"),
|
|
||||||
),
|
|
||||||
features=limits.get("features", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
|
|
||||||
"""
|
|
||||||
Get information for all tiers.
|
|
||||||
|
|
||||||
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
|
|
||||||
"""
|
|
||||||
if db is not None:
|
|
||||||
db_tiers = (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(
|
|
||||||
SubscriptionTier.is_active == True, # noqa: E712
|
|
||||||
SubscriptionTier.is_public == True, # noqa: E712
|
|
||||||
)
|
|
||||||
.order_by(SubscriptionTier.display_order)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
if db_tiers:
|
|
||||||
return [
|
|
||||||
TierInfo(
|
|
||||||
code=t.code,
|
|
||||||
name=t.name,
|
|
||||||
price_monthly_cents=t.price_monthly_cents,
|
|
||||||
price_annual_cents=t.price_annual_cents,
|
|
||||||
limits=TierLimits(
|
|
||||||
orders_per_month=t.orders_per_month,
|
|
||||||
products_limit=t.products_limit,
|
|
||||||
team_members=t.team_members,
|
|
||||||
order_history_months=t.order_history_months,
|
|
||||||
),
|
|
||||||
features=t.features or [],
|
|
||||||
)
|
|
||||||
for t in db_tiers
|
|
||||||
]
|
|
||||||
|
|
||||||
# Fallback to hardcoded
|
|
||||||
return [
|
|
||||||
self._get_tier_from_legacy(tier.value)
|
|
||||||
for tier in TierCode
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
|
|
||||||
"""Get subscription tier by code."""
|
|
||||||
return (
|
|
||||||
db.query(SubscriptionTier)
|
|
||||||
.filter(SubscriptionTier.code == tier_code)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
|
|
||||||
"""Get tier ID from tier code. Returns None if tier not found."""
|
|
||||||
tier = self.get_tier_by_code(db, tier_code)
|
|
||||||
return tier.id if tier else None
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Subscription CRUD
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_subscription(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> VendorSubscription | None:
|
|
||||||
"""Get vendor subscription."""
|
|
||||||
return (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_subscription_or_raise(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""Get vendor subscription or raise exception."""
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
if not subscription:
|
|
||||||
raise SubscriptionNotFoundException(
|
|
||||||
f"No subscription found for vendor {vendor_id}"
|
|
||||||
)
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def get_current_tier(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> TierCode | None:
|
|
||||||
"""Get vendor's current subscription tier code."""
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
if subscription:
|
|
||||||
try:
|
|
||||||
return TierCode(subscription.tier)
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_or_create_subscription(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor_id: int,
|
|
||||||
tier: str = TierCode.ESSENTIAL.value,
|
|
||||||
trial_days: int = 14,
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""
|
|
||||||
Get existing subscription or create a new trial subscription.
|
|
||||||
|
|
||||||
Used when a vendor first accesses the system.
|
|
||||||
"""
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
if subscription:
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
# Create new trial subscription
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
trial_end = now + timedelta(days=trial_days)
|
|
||||||
|
|
||||||
# Lookup tier_id from tier code
|
|
||||||
tier_id = self.get_tier_id(db, tier)
|
|
||||||
|
|
||||||
subscription = VendorSubscription(
|
|
||||||
vendor_id=vendor_id,
|
|
||||||
tier=tier,
|
|
||||||
tier_id=tier_id,
|
|
||||||
status=SubscriptionStatus.TRIAL.value,
|
|
||||||
period_start=now,
|
|
||||||
period_end=trial_end,
|
|
||||||
trial_ends_at=trial_end,
|
|
||||||
is_annual=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(subscription)
|
|
||||||
db.flush()
|
|
||||||
db.refresh(subscription)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created trial subscription for vendor {vendor_id} "
|
|
||||||
f"(tier={tier}, trial_ends={trial_end})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def create_subscription(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor_id: int,
|
|
||||||
data: SubscriptionCreate,
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""Create a subscription for a vendor."""
|
|
||||||
# Check if subscription exists
|
|
||||||
existing = self.get_subscription(db, vendor_id)
|
|
||||||
if existing:
|
|
||||||
raise ValueError("Vendor already has a subscription")
|
|
||||||
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
|
|
||||||
# Calculate period end based on billing cycle
|
|
||||||
if data.is_annual:
|
|
||||||
period_end = now + timedelta(days=365)
|
|
||||||
else:
|
|
||||||
period_end = now + timedelta(days=30)
|
|
||||||
|
|
||||||
# Handle trial
|
|
||||||
trial_ends_at = None
|
|
||||||
status = SubscriptionStatus.ACTIVE.value
|
|
||||||
if data.trial_days > 0:
|
|
||||||
trial_ends_at = now + timedelta(days=data.trial_days)
|
|
||||||
status = SubscriptionStatus.TRIAL.value
|
|
||||||
period_end = trial_ends_at
|
|
||||||
|
|
||||||
# Lookup tier_id from tier code
|
|
||||||
tier_id = self.get_tier_id(db, data.tier)
|
|
||||||
|
|
||||||
subscription = VendorSubscription(
|
|
||||||
vendor_id=vendor_id,
|
|
||||||
tier=data.tier,
|
|
||||||
tier_id=tier_id,
|
|
||||||
status=status,
|
|
||||||
period_start=now,
|
|
||||||
period_end=period_end,
|
|
||||||
trial_ends_at=trial_ends_at,
|
|
||||||
is_annual=data.is_annual,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(subscription)
|
|
||||||
db.flush()
|
|
||||||
db.refresh(subscription)
|
|
||||||
|
|
||||||
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def update_subscription(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor_id: int,
|
|
||||||
data: SubscriptionUpdate,
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""Update a vendor subscription."""
|
|
||||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
|
||||||
|
|
||||||
# If tier is being updated, also update tier_id
|
|
||||||
if "tier" in update_data:
|
|
||||||
tier_id = self.get_tier_id(db, update_data["tier"])
|
|
||||||
update_data["tier_id"] = tier_id
|
|
||||||
|
|
||||||
for key, value in update_data.items():
|
|
||||||
setattr(subscription, key, value)
|
|
||||||
|
|
||||||
subscription.updated_at = datetime.now(UTC)
|
|
||||||
db.flush()
|
|
||||||
db.refresh(subscription)
|
|
||||||
|
|
||||||
logger.info(f"Updated subscription for vendor {vendor_id}")
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def upgrade_tier(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor_id: int,
|
|
||||||
new_tier: str,
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""Upgrade vendor to a new tier."""
|
|
||||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
|
||||||
|
|
||||||
old_tier = subscription.tier
|
|
||||||
subscription.tier = new_tier
|
|
||||||
subscription.tier_id = self.get_tier_id(db, new_tier)
|
|
||||||
subscription.updated_at = datetime.now(UTC)
|
|
||||||
|
|
||||||
# If upgrading from trial, mark as active
|
|
||||||
if subscription.status == SubscriptionStatus.TRIAL.value:
|
|
||||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
|
||||||
|
|
||||||
db.flush()
|
|
||||||
db.refresh(subscription)
|
|
||||||
|
|
||||||
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
def cancel_subscription(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
vendor_id: int,
|
|
||||||
reason: str | None = None,
|
|
||||||
) -> VendorSubscription:
|
|
||||||
"""Cancel a vendor subscription (access until period end)."""
|
|
||||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
|
||||||
|
|
||||||
subscription.status = SubscriptionStatus.CANCELLED.value
|
|
||||||
subscription.cancelled_at = datetime.now(UTC)
|
|
||||||
subscription.cancellation_reason = reason
|
|
||||||
subscription.updated_at = datetime.now(UTC)
|
|
||||||
|
|
||||||
db.flush()
|
|
||||||
db.refresh(subscription)
|
|
||||||
|
|
||||||
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
|
||||||
return subscription
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Usage Tracking
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
|
||||||
"""Get current subscription usage statistics."""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
|
|
||||||
# Get actual counts
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate usage stats
|
|
||||||
orders_limit = subscription.orders_limit
|
|
||||||
products_limit = subscription.products_limit
|
|
||||||
team_limit = subscription.team_members_limit
|
|
||||||
|
|
||||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
|
||||||
if limit is None:
|
|
||||||
return None
|
|
||||||
return max(0, limit - current)
|
|
||||||
|
|
||||||
def calc_percent(current: int, limit: int | None) -> float | None:
|
|
||||||
if limit is None or limit == 0:
|
|
||||||
return None
|
|
||||||
return min(100.0, (current / limit) * 100)
|
|
||||||
|
|
||||||
return SubscriptionUsage(
|
|
||||||
orders_used=subscription.orders_this_period,
|
|
||||||
orders_limit=orders_limit,
|
|
||||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
|
||||||
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
|
||||||
products_used=products_count,
|
|
||||||
products_limit=products_limit,
|
|
||||||
products_remaining=calc_remaining(products_count, products_limit),
|
|
||||||
products_percent_used=calc_percent(products_count, products_limit),
|
|
||||||
team_members_used=team_count,
|
|
||||||
team_members_limit=team_limit,
|
|
||||||
team_members_remaining=calc_remaining(team_count, team_limit),
|
|
||||||
team_members_percent_used=calc_percent(team_count, team_limit),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
|
|
||||||
"""Get usage summary for billing page display."""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
|
|
||||||
# Get actual counts
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get limits
|
|
||||||
orders_limit = subscription.orders_limit
|
|
||||||
products_limit = subscription.products_limit
|
|
||||||
team_limit = subscription.team_members_limit
|
|
||||||
|
|
||||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
|
||||||
if limit is None:
|
|
||||||
return None
|
|
||||||
return max(0, limit - current)
|
|
||||||
|
|
||||||
return UsageSummary(
|
|
||||||
orders_this_period=subscription.orders_this_period,
|
|
||||||
orders_limit=orders_limit,
|
|
||||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
|
||||||
products_count=products_count,
|
|
||||||
products_limit=products_limit,
|
|
||||||
products_remaining=calc_remaining(products_count, products_limit),
|
|
||||||
team_count=team_count,
|
|
||||||
team_limit=team_limit,
|
|
||||||
team_remaining=calc_remaining(team_count, team_limit),
|
|
||||||
)
|
|
||||||
|
|
||||||
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
|
||||||
"""
|
|
||||||
Increment the order counter for the current period.
|
|
||||||
|
|
||||||
Call this after successfully creating/importing an order.
|
|
||||||
"""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
subscription.increment_order_count()
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
|
||||||
"""Reset counters for a new billing period."""
|
|
||||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
|
||||||
subscription.reset_period_counters()
|
|
||||||
db.flush()
|
|
||||||
logger.info(f"Reset period counters for vendor {vendor_id}")
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Limit Checks
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def can_create_order(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""
|
|
||||||
Check if vendor can create/import another order.
|
|
||||||
|
|
||||||
Returns: (allowed, error_message)
|
|
||||||
"""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
return subscription.can_create_order()
|
|
||||||
|
|
||||||
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
|
||||||
"""
|
|
||||||
Check order limit and raise exception if exceeded.
|
|
||||||
|
|
||||||
Use this in order creation flows.
|
|
||||||
"""
|
|
||||||
can_create, message = self.can_create_order(db, vendor_id)
|
|
||||||
if not can_create:
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
raise TierLimitExceededException(
|
|
||||||
message=message or "Order limit exceeded",
|
|
||||||
limit_type="orders",
|
|
||||||
current=subscription.orders_this_period if subscription else 0,
|
|
||||||
limit=subscription.orders_limit if subscription else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def can_add_product(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""
|
|
||||||
Check if vendor can add another product.
|
|
||||||
|
|
||||||
Returns: (allowed, error_message)
|
|
||||||
"""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return subscription.can_add_product(products_count)
|
|
||||||
|
|
||||||
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
|
||||||
"""
|
|
||||||
Check product limit and raise exception if exceeded.
|
|
||||||
|
|
||||||
Use this in product creation flows.
|
|
||||||
"""
|
|
||||||
can_add, message = self.can_add_product(db, vendor_id)
|
|
||||||
if not can_add:
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
products_count = (
|
|
||||||
db.query(func.count(Product.id))
|
|
||||||
.filter(Product.vendor_id == vendor_id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
raise TierLimitExceededException(
|
|
||||||
message=message or "Product limit exceeded",
|
|
||||||
limit_type="products",
|
|
||||||
current=products_count,
|
|
||||||
limit=subscription.products_limit if subscription else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def can_add_team_member(
|
|
||||||
self, db: Session, vendor_id: int
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""
|
|
||||||
Check if vendor can add another team member.
|
|
||||||
|
|
||||||
Returns: (allowed, error_message)
|
|
||||||
"""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
return subscription.can_add_team_member(team_count)
|
|
||||||
|
|
||||||
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
|
||||||
"""
|
|
||||||
Check team member limit and raise exception if exceeded.
|
|
||||||
|
|
||||||
Use this in team member invitation flows.
|
|
||||||
"""
|
|
||||||
can_add, message = self.can_add_team_member(db, vendor_id)
|
|
||||||
if not can_add:
|
|
||||||
subscription = self.get_subscription(db, vendor_id)
|
|
||||||
team_count = (
|
|
||||||
db.query(func.count(VendorUser.id))
|
|
||||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
raise TierLimitExceededException(
|
|
||||||
message=message or "Team member limit exceeded",
|
|
||||||
limit_type="team_members",
|
|
||||||
current=team_count,
|
|
||||||
limit=subscription.team_members_limit if subscription else 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Feature Gating
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
|
||||||
"""Check if vendor has access to a feature."""
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
return subscription.has_feature(feature)
|
|
||||||
|
|
||||||
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
|
||||||
"""
|
|
||||||
Check feature access and raise exception if not available.
|
|
||||||
|
|
||||||
Use this to gate premium features.
|
|
||||||
"""
|
|
||||||
if not self.has_feature(db, vendor_id, feature):
|
|
||||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
|
||||||
|
|
||||||
# Find which tier has this feature
|
|
||||||
required_tier = None
|
|
||||||
for tier_code, limits in TIER_LIMITS.items():
|
|
||||||
if feature in limits.get("features", []):
|
|
||||||
required_tier = limits["name"]
|
|
||||||
break
|
|
||||||
|
|
||||||
raise FeatureNotAvailableException(
|
|
||||||
feature=feature,
|
|
||||||
current_tier=subscription.tier,
|
|
||||||
required_tier=required_tier or "higher",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_feature_tier(self, feature: str) -> str | None:
|
|
||||||
"""Get the minimum tier required for a feature."""
|
|
||||||
for tier_code in [
|
|
||||||
TierCode.ESSENTIAL,
|
|
||||||
TierCode.PROFESSIONAL,
|
|
||||||
TierCode.BUSINESS,
|
|
||||||
TierCode.ENTERPRISE,
|
|
||||||
]:
|
|
||||||
if feature in TIER_LIMITS[tier_code].get("features", []):
|
|
||||||
return tier_code.value
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# Singleton instance
|
|
||||||
subscription_service = SubscriptionService()
|
|
||||||
|
|||||||
@@ -1,265 +1,27 @@
|
|||||||
# app/tasks/celery_tasks/subscription.py
|
# app/tasks/celery_tasks/subscription.py
|
||||||
"""
|
"""
|
||||||
Celery tasks for subscription management.
|
Legacy subscription tasks.
|
||||||
|
|
||||||
Scheduled tasks for:
|
MOSTLY MIGRATED: Most tasks have been migrated to app.modules.billing.tasks.
|
||||||
- Resetting period counters
|
|
||||||
- Checking trial expirations
|
The following tasks now live in the billing module:
|
||||||
- Syncing with Stripe
|
- reset_period_counters -> app.modules.billing.tasks.subscription
|
||||||
- Cleaning up stale subscriptions
|
- check_trial_expirations -> app.modules.billing.tasks.subscription
|
||||||
- Capturing capacity snapshots
|
- sync_stripe_status -> app.modules.billing.tasks.subscription
|
||||||
|
- cleanup_stale_subscriptions -> app.modules.billing.tasks.subscription
|
||||||
|
|
||||||
|
Remaining task (to be migrated to monitoring module):
|
||||||
|
- capture_capacity_snapshot
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime, timedelta
|
|
||||||
|
|
||||||
from app.core.celery_config import celery_app
|
from app.core.celery_config import celery_app
|
||||||
from app.services.stripe_service import stripe_service
|
|
||||||
from app.tasks.celery_tasks.base import DatabaseTask
|
from app.tasks.celery_tasks.base import DatabaseTask
|
||||||
from models.database.subscription import SubscriptionStatus, VendorSubscription
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
bind=True,
|
|
||||||
base=DatabaseTask,
|
|
||||||
name="app.tasks.celery_tasks.subscription.reset_period_counters",
|
|
||||||
)
|
|
||||||
def reset_period_counters(self):
|
|
||||||
"""
|
|
||||||
Reset order counters for subscriptions whose billing period has ended.
|
|
||||||
|
|
||||||
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
|
|
||||||
"""
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
reset_count = 0
|
|
||||||
|
|
||||||
with self.get_db() as db:
|
|
||||||
# Find subscriptions where period has ended
|
|
||||||
expired_periods = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(
|
|
||||||
VendorSubscription.period_end <= now,
|
|
||||||
VendorSubscription.status.in_(["active", "trial"]),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for subscription in expired_periods:
|
|
||||||
old_period_end = subscription.period_end
|
|
||||||
|
|
||||||
# Reset counters
|
|
||||||
subscription.orders_this_period = 0
|
|
||||||
subscription.orders_limit_reached_at = None
|
|
||||||
|
|
||||||
# Set new period dates
|
|
||||||
if subscription.is_annual:
|
|
||||||
subscription.period_start = now
|
|
||||||
subscription.period_end = now + timedelta(days=365)
|
|
||||||
else:
|
|
||||||
subscription.period_start = now
|
|
||||||
subscription.period_end = now + timedelta(days=30)
|
|
||||||
|
|
||||||
subscription.updated_at = now
|
|
||||||
reset_count += 1
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Reset period counters for vendor {subscription.vendor_id}: "
|
|
||||||
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Reset period counters for {reset_count} subscriptions")
|
|
||||||
|
|
||||||
return {"reset_count": reset_count}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
bind=True,
|
|
||||||
base=DatabaseTask,
|
|
||||||
name="app.tasks.celery_tasks.subscription.check_trial_expirations",
|
|
||||||
)
|
|
||||||
def check_trial_expirations(self):
|
|
||||||
"""
|
|
||||||
Check for expired trials and update their status.
|
|
||||||
|
|
||||||
Runs daily at 01:00.
|
|
||||||
- Trials without payment method -> expired
|
|
||||||
- Trials with payment method -> active
|
|
||||||
"""
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
expired_count = 0
|
|
||||||
activated_count = 0
|
|
||||||
|
|
||||||
with self.get_db() as db:
|
|
||||||
# Find expired trials
|
|
||||||
expired_trials = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(
|
|
||||||
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
|
|
||||||
VendorSubscription.trial_ends_at <= now,
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for subscription in expired_trials:
|
|
||||||
if subscription.stripe_payment_method_id:
|
|
||||||
# Has payment method - activate
|
|
||||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
|
||||||
activated_count += 1
|
|
||||||
logger.info(
|
|
||||||
f"Activated subscription for vendor {subscription.vendor_id} "
|
|
||||||
f"(trial ended with payment method)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No payment method - expire
|
|
||||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
|
||||||
expired_count += 1
|
|
||||||
logger.info(
|
|
||||||
f"Expired trial for vendor {subscription.vendor_id} "
|
|
||||||
f"(no payment method)"
|
|
||||||
)
|
|
||||||
|
|
||||||
subscription.updated_at = now
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
|
|
||||||
|
|
||||||
return {"expired_count": expired_count, "activated_count": activated_count}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
bind=True,
|
|
||||||
base=DatabaseTask,
|
|
||||||
name="app.tasks.celery_tasks.subscription.sync_stripe_status",
|
|
||||||
max_retries=3,
|
|
||||||
default_retry_delay=300,
|
|
||||||
)
|
|
||||||
def sync_stripe_status(self):
|
|
||||||
"""
|
|
||||||
Sync subscription status with Stripe.
|
|
||||||
|
|
||||||
Runs hourly at :30. Fetches current status from Stripe and updates local records.
|
|
||||||
"""
|
|
||||||
if not stripe_service.is_configured:
|
|
||||||
logger.warning("Stripe not configured, skipping sync")
|
|
||||||
return {"synced": 0, "skipped": True}
|
|
||||||
|
|
||||||
synced_count = 0
|
|
||||||
error_count = 0
|
|
||||||
|
|
||||||
with self.get_db() as db:
|
|
||||||
# Find subscriptions with Stripe IDs
|
|
||||||
subscriptions = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for subscription in subscriptions:
|
|
||||||
try:
|
|
||||||
# Fetch from Stripe
|
|
||||||
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
|
|
||||||
|
|
||||||
if not stripe_sub:
|
|
||||||
logger.warning(
|
|
||||||
f"Stripe subscription {subscription.stripe_subscription_id} "
|
|
||||||
f"not found for vendor {subscription.vendor_id}"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Map Stripe status to local status
|
|
||||||
status_map = {
|
|
||||||
"active": SubscriptionStatus.ACTIVE.value,
|
|
||||||
"trialing": SubscriptionStatus.TRIAL.value,
|
|
||||||
"past_due": SubscriptionStatus.PAST_DUE.value,
|
|
||||||
"canceled": SubscriptionStatus.CANCELLED.value,
|
|
||||||
"unpaid": SubscriptionStatus.PAST_DUE.value,
|
|
||||||
"incomplete": SubscriptionStatus.TRIAL.value,
|
|
||||||
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
new_status = status_map.get(stripe_sub.status)
|
|
||||||
if new_status and new_status != subscription.status:
|
|
||||||
old_status = subscription.status
|
|
||||||
subscription.status = new_status
|
|
||||||
subscription.updated_at = datetime.now(UTC)
|
|
||||||
logger.info(
|
|
||||||
f"Updated vendor {subscription.vendor_id} status: "
|
|
||||||
f"{old_status} -> {new_status} (from Stripe)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update period dates from Stripe
|
|
||||||
if stripe_sub.current_period_start:
|
|
||||||
subscription.period_start = datetime.fromtimestamp(
|
|
||||||
stripe_sub.current_period_start, tz=UTC
|
|
||||||
)
|
|
||||||
if stripe_sub.current_period_end:
|
|
||||||
subscription.period_end = datetime.fromtimestamp(
|
|
||||||
stripe_sub.current_period_end, tz=UTC
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update payment method
|
|
||||||
if stripe_sub.default_payment_method:
|
|
||||||
subscription.stripe_payment_method_id = (
|
|
||||||
stripe_sub.default_payment_method
|
|
||||||
if isinstance(stripe_sub.default_payment_method, str)
|
|
||||||
else stripe_sub.default_payment_method.id
|
|
||||||
)
|
|
||||||
|
|
||||||
synced_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
|
|
||||||
error_count += 1
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
|
|
||||||
|
|
||||||
return {"synced_count": synced_count, "error_count": error_count}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
|
||||||
bind=True,
|
|
||||||
base=DatabaseTask,
|
|
||||||
name="app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
|
|
||||||
)
|
|
||||||
def cleanup_stale_subscriptions(self):
|
|
||||||
"""
|
|
||||||
Clean up subscriptions in inconsistent states.
|
|
||||||
|
|
||||||
Runs weekly on Sunday at 03:00.
|
|
||||||
"""
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
cleaned_count = 0
|
|
||||||
|
|
||||||
with self.get_db() as db:
|
|
||||||
# Find cancelled subscriptions past their period end
|
|
||||||
stale_cancelled = (
|
|
||||||
db.query(VendorSubscription)
|
|
||||||
.filter(
|
|
||||||
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
|
|
||||||
VendorSubscription.period_end < now - timedelta(days=30),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for subscription in stale_cancelled:
|
|
||||||
# Mark as expired (fully terminated)
|
|
||||||
subscription.status = SubscriptionStatus.EXPIRED.value
|
|
||||||
subscription.updated_at = now
|
|
||||||
cleaned_count += 1
|
|
||||||
logger.info(
|
|
||||||
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
|
|
||||||
|
|
||||||
return {"cleaned_count": cleaned_count}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
bind=True,
|
bind=True,
|
||||||
base=DatabaseTask,
|
base=DatabaseTask,
|
||||||
@@ -270,6 +32,8 @@ def capture_capacity_snapshot(self):
|
|||||||
Capture a daily snapshot of platform capacity metrics.
|
Capture a daily snapshot of platform capacity metrics.
|
||||||
|
|
||||||
Runs daily at midnight.
|
Runs daily at midnight.
|
||||||
|
|
||||||
|
TODO: Migrate to app.modules.monitoring.tasks
|
||||||
"""
|
"""
|
||||||
from app.services.capacity_forecast_service import capacity_forecast_service
|
from app.services.capacity_forecast_service import capacity_forecast_service
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ Transform the platform from a monolithic structure to self-contained modules whe
|
|||||||
|--------|---------------|---------------|----------|--------|-------|--------|
|
|--------|---------------|---------------|----------|--------|-------|--------|
|
||||||
| `cms` | Core | ✅ **Complete** | ✅ | ✅ | - | Done |
|
| `cms` | Core | ✅ **Complete** | ✅ | ✅ | - | Done |
|
||||||
| `payments` | Optional | 🟡 Partial | ✅ | ✅ | - | Done |
|
| `payments` | Optional | 🟡 Partial | ✅ | ✅ | - | Done |
|
||||||
| `billing` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
|
| `billing` | Optional | ✅ **Complete** | ✅ | ✅ | ✅ | Done |
|
||||||
| `marketplace` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
|
| `marketplace` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
|
||||||
| `orders` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
| `orders` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
||||||
| `inventory` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
| `inventory` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
|
||||||
@@ -110,85 +110,22 @@ app/tasks/celery_tasks/ # → Move to respective modules
|
|||||||
- Observability docs
|
- Observability docs
|
||||||
- Creating modules developer guide
|
- Creating modules developer guide
|
||||||
|
|
||||||
---
|
### ✅ Phase 4: Celery Task Infrastructure
|
||||||
|
- Added `ScheduledTask` dataclass to `ModuleDefinition`
|
||||||
|
- Added `tasks_path` and `scheduled_tasks` fields
|
||||||
|
- Created `app/modules/task_base.py` with `ModuleTask` base class
|
||||||
|
- Created `app/modules/tasks.py` with discovery utilities
|
||||||
|
- Updated `celery_config.py` to use module discovery
|
||||||
|
|
||||||
## Infrastructure Work (Remaining)
|
### ✅ Phase 5: Billing Module Migration
|
||||||
|
- Created `app/modules/billing/services/` with subscription, stripe, admin services
|
||||||
### Phase 4: Celery Task Infrastructure
|
- Created `app/modules/billing/models/` re-exporting from central location
|
||||||
|
- Created `app/modules/billing/schemas/` re-exporting from central location
|
||||||
Add task support to module system before migrating modules.
|
- Created `app/modules/billing/tasks/` with 4 scheduled tasks
|
||||||
|
- Created `app/modules/billing/exceptions.py`
|
||||||
#### 4.1 Update ModuleDefinition
|
- Updated `definition.py` with self-contained configuration
|
||||||
|
- Created backward-compatible re-exports in `app/services/`
|
||||||
```python
|
- Updated legacy celery_config.py to not duplicate scheduled tasks
|
||||||
# app/modules/base.py - Add fields
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ScheduledTask:
|
|
||||||
"""Celery Beat scheduled task definition."""
|
|
||||||
name: str # e.g., "billing.reset_counters"
|
|
||||||
task: str # Full task path
|
|
||||||
schedule: str | dict # Cron string or crontab dict
|
|
||||||
args: tuple = ()
|
|
||||||
kwargs: dict = field(default_factory=dict)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModuleDefinition:
|
|
||||||
# ... existing fields ...
|
|
||||||
|
|
||||||
# Task configuration (NEW)
|
|
||||||
tasks_path: str | None = None
|
|
||||||
scheduled_tasks: list[ScheduledTask] = field(default_factory=list)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 4.2 Create Task Discovery
|
|
||||||
|
|
||||||
```python
|
|
||||||
# app/modules/tasks.py (NEW)
|
|
||||||
|
|
||||||
def discover_module_tasks() -> list[str]:
|
|
||||||
"""Discover task modules from all registered modules."""
|
|
||||||
...
|
|
||||||
|
|
||||||
def build_beat_schedule() -> dict:
|
|
||||||
"""Build Celery Beat schedule from module definitions."""
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 4.3 Create Module Task Base
|
|
||||||
|
|
||||||
```python
|
|
||||||
# app/modules/task_base.py (NEW)
|
|
||||||
|
|
||||||
class ModuleTask(Task):
|
|
||||||
"""Base Celery task with DB session management."""
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def get_db(self):
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 4.4 Update Celery Config
|
|
||||||
|
|
||||||
```python
|
|
||||||
# app/core/celery_config.py
|
|
||||||
|
|
||||||
from app.modules.tasks import discover_module_tasks, build_beat_schedule
|
|
||||||
|
|
||||||
# Auto-discover from modules
|
|
||||||
celery_app.autodiscover_tasks(discover_module_tasks())
|
|
||||||
|
|
||||||
# Build schedule from modules
|
|
||||||
celery_app.conf.beat_schedule = build_beat_schedule()
|
|
||||||
```
|
|
||||||
|
|
||||||
**Files to create:**
|
|
||||||
- `app/modules/tasks.py`
|
|
||||||
- `app/modules/task_base.py`
|
|
||||||
|
|
||||||
**Files to modify:**
|
|
||||||
- `app/modules/base.py`
|
|
||||||
- `app/core/celery_config.py`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user