feat: complete billing module migration (Phase 5)

Migrates billing module to self-contained structure:
- Create app/modules/billing/services/ with subscription, stripe, admin services
- Create app/modules/billing/models/ re-exporting from central location
- Create app/modules/billing/schemas/ re-exporting from central location
- Create app/modules/billing/tasks/ with 4 scheduled Celery tasks
- Create app/modules/billing/exceptions.py with module-specific exceptions
- Update definition.py with is_self_contained=True and scheduled_tasks

Celery task migration:
- reset_period_counters -> billing module
- check_trial_expirations -> billing module
- sync_stripe_status -> billing module
- cleanup_stale_subscriptions -> billing module
- capture_capacity_snapshot remains in legacy (will go to monitoring)

Backward compatibility:
- Create re-exports in app/services/ for subscription, stripe, admin services
- Old import paths continue to work
- Update celery_config.py to use module-defined schedules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-27 23:06:23 +01:00
parent f1f91abe51
commit 4f379b472b
17 changed files with 2198 additions and 1931 deletions

View File

@@ -49,10 +49,12 @@ if SENTRY_DSN:
# TASK DISCOVERY
# =============================================================================
# Legacy tasks (will be migrated to modules over time)
# NOTE: Most subscription tasks have been migrated to app.modules.billing.tasks
# The subscription module is kept for capture_capacity_snapshot (will move to monitoring)
LEGACY_TASK_MODULES = [
"app.tasks.celery_tasks.marketplace",
"app.tasks.celery_tasks.letzshop",
"app.tasks.celery_tasks.subscription",
"app.tasks.celery_tasks.subscription", # Kept for capture_capacity_snapshot only
"app.tasks.celery_tasks.export",
"app.tasks.celery_tasks.code_quality",
"app.tasks.celery_tasks.test_runner",
@@ -141,32 +143,9 @@ celery_app.conf.task_routes = {
# =============================================================================
# Legacy scheduled tasks (will be migrated to module definitions)
# NOTE: Subscription tasks have been migrated to billing module (see definition.py)
LEGACY_BEAT_SCHEDULE = {
# Reset usage counters at start of each period
"reset-period-counters-daily": {
"task": "app.tasks.celery_tasks.subscription.reset_period_counters",
"schedule": crontab(hour=0, minute=5), # 00:05 daily
"options": {"queue": "scheduled"},
},
# Check for expiring trials and send notifications
"check-trial-expirations-daily": {
"task": "app.tasks.celery_tasks.subscription.check_trial_expirations",
"schedule": crontab(hour=1, minute=0), # 01:00 daily
"options": {"queue": "scheduled"},
},
# Sync subscription status with Stripe
"sync-stripe-status-hourly": {
"task": "app.tasks.celery_tasks.subscription.sync_stripe_status",
"schedule": crontab(minute=30), # Every hour at :30
"options": {"queue": "scheduled"},
},
# Clean up stale/orphaned subscriptions
"cleanup-stale-subscriptions-weekly": {
"task": "app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
"schedule": crontab(hour=3, minute=0, day_of_week=0), # Sunday 03:00
"options": {"queue": "scheduled"},
},
# Capture daily capacity snapshot for analytics
# Capacity snapshot - will be migrated to monitoring module
"capture-capacity-snapshot-daily": {
"task": "app.tasks.celery_tasks.subscription.capture_capacity_snapshot",
"schedule": crontab(hour=0, minute=0), # Midnight daily

View File

@@ -7,6 +7,7 @@ This module provides:
- Vendor subscription CRUD
- Billing history and invoices
- Stripe integration
- Scheduled tasks for subscription lifecycle
Routes:
- Admin: /api/v1/admin/subscriptions/*
@@ -15,8 +16,17 @@ Routes:
Menu Items:
- Admin: subscription-tiers, subscriptions, billing-history
- Vendor: billing, invoices
Usage:
from app.modules.billing import billing_module
from app.modules.billing.services import subscription_service, stripe_service
from app.modules.billing.models import VendorSubscription, SubscriptionTier
from app.modules.billing.exceptions import TierLimitExceededException
"""
from app.modules.billing.definition import billing_module
from app.modules.billing.definition import billing_module, get_billing_module_with_routers
__all__ = ["billing_module"]
__all__ = [
"billing_module",
"get_billing_module_with_routers",
]

View File

@@ -3,10 +3,10 @@
Billing module definition.
Defines the billing module including its features, menu items,
and route configurations.
route configurations, and scheduled tasks.
"""
from app.modules.base import ModuleDefinition
from app.modules.base import ModuleDefinition, ScheduledTask
from models.database.admin_menu_config import FrontendType
@@ -54,6 +54,44 @@ billing_module = ModuleDefinition(
],
},
is_core=False, # Billing can be disabled (e.g., internal platforms)
# =========================================================================
# Self-Contained Module Configuration
# =========================================================================
is_self_contained=True,
services_path="app.modules.billing.services",
models_path="app.modules.billing.models",
schemas_path="app.modules.billing.schemas",
exceptions_path="app.modules.billing.exceptions",
tasks_path="app.modules.billing.tasks",
# =========================================================================
# Scheduled Tasks
# =========================================================================
scheduled_tasks=[
ScheduledTask(
name="billing.reset_period_counters",
task="app.modules.billing.tasks.subscription.reset_period_counters",
schedule="5 0 * * *", # Daily at 00:05
options={"queue": "scheduled"},
),
ScheduledTask(
name="billing.check_trial_expirations",
task="app.modules.billing.tasks.subscription.check_trial_expirations",
schedule="0 1 * * *", # Daily at 01:00
options={"queue": "scheduled"},
),
ScheduledTask(
name="billing.sync_stripe_status",
task="app.modules.billing.tasks.subscription.sync_stripe_status",
schedule="30 * * * *", # Hourly at :30
options={"queue": "scheduled"},
),
ScheduledTask(
name="billing.cleanup_stale_subscriptions",
task="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
schedule="0 3 * * 0", # Weekly on Sunday at 03:00
options={"queue": "scheduled"},
),
],
)

View File

@@ -0,0 +1,83 @@
# app/modules/billing/exceptions.py
"""
Billing module exceptions.
Custom exceptions for subscription, billing, and payment operations.
"""
from app.exceptions import BusinessLogicException, ResourceNotFoundException
class BillingException(BusinessLogicException):
"""Base exception for billing module errors."""
pass
class SubscriptionNotFoundException(ResourceNotFoundException):
"""Raised when a subscription is not found."""
def __init__(self, vendor_id: int):
super().__init__("Subscription", str(vendor_id))
class TierNotFoundException(ResourceNotFoundException):
"""Raised when a subscription tier is not found."""
def __init__(self, tier_code: str):
super().__init__("SubscriptionTier", tier_code)
class TierLimitExceededException(BillingException):
"""Raised when a tier limit is exceeded."""
def __init__(self, message: str, limit_type: str, current: int, limit: int):
super().__init__(message)
self.limit_type = limit_type
self.current = current
self.limit = limit
class FeatureNotAvailableException(BillingException):
"""Raised when a feature is not available in current tier."""
def __init__(self, feature: str, current_tier: str, required_tier: str):
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
super().__init__(message)
self.feature = feature
self.current_tier = current_tier
self.required_tier = required_tier
class StripeNotConfiguredException(BillingException):
"""Raised when Stripe is not configured."""
def __init__(self):
super().__init__("Stripe is not configured")
class PaymentFailedException(BillingException):
"""Raised when a payment fails."""
def __init__(self, message: str, stripe_error: str | None = None):
super().__init__(message)
self.stripe_error = stripe_error
class WebhookVerificationException(BillingException):
"""Raised when webhook signature verification fails."""
def __init__(self, message: str = "Invalid webhook signature"):
super().__init__(message)
__all__ = [
"BillingException",
"SubscriptionNotFoundException",
"TierNotFoundException",
"TierLimitExceededException",
"FeatureNotAvailableException",
"StripeNotConfiguredException",
"PaymentFailedException",
"WebhookVerificationException",
]

View File

@@ -0,0 +1,52 @@
# app/modules/billing/models/__init__.py
"""
Billing module models.
Re-exports subscription models from the central models location.
Models remain in models/database/ for now to avoid breaking existing imports
across the codebase. This provides a module-local import path.
Usage:
from app.modules.billing.models import (
VendorSubscription,
SubscriptionTier,
SubscriptionStatus,
TierCode,
)
"""
from models.database.subscription import (
# Enums
TierCode,
SubscriptionStatus,
AddOnCategory,
BillingPeriod,
# Models
SubscriptionTier,
AddOnProduct,
VendorAddOn,
StripeWebhookEvent,
BillingHistory,
VendorSubscription,
CapacitySnapshot,
# Legacy constants
TIER_LIMITS,
)
__all__ = [
# Enums
"TierCode",
"SubscriptionStatus",
"AddOnCategory",
"BillingPeriod",
# Models
"SubscriptionTier",
"AddOnProduct",
"VendorAddOn",
"StripeWebhookEvent",
"BillingHistory",
"VendorSubscription",
"CapacitySnapshot",
# Legacy constants
"TIER_LIMITS",
]

View File

@@ -0,0 +1,56 @@
# app/modules/billing/schemas/__init__.py
"""
Billing module Pydantic schemas.
Re-exports subscription schemas from the central schemas location.
Provides a module-local import path while maintaining backwards compatibility.
Usage:
from app.modules.billing.schemas import (
SubscriptionCreate,
SubscriptionResponse,
TierInfo,
)
"""
from models.schema.subscription import (
# Tier schemas
TierFeatures,
TierLimits,
TierInfo,
# Subscription CRUD schemas
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionResponse,
# Usage schemas
SubscriptionUsage,
UsageSummary,
SubscriptionStatusResponse,
# Limit check schemas
LimitCheckResult,
CanCreateOrderResponse,
CanAddProductResponse,
CanAddTeamMemberResponse,
FeatureCheckResponse,
)
__all__ = [
# Tier schemas
"TierFeatures",
"TierLimits",
"TierInfo",
# Subscription CRUD schemas
"SubscriptionCreate",
"SubscriptionUpdate",
"SubscriptionResponse",
# Usage schemas
"SubscriptionUsage",
"UsageSummary",
"SubscriptionStatusResponse",
# Limit check schemas
"LimitCheckResult",
"CanCreateOrderResponse",
"CanAddProductResponse",
"CanAddTeamMemberResponse",
"FeatureCheckResponse",
]

View File

@@ -0,0 +1,28 @@
# app/modules/billing/services/__init__.py
"""
Billing module services.
Provides subscription management, Stripe integration, and admin operations.
"""
from app.modules.billing.services.subscription_service import (
SubscriptionService,
subscription_service,
)
from app.modules.billing.services.stripe_service import (
StripeService,
stripe_service,
)
from app.modules.billing.services.admin_subscription_service import (
AdminSubscriptionService,
admin_subscription_service,
)
__all__ = [
"SubscriptionService",
"subscription_service",
"StripeService",
"stripe_service",
"AdminSubscriptionService",
"admin_subscription_service",
]

View File

@@ -0,0 +1,352 @@
# app/modules/billing/services/admin_subscription_service.py
"""
Admin Subscription Service.
Handles subscription management operations for platform administrators:
- Subscription tier CRUD
- Vendor subscription management
- Billing history queries
- Subscription analytics
"""
import logging
from math import ceil
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
BusinessLogicException,
ConflictException,
ResourceNotFoundException,
)
from app.modules.billing.exceptions import TierNotFoundException
from app.modules.billing.models import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
)
from models.database.product import Product
from models.database.vendor import Vendor, VendorUser
logger = logging.getLogger(__name__)
class AdminSubscriptionService:
"""Service for admin subscription management operations."""
# =========================================================================
# Subscription Tiers
# =========================================================================
def get_tiers(
self, db: Session, include_inactive: bool = False
) -> list[SubscriptionTier]:
"""Get all subscription tiers."""
query = db.query(SubscriptionTier)
if not include_inactive:
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
return query.order_by(SubscriptionTier.display_order).all()
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
"""Get a subscription tier by code."""
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
if not tier:
raise TierNotFoundException(tier_code)
return tier
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
"""Create a new subscription tier."""
# Check for duplicate code
existing = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_data["code"])
.first()
)
if existing:
raise ConflictException(
f"Tier with code '{tier_data['code']}' already exists"
)
tier = SubscriptionTier(**tier_data)
db.add(tier)
logger.info(f"Created subscription tier: {tier.code}")
return tier
def update_tier(
self, db: Session, tier_code: str, update_data: dict
) -> SubscriptionTier:
"""Update a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
for field, value in update_data.items():
setattr(tier, field, value)
logger.info(f"Updated subscription tier: {tier.code}")
return tier
def deactivate_tier(self, db: Session, tier_code: str) -> None:
"""Soft-delete a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
# Check if any active subscriptions use this tier
active_subs = (
db.query(VendorSubscription)
.filter(
VendorSubscription.tier == tier_code,
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
]),
)
.count()
)
if active_subs > 0:
raise BusinessLogicException(
f"Cannot delete tier: {active_subs} active subscriptions are using it"
)
tier.is_active = False
logger.info(f"Soft-deleted subscription tier: {tier.code}")
# =========================================================================
# Vendor Subscriptions
# =========================================================================
def list_subscriptions(
self,
db: Session,
page: int = 1,
per_page: int = 20,
status: str | None = None,
tier: str | None = None,
search: str | None = None,
) -> dict:
"""List vendor subscriptions with filtering and pagination."""
query = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
)
# Apply filters
if status:
query = query.filter(VendorSubscription.status == status)
if tier:
query = query.filter(VendorSubscription.tier == tier)
if search:
query = query.filter(Vendor.name.ilike(f"%{search}%"))
# Count total
total = query.count()
# Paginate
offset = (page - 1) * per_page
results = (
query.order_by(VendorSubscription.created_at.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
"""Get subscription for a specific vendor."""
result = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
if not result:
raise ResourceNotFoundException("Subscription", str(vendor_id))
return result
def update_subscription(
self, db: Session, vendor_id: int, update_data: dict
) -> tuple:
"""Update a vendor's subscription."""
result = self.get_subscription(db, vendor_id)
sub, vendor = result
for field, value in update_data.items():
setattr(sub, field, value)
logger.info(
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
)
return sub, vendor
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
"""Get a vendor by ID."""
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise ResourceNotFoundException("Vendor", str(vendor_id))
return vendor
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
"""Get usage counts (products and team members) for a vendor."""
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.is_active == True, # noqa: E712
)
.scalar()
or 0
)
return {
"products_count": products_count,
"team_count": team_count,
}
# =========================================================================
# Billing History
# =========================================================================
def list_billing_history(
self,
db: Session,
page: int = 1,
per_page: int = 20,
vendor_id: int | None = None,
status: str | None = None,
) -> dict:
"""List billing history across all vendors."""
query = (
db.query(BillingHistory, Vendor)
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
)
if vendor_id:
query = query.filter(BillingHistory.vendor_id == vendor_id)
if status:
query = query.filter(BillingHistory.status == status)
total = query.count()
offset = (page - 1) * per_page
results = (
query.order_by(BillingHistory.invoice_date.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
# =========================================================================
# Statistics
# =========================================================================
def get_stats(self, db: Session) -> dict:
"""Get subscription statistics for admin dashboard."""
# Count by status
status_counts = (
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
.group_by(VendorSubscription.status)
.all()
)
stats = {
"total_subscriptions": 0,
"active_count": 0,
"trial_count": 0,
"past_due_count": 0,
"cancelled_count": 0,
"expired_count": 0,
}
for status, count in status_counts:
stats["total_subscriptions"] += count
if status == SubscriptionStatus.ACTIVE.value:
stats["active_count"] = count
elif status == SubscriptionStatus.TRIAL.value:
stats["trial_count"] = count
elif status == SubscriptionStatus.PAST_DUE.value:
stats["past_due_count"] = count
elif status == SubscriptionStatus.CANCELLED.value:
stats["cancelled_count"] = count
elif status == SubscriptionStatus.EXPIRED.value:
stats["expired_count"] = count
# Count by tier
tier_counts = (
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
.filter(
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
])
)
.group_by(VendorSubscription.tier)
.all()
)
tier_distribution = {tier: count for tier, count in tier_counts}
# Calculate MRR (Monthly Recurring Revenue)
mrr_cents = 0
arr_cents = 0
active_subs = (
db.query(VendorSubscription, SubscriptionTier)
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
.all()
)
for sub, tier in active_subs:
if sub.is_annual and tier.price_annual_cents:
mrr_cents += tier.price_annual_cents // 12
arr_cents += tier.price_annual_cents
else:
mrr_cents += tier.price_monthly_cents
arr_cents += tier.price_monthly_cents * 12
stats["tier_distribution"] = tier_distribution
stats["mrr_cents"] = mrr_cents
stats["arr_cents"] = arr_cents
return stats
# Singleton instance
admin_subscription_service = AdminSubscriptionService()

View File

@@ -0,0 +1,582 @@
# app/modules/billing/services/stripe_service.py
"""
Stripe payment integration service.
Provides:
- Customer management
- Subscription management
- Checkout session creation
- Customer portal access
- Webhook event construction
"""
import logging
from datetime import datetime
import stripe
from sqlalchemy.orm import Session
from app.core.config import settings
from app.modules.billing.exceptions import (
StripeNotConfiguredException,
WebhookVerificationException,
)
from app.modules.billing.models import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class StripeService:
"""Service for Stripe payment operations."""
def __init__(self):
self._configured = False
self._configure()
def _configure(self):
"""Configure Stripe with API key."""
if settings.stripe_secret_key:
stripe.api_key = settings.stripe_secret_key
self._configured = True
else:
logger.warning("Stripe API key not configured")
@property
def is_configured(self) -> bool:
"""Check if Stripe is properly configured."""
return self._configured and bool(settings.stripe_secret_key)
def _check_configured(self) -> None:
"""Raise exception if Stripe is not configured."""
if not self.is_configured:
raise StripeNotConfiguredException()
# =========================================================================
# Customer Management
# =========================================================================
def create_customer(
self,
vendor: Vendor,
email: str,
name: str | None = None,
metadata: dict | None = None,
) -> str:
"""
Create a Stripe customer for a vendor.
Returns the Stripe customer ID.
"""
self._check_configured()
customer_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
**(metadata or {}),
}
customer = stripe.Customer.create(
email=email,
name=name or vendor.name,
metadata=customer_metadata,
)
logger.info(
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
)
return customer.id
def get_customer(self, customer_id: str) -> stripe.Customer:
"""Get a Stripe customer by ID."""
self._check_configured()
return stripe.Customer.retrieve(customer_id)
def update_customer(
self,
customer_id: str,
email: str | None = None,
name: str | None = None,
metadata: dict | None = None,
) -> stripe.Customer:
"""Update a Stripe customer."""
self._check_configured()
update_data = {}
if email:
update_data["email"] = email
if name:
update_data["name"] = name
if metadata:
update_data["metadata"] = metadata
return stripe.Customer.modify(customer_id, **update_data)
# =========================================================================
# Subscription Management
# =========================================================================
def create_subscription(
self,
customer_id: str,
price_id: str,
trial_days: int | None = None,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create a new Stripe subscription.
Args:
customer_id: Stripe customer ID
price_id: Stripe price ID for the subscription
trial_days: Optional trial period in days
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
self._check_configured()
subscription_data = {
"customer": customer_id,
"items": [{"price": price_id}],
"metadata": metadata or {},
"payment_behavior": "default_incomplete",
"expand": ["latest_invoice.payment_intent"],
}
if trial_days:
subscription_data["trial_period_days"] = trial_days
subscription = stripe.Subscription.create(**subscription_data)
logger.info(
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
)
return subscription
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
"""Get a Stripe subscription by ID."""
self._check_configured()
return stripe.Subscription.retrieve(subscription_id)
def update_subscription(
self,
subscription_id: str,
new_price_id: str | None = None,
proration_behavior: str = "create_prorations",
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Update a Stripe subscription (e.g., change tier).
Args:
subscription_id: Stripe subscription ID
new_price_id: New price ID for tier change
proration_behavior: How to handle prorations
metadata: Optional metadata to update
Returns:
Updated Stripe Subscription object
"""
self._check_configured()
update_data = {"proration_behavior": proration_behavior}
if new_price_id:
# Get the subscription to find the item ID
subscription = stripe.Subscription.retrieve(subscription_id)
item_id = subscription["items"]["data"][0]["id"]
update_data["items"] = [{"id": item_id, "price": new_price_id}]
if metadata:
update_data["metadata"] = metadata
updated = stripe.Subscription.modify(subscription_id, **update_data)
logger.info(f"Updated Stripe subscription {subscription_id}")
return updated
def cancel_subscription(
self,
subscription_id: str,
immediately: bool = False,
cancellation_reason: str | None = None,
) -> stripe.Subscription:
"""
Cancel a Stripe subscription.
Args:
subscription_id: Stripe subscription ID
immediately: If True, cancel now. If False, cancel at period end.
cancellation_reason: Optional reason for cancellation
Returns:
Cancelled Stripe Subscription object
"""
self._check_configured()
if immediately:
subscription = stripe.Subscription.cancel(subscription_id)
else:
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=True,
metadata={"cancellation_reason": cancellation_reason or "user_request"},
)
logger.info(
f"Cancelled Stripe subscription {subscription_id} "
f"(immediately={immediately})"
)
return subscription
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
"""
Reactivate a cancelled subscription (if not yet ended).
Returns:
Reactivated Stripe Subscription object
"""
self._check_configured()
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=False,
)
logger.info(f"Reactivated Stripe subscription {subscription_id}")
return subscription
def cancel_subscription_item(self, subscription_item_id: str) -> None:
"""
Cancel a subscription item (used for add-ons).
Args:
subscription_item_id: Stripe subscription item ID
"""
self._check_configured()
stripe.SubscriptionItem.delete(subscription_item_id)
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
# =========================================================================
# Checkout & Portal
# =========================================================================
def create_checkout_session(
self,
db: Session,
vendor: Vendor,
price_id: str,
success_url: str,
cancel_url: str,
trial_days: int | None = None,
quantity: int = 1,
metadata: dict | None = None,
) -> stripe.checkout.Session:
"""
Create a Stripe Checkout session for subscription signup.
Args:
db: Database session
vendor: Vendor to create checkout for
price_id: Stripe price ID
success_url: URL to redirect on success
cancel_url: URL to redirect on cancel
trial_days: Optional trial period
quantity: Number of items (default 1)
metadata: Additional metadata to store
Returns:
Stripe Checkout Session object
"""
self._check_configured()
# Get or create Stripe customer
subscription = (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor.id)
.first()
)
if subscription and subscription.stripe_customer_id:
customer_id = subscription.stripe_customer_id
else:
# Get vendor owner email
from models.database.vendor import VendorUser
owner = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_owner == True,
)
.first()
)
email = owner.user.email if owner and owner.user else None
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
# Store the customer ID
if subscription:
subscription.stripe_customer_id = customer_id
db.flush()
# Build metadata
session_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
}
if metadata:
session_metadata.update(metadata)
session_data = {
"customer": customer_id,
"line_items": [{"price": price_id, "quantity": quantity}],
"mode": "subscription",
"success_url": success_url,
"cancel_url": cancel_url,
"metadata": session_metadata,
}
if trial_days:
session_data["subscription_data"] = {"trial_period_days": trial_days}
session = stripe.checkout.Session.create(**session_data)
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
return session
def create_portal_session(
self,
customer_id: str,
return_url: str,
) -> stripe.billing_portal.Session:
"""
Create a Stripe Customer Portal session.
Allows customers to manage their subscription, payment methods, and invoices.
Args:
customer_id: Stripe customer ID
return_url: URL to return to after portal
Returns:
Stripe Portal Session object
"""
self._check_configured()
session = stripe.billing_portal.Session.create(
customer=customer_id,
return_url=return_url,
)
logger.info(f"Created portal session for customer {customer_id}")
return session
# =========================================================================
# Invoice Management
# =========================================================================
def get_invoices(
self,
customer_id: str,
limit: int = 10,
) -> list[stripe.Invoice]:
"""Get invoices for a customer."""
self._check_configured()
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
return list(invoices.data)
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
"""Get the upcoming invoice for a customer."""
self._check_configured()
try:
return stripe.Invoice.upcoming(customer=customer_id)
except stripe.error.InvalidRequestError:
# No upcoming invoice
return None
# =========================================================================
# Webhook Handling
# =========================================================================
def construct_event(
self,
payload: bytes,
sig_header: str,
) -> stripe.Event:
"""
Construct and verify a Stripe webhook event.
Args:
payload: Raw request body
sig_header: Stripe-Signature header
Returns:
Verified Stripe Event object
Raises:
WebhookVerificationException: If signature verification fails
"""
if not settings.stripe_webhook_secret:
raise WebhookVerificationException("Stripe webhook secret not configured")
try:
event = stripe.Webhook.construct_event(
payload,
sig_header,
settings.stripe_webhook_secret,
)
return event
except stripe.error.SignatureVerificationError as e:
logger.error(f"Webhook signature verification failed: {e}")
raise WebhookVerificationException("Invalid webhook signature")
# =========================================================================
# SetupIntent & Payment Method Management
# =========================================================================
def create_setup_intent(
self,
customer_id: str,
metadata: dict | None = None,
) -> stripe.SetupIntent:
"""
Create a SetupIntent to collect card without charging.
Used for trial signups where we collect card upfront
but don't charge until trial ends.
Args:
customer_id: Stripe customer ID
metadata: Optional metadata to attach
Returns:
Stripe SetupIntent object with client_secret for frontend
"""
self._check_configured()
setup_intent = stripe.SetupIntent.create(
customer=customer_id,
payment_method_types=["card"],
metadata=metadata or {},
)
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
return setup_intent
def attach_payment_method_to_customer(
self,
customer_id: str,
payment_method_id: str,
set_as_default: bool = True,
) -> None:
"""
Attach a payment method to customer and optionally set as default.
Args:
customer_id: Stripe customer ID
payment_method_id: Payment method ID from confirmed SetupIntent
set_as_default: Whether to set as default payment method
"""
self._check_configured()
# Attach the payment method to the customer
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
if set_as_default:
stripe.Customer.modify(
customer_id,
invoice_settings={"default_payment_method": payment_method_id},
)
logger.info(
f"Attached payment method {payment_method_id} to customer {customer_id} "
f"(default={set_as_default})"
)
def create_subscription_with_trial(
self,
customer_id: str,
price_id: str,
trial_days: int = 30,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create subscription with trial period.
Customer must have a default payment method attached.
Card will be charged automatically after trial ends.
Args:
customer_id: Stripe customer ID (must have default payment method)
price_id: Stripe price ID for the subscription tier
trial_days: Number of trial days (default 30)
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
self._check_configured()
subscription = stripe.Subscription.create(
customer=customer_id,
items=[{"price": price_id}],
trial_period_days=trial_days,
metadata=metadata or {},
# Use default payment method for future charges
default_payment_method=None, # Uses customer's default
)
logger.info(
f"Created subscription {subscription.id} with {trial_days}-day trial "
f"for customer {customer_id}"
)
return subscription
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
"""Get a SetupIntent by ID."""
self._check_configured()
return stripe.SetupIntent.retrieve(setup_intent_id)
# =========================================================================
# Price/Product Management
# =========================================================================
def get_price(self, price_id: str) -> stripe.Price:
"""Get a Stripe price by ID."""
self._check_configured()
return stripe.Price.retrieve(price_id)
def get_product(self, product_id: str) -> stripe.Product:
"""Get a Stripe product by ID."""
self._check_configured()
return stripe.Product.retrieve(product_id)
def list_prices(
self,
product_id: str | None = None,
active: bool = True,
) -> list[stripe.Price]:
"""List Stripe prices, optionally filtered by product."""
self._check_configured()
params = {"active": active}
if product_id:
params["product"] = product_id
prices = stripe.Price.list(**params)
return list(prices.data)
# Create service instance
stripe_service = StripeService()

View File

@@ -0,0 +1,631 @@
# app/modules/billing/services/subscription_service.py
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
Usage:
from app.modules.billing.services import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.modules.billing.exceptions import (
FeatureNotAvailableException,
SubscriptionNotFoundException,
TierLimitExceededException,
)
from app.modules.billing.models import (
SubscriptionStatus,
SubscriptionTier,
TIER_LIMITS,
TierCode,
VendorSubscription,
)
from app.modules.billing.schemas import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
UsageSummary,
)
from models.database.product import Product
from models.database.vendor import Vendor, VendorUser
logger = logging.getLogger(__name__)
class SubscriptionService:
"""Service for subscription and tier limit operations."""
# =========================================================================
# Tier Information
# =========================================================================
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
"""
Get full tier information.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
# Try database first if session provided
if db is not None:
db_tier = self.get_tier_by_code(db, tier_code)
if db_tier:
return TierInfo(
code=db_tier.code,
name=db_tier.name,
price_monthly_cents=db_tier.price_monthly_cents,
price_annual_cents=db_tier.price_annual_cents,
limits=TierLimits(
orders_per_month=db_tier.orders_per_month,
products_limit=db_tier.products_limit,
team_members=db_tier.team_members,
order_history_months=db_tier.order_history_months,
),
features=db_tier.features or [],
)
# Fallback to hardcoded TIER_LIMITS
return self._get_tier_from_legacy(tier_code)
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
try:
tier = TierCode(tier_code)
except ValueError:
tier = TierCode.ESSENTIAL
limits = TIER_LIMITS[tier]
return TierInfo(
code=tier.value,
name=limits["name"],
price_monthly_cents=limits["price_monthly_cents"],
price_annual_cents=limits.get("price_annual_cents"),
limits=TierLimits(
orders_per_month=limits.get("orders_per_month"),
products_limit=limits.get("products_limit"),
team_members=limits.get("team_members"),
order_history_months=limits.get("order_history_months"),
),
features=limits.get("features", []),
)
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
"""
Get information for all tiers.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
if db is not None:
db_tiers = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.is_public == True, # noqa: E712
)
.order_by(SubscriptionTier.display_order)
.all()
)
if db_tiers:
return [
TierInfo(
code=t.code,
name=t.name,
price_monthly_cents=t.price_monthly_cents,
price_annual_cents=t.price_annual_cents,
limits=TierLimits(
orders_per_month=t.orders_per_month,
products_limit=t.products_limit,
team_members=t.team_members,
order_history_months=t.order_history_months,
),
features=t.features or [],
)
for t in db_tiers
]
# Fallback to hardcoded
return [
self._get_tier_from_legacy(tier.value)
for tier in TierCode
]
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
"""Get subscription tier by code."""
return (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
"""Get tier ID from tier code. Returns None if tier not found."""
tier = self.get_tier_by_code(db, tier_code)
return tier.id if tier else None
# =========================================================================
# Subscription CRUD
# =========================================================================
def get_subscription(
self, db: Session, vendor_id: int
) -> VendorSubscription | None:
"""Get vendor subscription."""
return (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
def get_subscription_or_raise(
self, db: Session, vendor_id: int
) -> VendorSubscription:
"""Get vendor subscription or raise exception."""
subscription = self.get_subscription(db, vendor_id)
if not subscription:
raise SubscriptionNotFoundException(vendor_id)
return subscription
def get_current_tier(
self, db: Session, vendor_id: int
) -> TierCode | None:
"""Get vendor's current subscription tier code."""
subscription = self.get_subscription(db, vendor_id)
if subscription:
try:
return TierCode(subscription.tier)
except ValueError:
return None
return None
def get_or_create_subscription(
self,
db: Session,
vendor_id: int,
tier: str = TierCode.ESSENTIAL.value,
trial_days: int = 14,
) -> VendorSubscription:
"""
Get existing subscription or create a new trial subscription.
Used when a vendor first accesses the system.
"""
subscription = self.get_subscription(db, vendor_id)
if subscription:
return subscription
# Create new trial subscription
now = datetime.now(UTC)
trial_end = now + timedelta(days=trial_days)
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=tier,
tier_id=tier_id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=trial_end,
trial_ends_at=trial_end,
is_annual=False,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(
f"Created trial subscription for vendor {vendor_id} "
f"(tier={tier}, trial_ends={trial_end})"
)
return subscription
def create_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionCreate,
) -> VendorSubscription:
"""Create a subscription for a vendor."""
# Check if subscription exists
existing = self.get_subscription(db, vendor_id)
if existing:
raise ValueError("Vendor already has a subscription")
now = datetime.now(UTC)
# Calculate period end based on billing cycle
if data.is_annual:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
# Handle trial
trial_ends_at = None
status = SubscriptionStatus.ACTIVE.value
if data.trial_days > 0:
trial_ends_at = now + timedelta(days=data.trial_days)
status = SubscriptionStatus.TRIAL.value
period_end = trial_ends_at
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, data.tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=data.tier,
tier_id=tier_id,
status=status,
period_start=now,
period_end=period_end,
trial_ends_at=trial_ends_at,
is_annual=data.is_annual,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
return subscription
def update_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionUpdate,
) -> VendorSubscription:
"""Update a vendor subscription."""
subscription = self.get_subscription_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
# If tier is being updated, also update tier_id
if "tier" in update_data:
tier_id = self.get_tier_id(db, update_data["tier"])
update_data["tier_id"] = tier_id
for key, value in update_data.items():
setattr(subscription, key, value)
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Updated subscription for vendor {vendor_id}")
return subscription
def upgrade_tier(
self,
db: Session,
vendor_id: int,
new_tier: str,
) -> VendorSubscription:
"""Upgrade vendor to a new tier."""
subscription = self.get_subscription_or_raise(db, vendor_id)
old_tier = subscription.tier
subscription.tier = new_tier
subscription.tier_id = self.get_tier_id(db, new_tier)
subscription.updated_at = datetime.now(UTC)
# If upgrading from trial, mark as active
if subscription.status == SubscriptionStatus.TRIAL.value:
subscription.status = SubscriptionStatus.ACTIVE.value
db.flush()
db.refresh(subscription)
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
return subscription
def cancel_subscription(
self,
db: Session,
vendor_id: int,
reason: str | None = None,
) -> VendorSubscription:
"""Cancel a vendor subscription (access until period end)."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.status = SubscriptionStatus.CANCELLED.value
subscription.cancelled_at = datetime.now(UTC)
subscription.cancellation_reason = reason
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Cancelled subscription for vendor {vendor_id}")
return subscription
# =========================================================================
# Usage Tracking
# =========================================================================
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
"""Get current subscription usage statistics."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Calculate usage stats
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
def calc_percent(current: int, limit: int | None) -> float | None:
if limit is None or limit == 0:
return None
return min(100.0, (current / limit) * 100)
return SubscriptionUsage(
orders_used=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
products_used=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
products_percent_used=calc_percent(products_count, products_limit),
team_members_used=team_count,
team_members_limit=team_limit,
team_members_remaining=calc_remaining(team_count, team_limit),
team_members_percent_used=calc_percent(team_count, team_limit),
)
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
"""Get usage summary for billing page display."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Get limits
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
return UsageSummary(
orders_this_period=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
products_count=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
team_count=team_count,
team_limit=team_limit,
team_remaining=calc_remaining(team_count, team_limit),
)
def increment_order_count(self, db: Session, vendor_id: int) -> None:
"""
Increment the order counter for the current period.
Call this after successfully creating/importing an order.
"""
subscription = self.get_or_create_subscription(db, vendor_id)
subscription.increment_order_count()
db.flush()
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
"""Reset counters for a new billing period."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.reset_period_counters()
db.flush()
logger.info(f"Reset period counters for vendor {vendor_id}")
# =========================================================================
# Limit Checks
# =========================================================================
def can_create_order(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can create/import another order.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.can_create_order()
def check_order_limit(self, db: Session, vendor_id: int) -> None:
"""
Check order limit and raise exception if exceeded.
Use this in order creation flows.
"""
can_create, message = self.can_create_order(db, vendor_id)
if not can_create:
subscription = self.get_subscription(db, vendor_id)
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=subscription.orders_this_period if subscription else 0,
limit=subscription.orders_limit if subscription else 0,
)
def can_add_product(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another product.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
return subscription.can_add_product(products_count)
def check_product_limit(self, db: Session, vendor_id: int) -> None:
"""
Check product limit and raise exception if exceeded.
Use this in product creation flows.
"""
can_add, message = self.can_add_product(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Product limit exceeded",
limit_type="products",
current=products_count,
limit=subscription.products_limit if subscription else 0,
)
def can_add_team_member(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another team member.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
return subscription.can_add_team_member(team_count)
def check_team_limit(self, db: Session, vendor_id: int) -> None:
"""
Check team member limit and raise exception if exceeded.
Use this in team member invitation flows.
"""
can_add, message = self.can_add_team_member(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Team member limit exceeded",
limit_type="team_members",
current=team_count,
limit=subscription.team_members_limit if subscription else 0,
)
# =========================================================================
# Feature Gating
# =========================================================================
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
"""Check if vendor has access to a feature."""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.has_feature(feature)
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
"""
Check feature access and raise exception if not available.
Use this to gate premium features.
"""
if not self.has_feature(db, vendor_id, feature):
subscription = self.get_or_create_subscription(db, vendor_id)
# Find which tier has this feature
required_tier = None
for tier_code, limits in TIER_LIMITS.items():
if feature in limits.get("features", []):
required_tier = limits["name"]
break
raise FeatureNotAvailableException(
feature=feature,
current_tier=subscription.tier,
required_tier=required_tier or "higher",
)
def get_feature_tier(self, feature: str) -> str | None:
"""Get the minimum tier required for a feature."""
for tier_code in [
TierCode.ESSENTIAL,
TierCode.PROFESSIONAL,
TierCode.BUSINESS,
TierCode.ENTERPRISE,
]:
if feature in TIER_LIMITS[tier_code].get("features", []):
return tier_code.value
return None
# Singleton instance
subscription_service = SubscriptionService()

View File

@@ -0,0 +1,26 @@
# app/modules/billing/tasks/__init__.py
"""
Billing module Celery tasks.
Scheduled tasks for:
- Resetting period counters
- Checking trial expirations
- Syncing with Stripe
- Cleaning up stale subscriptions
Note: capture_capacity_snapshot moved to monitoring module.
"""
from app.modules.billing.tasks.subscription import (
reset_period_counters,
check_trial_expirations,
sync_stripe_status,
cleanup_stale_subscriptions,
)
__all__ = [
"reset_period_counters",
"check_trial_expirations",
"sync_stripe_status",
"cleanup_stale_subscriptions",
]

View File

@@ -0,0 +1,255 @@
# app/modules/billing/tasks/subscription.py
"""
Celery tasks for subscription management.
Scheduled tasks for:
- Resetting period counters
- Checking trial expirations
- Syncing with Stripe
- Cleaning up stale subscriptions
"""
import logging
from datetime import UTC, datetime, timedelta
from app.core.celery_config import celery_app
from app.modules.billing.models import SubscriptionStatus, VendorSubscription
from app.modules.billing.services import stripe_service
from app.modules.task_base import ModuleTask
logger = logging.getLogger(__name__)
@celery_app.task(
bind=True,
base=ModuleTask,
name="app.modules.billing.tasks.subscription.reset_period_counters",
)
def reset_period_counters(self):
"""
Reset order counters for subscriptions whose billing period has ended.
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
"""
now = datetime.now(UTC)
reset_count = 0
with self.get_db() as db:
# Find subscriptions where period has ended
expired_periods = (
db.query(VendorSubscription)
.filter(
VendorSubscription.period_end <= now,
VendorSubscription.status.in_(["active", "trial"]),
)
.all()
)
for subscription in expired_periods:
old_period_end = subscription.period_end
# Reset counters
subscription.orders_this_period = 0
subscription.orders_limit_reached_at = None
# Set new period dates
if subscription.is_annual:
subscription.period_start = now
subscription.period_end = now + timedelta(days=365)
else:
subscription.period_start = now
subscription.period_end = now + timedelta(days=30)
subscription.updated_at = now
reset_count += 1
logger.info(
f"Reset period counters for vendor {subscription.vendor_id}: "
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
)
logger.info(f"Reset period counters for {reset_count} subscriptions")
return {"reset_count": reset_count}
@celery_app.task(
bind=True,
base=ModuleTask,
name="app.modules.billing.tasks.subscription.check_trial_expirations",
)
def check_trial_expirations(self):
"""
Check for expired trials and update their status.
Runs daily at 01:00.
- Trials without payment method -> expired
- Trials with payment method -> active
"""
now = datetime.now(UTC)
expired_count = 0
activated_count = 0
with self.get_db() as db:
# Find expired trials
expired_trials = (
db.query(VendorSubscription)
.filter(
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
VendorSubscription.trial_ends_at <= now,
)
.all()
)
for subscription in expired_trials:
if subscription.stripe_payment_method_id:
# Has payment method - activate
subscription.status = SubscriptionStatus.ACTIVE.value
activated_count += 1
logger.info(
f"Activated subscription for vendor {subscription.vendor_id} "
f"(trial ended with payment method)"
)
else:
# No payment method - expire
subscription.status = SubscriptionStatus.EXPIRED.value
expired_count += 1
logger.info(
f"Expired trial for vendor {subscription.vendor_id} "
f"(no payment method)"
)
subscription.updated_at = now
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
return {"expired_count": expired_count, "activated_count": activated_count}
@celery_app.task(
bind=True,
base=ModuleTask,
name="app.modules.billing.tasks.subscription.sync_stripe_status",
max_retries=3,
default_retry_delay=300,
)
def sync_stripe_status(self):
"""
Sync subscription status with Stripe.
Runs hourly at :30. Fetches current status from Stripe and updates local records.
"""
if not stripe_service.is_configured:
logger.warning("Stripe not configured, skipping sync")
return {"synced": 0, "skipped": True}
synced_count = 0
error_count = 0
with self.get_db() as db:
# Find subscriptions with Stripe IDs
subscriptions = (
db.query(VendorSubscription)
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
.all()
)
for subscription in subscriptions:
try:
# Fetch from Stripe
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
if not stripe_sub:
logger.warning(
f"Stripe subscription {subscription.stripe_subscription_id} "
f"not found for vendor {subscription.vendor_id}"
)
continue
# Map Stripe status to local status
status_map = {
"active": SubscriptionStatus.ACTIVE.value,
"trialing": SubscriptionStatus.TRIAL.value,
"past_due": SubscriptionStatus.PAST_DUE.value,
"canceled": SubscriptionStatus.CANCELLED.value,
"unpaid": SubscriptionStatus.PAST_DUE.value,
"incomplete": SubscriptionStatus.TRIAL.value,
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
}
new_status = status_map.get(stripe_sub.status)
if new_status and new_status != subscription.status:
old_status = subscription.status
subscription.status = new_status
subscription.updated_at = datetime.now(UTC)
logger.info(
f"Updated vendor {subscription.vendor_id} status: "
f"{old_status} -> {new_status} (from Stripe)"
)
# Update period dates from Stripe
if stripe_sub.current_period_start:
subscription.period_start = datetime.fromtimestamp(
stripe_sub.current_period_start, tz=UTC
)
if stripe_sub.current_period_end:
subscription.period_end = datetime.fromtimestamp(
stripe_sub.current_period_end, tz=UTC
)
# Update payment method
if stripe_sub.default_payment_method:
subscription.stripe_payment_method_id = (
stripe_sub.default_payment_method
if isinstance(stripe_sub.default_payment_method, str)
else stripe_sub.default_payment_method.id
)
synced_count += 1
except Exception as e:
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
error_count += 1
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
return {"synced_count": synced_count, "error_count": error_count}
@celery_app.task(
bind=True,
base=ModuleTask,
name="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
)
def cleanup_stale_subscriptions(self):
"""
Clean up subscriptions in inconsistent states.
Runs weekly on Sunday at 03:00.
"""
now = datetime.now(UTC)
cleaned_count = 0
with self.get_db() as db:
# Find cancelled subscriptions past their period end
stale_cancelled = (
db.query(VendorSubscription)
.filter(
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
VendorSubscription.period_end < now - timedelta(days=30),
)
.all()
)
for subscription in stale_cancelled:
# Mark as expired (fully terminated)
subscription.status = SubscriptionStatus.EXPIRED.value
subscription.updated_at = now
cleaned_count += 1
logger.info(
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
)
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
return {"cleaned_count": cleaned_count}

View File

@@ -2,351 +2,21 @@
"""
Admin Subscription Service.
Handles subscription management operations for platform administrators:
- Subscription tier CRUD
- Vendor subscription management
- Billing history queries
- Subscription analytics
DEPRECATED: This file is maintained for backward compatibility.
Import from app.modules.billing.services instead:
from app.modules.billing.services import admin_subscription_service
This file re-exports the service from its new location in the billing module.
"""
import logging
from math import ceil
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
BusinessLogicException,
ConflictException,
ResourceNotFoundException,
TierNotFoundException,
# Re-export from new location for backward compatibility
from app.modules.billing.services.admin_subscription_service import (
AdminSubscriptionService,
admin_subscription_service,
)
from models.database.product import Product
from models.database.subscription import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
)
from models.database.vendor import Vendor, VendorUser
logger = logging.getLogger(__name__)
class AdminSubscriptionService:
"""Service for admin subscription management operations."""
# =========================================================================
# Subscription Tiers
# =========================================================================
def get_tiers(
self, db: Session, include_inactive: bool = False
) -> list[SubscriptionTier]:
"""Get all subscription tiers."""
query = db.query(SubscriptionTier)
if not include_inactive:
query = query.filter(SubscriptionTier.is_active == True) # noqa: E712
return query.order_by(SubscriptionTier.display_order).all()
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
"""Get a subscription tier by code."""
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
if not tier:
raise TierNotFoundException(tier_code)
return tier
def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier:
"""Create a new subscription tier."""
# Check for duplicate code
existing = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_data["code"])
.first()
)
if existing:
raise ConflictException(
f"Tier with code '{tier_data['code']}' already exists"
)
tier = SubscriptionTier(**tier_data)
db.add(tier)
logger.info(f"Created subscription tier: {tier.code}")
return tier
def update_tier(
self, db: Session, tier_code: str, update_data: dict
) -> SubscriptionTier:
"""Update a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
for field, value in update_data.items():
setattr(tier, field, value)
logger.info(f"Updated subscription tier: {tier.code}")
return tier
def deactivate_tier(self, db: Session, tier_code: str) -> None:
"""Soft-delete a subscription tier."""
tier = self.get_tier_by_code(db, tier_code)
# Check if any active subscriptions use this tier
active_subs = (
db.query(VendorSubscription)
.filter(
VendorSubscription.tier == tier_code,
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
]),
)
.count()
)
if active_subs > 0:
raise BusinessLogicException(
f"Cannot delete tier: {active_subs} active subscriptions are using it"
)
tier.is_active = False
logger.info(f"Soft-deleted subscription tier: {tier.code}")
# =========================================================================
# Vendor Subscriptions
# =========================================================================
def list_subscriptions(
self,
db: Session,
page: int = 1,
per_page: int = 20,
status: str | None = None,
tier: str | None = None,
search: str | None = None,
) -> dict:
"""List vendor subscriptions with filtering and pagination."""
query = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
)
# Apply filters
if status:
query = query.filter(VendorSubscription.status == status)
if tier:
query = query.filter(VendorSubscription.tier == tier)
if search:
query = query.filter(Vendor.name.ilike(f"%{search}%"))
# Count total
total = query.count()
# Paginate
offset = (page - 1) * per_page
results = (
query.order_by(VendorSubscription.created_at.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
def get_subscription(self, db: Session, vendor_id: int) -> tuple:
"""Get subscription for a specific vendor."""
result = (
db.query(VendorSubscription, Vendor)
.join(Vendor, VendorSubscription.vendor_id == Vendor.id)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
if not result:
raise ResourceNotFoundException("Subscription", str(vendor_id))
return result
def update_subscription(
self, db: Session, vendor_id: int, update_data: dict
) -> tuple:
"""Update a vendor's subscription."""
result = self.get_subscription(db, vendor_id)
sub, vendor = result
for field, value in update_data.items():
setattr(sub, field, value)
logger.info(
f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}"
)
return sub, vendor
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
"""Get a vendor by ID."""
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise ResourceNotFoundException("Vendor", str(vendor_id))
return vendor
def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict:
"""Get usage counts (products and team members) for a vendor."""
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(
VendorUser.vendor_id == vendor_id,
VendorUser.is_active == True, # noqa: E712
)
.scalar()
or 0
)
return {
"products_count": products_count,
"team_count": team_count,
}
# =========================================================================
# Billing History
# =========================================================================
def list_billing_history(
self,
db: Session,
page: int = 1,
per_page: int = 20,
vendor_id: int | None = None,
status: str | None = None,
) -> dict:
"""List billing history across all vendors."""
query = (
db.query(BillingHistory, Vendor)
.join(Vendor, BillingHistory.vendor_id == Vendor.id)
)
if vendor_id:
query = query.filter(BillingHistory.vendor_id == vendor_id)
if status:
query = query.filter(BillingHistory.status == status)
total = query.count()
offset = (page - 1) * per_page
results = (
query.order_by(BillingHistory.invoice_date.desc())
.offset(offset)
.limit(per_page)
.all()
)
return {
"results": results,
"total": total,
"page": page,
"per_page": per_page,
"pages": ceil(total / per_page) if total > 0 else 0,
}
# =========================================================================
# Statistics
# =========================================================================
def get_stats(self, db: Session) -> dict:
"""Get subscription statistics for admin dashboard."""
# Count by status
status_counts = (
db.query(VendorSubscription.status, func.count(VendorSubscription.id))
.group_by(VendorSubscription.status)
.all()
)
stats = {
"total_subscriptions": 0,
"active_count": 0,
"trial_count": 0,
"past_due_count": 0,
"cancelled_count": 0,
"expired_count": 0,
}
for status, count in status_counts:
stats["total_subscriptions"] += count
if status == SubscriptionStatus.ACTIVE.value:
stats["active_count"] = count
elif status == SubscriptionStatus.TRIAL.value:
stats["trial_count"] = count
elif status == SubscriptionStatus.PAST_DUE.value:
stats["past_due_count"] = count
elif status == SubscriptionStatus.CANCELLED.value:
stats["cancelled_count"] = count
elif status == SubscriptionStatus.EXPIRED.value:
stats["expired_count"] = count
# Count by tier
tier_counts = (
db.query(VendorSubscription.tier, func.count(VendorSubscription.id))
.filter(
VendorSubscription.status.in_([
SubscriptionStatus.ACTIVE.value,
SubscriptionStatus.TRIAL.value,
])
)
.group_by(VendorSubscription.tier)
.all()
)
tier_distribution = {tier: count for tier, count in tier_counts}
# Calculate MRR (Monthly Recurring Revenue)
mrr_cents = 0
arr_cents = 0
active_subs = (
db.query(VendorSubscription, SubscriptionTier)
.join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code)
.filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value)
.all()
)
for sub, tier in active_subs:
if sub.is_annual and tier.price_annual_cents:
mrr_cents += tier.price_annual_cents // 12
arr_cents += tier.price_annual_cents
else:
mrr_cents += tier.price_monthly_cents
arr_cents += tier.price_monthly_cents * 12
stats["tier_distribution"] = tier_distribution
stats["mrr_cents"] = mrr_cents
stats["arr_cents"] = arr_cents
return stats
# Singleton instance
admin_subscription_service = AdminSubscriptionService()
__all__ = [
"AdminSubscriptionService",
"admin_subscription_service",
]

View File

@@ -2,592 +2,21 @@
"""
Stripe payment integration service.
Provides:
- Customer management
- Subscription management
- Checkout session creation
- Customer portal access
- Webhook event construction
DEPRECATED: This file is maintained for backward compatibility.
Import from app.modules.billing.services instead:
from app.modules.billing.services import stripe_service
This file re-exports the service from its new location in the billing module.
"""
import logging
from datetime import datetime
import stripe
from sqlalchemy.orm import Session
from app.core.config import settings
from models.database.subscription import (
BillingHistory,
SubscriptionStatus,
SubscriptionTier,
VendorSubscription,
# Re-export from new location for backward compatibility
from app.modules.billing.services.stripe_service import (
StripeService,
stripe_service,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class StripeService:
"""Service for Stripe payment operations."""
def __init__(self):
self._configured = False
self._configure()
def _configure(self):
"""Configure Stripe with API key."""
if settings.stripe_secret_key:
stripe.api_key = settings.stripe_secret_key
self._configured = True
else:
logger.warning("Stripe API key not configured")
@property
def is_configured(self) -> bool:
"""Check if Stripe is properly configured."""
return self._configured and bool(settings.stripe_secret_key)
# =========================================================================
# Customer Management
# =========================================================================
def create_customer(
self,
vendor: Vendor,
email: str,
name: str | None = None,
metadata: dict | None = None,
) -> str:
"""
Create a Stripe customer for a vendor.
Returns the Stripe customer ID.
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
customer_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
**(metadata or {}),
}
customer = stripe.Customer.create(
email=email,
name=name or vendor.name,
metadata=customer_metadata,
)
logger.info(
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
)
return customer.id
def get_customer(self, customer_id: str) -> stripe.Customer:
"""Get a Stripe customer by ID."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
return stripe.Customer.retrieve(customer_id)
def update_customer(
self,
customer_id: str,
email: str | None = None,
name: str | None = None,
metadata: dict | None = None,
) -> stripe.Customer:
"""Update a Stripe customer."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
update_data = {}
if email:
update_data["email"] = email
if name:
update_data["name"] = name
if metadata:
update_data["metadata"] = metadata
return stripe.Customer.modify(customer_id, **update_data)
# =========================================================================
# Subscription Management
# =========================================================================
def create_subscription(
self,
customer_id: str,
price_id: str,
trial_days: int | None = None,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create a new Stripe subscription.
Args:
customer_id: Stripe customer ID
price_id: Stripe price ID for the subscription
trial_days: Optional trial period in days
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
subscription_data = {
"customer": customer_id,
"items": [{"price": price_id}],
"metadata": metadata or {},
"payment_behavior": "default_incomplete",
"expand": ["latest_invoice.payment_intent"],
}
if trial_days:
subscription_data["trial_period_days"] = trial_days
subscription = stripe.Subscription.create(**subscription_data)
logger.info(
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
)
return subscription
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
"""Get a Stripe subscription by ID."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
return stripe.Subscription.retrieve(subscription_id)
def update_subscription(
self,
subscription_id: str,
new_price_id: str | None = None,
proration_behavior: str = "create_prorations",
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Update a Stripe subscription (e.g., change tier).
Args:
subscription_id: Stripe subscription ID
new_price_id: New price ID for tier change
proration_behavior: How to handle prorations
metadata: Optional metadata to update
Returns:
Updated Stripe Subscription object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
update_data = {"proration_behavior": proration_behavior}
if new_price_id:
# Get the subscription to find the item ID
subscription = stripe.Subscription.retrieve(subscription_id)
item_id = subscription["items"]["data"][0]["id"]
update_data["items"] = [{"id": item_id, "price": new_price_id}]
if metadata:
update_data["metadata"] = metadata
updated = stripe.Subscription.modify(subscription_id, **update_data)
logger.info(f"Updated Stripe subscription {subscription_id}")
return updated
def cancel_subscription(
self,
subscription_id: str,
immediately: bool = False,
cancellation_reason: str | None = None,
) -> stripe.Subscription:
"""
Cancel a Stripe subscription.
Args:
subscription_id: Stripe subscription ID
immediately: If True, cancel now. If False, cancel at period end.
cancellation_reason: Optional reason for cancellation
Returns:
Cancelled Stripe Subscription object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
if immediately:
subscription = stripe.Subscription.cancel(subscription_id)
else:
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=True,
metadata={"cancellation_reason": cancellation_reason or "user_request"},
)
logger.info(
f"Cancelled Stripe subscription {subscription_id} "
f"(immediately={immediately})"
)
return subscription
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
"""
Reactivate a cancelled subscription (if not yet ended).
Returns:
Reactivated Stripe Subscription object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
subscription = stripe.Subscription.modify(
subscription_id,
cancel_at_period_end=False,
)
logger.info(f"Reactivated Stripe subscription {subscription_id}")
return subscription
def cancel_subscription_item(self, subscription_item_id: str) -> None:
"""
Cancel a subscription item (used for add-ons).
Args:
subscription_item_id: Stripe subscription item ID
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
stripe.SubscriptionItem.delete(subscription_item_id)
logger.info(f"Cancelled Stripe subscription item {subscription_item_id}")
# =========================================================================
# Checkout & Portal
# =========================================================================
def create_checkout_session(
self,
db: Session,
vendor: Vendor,
price_id: str,
success_url: str,
cancel_url: str,
trial_days: int | None = None,
quantity: int = 1,
metadata: dict | None = None,
) -> stripe.checkout.Session:
"""
Create a Stripe Checkout session for subscription signup.
Args:
db: Database session
vendor: Vendor to create checkout for
price_id: Stripe price ID
success_url: URL to redirect on success
cancel_url: URL to redirect on cancel
trial_days: Optional trial period
quantity: Number of items (default 1)
metadata: Additional metadata to store
Returns:
Stripe Checkout Session object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
# Get or create Stripe customer
subscription = (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor.id)
.first()
)
if subscription and subscription.stripe_customer_id:
customer_id = subscription.stripe_customer_id
else:
# Get vendor owner email
from models.database.vendor import VendorUser
owner = (
db.query(VendorUser)
.filter(
VendorUser.vendor_id == vendor.id,
VendorUser.is_owner == True,
)
.first()
)
email = owner.user.email if owner and owner.user else None
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
# Store the customer ID
if subscription:
subscription.stripe_customer_id = customer_id
db.flush()
# Build metadata
session_metadata = {
"vendor_id": str(vendor.id),
"vendor_code": vendor.vendor_code,
}
if metadata:
session_metadata.update(metadata)
session_data = {
"customer": customer_id,
"line_items": [{"price": price_id, "quantity": quantity}],
"mode": "subscription",
"success_url": success_url,
"cancel_url": cancel_url,
"metadata": session_metadata,
}
if trial_days:
session_data["subscription_data"] = {"trial_period_days": trial_days}
session = stripe.checkout.Session.create(**session_data)
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
return session
def create_portal_session(
self,
customer_id: str,
return_url: str,
) -> stripe.billing_portal.Session:
"""
Create a Stripe Customer Portal session.
Allows customers to manage their subscription, payment methods, and invoices.
Args:
customer_id: Stripe customer ID
return_url: URL to return to after portal
Returns:
Stripe Portal Session object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
session = stripe.billing_portal.Session.create(
customer=customer_id,
return_url=return_url,
)
logger.info(f"Created portal session for customer {customer_id}")
return session
# =========================================================================
# Invoice Management
# =========================================================================
def get_invoices(
self,
customer_id: str,
limit: int = 10,
) -> list[stripe.Invoice]:
"""Get invoices for a customer."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
return list(invoices.data)
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
"""Get the upcoming invoice for a customer."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
try:
return stripe.Invoice.upcoming(customer=customer_id)
except stripe.error.InvalidRequestError:
# No upcoming invoice
return None
# =========================================================================
# Webhook Handling
# =========================================================================
def construct_event(
self,
payload: bytes,
sig_header: str,
) -> stripe.Event:
"""
Construct and verify a Stripe webhook event.
Args:
payload: Raw request body
sig_header: Stripe-Signature header
Returns:
Verified Stripe Event object
Raises:
ValueError: If signature verification fails
"""
if not settings.stripe_webhook_secret:
raise ValueError("Stripe webhook secret not configured")
try:
event = stripe.Webhook.construct_event(
payload,
sig_header,
settings.stripe_webhook_secret,
)
return event
except stripe.error.SignatureVerificationError as e:
logger.error(f"Webhook signature verification failed: {e}")
raise ValueError("Invalid webhook signature")
# =========================================================================
# SetupIntent & Payment Method Management
# =========================================================================
def create_setup_intent(
self,
customer_id: str,
metadata: dict | None = None,
) -> stripe.SetupIntent:
"""
Create a SetupIntent to collect card without charging.
Used for trial signups where we collect card upfront
but don't charge until trial ends.
Args:
customer_id: Stripe customer ID
metadata: Optional metadata to attach
Returns:
Stripe SetupIntent object with client_secret for frontend
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
setup_intent = stripe.SetupIntent.create(
customer=customer_id,
payment_method_types=["card"],
metadata=metadata or {},
)
logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}")
return setup_intent
def attach_payment_method_to_customer(
self,
customer_id: str,
payment_method_id: str,
set_as_default: bool = True,
) -> None:
"""
Attach a payment method to customer and optionally set as default.
Args:
customer_id: Stripe customer ID
payment_method_id: Payment method ID from confirmed SetupIntent
set_as_default: Whether to set as default payment method
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
# Attach the payment method to the customer
stripe.PaymentMethod.attach(payment_method_id, customer=customer_id)
if set_as_default:
stripe.Customer.modify(
customer_id,
invoice_settings={"default_payment_method": payment_method_id},
)
logger.info(
f"Attached payment method {payment_method_id} to customer {customer_id} "
f"(default={set_as_default})"
)
def create_subscription_with_trial(
self,
customer_id: str,
price_id: str,
trial_days: int = 30,
metadata: dict | None = None,
) -> stripe.Subscription:
"""
Create subscription with trial period.
Customer must have a default payment method attached.
Card will be charged automatically after trial ends.
Args:
customer_id: Stripe customer ID (must have default payment method)
price_id: Stripe price ID for the subscription tier
trial_days: Number of trial days (default 30)
metadata: Optional metadata to attach
Returns:
Stripe Subscription object
"""
if not self.is_configured:
raise ValueError("Stripe is not configured")
subscription = stripe.Subscription.create(
customer=customer_id,
items=[{"price": price_id}],
trial_period_days=trial_days,
metadata=metadata or {},
# Use default payment method for future charges
default_payment_method=None, # Uses customer's default
)
logger.info(
f"Created subscription {subscription.id} with {trial_days}-day trial "
f"for customer {customer_id}"
)
return subscription
def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent:
"""Get a SetupIntent by ID."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
return stripe.SetupIntent.retrieve(setup_intent_id)
# =========================================================================
# Price/Product Management
# =========================================================================
def get_price(self, price_id: str) -> stripe.Price:
"""Get a Stripe price by ID."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
return stripe.Price.retrieve(price_id)
def get_product(self, product_id: str) -> stripe.Product:
"""Get a Stripe product by ID."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
return stripe.Product.retrieve(product_id)
def list_prices(
self,
product_id: str | None = None,
active: bool = True,
) -> list[stripe.Price]:
"""List Stripe prices, optionally filtered by product."""
if not self.is_configured:
raise ValueError("Stripe is not configured")
params = {"active": active}
if product_id:
params["product"] = product_id
prices = stripe.Price.list(**params)
return list(prices.data)
# Create service instance
stripe_service = StripeService()
__all__ = [
"StripeService",
"stripe_service",
]

View File

@@ -2,654 +2,29 @@
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
DEPRECATED: This file is maintained for backward compatibility.
Import from app.modules.billing.services instead:
Usage:
from app.services.subscription_service import subscription_service
from app.modules.billing.services import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
This file re-exports the service from its new location in the billing module.
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.subscription import (
SubscriptionStatus,
SubscriptionTier,
TIER_LIMITS,
TierCode,
VendorSubscription,
# Re-export from new location for backward compatibility
from app.modules.billing.services.subscription_service import (
SubscriptionService,
subscription_service,
)
from models.database.vendor import Vendor, VendorUser
from models.schema.subscription import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
UsageSummary,
from app.modules.billing.exceptions import (
SubscriptionNotFoundException,
TierLimitExceededException,
FeatureNotAvailableException,
)
logger = logging.getLogger(__name__)
class SubscriptionNotFoundException(Exception):
"""Raised when subscription not found."""
pass
class TierLimitExceededException(Exception):
"""Raised when a tier limit is exceeded."""
def __init__(self, message: str, limit_type: str, current: int, limit: int):
super().__init__(message)
self.limit_type = limit_type
self.current = current
self.limit = limit
class FeatureNotAvailableException(Exception):
"""Raised when a feature is not available in current tier."""
def __init__(self, feature: str, current_tier: str, required_tier: str):
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
super().__init__(message)
self.feature = feature
self.current_tier = current_tier
self.required_tier = required_tier
class SubscriptionService:
"""Service for subscription and tier limit operations."""
# =========================================================================
# Tier Information
# =========================================================================
def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
"""
Get full tier information.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
# Try database first if session provided
if db is not None:
db_tier = self.get_tier_by_code(db, tier_code)
if db_tier:
return TierInfo(
code=db_tier.code,
name=db_tier.name,
price_monthly_cents=db_tier.price_monthly_cents,
price_annual_cents=db_tier.price_annual_cents,
limits=TierLimits(
orders_per_month=db_tier.orders_per_month,
products_limit=db_tier.products_limit,
team_members=db_tier.team_members,
order_history_months=db_tier.order_history_months,
),
features=db_tier.features or [],
)
# Fallback to hardcoded TIER_LIMITS
return self._get_tier_from_legacy(tier_code)
def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
"""Get tier info from hardcoded TIER_LIMITS (fallback)."""
try:
tier = TierCode(tier_code)
except ValueError:
tier = TierCode.ESSENTIAL
limits = TIER_LIMITS[tier]
return TierInfo(
code=tier.value,
name=limits["name"],
price_monthly_cents=limits["price_monthly_cents"],
price_annual_cents=limits.get("price_annual_cents"),
limits=TierLimits(
orders_per_month=limits.get("orders_per_month"),
products_limit=limits.get("products_limit"),
team_members=limits.get("team_members"),
order_history_months=limits.get("order_history_months"),
),
features=limits.get("features", []),
)
def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
"""
Get information for all tiers.
Queries database if db session provided, otherwise falls back to TIER_LIMITS.
"""
if db is not None:
db_tiers = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.is_public == True, # noqa: E712
)
.order_by(SubscriptionTier.display_order)
.all()
)
if db_tiers:
return [
TierInfo(
code=t.code,
name=t.name,
price_monthly_cents=t.price_monthly_cents,
price_annual_cents=t.price_annual_cents,
limits=TierLimits(
orders_per_month=t.orders_per_month,
products_limit=t.products_limit,
team_members=t.team_members,
order_history_months=t.order_history_months,
),
features=t.features or [],
)
for t in db_tiers
]
# Fallback to hardcoded
return [
self._get_tier_from_legacy(tier.value)
for tier in TierCode
]
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
"""Get subscription tier by code."""
return (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == tier_code)
.first()
)
def get_tier_id(self, db: Session, tier_code: str) -> int | None:
"""Get tier ID from tier code. Returns None if tier not found."""
tier = self.get_tier_by_code(db, tier_code)
return tier.id if tier else None
# =========================================================================
# Subscription CRUD
# =========================================================================
def get_subscription(
self, db: Session, vendor_id: int
) -> VendorSubscription | None:
"""Get vendor subscription."""
return (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
def get_subscription_or_raise(
self, db: Session, vendor_id: int
) -> VendorSubscription:
"""Get vendor subscription or raise exception."""
subscription = self.get_subscription(db, vendor_id)
if not subscription:
raise SubscriptionNotFoundException(
f"No subscription found for vendor {vendor_id}"
)
return subscription
def get_current_tier(
self, db: Session, vendor_id: int
) -> TierCode | None:
"""Get vendor's current subscription tier code."""
subscription = self.get_subscription(db, vendor_id)
if subscription:
try:
return TierCode(subscription.tier)
except ValueError:
return None
return None
def get_or_create_subscription(
self,
db: Session,
vendor_id: int,
tier: str = TierCode.ESSENTIAL.value,
trial_days: int = 14,
) -> VendorSubscription:
"""
Get existing subscription or create a new trial subscription.
Used when a vendor first accesses the system.
"""
subscription = self.get_subscription(db, vendor_id)
if subscription:
return subscription
# Create new trial subscription
now = datetime.now(UTC)
trial_end = now + timedelta(days=trial_days)
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=tier,
tier_id=tier_id,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=trial_end,
trial_ends_at=trial_end,
is_annual=False,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(
f"Created trial subscription for vendor {vendor_id} "
f"(tier={tier}, trial_ends={trial_end})"
)
return subscription
def create_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionCreate,
) -> VendorSubscription:
"""Create a subscription for a vendor."""
# Check if subscription exists
existing = self.get_subscription(db, vendor_id)
if existing:
raise ValueError("Vendor already has a subscription")
now = datetime.now(UTC)
# Calculate period end based on billing cycle
if data.is_annual:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
# Handle trial
trial_ends_at = None
status = SubscriptionStatus.ACTIVE.value
if data.trial_days > 0:
trial_ends_at = now + timedelta(days=data.trial_days)
status = SubscriptionStatus.TRIAL.value
period_end = trial_ends_at
# Lookup tier_id from tier code
tier_id = self.get_tier_id(db, data.tier)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=data.tier,
tier_id=tier_id,
status=status,
period_start=now,
period_end=period_end,
trial_ends_at=trial_ends_at,
is_annual=data.is_annual,
)
db.add(subscription)
db.flush()
db.refresh(subscription)
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
return subscription
def update_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionUpdate,
) -> VendorSubscription:
"""Update a vendor subscription."""
subscription = self.get_subscription_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
# If tier is being updated, also update tier_id
if "tier" in update_data:
tier_id = self.get_tier_id(db, update_data["tier"])
update_data["tier_id"] = tier_id
for key, value in update_data.items():
setattr(subscription, key, value)
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Updated subscription for vendor {vendor_id}")
return subscription
def upgrade_tier(
self,
db: Session,
vendor_id: int,
new_tier: str,
) -> VendorSubscription:
"""Upgrade vendor to a new tier."""
subscription = self.get_subscription_or_raise(db, vendor_id)
old_tier = subscription.tier
subscription.tier = new_tier
subscription.tier_id = self.get_tier_id(db, new_tier)
subscription.updated_at = datetime.now(UTC)
# If upgrading from trial, mark as active
if subscription.status == SubscriptionStatus.TRIAL.value:
subscription.status = SubscriptionStatus.ACTIVE.value
db.flush()
db.refresh(subscription)
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
return subscription
def cancel_subscription(
self,
db: Session,
vendor_id: int,
reason: str | None = None,
) -> VendorSubscription:
"""Cancel a vendor subscription (access until period end)."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.status = SubscriptionStatus.CANCELLED.value
subscription.cancelled_at = datetime.now(UTC)
subscription.cancellation_reason = reason
subscription.updated_at = datetime.now(UTC)
db.flush()
db.refresh(subscription)
logger.info(f"Cancelled subscription for vendor {vendor_id}")
return subscription
# =========================================================================
# Usage Tracking
# =========================================================================
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
"""Get current subscription usage statistics."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Calculate usage stats
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
def calc_percent(current: int, limit: int | None) -> float | None:
if limit is None or limit == 0:
return None
return min(100.0, (current / limit) * 100)
return SubscriptionUsage(
orders_used=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
products_used=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
products_percent_used=calc_percent(products_count, products_limit),
team_members_used=team_count,
team_members_limit=team_limit,
team_members_remaining=calc_remaining(team_count, team_limit),
team_members_percent_used=calc_percent(team_count, team_limit),
)
def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
"""Get usage summary for billing page display."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Get limits
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
return UsageSummary(
orders_this_period=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
products_count=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
team_count=team_count,
team_limit=team_limit,
team_remaining=calc_remaining(team_count, team_limit),
)
def increment_order_count(self, db: Session, vendor_id: int) -> None:
"""
Increment the order counter for the current period.
Call this after successfully creating/importing an order.
"""
subscription = self.get_or_create_subscription(db, vendor_id)
subscription.increment_order_count()
db.flush()
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
"""Reset counters for a new billing period."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.reset_period_counters()
db.flush()
logger.info(f"Reset period counters for vendor {vendor_id}")
# =========================================================================
# Limit Checks
# =========================================================================
def can_create_order(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can create/import another order.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.can_create_order()
def check_order_limit(self, db: Session, vendor_id: int) -> None:
"""
Check order limit and raise exception if exceeded.
Use this in order creation flows.
"""
can_create, message = self.can_create_order(db, vendor_id)
if not can_create:
subscription = self.get_subscription(db, vendor_id)
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=subscription.orders_this_period if subscription else 0,
limit=subscription.orders_limit if subscription else 0,
)
def can_add_product(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another product.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
return subscription.can_add_product(products_count)
def check_product_limit(self, db: Session, vendor_id: int) -> None:
"""
Check product limit and raise exception if exceeded.
Use this in product creation flows.
"""
can_add, message = self.can_add_product(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Product limit exceeded",
limit_type="products",
current=products_count,
limit=subscription.products_limit if subscription else 0,
)
def can_add_team_member(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another team member.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
return subscription.can_add_team_member(team_count)
def check_team_limit(self, db: Session, vendor_id: int) -> None:
"""
Check team member limit and raise exception if exceeded.
Use this in team member invitation flows.
"""
can_add, message = self.can_add_team_member(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Team member limit exceeded",
limit_type="team_members",
current=team_count,
limit=subscription.team_members_limit if subscription else 0,
)
# =========================================================================
# Feature Gating
# =========================================================================
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
"""Check if vendor has access to a feature."""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.has_feature(feature)
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
"""
Check feature access and raise exception if not available.
Use this to gate premium features.
"""
if not self.has_feature(db, vendor_id, feature):
subscription = self.get_or_create_subscription(db, vendor_id)
# Find which tier has this feature
required_tier = None
for tier_code, limits in TIER_LIMITS.items():
if feature in limits.get("features", []):
required_tier = limits["name"]
break
raise FeatureNotAvailableException(
feature=feature,
current_tier=subscription.tier,
required_tier=required_tier or "higher",
)
def get_feature_tier(self, feature: str) -> str | None:
"""Get the minimum tier required for a feature."""
for tier_code in [
TierCode.ESSENTIAL,
TierCode.PROFESSIONAL,
TierCode.BUSINESS,
TierCode.ENTERPRISE,
]:
if feature in TIER_LIMITS[tier_code].get("features", []):
return tier_code.value
return None
# Singleton instance
subscription_service = SubscriptionService()
__all__ = [
"SubscriptionService",
"subscription_service",
"SubscriptionNotFoundException",
"TierLimitExceededException",
"FeatureNotAvailableException",
]

View File

@@ -1,265 +1,27 @@
# app/tasks/celery_tasks/subscription.py
"""
Celery tasks for subscription management.
Legacy subscription tasks.
Scheduled tasks for:
- Resetting period counters
- Checking trial expirations
- Syncing with Stripe
- Cleaning up stale subscriptions
- Capturing capacity snapshots
MOSTLY MIGRATED: Most tasks have been migrated to app.modules.billing.tasks.
The following tasks now live in the billing module:
- reset_period_counters -> app.modules.billing.tasks.subscription
- check_trial_expirations -> app.modules.billing.tasks.subscription
- sync_stripe_status -> app.modules.billing.tasks.subscription
- cleanup_stale_subscriptions -> app.modules.billing.tasks.subscription
Remaining task (to be migrated to monitoring module):
- capture_capacity_snapshot
"""
import logging
from datetime import UTC, datetime, timedelta
from app.core.celery_config import celery_app
from app.services.stripe_service import stripe_service
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.subscription import SubscriptionStatus, VendorSubscription
logger = logging.getLogger(__name__)
@celery_app.task(
bind=True,
base=DatabaseTask,
name="app.tasks.celery_tasks.subscription.reset_period_counters",
)
def reset_period_counters(self):
"""
Reset order counters for subscriptions whose billing period has ended.
Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates.
"""
now = datetime.now(UTC)
reset_count = 0
with self.get_db() as db:
# Find subscriptions where period has ended
expired_periods = (
db.query(VendorSubscription)
.filter(
VendorSubscription.period_end <= now,
VendorSubscription.status.in_(["active", "trial"]),
)
.all()
)
for subscription in expired_periods:
old_period_end = subscription.period_end
# Reset counters
subscription.orders_this_period = 0
subscription.orders_limit_reached_at = None
# Set new period dates
if subscription.is_annual:
subscription.period_start = now
subscription.period_end = now + timedelta(days=365)
else:
subscription.period_start = now
subscription.period_end = now + timedelta(days=30)
subscription.updated_at = now
reset_count += 1
logger.info(
f"Reset period counters for vendor {subscription.vendor_id}: "
f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
)
db.commit()
logger.info(f"Reset period counters for {reset_count} subscriptions")
return {"reset_count": reset_count}
@celery_app.task(
bind=True,
base=DatabaseTask,
name="app.tasks.celery_tasks.subscription.check_trial_expirations",
)
def check_trial_expirations(self):
"""
Check for expired trials and update their status.
Runs daily at 01:00.
- Trials without payment method -> expired
- Trials with payment method -> active
"""
now = datetime.now(UTC)
expired_count = 0
activated_count = 0
with self.get_db() as db:
# Find expired trials
expired_trials = (
db.query(VendorSubscription)
.filter(
VendorSubscription.status == SubscriptionStatus.TRIAL.value,
VendorSubscription.trial_ends_at <= now,
)
.all()
)
for subscription in expired_trials:
if subscription.stripe_payment_method_id:
# Has payment method - activate
subscription.status = SubscriptionStatus.ACTIVE.value
activated_count += 1
logger.info(
f"Activated subscription for vendor {subscription.vendor_id} "
f"(trial ended with payment method)"
)
else:
# No payment method - expire
subscription.status = SubscriptionStatus.EXPIRED.value
expired_count += 1
logger.info(
f"Expired trial for vendor {subscription.vendor_id} "
f"(no payment method)"
)
subscription.updated_at = now
db.commit()
logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
return {"expired_count": expired_count, "activated_count": activated_count}
@celery_app.task(
bind=True,
base=DatabaseTask,
name="app.tasks.celery_tasks.subscription.sync_stripe_status",
max_retries=3,
default_retry_delay=300,
)
def sync_stripe_status(self):
"""
Sync subscription status with Stripe.
Runs hourly at :30. Fetches current status from Stripe and updates local records.
"""
if not stripe_service.is_configured:
logger.warning("Stripe not configured, skipping sync")
return {"synced": 0, "skipped": True}
synced_count = 0
error_count = 0
with self.get_db() as db:
# Find subscriptions with Stripe IDs
subscriptions = (
db.query(VendorSubscription)
.filter(VendorSubscription.stripe_subscription_id.isnot(None))
.all()
)
for subscription in subscriptions:
try:
# Fetch from Stripe
stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
if not stripe_sub:
logger.warning(
f"Stripe subscription {subscription.stripe_subscription_id} "
f"not found for vendor {subscription.vendor_id}"
)
continue
# Map Stripe status to local status
status_map = {
"active": SubscriptionStatus.ACTIVE.value,
"trialing": SubscriptionStatus.TRIAL.value,
"past_due": SubscriptionStatus.PAST_DUE.value,
"canceled": SubscriptionStatus.CANCELLED.value,
"unpaid": SubscriptionStatus.PAST_DUE.value,
"incomplete": SubscriptionStatus.TRIAL.value,
"incomplete_expired": SubscriptionStatus.EXPIRED.value,
}
new_status = status_map.get(stripe_sub.status)
if new_status and new_status != subscription.status:
old_status = subscription.status
subscription.status = new_status
subscription.updated_at = datetime.now(UTC)
logger.info(
f"Updated vendor {subscription.vendor_id} status: "
f"{old_status} -> {new_status} (from Stripe)"
)
# Update period dates from Stripe
if stripe_sub.current_period_start:
subscription.period_start = datetime.fromtimestamp(
stripe_sub.current_period_start, tz=UTC
)
if stripe_sub.current_period_end:
subscription.period_end = datetime.fromtimestamp(
stripe_sub.current_period_end, tz=UTC
)
# Update payment method
if stripe_sub.default_payment_method:
subscription.stripe_payment_method_id = (
stripe_sub.default_payment_method
if isinstance(stripe_sub.default_payment_method, str)
else stripe_sub.default_payment_method.id
)
synced_count += 1
except Exception as e:
logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
error_count += 1
db.commit()
logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
return {"synced_count": synced_count, "error_count": error_count}
@celery_app.task(
bind=True,
base=DatabaseTask,
name="app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
)
def cleanup_stale_subscriptions(self):
"""
Clean up subscriptions in inconsistent states.
Runs weekly on Sunday at 03:00.
"""
now = datetime.now(UTC)
cleaned_count = 0
with self.get_db() as db:
# Find cancelled subscriptions past their period end
stale_cancelled = (
db.query(VendorSubscription)
.filter(
VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
VendorSubscription.period_end < now - timedelta(days=30),
)
.all()
)
for subscription in stale_cancelled:
# Mark as expired (fully terminated)
subscription.status = SubscriptionStatus.EXPIRED.value
subscription.updated_at = now
cleaned_count += 1
logger.info(
f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}"
)
db.commit()
logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
return {"cleaned_count": cleaned_count}
@celery_app.task(
bind=True,
base=DatabaseTask,
@@ -270,6 +32,8 @@ def capture_capacity_snapshot(self):
Capture a daily snapshot of platform capacity metrics.
Runs daily at midnight.
TODO: Migrate to app.modules.monitoring.tasks
"""
from app.services.capacity_forecast_service import capacity_forecast_service

View File

@@ -29,7 +29,7 @@ Transform the platform from a monolithic structure to self-contained modules whe
|--------|---------------|---------------|----------|--------|-------|--------|
| `cms` | Core | ✅ **Complete** | ✅ | ✅ | - | Done |
| `payments` | Optional | 🟡 Partial | ✅ | ✅ | - | Done |
| `billing` | Optional | 🔴 Shell | | | | Full |
| `billing` | Optional | **Complete** | | | | Done |
| `marketplace` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full |
| `orders` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
| `inventory` | Optional | 🔴 Shell | ❌ | ❌ | - | Full |
@@ -110,85 +110,22 @@ app/tasks/celery_tasks/ # → Move to respective modules
- Observability docs
- Creating modules developer guide
---
### ✅ Phase 4: Celery Task Infrastructure
- Added `ScheduledTask` dataclass to `ModuleDefinition`
- Added `tasks_path` and `scheduled_tasks` fields
- Created `app/modules/task_base.py` with `ModuleTask` base class
- Created `app/modules/tasks.py` with discovery utilities
- Updated `celery_config.py` to use module discovery
## Infrastructure Work (Remaining)
### Phase 4: Celery Task Infrastructure
Add task support to module system before migrating modules.
#### 4.1 Update ModuleDefinition
```python
# app/modules/base.py - Add fields
@dataclass
class ScheduledTask:
"""Celery Beat scheduled task definition."""
name: str # e.g., "billing.reset_counters"
task: str # Full task path
schedule: str | dict # Cron string or crontab dict
args: tuple = ()
kwargs: dict = field(default_factory=dict)
@dataclass
class ModuleDefinition:
# ... existing fields ...
# Task configuration (NEW)
tasks_path: str | None = None
scheduled_tasks: list[ScheduledTask] = field(default_factory=list)
```
#### 4.2 Create Task Discovery
```python
# app/modules/tasks.py (NEW)
def discover_module_tasks() -> list[str]:
"""Discover task modules from all registered modules."""
...
def build_beat_schedule() -> dict:
"""Build Celery Beat schedule from module definitions."""
...
```
#### 4.3 Create Module Task Base
```python
# app/modules/task_base.py (NEW)
class ModuleTask(Task):
"""Base Celery task with DB session management."""
@contextmanager
def get_db(self):
...
```
#### 4.4 Update Celery Config
```python
# app/core/celery_config.py
from app.modules.tasks import discover_module_tasks, build_beat_schedule
# Auto-discover from modules
celery_app.autodiscover_tasks(discover_module_tasks())
# Build schedule from modules
celery_app.conf.beat_schedule = build_beat_schedule()
```
**Files to create:**
- `app/modules/tasks.py`
- `app/modules/task_base.py`
**Files to modify:**
- `app/modules/base.py`
- `app/core/celery_config.py`
### ✅ Phase 5: Billing Module Migration
- Created `app/modules/billing/services/` with subscription, stripe, admin services
- Created `app/modules/billing/models/` re-exporting from central location
- Created `app/modules/billing/schemas/` re-exporting from central location
- Created `app/modules/billing/tasks/` with 4 scheduled tasks
- Created `app/modules/billing/exceptions.py`
- Updated `definition.py` with self-contained configuration
- Created backward-compatible re-exports in `app/services/`
- Updated legacy celery_config.py to not duplicate scheduled tasks
---