From 4f379b472b019a5c9a1498894156a05a08fb084c Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Tue, 27 Jan 2026 23:06:23 +0100 Subject: [PATCH] feat: complete billing module migration (Phase 5) Migrates billing module to self-contained structure: - Create app/modules/billing/services/ with subscription, stripe, admin services - Create app/modules/billing/models/ re-exporting from central location - Create app/modules/billing/schemas/ re-exporting from central location - Create app/modules/billing/tasks/ with 4 scheduled Celery tasks - Create app/modules/billing/exceptions.py with module-specific exceptions - Update definition.py with is_self_contained=True and scheduled_tasks Celery task migration: - reset_period_counters -> billing module - check_trial_expirations -> billing module - sync_stripe_status -> billing module - cleanup_stale_subscriptions -> billing module - capture_capacity_snapshot remains in legacy (will go to monitoring) Backward compatibility: - Create re-exports in app/services/ for subscription, stripe, admin services - Old import paths continue to work - Update celery_config.py to use module-defined schedules Co-Authored-By: Claude Opus 4.5 --- app/core/celery_config.py | 31 +- app/modules/billing/__init__.py | 14 +- app/modules/billing/definition.py | 42 +- app/modules/billing/exceptions.py | 83 +++ app/modules/billing/models/__init__.py | 52 ++ app/modules/billing/schemas/__init__.py | 56 ++ app/modules/billing/services/__init__.py | 28 + .../services/admin_subscription_service.py | 352 ++++++++++ .../billing/services/stripe_service.py | 582 +++++++++++++++ .../billing/services/subscription_service.py | 631 +++++++++++++++++ app/modules/billing/tasks/__init__.py | 26 + app/modules/billing/tasks/subscription.py | 255 +++++++ app/services/admin_subscription_service.py | 358 +--------- app/services/stripe_service.py | 599 +--------------- app/services/subscription_service.py | 663 +----------------- app/tasks/celery_tasks/subscription.py | 262 +------ docs/proposals/module-migration-plan.md | 95 +-- 17 files changed, 2198 insertions(+), 1931 deletions(-) create mode 100644 app/modules/billing/exceptions.py create mode 100644 app/modules/billing/models/__init__.py create mode 100644 app/modules/billing/schemas/__init__.py create mode 100644 app/modules/billing/services/__init__.py create mode 100644 app/modules/billing/services/admin_subscription_service.py create mode 100644 app/modules/billing/services/stripe_service.py create mode 100644 app/modules/billing/services/subscription_service.py create mode 100644 app/modules/billing/tasks/__init__.py create mode 100644 app/modules/billing/tasks/subscription.py diff --git a/app/core/celery_config.py b/app/core/celery_config.py index 3fb3d916..178fb911 100644 --- a/app/core/celery_config.py +++ b/app/core/celery_config.py @@ -49,10 +49,12 @@ if SENTRY_DSN: # TASK DISCOVERY # ============================================================================= # Legacy tasks (will be migrated to modules over time) +# NOTE: Most subscription tasks have been migrated to app.modules.billing.tasks +# The subscription module is kept for capture_capacity_snapshot (will move to monitoring) LEGACY_TASK_MODULES = [ "app.tasks.celery_tasks.marketplace", "app.tasks.celery_tasks.letzshop", - "app.tasks.celery_tasks.subscription", + "app.tasks.celery_tasks.subscription", # Kept for capture_capacity_snapshot only "app.tasks.celery_tasks.export", "app.tasks.celery_tasks.code_quality", "app.tasks.celery_tasks.test_runner", @@ -141,32 +143,9 @@ celery_app.conf.task_routes = { # ============================================================================= # Legacy scheduled tasks (will be migrated to module definitions) +# NOTE: Subscription tasks have been migrated to billing module (see definition.py) LEGACY_BEAT_SCHEDULE = { - # Reset usage counters at start of each period - "reset-period-counters-daily": { - "task": "app.tasks.celery_tasks.subscription.reset_period_counters", - "schedule": crontab(hour=0, minute=5), # 00:05 daily - "options": {"queue": "scheduled"}, - }, - # Check for expiring trials and send notifications - "check-trial-expirations-daily": { - "task": "app.tasks.celery_tasks.subscription.check_trial_expirations", - "schedule": crontab(hour=1, minute=0), # 01:00 daily - "options": {"queue": "scheduled"}, - }, - # Sync subscription status with Stripe - "sync-stripe-status-hourly": { - "task": "app.tasks.celery_tasks.subscription.sync_stripe_status", - "schedule": crontab(minute=30), # Every hour at :30 - "options": {"queue": "scheduled"}, - }, - # Clean up stale/orphaned subscriptions - "cleanup-stale-subscriptions-weekly": { - "task": "app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions", - "schedule": crontab(hour=3, minute=0, day_of_week=0), # Sunday 03:00 - "options": {"queue": "scheduled"}, - }, - # Capture daily capacity snapshot for analytics + # Capacity snapshot - will be migrated to monitoring module "capture-capacity-snapshot-daily": { "task": "app.tasks.celery_tasks.subscription.capture_capacity_snapshot", "schedule": crontab(hour=0, minute=0), # Midnight daily diff --git a/app/modules/billing/__init__.py b/app/modules/billing/__init__.py index 34b1710f..47f37d32 100644 --- a/app/modules/billing/__init__.py +++ b/app/modules/billing/__init__.py @@ -7,6 +7,7 @@ This module provides: - Vendor subscription CRUD - Billing history and invoices - Stripe integration +- Scheduled tasks for subscription lifecycle Routes: - Admin: /api/v1/admin/subscriptions/* @@ -15,8 +16,17 @@ Routes: Menu Items: - Admin: subscription-tiers, subscriptions, billing-history - Vendor: billing, invoices + +Usage: + from app.modules.billing import billing_module + from app.modules.billing.services import subscription_service, stripe_service + from app.modules.billing.models import VendorSubscription, SubscriptionTier + from app.modules.billing.exceptions import TierLimitExceededException """ -from app.modules.billing.definition import billing_module +from app.modules.billing.definition import billing_module, get_billing_module_with_routers -__all__ = ["billing_module"] +__all__ = [ + "billing_module", + "get_billing_module_with_routers", +] diff --git a/app/modules/billing/definition.py b/app/modules/billing/definition.py index ca7ceff4..9b1440a5 100644 --- a/app/modules/billing/definition.py +++ b/app/modules/billing/definition.py @@ -3,10 +3,10 @@ Billing module definition. Defines the billing module including its features, menu items, -and route configurations. +route configurations, and scheduled tasks. """ -from app.modules.base import ModuleDefinition +from app.modules.base import ModuleDefinition, ScheduledTask from models.database.admin_menu_config import FrontendType @@ -54,6 +54,44 @@ billing_module = ModuleDefinition( ], }, is_core=False, # Billing can be disabled (e.g., internal platforms) + # ========================================================================= + # Self-Contained Module Configuration + # ========================================================================= + is_self_contained=True, + services_path="app.modules.billing.services", + models_path="app.modules.billing.models", + schemas_path="app.modules.billing.schemas", + exceptions_path="app.modules.billing.exceptions", + tasks_path="app.modules.billing.tasks", + # ========================================================================= + # Scheduled Tasks + # ========================================================================= + scheduled_tasks=[ + ScheduledTask( + name="billing.reset_period_counters", + task="app.modules.billing.tasks.subscription.reset_period_counters", + schedule="5 0 * * *", # Daily at 00:05 + options={"queue": "scheduled"}, + ), + ScheduledTask( + name="billing.check_trial_expirations", + task="app.modules.billing.tasks.subscription.check_trial_expirations", + schedule="0 1 * * *", # Daily at 01:00 + options={"queue": "scheduled"}, + ), + ScheduledTask( + name="billing.sync_stripe_status", + task="app.modules.billing.tasks.subscription.sync_stripe_status", + schedule="30 * * * *", # Hourly at :30 + options={"queue": "scheduled"}, + ), + ScheduledTask( + name="billing.cleanup_stale_subscriptions", + task="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions", + schedule="0 3 * * 0", # Weekly on Sunday at 03:00 + options={"queue": "scheduled"}, + ), + ], ) diff --git a/app/modules/billing/exceptions.py b/app/modules/billing/exceptions.py new file mode 100644 index 00000000..b81cd2c9 --- /dev/null +++ b/app/modules/billing/exceptions.py @@ -0,0 +1,83 @@ +# app/modules/billing/exceptions.py +""" +Billing module exceptions. + +Custom exceptions for subscription, billing, and payment operations. +""" + +from app.exceptions import BusinessLogicException, ResourceNotFoundException + + +class BillingException(BusinessLogicException): + """Base exception for billing module errors.""" + + pass + + +class SubscriptionNotFoundException(ResourceNotFoundException): + """Raised when a subscription is not found.""" + + def __init__(self, vendor_id: int): + super().__init__("Subscription", str(vendor_id)) + + +class TierNotFoundException(ResourceNotFoundException): + """Raised when a subscription tier is not found.""" + + def __init__(self, tier_code: str): + super().__init__("SubscriptionTier", tier_code) + + +class TierLimitExceededException(BillingException): + """Raised when a tier limit is exceeded.""" + + def __init__(self, message: str, limit_type: str, current: int, limit: int): + super().__init__(message) + self.limit_type = limit_type + self.current = current + self.limit = limit + + +class FeatureNotAvailableException(BillingException): + """Raised when a feature is not available in current tier.""" + + def __init__(self, feature: str, current_tier: str, required_tier: str): + message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})" + super().__init__(message) + self.feature = feature + self.current_tier = current_tier + self.required_tier = required_tier + + +class StripeNotConfiguredException(BillingException): + """Raised when Stripe is not configured.""" + + def __init__(self): + super().__init__("Stripe is not configured") + + +class PaymentFailedException(BillingException): + """Raised when a payment fails.""" + + def __init__(self, message: str, stripe_error: str | None = None): + super().__init__(message) + self.stripe_error = stripe_error + + +class WebhookVerificationException(BillingException): + """Raised when webhook signature verification fails.""" + + def __init__(self, message: str = "Invalid webhook signature"): + super().__init__(message) + + +__all__ = [ + "BillingException", + "SubscriptionNotFoundException", + "TierNotFoundException", + "TierLimitExceededException", + "FeatureNotAvailableException", + "StripeNotConfiguredException", + "PaymentFailedException", + "WebhookVerificationException", +] diff --git a/app/modules/billing/models/__init__.py b/app/modules/billing/models/__init__.py new file mode 100644 index 00000000..f450bb1a --- /dev/null +++ b/app/modules/billing/models/__init__.py @@ -0,0 +1,52 @@ +# app/modules/billing/models/__init__.py +""" +Billing module models. + +Re-exports subscription models from the central models location. +Models remain in models/database/ for now to avoid breaking existing imports +across the codebase. This provides a module-local import path. + +Usage: + from app.modules.billing.models import ( + VendorSubscription, + SubscriptionTier, + SubscriptionStatus, + TierCode, + ) +""" + +from models.database.subscription import ( + # Enums + TierCode, + SubscriptionStatus, + AddOnCategory, + BillingPeriod, + # Models + SubscriptionTier, + AddOnProduct, + VendorAddOn, + StripeWebhookEvent, + BillingHistory, + VendorSubscription, + CapacitySnapshot, + # Legacy constants + TIER_LIMITS, +) + +__all__ = [ + # Enums + "TierCode", + "SubscriptionStatus", + "AddOnCategory", + "BillingPeriod", + # Models + "SubscriptionTier", + "AddOnProduct", + "VendorAddOn", + "StripeWebhookEvent", + "BillingHistory", + "VendorSubscription", + "CapacitySnapshot", + # Legacy constants + "TIER_LIMITS", +] diff --git a/app/modules/billing/schemas/__init__.py b/app/modules/billing/schemas/__init__.py new file mode 100644 index 00000000..3537780f --- /dev/null +++ b/app/modules/billing/schemas/__init__.py @@ -0,0 +1,56 @@ +# app/modules/billing/schemas/__init__.py +""" +Billing module Pydantic schemas. + +Re-exports subscription schemas from the central schemas location. +Provides a module-local import path while maintaining backwards compatibility. + +Usage: + from app.modules.billing.schemas import ( + SubscriptionCreate, + SubscriptionResponse, + TierInfo, + ) +""" + +from models.schema.subscription import ( + # Tier schemas + TierFeatures, + TierLimits, + TierInfo, + # Subscription CRUD schemas + SubscriptionCreate, + SubscriptionUpdate, + SubscriptionResponse, + # Usage schemas + SubscriptionUsage, + UsageSummary, + SubscriptionStatusResponse, + # Limit check schemas + LimitCheckResult, + CanCreateOrderResponse, + CanAddProductResponse, + CanAddTeamMemberResponse, + FeatureCheckResponse, +) + +__all__ = [ + # Tier schemas + "TierFeatures", + "TierLimits", + "TierInfo", + # Subscription CRUD schemas + "SubscriptionCreate", + "SubscriptionUpdate", + "SubscriptionResponse", + # Usage schemas + "SubscriptionUsage", + "UsageSummary", + "SubscriptionStatusResponse", + # Limit check schemas + "LimitCheckResult", + "CanCreateOrderResponse", + "CanAddProductResponse", + "CanAddTeamMemberResponse", + "FeatureCheckResponse", +] diff --git a/app/modules/billing/services/__init__.py b/app/modules/billing/services/__init__.py new file mode 100644 index 00000000..fbee9824 --- /dev/null +++ b/app/modules/billing/services/__init__.py @@ -0,0 +1,28 @@ +# app/modules/billing/services/__init__.py +""" +Billing module services. + +Provides subscription management, Stripe integration, and admin operations. +""" + +from app.modules.billing.services.subscription_service import ( + SubscriptionService, + subscription_service, +) +from app.modules.billing.services.stripe_service import ( + StripeService, + stripe_service, +) +from app.modules.billing.services.admin_subscription_service import ( + AdminSubscriptionService, + admin_subscription_service, +) + +__all__ = [ + "SubscriptionService", + "subscription_service", + "StripeService", + "stripe_service", + "AdminSubscriptionService", + "admin_subscription_service", +] diff --git a/app/modules/billing/services/admin_subscription_service.py b/app/modules/billing/services/admin_subscription_service.py new file mode 100644 index 00000000..509dfc0d --- /dev/null +++ b/app/modules/billing/services/admin_subscription_service.py @@ -0,0 +1,352 @@ +# app/modules/billing/services/admin_subscription_service.py +""" +Admin Subscription Service. + +Handles subscription management operations for platform administrators: +- Subscription tier CRUD +- Vendor subscription management +- Billing history queries +- Subscription analytics +""" + +import logging +from math import ceil + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.exceptions import ( + BusinessLogicException, + ConflictException, + ResourceNotFoundException, +) +from app.modules.billing.exceptions import TierNotFoundException +from app.modules.billing.models import ( + BillingHistory, + SubscriptionStatus, + SubscriptionTier, + VendorSubscription, +) +from models.database.product import Product +from models.database.vendor import Vendor, VendorUser + +logger = logging.getLogger(__name__) + + +class AdminSubscriptionService: + """Service for admin subscription management operations.""" + + # ========================================================================= + # Subscription Tiers + # ========================================================================= + + def get_tiers( + self, db: Session, include_inactive: bool = False + ) -> list[SubscriptionTier]: + """Get all subscription tiers.""" + query = db.query(SubscriptionTier) + + if not include_inactive: + query = query.filter(SubscriptionTier.is_active == True) # noqa: E712 + + return query.order_by(SubscriptionTier.display_order).all() + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: + """Get a subscription tier by code.""" + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == tier_code) + .first() + ) + + if not tier: + raise TierNotFoundException(tier_code) + + return tier + + def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier: + """Create a new subscription tier.""" + # Check for duplicate code + existing = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == tier_data["code"]) + .first() + ) + if existing: + raise ConflictException( + f"Tier with code '{tier_data['code']}' already exists" + ) + + tier = SubscriptionTier(**tier_data) + db.add(tier) + + logger.info(f"Created subscription tier: {tier.code}") + return tier + + def update_tier( + self, db: Session, tier_code: str, update_data: dict + ) -> SubscriptionTier: + """Update a subscription tier.""" + tier = self.get_tier_by_code(db, tier_code) + + for field, value in update_data.items(): + setattr(tier, field, value) + + logger.info(f"Updated subscription tier: {tier.code}") + return tier + + def deactivate_tier(self, db: Session, tier_code: str) -> None: + """Soft-delete a subscription tier.""" + tier = self.get_tier_by_code(db, tier_code) + + # Check if any active subscriptions use this tier + active_subs = ( + db.query(VendorSubscription) + .filter( + VendorSubscription.tier == tier_code, + VendorSubscription.status.in_([ + SubscriptionStatus.ACTIVE.value, + SubscriptionStatus.TRIAL.value, + ]), + ) + .count() + ) + + if active_subs > 0: + raise BusinessLogicException( + f"Cannot delete tier: {active_subs} active subscriptions are using it" + ) + + tier.is_active = False + + logger.info(f"Soft-deleted subscription tier: {tier.code}") + + # ========================================================================= + # Vendor Subscriptions + # ========================================================================= + + def list_subscriptions( + self, + db: Session, + page: int = 1, + per_page: int = 20, + status: str | None = None, + tier: str | None = None, + search: str | None = None, + ) -> dict: + """List vendor subscriptions with filtering and pagination.""" + query = ( + db.query(VendorSubscription, Vendor) + .join(Vendor, VendorSubscription.vendor_id == Vendor.id) + ) + + # Apply filters + if status: + query = query.filter(VendorSubscription.status == status) + if tier: + query = query.filter(VendorSubscription.tier == tier) + if search: + query = query.filter(Vendor.name.ilike(f"%{search}%")) + + # Count total + total = query.count() + + # Paginate + offset = (page - 1) * per_page + results = ( + query.order_by(VendorSubscription.created_at.desc()) + .offset(offset) + .limit(per_page) + .all() + ) + + return { + "results": results, + "total": total, + "page": page, + "per_page": per_page, + "pages": ceil(total / per_page) if total > 0 else 0, + } + + def get_subscription(self, db: Session, vendor_id: int) -> tuple: + """Get subscription for a specific vendor.""" + result = ( + db.query(VendorSubscription, Vendor) + .join(Vendor, VendorSubscription.vendor_id == Vendor.id) + .filter(VendorSubscription.vendor_id == vendor_id) + .first() + ) + + if not result: + raise ResourceNotFoundException("Subscription", str(vendor_id)) + + return result + + def update_subscription( + self, db: Session, vendor_id: int, update_data: dict + ) -> tuple: + """Update a vendor's subscription.""" + result = self.get_subscription(db, vendor_id) + sub, vendor = result + + for field, value in update_data.items(): + setattr(sub, field, value) + + logger.info( + f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}" + ) + + return sub, vendor + + def get_vendor(self, db: Session, vendor_id: int) -> Vendor: + """Get a vendor by ID.""" + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + + if not vendor: + raise ResourceNotFoundException("Vendor", str(vendor_id)) + + return vendor + + def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict: + """Get usage counts (products and team members) for a vendor.""" + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + team_count = ( + db.query(func.count(VendorUser.id)) + .filter( + VendorUser.vendor_id == vendor_id, + VendorUser.is_active == True, # noqa: E712 + ) + .scalar() + or 0 + ) + + return { + "products_count": products_count, + "team_count": team_count, + } + + # ========================================================================= + # Billing History + # ========================================================================= + + def list_billing_history( + self, + db: Session, + page: int = 1, + per_page: int = 20, + vendor_id: int | None = None, + status: str | None = None, + ) -> dict: + """List billing history across all vendors.""" + query = ( + db.query(BillingHistory, Vendor) + .join(Vendor, BillingHistory.vendor_id == Vendor.id) + ) + + if vendor_id: + query = query.filter(BillingHistory.vendor_id == vendor_id) + if status: + query = query.filter(BillingHistory.status == status) + + total = query.count() + + offset = (page - 1) * per_page + results = ( + query.order_by(BillingHistory.invoice_date.desc()) + .offset(offset) + .limit(per_page) + .all() + ) + + return { + "results": results, + "total": total, + "page": page, + "per_page": per_page, + "pages": ceil(total / per_page) if total > 0 else 0, + } + + # ========================================================================= + # Statistics + # ========================================================================= + + def get_stats(self, db: Session) -> dict: + """Get subscription statistics for admin dashboard.""" + # Count by status + status_counts = ( + db.query(VendorSubscription.status, func.count(VendorSubscription.id)) + .group_by(VendorSubscription.status) + .all() + ) + + stats = { + "total_subscriptions": 0, + "active_count": 0, + "trial_count": 0, + "past_due_count": 0, + "cancelled_count": 0, + "expired_count": 0, + } + + for status, count in status_counts: + stats["total_subscriptions"] += count + if status == SubscriptionStatus.ACTIVE.value: + stats["active_count"] = count + elif status == SubscriptionStatus.TRIAL.value: + stats["trial_count"] = count + elif status == SubscriptionStatus.PAST_DUE.value: + stats["past_due_count"] = count + elif status == SubscriptionStatus.CANCELLED.value: + stats["cancelled_count"] = count + elif status == SubscriptionStatus.EXPIRED.value: + stats["expired_count"] = count + + # Count by tier + tier_counts = ( + db.query(VendorSubscription.tier, func.count(VendorSubscription.id)) + .filter( + VendorSubscription.status.in_([ + SubscriptionStatus.ACTIVE.value, + SubscriptionStatus.TRIAL.value, + ]) + ) + .group_by(VendorSubscription.tier) + .all() + ) + + tier_distribution = {tier: count for tier, count in tier_counts} + + # Calculate MRR (Monthly Recurring Revenue) + mrr_cents = 0 + arr_cents = 0 + + active_subs = ( + db.query(VendorSubscription, SubscriptionTier) + .join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code) + .filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value) + .all() + ) + + for sub, tier in active_subs: + if sub.is_annual and tier.price_annual_cents: + mrr_cents += tier.price_annual_cents // 12 + arr_cents += tier.price_annual_cents + else: + mrr_cents += tier.price_monthly_cents + arr_cents += tier.price_monthly_cents * 12 + + stats["tier_distribution"] = tier_distribution + stats["mrr_cents"] = mrr_cents + stats["arr_cents"] = arr_cents + + return stats + + +# Singleton instance +admin_subscription_service = AdminSubscriptionService() diff --git a/app/modules/billing/services/stripe_service.py b/app/modules/billing/services/stripe_service.py new file mode 100644 index 00000000..343a2eca --- /dev/null +++ b/app/modules/billing/services/stripe_service.py @@ -0,0 +1,582 @@ +# app/modules/billing/services/stripe_service.py +""" +Stripe payment integration service. + +Provides: +- Customer management +- Subscription management +- Checkout session creation +- Customer portal access +- Webhook event construction +""" + +import logging +from datetime import datetime + +import stripe +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.modules.billing.exceptions import ( + StripeNotConfiguredException, + WebhookVerificationException, +) +from app.modules.billing.models import ( + BillingHistory, + SubscriptionStatus, + SubscriptionTier, + VendorSubscription, +) +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class StripeService: + """Service for Stripe payment operations.""" + + def __init__(self): + self._configured = False + self._configure() + + def _configure(self): + """Configure Stripe with API key.""" + if settings.stripe_secret_key: + stripe.api_key = settings.stripe_secret_key + self._configured = True + else: + logger.warning("Stripe API key not configured") + + @property + def is_configured(self) -> bool: + """Check if Stripe is properly configured.""" + return self._configured and bool(settings.stripe_secret_key) + + def _check_configured(self) -> None: + """Raise exception if Stripe is not configured.""" + if not self.is_configured: + raise StripeNotConfiguredException() + + # ========================================================================= + # Customer Management + # ========================================================================= + + def create_customer( + self, + vendor: Vendor, + email: str, + name: str | None = None, + metadata: dict | None = None, + ) -> str: + """ + Create a Stripe customer for a vendor. + + Returns the Stripe customer ID. + """ + self._check_configured() + + customer_metadata = { + "vendor_id": str(vendor.id), + "vendor_code": vendor.vendor_code, + **(metadata or {}), + } + + customer = stripe.Customer.create( + email=email, + name=name or vendor.name, + metadata=customer_metadata, + ) + + logger.info( + f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}" + ) + return customer.id + + def get_customer(self, customer_id: str) -> stripe.Customer: + """Get a Stripe customer by ID.""" + self._check_configured() + + return stripe.Customer.retrieve(customer_id) + + def update_customer( + self, + customer_id: str, + email: str | None = None, + name: str | None = None, + metadata: dict | None = None, + ) -> stripe.Customer: + """Update a Stripe customer.""" + self._check_configured() + + update_data = {} + if email: + update_data["email"] = email + if name: + update_data["name"] = name + if metadata: + update_data["metadata"] = metadata + + return stripe.Customer.modify(customer_id, **update_data) + + # ========================================================================= + # Subscription Management + # ========================================================================= + + def create_subscription( + self, + customer_id: str, + price_id: str, + trial_days: int | None = None, + metadata: dict | None = None, + ) -> stripe.Subscription: + """ + Create a new Stripe subscription. + + Args: + customer_id: Stripe customer ID + price_id: Stripe price ID for the subscription + trial_days: Optional trial period in days + metadata: Optional metadata to attach + + Returns: + Stripe Subscription object + """ + self._check_configured() + + subscription_data = { + "customer": customer_id, + "items": [{"price": price_id}], + "metadata": metadata or {}, + "payment_behavior": "default_incomplete", + "expand": ["latest_invoice.payment_intent"], + } + + if trial_days: + subscription_data["trial_period_days"] = trial_days + + subscription = stripe.Subscription.create(**subscription_data) + logger.info( + f"Created Stripe subscription {subscription.id} for customer {customer_id}" + ) + return subscription + + def get_subscription(self, subscription_id: str) -> stripe.Subscription: + """Get a Stripe subscription by ID.""" + self._check_configured() + + return stripe.Subscription.retrieve(subscription_id) + + def update_subscription( + self, + subscription_id: str, + new_price_id: str | None = None, + proration_behavior: str = "create_prorations", + metadata: dict | None = None, + ) -> stripe.Subscription: + """ + Update a Stripe subscription (e.g., change tier). + + Args: + subscription_id: Stripe subscription ID + new_price_id: New price ID for tier change + proration_behavior: How to handle prorations + metadata: Optional metadata to update + + Returns: + Updated Stripe Subscription object + """ + self._check_configured() + + update_data = {"proration_behavior": proration_behavior} + + if new_price_id: + # Get the subscription to find the item ID + subscription = stripe.Subscription.retrieve(subscription_id) + item_id = subscription["items"]["data"][0]["id"] + update_data["items"] = [{"id": item_id, "price": new_price_id}] + + if metadata: + update_data["metadata"] = metadata + + updated = stripe.Subscription.modify(subscription_id, **update_data) + logger.info(f"Updated Stripe subscription {subscription_id}") + return updated + + def cancel_subscription( + self, + subscription_id: str, + immediately: bool = False, + cancellation_reason: str | None = None, + ) -> stripe.Subscription: + """ + Cancel a Stripe subscription. + + Args: + subscription_id: Stripe subscription ID + immediately: If True, cancel now. If False, cancel at period end. + cancellation_reason: Optional reason for cancellation + + Returns: + Cancelled Stripe Subscription object + """ + self._check_configured() + + if immediately: + subscription = stripe.Subscription.cancel(subscription_id) + else: + subscription = stripe.Subscription.modify( + subscription_id, + cancel_at_period_end=True, + metadata={"cancellation_reason": cancellation_reason or "user_request"}, + ) + + logger.info( + f"Cancelled Stripe subscription {subscription_id} " + f"(immediately={immediately})" + ) + return subscription + + def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription: + """ + Reactivate a cancelled subscription (if not yet ended). + + Returns: + Reactivated Stripe Subscription object + """ + self._check_configured() + + subscription = stripe.Subscription.modify( + subscription_id, + cancel_at_period_end=False, + ) + logger.info(f"Reactivated Stripe subscription {subscription_id}") + return subscription + + def cancel_subscription_item(self, subscription_item_id: str) -> None: + """ + Cancel a subscription item (used for add-ons). + + Args: + subscription_item_id: Stripe subscription item ID + """ + self._check_configured() + + stripe.SubscriptionItem.delete(subscription_item_id) + logger.info(f"Cancelled Stripe subscription item {subscription_item_id}") + + # ========================================================================= + # Checkout & Portal + # ========================================================================= + + def create_checkout_session( + self, + db: Session, + vendor: Vendor, + price_id: str, + success_url: str, + cancel_url: str, + trial_days: int | None = None, + quantity: int = 1, + metadata: dict | None = None, + ) -> stripe.checkout.Session: + """ + Create a Stripe Checkout session for subscription signup. + + Args: + db: Database session + vendor: Vendor to create checkout for + price_id: Stripe price ID + success_url: URL to redirect on success + cancel_url: URL to redirect on cancel + trial_days: Optional trial period + quantity: Number of items (default 1) + metadata: Additional metadata to store + + Returns: + Stripe Checkout Session object + """ + self._check_configured() + + # Get or create Stripe customer + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.vendor_id == vendor.id) + .first() + ) + + if subscription and subscription.stripe_customer_id: + customer_id = subscription.stripe_customer_id + else: + # Get vendor owner email + from models.database.vendor import VendorUser + + owner = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.is_owner == True, + ) + .first() + ) + email = owner.user.email if owner and owner.user else None + + customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com") + + # Store the customer ID + if subscription: + subscription.stripe_customer_id = customer_id + db.flush() + + # Build metadata + session_metadata = { + "vendor_id": str(vendor.id), + "vendor_code": vendor.vendor_code, + } + if metadata: + session_metadata.update(metadata) + + session_data = { + "customer": customer_id, + "line_items": [{"price": price_id, "quantity": quantity}], + "mode": "subscription", + "success_url": success_url, + "cancel_url": cancel_url, + "metadata": session_metadata, + } + + if trial_days: + session_data["subscription_data"] = {"trial_period_days": trial_days} + + session = stripe.checkout.Session.create(**session_data) + logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}") + return session + + def create_portal_session( + self, + customer_id: str, + return_url: str, + ) -> stripe.billing_portal.Session: + """ + Create a Stripe Customer Portal session. + + Allows customers to manage their subscription, payment methods, and invoices. + + Args: + customer_id: Stripe customer ID + return_url: URL to return to after portal + + Returns: + Stripe Portal Session object + """ + self._check_configured() + + session = stripe.billing_portal.Session.create( + customer=customer_id, + return_url=return_url, + ) + logger.info(f"Created portal session for customer {customer_id}") + return session + + # ========================================================================= + # Invoice Management + # ========================================================================= + + def get_invoices( + self, + customer_id: str, + limit: int = 10, + ) -> list[stripe.Invoice]: + """Get invoices for a customer.""" + self._check_configured() + + invoices = stripe.Invoice.list(customer=customer_id, limit=limit) + return list(invoices.data) + + def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None: + """Get the upcoming invoice for a customer.""" + self._check_configured() + + try: + return stripe.Invoice.upcoming(customer=customer_id) + except stripe.error.InvalidRequestError: + # No upcoming invoice + return None + + # ========================================================================= + # Webhook Handling + # ========================================================================= + + def construct_event( + self, + payload: bytes, + sig_header: str, + ) -> stripe.Event: + """ + Construct and verify a Stripe webhook event. + + Args: + payload: Raw request body + sig_header: Stripe-Signature header + + Returns: + Verified Stripe Event object + + Raises: + WebhookVerificationException: If signature verification fails + """ + if not settings.stripe_webhook_secret: + raise WebhookVerificationException("Stripe webhook secret not configured") + + try: + event = stripe.Webhook.construct_event( + payload, + sig_header, + settings.stripe_webhook_secret, + ) + return event + except stripe.error.SignatureVerificationError as e: + logger.error(f"Webhook signature verification failed: {e}") + raise WebhookVerificationException("Invalid webhook signature") + + # ========================================================================= + # SetupIntent & Payment Method Management + # ========================================================================= + + def create_setup_intent( + self, + customer_id: str, + metadata: dict | None = None, + ) -> stripe.SetupIntent: + """ + Create a SetupIntent to collect card without charging. + + Used for trial signups where we collect card upfront + but don't charge until trial ends. + + Args: + customer_id: Stripe customer ID + metadata: Optional metadata to attach + + Returns: + Stripe SetupIntent object with client_secret for frontend + """ + self._check_configured() + + setup_intent = stripe.SetupIntent.create( + customer=customer_id, + payment_method_types=["card"], + metadata=metadata or {}, + ) + + logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}") + return setup_intent + + def attach_payment_method_to_customer( + self, + customer_id: str, + payment_method_id: str, + set_as_default: bool = True, + ) -> None: + """ + Attach a payment method to customer and optionally set as default. + + Args: + customer_id: Stripe customer ID + payment_method_id: Payment method ID from confirmed SetupIntent + set_as_default: Whether to set as default payment method + """ + self._check_configured() + + # Attach the payment method to the customer + stripe.PaymentMethod.attach(payment_method_id, customer=customer_id) + + if set_as_default: + stripe.Customer.modify( + customer_id, + invoice_settings={"default_payment_method": payment_method_id}, + ) + + logger.info( + f"Attached payment method {payment_method_id} to customer {customer_id} " + f"(default={set_as_default})" + ) + + def create_subscription_with_trial( + self, + customer_id: str, + price_id: str, + trial_days: int = 30, + metadata: dict | None = None, + ) -> stripe.Subscription: + """ + Create subscription with trial period. + + Customer must have a default payment method attached. + Card will be charged automatically after trial ends. + + Args: + customer_id: Stripe customer ID (must have default payment method) + price_id: Stripe price ID for the subscription tier + trial_days: Number of trial days (default 30) + metadata: Optional metadata to attach + + Returns: + Stripe Subscription object + """ + self._check_configured() + + subscription = stripe.Subscription.create( + customer=customer_id, + items=[{"price": price_id}], + trial_period_days=trial_days, + metadata=metadata or {}, + # Use default payment method for future charges + default_payment_method=None, # Uses customer's default + ) + + logger.info( + f"Created subscription {subscription.id} with {trial_days}-day trial " + f"for customer {customer_id}" + ) + return subscription + + def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent: + """Get a SetupIntent by ID.""" + self._check_configured() + + return stripe.SetupIntent.retrieve(setup_intent_id) + + # ========================================================================= + # Price/Product Management + # ========================================================================= + + def get_price(self, price_id: str) -> stripe.Price: + """Get a Stripe price by ID.""" + self._check_configured() + + return stripe.Price.retrieve(price_id) + + def get_product(self, product_id: str) -> stripe.Product: + """Get a Stripe product by ID.""" + self._check_configured() + + return stripe.Product.retrieve(product_id) + + def list_prices( + self, + product_id: str | None = None, + active: bool = True, + ) -> list[stripe.Price]: + """List Stripe prices, optionally filtered by product.""" + self._check_configured() + + params = {"active": active} + if product_id: + params["product"] = product_id + + prices = stripe.Price.list(**params) + return list(prices.data) + + +# Create service instance +stripe_service = StripeService() diff --git a/app/modules/billing/services/subscription_service.py b/app/modules/billing/services/subscription_service.py new file mode 100644 index 00000000..51eb48ae --- /dev/null +++ b/app/modules/billing/services/subscription_service.py @@ -0,0 +1,631 @@ +# app/modules/billing/services/subscription_service.py +""" +Subscription service for tier-based access control. + +Handles: +- Subscription creation and management +- Tier limit enforcement +- Usage tracking +- Feature gating + +Usage: + from app.modules.billing.services import subscription_service + + # Check if vendor can create an order + can_create, message = subscription_service.can_create_order(db, vendor_id) + + # Increment order counter after successful order + subscription_service.increment_order_count(db, vendor_id) +""" + +import logging +from datetime import UTC, datetime, timedelta +from typing import Any + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.modules.billing.exceptions import ( + FeatureNotAvailableException, + SubscriptionNotFoundException, + TierLimitExceededException, +) +from app.modules.billing.models import ( + SubscriptionStatus, + SubscriptionTier, + TIER_LIMITS, + TierCode, + VendorSubscription, +) +from app.modules.billing.schemas import ( + SubscriptionCreate, + SubscriptionUpdate, + SubscriptionUsage, + TierInfo, + TierLimits, + UsageSummary, +) +from models.database.product import Product +from models.database.vendor import Vendor, VendorUser + +logger = logging.getLogger(__name__) + + +class SubscriptionService: + """Service for subscription and tier limit operations.""" + + # ========================================================================= + # Tier Information + # ========================================================================= + + def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo: + """ + Get full tier information. + + Queries database if db session provided, otherwise falls back to TIER_LIMITS. + """ + # Try database first if session provided + if db is not None: + db_tier = self.get_tier_by_code(db, tier_code) + if db_tier: + return TierInfo( + code=db_tier.code, + name=db_tier.name, + price_monthly_cents=db_tier.price_monthly_cents, + price_annual_cents=db_tier.price_annual_cents, + limits=TierLimits( + orders_per_month=db_tier.orders_per_month, + products_limit=db_tier.products_limit, + team_members=db_tier.team_members, + order_history_months=db_tier.order_history_months, + ), + features=db_tier.features or [], + ) + + # Fallback to hardcoded TIER_LIMITS + return self._get_tier_from_legacy(tier_code) + + def _get_tier_from_legacy(self, tier_code: str) -> TierInfo: + """Get tier info from hardcoded TIER_LIMITS (fallback).""" + try: + tier = TierCode(tier_code) + except ValueError: + tier = TierCode.ESSENTIAL + + limits = TIER_LIMITS[tier] + return TierInfo( + code=tier.value, + name=limits["name"], + price_monthly_cents=limits["price_monthly_cents"], + price_annual_cents=limits.get("price_annual_cents"), + limits=TierLimits( + orders_per_month=limits.get("orders_per_month"), + products_limit=limits.get("products_limit"), + team_members=limits.get("team_members"), + order_history_months=limits.get("order_history_months"), + ), + features=limits.get("features", []), + ) + + def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]: + """ + Get information for all tiers. + + Queries database if db session provided, otherwise falls back to TIER_LIMITS. + """ + if db is not None: + db_tiers = ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.is_active == True, # noqa: E712 + SubscriptionTier.is_public == True, # noqa: E712 + ) + .order_by(SubscriptionTier.display_order) + .all() + ) + if db_tiers: + return [ + TierInfo( + code=t.code, + name=t.name, + price_monthly_cents=t.price_monthly_cents, + price_annual_cents=t.price_annual_cents, + limits=TierLimits( + orders_per_month=t.orders_per_month, + products_limit=t.products_limit, + team_members=t.team_members, + order_history_months=t.order_history_months, + ), + features=t.features or [], + ) + for t in db_tiers + ] + + # Fallback to hardcoded + return [ + self._get_tier_from_legacy(tier.value) + for tier in TierCode + ] + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: + """Get subscription tier by code.""" + return ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == tier_code) + .first() + ) + + def get_tier_id(self, db: Session, tier_code: str) -> int | None: + """Get tier ID from tier code. Returns None if tier not found.""" + tier = self.get_tier_by_code(db, tier_code) + return tier.id if tier else None + + # ========================================================================= + # Subscription CRUD + # ========================================================================= + + def get_subscription( + self, db: Session, vendor_id: int + ) -> VendorSubscription | None: + """Get vendor subscription.""" + return ( + db.query(VendorSubscription) + .filter(VendorSubscription.vendor_id == vendor_id) + .first() + ) + + def get_subscription_or_raise( + self, db: Session, vendor_id: int + ) -> VendorSubscription: + """Get vendor subscription or raise exception.""" + subscription = self.get_subscription(db, vendor_id) + if not subscription: + raise SubscriptionNotFoundException(vendor_id) + return subscription + + def get_current_tier( + self, db: Session, vendor_id: int + ) -> TierCode | None: + """Get vendor's current subscription tier code.""" + subscription = self.get_subscription(db, vendor_id) + if subscription: + try: + return TierCode(subscription.tier) + except ValueError: + return None + return None + + def get_or_create_subscription( + self, + db: Session, + vendor_id: int, + tier: str = TierCode.ESSENTIAL.value, + trial_days: int = 14, + ) -> VendorSubscription: + """ + Get existing subscription or create a new trial subscription. + + Used when a vendor first accesses the system. + """ + subscription = self.get_subscription(db, vendor_id) + if subscription: + return subscription + + # Create new trial subscription + now = datetime.now(UTC) + trial_end = now + timedelta(days=trial_days) + + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, tier) + + subscription = VendorSubscription( + vendor_id=vendor_id, + tier=tier, + tier_id=tier_id, + status=SubscriptionStatus.TRIAL.value, + period_start=now, + period_end=trial_end, + trial_ends_at=trial_end, + is_annual=False, + ) + + db.add(subscription) + db.flush() + db.refresh(subscription) + + logger.info( + f"Created trial subscription for vendor {vendor_id} " + f"(tier={tier}, trial_ends={trial_end})" + ) + + return subscription + + def create_subscription( + self, + db: Session, + vendor_id: int, + data: SubscriptionCreate, + ) -> VendorSubscription: + """Create a subscription for a vendor.""" + # Check if subscription exists + existing = self.get_subscription(db, vendor_id) + if existing: + raise ValueError("Vendor already has a subscription") + + now = datetime.now(UTC) + + # Calculate period end based on billing cycle + if data.is_annual: + period_end = now + timedelta(days=365) + else: + period_end = now + timedelta(days=30) + + # Handle trial + trial_ends_at = None + status = SubscriptionStatus.ACTIVE.value + if data.trial_days > 0: + trial_ends_at = now + timedelta(days=data.trial_days) + status = SubscriptionStatus.TRIAL.value + period_end = trial_ends_at + + # Lookup tier_id from tier code + tier_id = self.get_tier_id(db, data.tier) + + subscription = VendorSubscription( + vendor_id=vendor_id, + tier=data.tier, + tier_id=tier_id, + status=status, + period_start=now, + period_end=period_end, + trial_ends_at=trial_ends_at, + is_annual=data.is_annual, + ) + + db.add(subscription) + db.flush() + db.refresh(subscription) + + logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}") + return subscription + + def update_subscription( + self, + db: Session, + vendor_id: int, + data: SubscriptionUpdate, + ) -> VendorSubscription: + """Update a vendor subscription.""" + subscription = self.get_subscription_or_raise(db, vendor_id) + + update_data = data.model_dump(exclude_unset=True) + + # If tier is being updated, also update tier_id + if "tier" in update_data: + tier_id = self.get_tier_id(db, update_data["tier"]) + update_data["tier_id"] = tier_id + + for key, value in update_data.items(): + setattr(subscription, key, value) + + subscription.updated_at = datetime.now(UTC) + db.flush() + db.refresh(subscription) + + logger.info(f"Updated subscription for vendor {vendor_id}") + return subscription + + def upgrade_tier( + self, + db: Session, + vendor_id: int, + new_tier: str, + ) -> VendorSubscription: + """Upgrade vendor to a new tier.""" + subscription = self.get_subscription_or_raise(db, vendor_id) + + old_tier = subscription.tier + subscription.tier = new_tier + subscription.tier_id = self.get_tier_id(db, new_tier) + subscription.updated_at = datetime.now(UTC) + + # If upgrading from trial, mark as active + if subscription.status == SubscriptionStatus.TRIAL.value: + subscription.status = SubscriptionStatus.ACTIVE.value + + db.flush() + db.refresh(subscription) + + logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}") + return subscription + + def cancel_subscription( + self, + db: Session, + vendor_id: int, + reason: str | None = None, + ) -> VendorSubscription: + """Cancel a vendor subscription (access until period end).""" + subscription = self.get_subscription_or_raise(db, vendor_id) + + subscription.status = SubscriptionStatus.CANCELLED.value + subscription.cancelled_at = datetime.now(UTC) + subscription.cancellation_reason = reason + subscription.updated_at = datetime.now(UTC) + + db.flush() + db.refresh(subscription) + + logger.info(f"Cancelled subscription for vendor {vendor_id}") + return subscription + + # ========================================================================= + # Usage Tracking + # ========================================================================= + + def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage: + """Get current subscription usage statistics.""" + subscription = self.get_or_create_subscription(db, vendor_id) + + # Get actual counts + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + team_count = ( + db.query(func.count(VendorUser.id)) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) + .scalar() + or 0 + ) + + # Calculate usage stats + orders_limit = subscription.orders_limit + products_limit = subscription.products_limit + team_limit = subscription.team_members_limit + + def calc_remaining(current: int, limit: int | None) -> int | None: + if limit is None: + return None + return max(0, limit - current) + + def calc_percent(current: int, limit: int | None) -> float | None: + if limit is None or limit == 0: + return None + return min(100.0, (current / limit) * 100) + + return SubscriptionUsage( + orders_used=subscription.orders_this_period, + orders_limit=orders_limit, + orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit), + orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit), + products_used=products_count, + products_limit=products_limit, + products_remaining=calc_remaining(products_count, products_limit), + products_percent_used=calc_percent(products_count, products_limit), + team_members_used=team_count, + team_members_limit=team_limit, + team_members_remaining=calc_remaining(team_count, team_limit), + team_members_percent_used=calc_percent(team_count, team_limit), + ) + + def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary: + """Get usage summary for billing page display.""" + subscription = self.get_or_create_subscription(db, vendor_id) + + # Get actual counts + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + team_count = ( + db.query(func.count(VendorUser.id)) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) + .scalar() + or 0 + ) + + # Get limits + orders_limit = subscription.orders_limit + products_limit = subscription.products_limit + team_limit = subscription.team_members_limit + + def calc_remaining(current: int, limit: int | None) -> int | None: + if limit is None: + return None + return max(0, limit - current) + + return UsageSummary( + orders_this_period=subscription.orders_this_period, + orders_limit=orders_limit, + orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit), + products_count=products_count, + products_limit=products_limit, + products_remaining=calc_remaining(products_count, products_limit), + team_count=team_count, + team_limit=team_limit, + team_remaining=calc_remaining(team_count, team_limit), + ) + + def increment_order_count(self, db: Session, vendor_id: int) -> None: + """ + Increment the order counter for the current period. + + Call this after successfully creating/importing an order. + """ + subscription = self.get_or_create_subscription(db, vendor_id) + subscription.increment_order_count() + db.flush() + + def reset_period_counters(self, db: Session, vendor_id: int) -> None: + """Reset counters for a new billing period.""" + subscription = self.get_subscription_or_raise(db, vendor_id) + subscription.reset_period_counters() + db.flush() + logger.info(f"Reset period counters for vendor {vendor_id}") + + # ========================================================================= + # Limit Checks + # ========================================================================= + + def can_create_order( + self, db: Session, vendor_id: int + ) -> tuple[bool, str | None]: + """ + Check if vendor can create/import another order. + + Returns: (allowed, error_message) + """ + subscription = self.get_or_create_subscription(db, vendor_id) + return subscription.can_create_order() + + def check_order_limit(self, db: Session, vendor_id: int) -> None: + """ + Check order limit and raise exception if exceeded. + + Use this in order creation flows. + """ + can_create, message = self.can_create_order(db, vendor_id) + if not can_create: + subscription = self.get_subscription(db, vendor_id) + raise TierLimitExceededException( + message=message or "Order limit exceeded", + limit_type="orders", + current=subscription.orders_this_period if subscription else 0, + limit=subscription.orders_limit if subscription else 0, + ) + + def can_add_product( + self, db: Session, vendor_id: int + ) -> tuple[bool, str | None]: + """ + Check if vendor can add another product. + + Returns: (allowed, error_message) + """ + subscription = self.get_or_create_subscription(db, vendor_id) + + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + return subscription.can_add_product(products_count) + + def check_product_limit(self, db: Session, vendor_id: int) -> None: + """ + Check product limit and raise exception if exceeded. + + Use this in product creation flows. + """ + can_add, message = self.can_add_product(db, vendor_id) + if not can_add: + subscription = self.get_subscription(db, vendor_id) + products_count = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + raise TierLimitExceededException( + message=message or "Product limit exceeded", + limit_type="products", + current=products_count, + limit=subscription.products_limit if subscription else 0, + ) + + def can_add_team_member( + self, db: Session, vendor_id: int + ) -> tuple[bool, str | None]: + """ + Check if vendor can add another team member. + + Returns: (allowed, error_message) + """ + subscription = self.get_or_create_subscription(db, vendor_id) + + team_count = ( + db.query(func.count(VendorUser.id)) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) + .scalar() + or 0 + ) + + return subscription.can_add_team_member(team_count) + + def check_team_limit(self, db: Session, vendor_id: int) -> None: + """ + Check team member limit and raise exception if exceeded. + + Use this in team member invitation flows. + """ + can_add, message = self.can_add_team_member(db, vendor_id) + if not can_add: + subscription = self.get_subscription(db, vendor_id) + team_count = ( + db.query(func.count(VendorUser.id)) + .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) + .scalar() + or 0 + ) + raise TierLimitExceededException( + message=message or "Team member limit exceeded", + limit_type="team_members", + current=team_count, + limit=subscription.team_members_limit if subscription else 0, + ) + + # ========================================================================= + # Feature Gating + # ========================================================================= + + def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool: + """Check if vendor has access to a feature.""" + subscription = self.get_or_create_subscription(db, vendor_id) + return subscription.has_feature(feature) + + def check_feature(self, db: Session, vendor_id: int, feature: str) -> None: + """ + Check feature access and raise exception if not available. + + Use this to gate premium features. + """ + if not self.has_feature(db, vendor_id, feature): + subscription = self.get_or_create_subscription(db, vendor_id) + + # Find which tier has this feature + required_tier = None + for tier_code, limits in TIER_LIMITS.items(): + if feature in limits.get("features", []): + required_tier = limits["name"] + break + + raise FeatureNotAvailableException( + feature=feature, + current_tier=subscription.tier, + required_tier=required_tier or "higher", + ) + + def get_feature_tier(self, feature: str) -> str | None: + """Get the minimum tier required for a feature.""" + for tier_code in [ + TierCode.ESSENTIAL, + TierCode.PROFESSIONAL, + TierCode.BUSINESS, + TierCode.ENTERPRISE, + ]: + if feature in TIER_LIMITS[tier_code].get("features", []): + return tier_code.value + return None + + +# Singleton instance +subscription_service = SubscriptionService() diff --git a/app/modules/billing/tasks/__init__.py b/app/modules/billing/tasks/__init__.py new file mode 100644 index 00000000..4efa7ef3 --- /dev/null +++ b/app/modules/billing/tasks/__init__.py @@ -0,0 +1,26 @@ +# app/modules/billing/tasks/__init__.py +""" +Billing module Celery tasks. + +Scheduled tasks for: +- Resetting period counters +- Checking trial expirations +- Syncing with Stripe +- Cleaning up stale subscriptions + +Note: capture_capacity_snapshot moved to monitoring module. +""" + +from app.modules.billing.tasks.subscription import ( + reset_period_counters, + check_trial_expirations, + sync_stripe_status, + cleanup_stale_subscriptions, +) + +__all__ = [ + "reset_period_counters", + "check_trial_expirations", + "sync_stripe_status", + "cleanup_stale_subscriptions", +] diff --git a/app/modules/billing/tasks/subscription.py b/app/modules/billing/tasks/subscription.py new file mode 100644 index 00000000..878c2d00 --- /dev/null +++ b/app/modules/billing/tasks/subscription.py @@ -0,0 +1,255 @@ +# app/modules/billing/tasks/subscription.py +""" +Celery tasks for subscription management. + +Scheduled tasks for: +- Resetting period counters +- Checking trial expirations +- Syncing with Stripe +- Cleaning up stale subscriptions +""" + +import logging +from datetime import UTC, datetime, timedelta + +from app.core.celery_config import celery_app +from app.modules.billing.models import SubscriptionStatus, VendorSubscription +from app.modules.billing.services import stripe_service +from app.modules.task_base import ModuleTask + +logger = logging.getLogger(__name__) + + +@celery_app.task( + bind=True, + base=ModuleTask, + name="app.modules.billing.tasks.subscription.reset_period_counters", +) +def reset_period_counters(self): + """ + Reset order counters for subscriptions whose billing period has ended. + + Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates. + """ + now = datetime.now(UTC) + reset_count = 0 + + with self.get_db() as db: + # Find subscriptions where period has ended + expired_periods = ( + db.query(VendorSubscription) + .filter( + VendorSubscription.period_end <= now, + VendorSubscription.status.in_(["active", "trial"]), + ) + .all() + ) + + for subscription in expired_periods: + old_period_end = subscription.period_end + + # Reset counters + subscription.orders_this_period = 0 + subscription.orders_limit_reached_at = None + + # Set new period dates + if subscription.is_annual: + subscription.period_start = now + subscription.period_end = now + timedelta(days=365) + else: + subscription.period_start = now + subscription.period_end = now + timedelta(days=30) + + subscription.updated_at = now + reset_count += 1 + + logger.info( + f"Reset period counters for vendor {subscription.vendor_id}: " + f"old_period_end={old_period_end}, new_period_end={subscription.period_end}" + ) + + logger.info(f"Reset period counters for {reset_count} subscriptions") + + return {"reset_count": reset_count} + + +@celery_app.task( + bind=True, + base=ModuleTask, + name="app.modules.billing.tasks.subscription.check_trial_expirations", +) +def check_trial_expirations(self): + """ + Check for expired trials and update their status. + + Runs daily at 01:00. + - Trials without payment method -> expired + - Trials with payment method -> active + """ + now = datetime.now(UTC) + expired_count = 0 + activated_count = 0 + + with self.get_db() as db: + # Find expired trials + expired_trials = ( + db.query(VendorSubscription) + .filter( + VendorSubscription.status == SubscriptionStatus.TRIAL.value, + VendorSubscription.trial_ends_at <= now, + ) + .all() + ) + + for subscription in expired_trials: + if subscription.stripe_payment_method_id: + # Has payment method - activate + subscription.status = SubscriptionStatus.ACTIVE.value + activated_count += 1 + logger.info( + f"Activated subscription for vendor {subscription.vendor_id} " + f"(trial ended with payment method)" + ) + else: + # No payment method - expire + subscription.status = SubscriptionStatus.EXPIRED.value + expired_count += 1 + logger.info( + f"Expired trial for vendor {subscription.vendor_id} " + f"(no payment method)" + ) + + subscription.updated_at = now + + logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated") + + return {"expired_count": expired_count, "activated_count": activated_count} + + +@celery_app.task( + bind=True, + base=ModuleTask, + name="app.modules.billing.tasks.subscription.sync_stripe_status", + max_retries=3, + default_retry_delay=300, +) +def sync_stripe_status(self): + """ + Sync subscription status with Stripe. + + Runs hourly at :30. Fetches current status from Stripe and updates local records. + """ + if not stripe_service.is_configured: + logger.warning("Stripe not configured, skipping sync") + return {"synced": 0, "skipped": True} + + synced_count = 0 + error_count = 0 + + with self.get_db() as db: + # Find subscriptions with Stripe IDs + subscriptions = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_subscription_id.isnot(None)) + .all() + ) + + for subscription in subscriptions: + try: + # Fetch from Stripe + stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id) + + if not stripe_sub: + logger.warning( + f"Stripe subscription {subscription.stripe_subscription_id} " + f"not found for vendor {subscription.vendor_id}" + ) + continue + + # Map Stripe status to local status + status_map = { + "active": SubscriptionStatus.ACTIVE.value, + "trialing": SubscriptionStatus.TRIAL.value, + "past_due": SubscriptionStatus.PAST_DUE.value, + "canceled": SubscriptionStatus.CANCELLED.value, + "unpaid": SubscriptionStatus.PAST_DUE.value, + "incomplete": SubscriptionStatus.TRIAL.value, + "incomplete_expired": SubscriptionStatus.EXPIRED.value, + } + + new_status = status_map.get(stripe_sub.status) + if new_status and new_status != subscription.status: + old_status = subscription.status + subscription.status = new_status + subscription.updated_at = datetime.now(UTC) + logger.info( + f"Updated vendor {subscription.vendor_id} status: " + f"{old_status} -> {new_status} (from Stripe)" + ) + + # Update period dates from Stripe + if stripe_sub.current_period_start: + subscription.period_start = datetime.fromtimestamp( + stripe_sub.current_period_start, tz=UTC + ) + if stripe_sub.current_period_end: + subscription.period_end = datetime.fromtimestamp( + stripe_sub.current_period_end, tz=UTC + ) + + # Update payment method + if stripe_sub.default_payment_method: + subscription.stripe_payment_method_id = ( + stripe_sub.default_payment_method + if isinstance(stripe_sub.default_payment_method, str) + else stripe_sub.default_payment_method.id + ) + + synced_count += 1 + + except Exception as e: + logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}") + error_count += 1 + + logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors") + + return {"synced_count": synced_count, "error_count": error_count} + + +@celery_app.task( + bind=True, + base=ModuleTask, + name="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions", +) +def cleanup_stale_subscriptions(self): + """ + Clean up subscriptions in inconsistent states. + + Runs weekly on Sunday at 03:00. + """ + now = datetime.now(UTC) + cleaned_count = 0 + + with self.get_db() as db: + # Find cancelled subscriptions past their period end + stale_cancelled = ( + db.query(VendorSubscription) + .filter( + VendorSubscription.status == SubscriptionStatus.CANCELLED.value, + VendorSubscription.period_end < now - timedelta(days=30), + ) + .all() + ) + + for subscription in stale_cancelled: + # Mark as expired (fully terminated) + subscription.status = SubscriptionStatus.EXPIRED.value + subscription.updated_at = now + cleaned_count += 1 + logger.info( + f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}" + ) + + logger.info(f"Cleaned up {cleaned_count} stale subscriptions") + + return {"cleaned_count": cleaned_count} diff --git a/app/services/admin_subscription_service.py b/app/services/admin_subscription_service.py index 9e5d37fc..8dc32c31 100644 --- a/app/services/admin_subscription_service.py +++ b/app/services/admin_subscription_service.py @@ -2,351 +2,21 @@ """ Admin Subscription Service. -Handles subscription management operations for platform administrators: -- Subscription tier CRUD -- Vendor subscription management -- Billing history queries -- Subscription analytics +DEPRECATED: This file is maintained for backward compatibility. +Import from app.modules.billing.services instead: + + from app.modules.billing.services import admin_subscription_service + +This file re-exports the service from its new location in the billing module. """ -import logging -from math import ceil - -from sqlalchemy import func -from sqlalchemy.orm import Session - -from app.exceptions import ( - BusinessLogicException, - ConflictException, - ResourceNotFoundException, - TierNotFoundException, +# Re-export from new location for backward compatibility +from app.modules.billing.services.admin_subscription_service import ( + AdminSubscriptionService, + admin_subscription_service, ) -from models.database.product import Product -from models.database.subscription import ( - BillingHistory, - SubscriptionStatus, - SubscriptionTier, - VendorSubscription, -) -from models.database.vendor import Vendor, VendorUser -logger = logging.getLogger(__name__) - - -class AdminSubscriptionService: - """Service for admin subscription management operations.""" - - # ========================================================================= - # Subscription Tiers - # ========================================================================= - - def get_tiers( - self, db: Session, include_inactive: bool = False - ) -> list[SubscriptionTier]: - """Get all subscription tiers.""" - query = db.query(SubscriptionTier) - - if not include_inactive: - query = query.filter(SubscriptionTier.is_active == True) # noqa: E712 - - return query.order_by(SubscriptionTier.display_order).all() - - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: - """Get a subscription tier by code.""" - tier = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == tier_code) - .first() - ) - - if not tier: - raise TierNotFoundException(tier_code) - - return tier - - def create_tier(self, db: Session, tier_data: dict) -> SubscriptionTier: - """Create a new subscription tier.""" - # Check for duplicate code - existing = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == tier_data["code"]) - .first() - ) - if existing: - raise ConflictException( - f"Tier with code '{tier_data['code']}' already exists" - ) - - tier = SubscriptionTier(**tier_data) - db.add(tier) - - logger.info(f"Created subscription tier: {tier.code}") - return tier - - def update_tier( - self, db: Session, tier_code: str, update_data: dict - ) -> SubscriptionTier: - """Update a subscription tier.""" - tier = self.get_tier_by_code(db, tier_code) - - for field, value in update_data.items(): - setattr(tier, field, value) - - logger.info(f"Updated subscription tier: {tier.code}") - return tier - - def deactivate_tier(self, db: Session, tier_code: str) -> None: - """Soft-delete a subscription tier.""" - tier = self.get_tier_by_code(db, tier_code) - - # Check if any active subscriptions use this tier - active_subs = ( - db.query(VendorSubscription) - .filter( - VendorSubscription.tier == tier_code, - VendorSubscription.status.in_([ - SubscriptionStatus.ACTIVE.value, - SubscriptionStatus.TRIAL.value, - ]), - ) - .count() - ) - - if active_subs > 0: - raise BusinessLogicException( - f"Cannot delete tier: {active_subs} active subscriptions are using it" - ) - - tier.is_active = False - - logger.info(f"Soft-deleted subscription tier: {tier.code}") - - # ========================================================================= - # Vendor Subscriptions - # ========================================================================= - - def list_subscriptions( - self, - db: Session, - page: int = 1, - per_page: int = 20, - status: str | None = None, - tier: str | None = None, - search: str | None = None, - ) -> dict: - """List vendor subscriptions with filtering and pagination.""" - query = ( - db.query(VendorSubscription, Vendor) - .join(Vendor, VendorSubscription.vendor_id == Vendor.id) - ) - - # Apply filters - if status: - query = query.filter(VendorSubscription.status == status) - if tier: - query = query.filter(VendorSubscription.tier == tier) - if search: - query = query.filter(Vendor.name.ilike(f"%{search}%")) - - # Count total - total = query.count() - - # Paginate - offset = (page - 1) * per_page - results = ( - query.order_by(VendorSubscription.created_at.desc()) - .offset(offset) - .limit(per_page) - .all() - ) - - return { - "results": results, - "total": total, - "page": page, - "per_page": per_page, - "pages": ceil(total / per_page) if total > 0 else 0, - } - - def get_subscription(self, db: Session, vendor_id: int) -> tuple: - """Get subscription for a specific vendor.""" - result = ( - db.query(VendorSubscription, Vendor) - .join(Vendor, VendorSubscription.vendor_id == Vendor.id) - .filter(VendorSubscription.vendor_id == vendor_id) - .first() - ) - - if not result: - raise ResourceNotFoundException("Subscription", str(vendor_id)) - - return result - - def update_subscription( - self, db: Session, vendor_id: int, update_data: dict - ) -> tuple: - """Update a vendor's subscription.""" - result = self.get_subscription(db, vendor_id) - sub, vendor = result - - for field, value in update_data.items(): - setattr(sub, field, value) - - logger.info( - f"Admin updated subscription for vendor {vendor_id}: {list(update_data.keys())}" - ) - - return sub, vendor - - def get_vendor(self, db: Session, vendor_id: int) -> Vendor: - """Get a vendor by ID.""" - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - - if not vendor: - raise ResourceNotFoundException("Vendor", str(vendor_id)) - - return vendor - - def get_vendor_usage_counts(self, db: Session, vendor_id: int) -> dict: - """Get usage counts (products and team members) for a vendor.""" - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter( - VendorUser.vendor_id == vendor_id, - VendorUser.is_active == True, # noqa: E712 - ) - .scalar() - or 0 - ) - - return { - "products_count": products_count, - "team_count": team_count, - } - - # ========================================================================= - # Billing History - # ========================================================================= - - def list_billing_history( - self, - db: Session, - page: int = 1, - per_page: int = 20, - vendor_id: int | None = None, - status: str | None = None, - ) -> dict: - """List billing history across all vendors.""" - query = ( - db.query(BillingHistory, Vendor) - .join(Vendor, BillingHistory.vendor_id == Vendor.id) - ) - - if vendor_id: - query = query.filter(BillingHistory.vendor_id == vendor_id) - if status: - query = query.filter(BillingHistory.status == status) - - total = query.count() - - offset = (page - 1) * per_page - results = ( - query.order_by(BillingHistory.invoice_date.desc()) - .offset(offset) - .limit(per_page) - .all() - ) - - return { - "results": results, - "total": total, - "page": page, - "per_page": per_page, - "pages": ceil(total / per_page) if total > 0 else 0, - } - - # ========================================================================= - # Statistics - # ========================================================================= - - def get_stats(self, db: Session) -> dict: - """Get subscription statistics for admin dashboard.""" - # Count by status - status_counts = ( - db.query(VendorSubscription.status, func.count(VendorSubscription.id)) - .group_by(VendorSubscription.status) - .all() - ) - - stats = { - "total_subscriptions": 0, - "active_count": 0, - "trial_count": 0, - "past_due_count": 0, - "cancelled_count": 0, - "expired_count": 0, - } - - for status, count in status_counts: - stats["total_subscriptions"] += count - if status == SubscriptionStatus.ACTIVE.value: - stats["active_count"] = count - elif status == SubscriptionStatus.TRIAL.value: - stats["trial_count"] = count - elif status == SubscriptionStatus.PAST_DUE.value: - stats["past_due_count"] = count - elif status == SubscriptionStatus.CANCELLED.value: - stats["cancelled_count"] = count - elif status == SubscriptionStatus.EXPIRED.value: - stats["expired_count"] = count - - # Count by tier - tier_counts = ( - db.query(VendorSubscription.tier, func.count(VendorSubscription.id)) - .filter( - VendorSubscription.status.in_([ - SubscriptionStatus.ACTIVE.value, - SubscriptionStatus.TRIAL.value, - ]) - ) - .group_by(VendorSubscription.tier) - .all() - ) - - tier_distribution = {tier: count for tier, count in tier_counts} - - # Calculate MRR (Monthly Recurring Revenue) - mrr_cents = 0 - arr_cents = 0 - - active_subs = ( - db.query(VendorSubscription, SubscriptionTier) - .join(SubscriptionTier, VendorSubscription.tier == SubscriptionTier.code) - .filter(VendorSubscription.status == SubscriptionStatus.ACTIVE.value) - .all() - ) - - for sub, tier in active_subs: - if sub.is_annual and tier.price_annual_cents: - mrr_cents += tier.price_annual_cents // 12 - arr_cents += tier.price_annual_cents - else: - mrr_cents += tier.price_monthly_cents - arr_cents += tier.price_monthly_cents * 12 - - stats["tier_distribution"] = tier_distribution - stats["mrr_cents"] = mrr_cents - stats["arr_cents"] = arr_cents - - return stats - - -# Singleton instance -admin_subscription_service = AdminSubscriptionService() +__all__ = [ + "AdminSubscriptionService", + "admin_subscription_service", +] diff --git a/app/services/stripe_service.py b/app/services/stripe_service.py index 60e02e36..81e13f0f 100644 --- a/app/services/stripe_service.py +++ b/app/services/stripe_service.py @@ -2,592 +2,21 @@ """ Stripe payment integration service. -Provides: -- Customer management -- Subscription management -- Checkout session creation -- Customer portal access -- Webhook event construction +DEPRECATED: This file is maintained for backward compatibility. +Import from app.modules.billing.services instead: + + from app.modules.billing.services import stripe_service + +This file re-exports the service from its new location in the billing module. """ -import logging -from datetime import datetime - -import stripe -from sqlalchemy.orm import Session - -from app.core.config import settings -from models.database.subscription import ( - BillingHistory, - SubscriptionStatus, - SubscriptionTier, - VendorSubscription, +# Re-export from new location for backward compatibility +from app.modules.billing.services.stripe_service import ( + StripeService, + stripe_service, ) -from models.database.vendor import Vendor -logger = logging.getLogger(__name__) - - -class StripeService: - """Service for Stripe payment operations.""" - - def __init__(self): - self._configured = False - self._configure() - - def _configure(self): - """Configure Stripe with API key.""" - if settings.stripe_secret_key: - stripe.api_key = settings.stripe_secret_key - self._configured = True - else: - logger.warning("Stripe API key not configured") - - @property - def is_configured(self) -> bool: - """Check if Stripe is properly configured.""" - return self._configured and bool(settings.stripe_secret_key) - - # ========================================================================= - # Customer Management - # ========================================================================= - - def create_customer( - self, - vendor: Vendor, - email: str, - name: str | None = None, - metadata: dict | None = None, - ) -> str: - """ - Create a Stripe customer for a vendor. - - Returns the Stripe customer ID. - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - customer_metadata = { - "vendor_id": str(vendor.id), - "vendor_code": vendor.vendor_code, - **(metadata or {}), - } - - customer = stripe.Customer.create( - email=email, - name=name or vendor.name, - metadata=customer_metadata, - ) - - logger.info( - f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}" - ) - return customer.id - - def get_customer(self, customer_id: str) -> stripe.Customer: - """Get a Stripe customer by ID.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - return stripe.Customer.retrieve(customer_id) - - def update_customer( - self, - customer_id: str, - email: str | None = None, - name: str | None = None, - metadata: dict | None = None, - ) -> stripe.Customer: - """Update a Stripe customer.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - update_data = {} - if email: - update_data["email"] = email - if name: - update_data["name"] = name - if metadata: - update_data["metadata"] = metadata - - return stripe.Customer.modify(customer_id, **update_data) - - # ========================================================================= - # Subscription Management - # ========================================================================= - - def create_subscription( - self, - customer_id: str, - price_id: str, - trial_days: int | None = None, - metadata: dict | None = None, - ) -> stripe.Subscription: - """ - Create a new Stripe subscription. - - Args: - customer_id: Stripe customer ID - price_id: Stripe price ID for the subscription - trial_days: Optional trial period in days - metadata: Optional metadata to attach - - Returns: - Stripe Subscription object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - subscription_data = { - "customer": customer_id, - "items": [{"price": price_id}], - "metadata": metadata or {}, - "payment_behavior": "default_incomplete", - "expand": ["latest_invoice.payment_intent"], - } - - if trial_days: - subscription_data["trial_period_days"] = trial_days - - subscription = stripe.Subscription.create(**subscription_data) - logger.info( - f"Created Stripe subscription {subscription.id} for customer {customer_id}" - ) - return subscription - - def get_subscription(self, subscription_id: str) -> stripe.Subscription: - """Get a Stripe subscription by ID.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - return stripe.Subscription.retrieve(subscription_id) - - def update_subscription( - self, - subscription_id: str, - new_price_id: str | None = None, - proration_behavior: str = "create_prorations", - metadata: dict | None = None, - ) -> stripe.Subscription: - """ - Update a Stripe subscription (e.g., change tier). - - Args: - subscription_id: Stripe subscription ID - new_price_id: New price ID for tier change - proration_behavior: How to handle prorations - metadata: Optional metadata to update - - Returns: - Updated Stripe Subscription object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - update_data = {"proration_behavior": proration_behavior} - - if new_price_id: - # Get the subscription to find the item ID - subscription = stripe.Subscription.retrieve(subscription_id) - item_id = subscription["items"]["data"][0]["id"] - update_data["items"] = [{"id": item_id, "price": new_price_id}] - - if metadata: - update_data["metadata"] = metadata - - updated = stripe.Subscription.modify(subscription_id, **update_data) - logger.info(f"Updated Stripe subscription {subscription_id}") - return updated - - def cancel_subscription( - self, - subscription_id: str, - immediately: bool = False, - cancellation_reason: str | None = None, - ) -> stripe.Subscription: - """ - Cancel a Stripe subscription. - - Args: - subscription_id: Stripe subscription ID - immediately: If True, cancel now. If False, cancel at period end. - cancellation_reason: Optional reason for cancellation - - Returns: - Cancelled Stripe Subscription object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - if immediately: - subscription = stripe.Subscription.cancel(subscription_id) - else: - subscription = stripe.Subscription.modify( - subscription_id, - cancel_at_period_end=True, - metadata={"cancellation_reason": cancellation_reason or "user_request"}, - ) - - logger.info( - f"Cancelled Stripe subscription {subscription_id} " - f"(immediately={immediately})" - ) - return subscription - - def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription: - """ - Reactivate a cancelled subscription (if not yet ended). - - Returns: - Reactivated Stripe Subscription object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - subscription = stripe.Subscription.modify( - subscription_id, - cancel_at_period_end=False, - ) - logger.info(f"Reactivated Stripe subscription {subscription_id}") - return subscription - - def cancel_subscription_item(self, subscription_item_id: str) -> None: - """ - Cancel a subscription item (used for add-ons). - - Args: - subscription_item_id: Stripe subscription item ID - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - stripe.SubscriptionItem.delete(subscription_item_id) - logger.info(f"Cancelled Stripe subscription item {subscription_item_id}") - - # ========================================================================= - # Checkout & Portal - # ========================================================================= - - def create_checkout_session( - self, - db: Session, - vendor: Vendor, - price_id: str, - success_url: str, - cancel_url: str, - trial_days: int | None = None, - quantity: int = 1, - metadata: dict | None = None, - ) -> stripe.checkout.Session: - """ - Create a Stripe Checkout session for subscription signup. - - Args: - db: Database session - vendor: Vendor to create checkout for - price_id: Stripe price ID - success_url: URL to redirect on success - cancel_url: URL to redirect on cancel - trial_days: Optional trial period - quantity: Number of items (default 1) - metadata: Additional metadata to store - - Returns: - Stripe Checkout Session object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - # Get or create Stripe customer - subscription = ( - db.query(VendorSubscription) - .filter(VendorSubscription.vendor_id == vendor.id) - .first() - ) - - if subscription and subscription.stripe_customer_id: - customer_id = subscription.stripe_customer_id - else: - # Get vendor owner email - from models.database.vendor import VendorUser - - owner = ( - db.query(VendorUser) - .filter( - VendorUser.vendor_id == vendor.id, - VendorUser.is_owner == True, - ) - .first() - ) - email = owner.user.email if owner and owner.user else None - - customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com") - - # Store the customer ID - if subscription: - subscription.stripe_customer_id = customer_id - db.flush() - - # Build metadata - session_metadata = { - "vendor_id": str(vendor.id), - "vendor_code": vendor.vendor_code, - } - if metadata: - session_metadata.update(metadata) - - session_data = { - "customer": customer_id, - "line_items": [{"price": price_id, "quantity": quantity}], - "mode": "subscription", - "success_url": success_url, - "cancel_url": cancel_url, - "metadata": session_metadata, - } - - if trial_days: - session_data["subscription_data"] = {"trial_period_days": trial_days} - - session = stripe.checkout.Session.create(**session_data) - logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}") - return session - - def create_portal_session( - self, - customer_id: str, - return_url: str, - ) -> stripe.billing_portal.Session: - """ - Create a Stripe Customer Portal session. - - Allows customers to manage their subscription, payment methods, and invoices. - - Args: - customer_id: Stripe customer ID - return_url: URL to return to after portal - - Returns: - Stripe Portal Session object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - session = stripe.billing_portal.Session.create( - customer=customer_id, - return_url=return_url, - ) - logger.info(f"Created portal session for customer {customer_id}") - return session - - # ========================================================================= - # Invoice Management - # ========================================================================= - - def get_invoices( - self, - customer_id: str, - limit: int = 10, - ) -> list[stripe.Invoice]: - """Get invoices for a customer.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - invoices = stripe.Invoice.list(customer=customer_id, limit=limit) - return list(invoices.data) - - def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None: - """Get the upcoming invoice for a customer.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - try: - return stripe.Invoice.upcoming(customer=customer_id) - except stripe.error.InvalidRequestError: - # No upcoming invoice - return None - - # ========================================================================= - # Webhook Handling - # ========================================================================= - - def construct_event( - self, - payload: bytes, - sig_header: str, - ) -> stripe.Event: - """ - Construct and verify a Stripe webhook event. - - Args: - payload: Raw request body - sig_header: Stripe-Signature header - - Returns: - Verified Stripe Event object - - Raises: - ValueError: If signature verification fails - """ - if not settings.stripe_webhook_secret: - raise ValueError("Stripe webhook secret not configured") - - try: - event = stripe.Webhook.construct_event( - payload, - sig_header, - settings.stripe_webhook_secret, - ) - return event - except stripe.error.SignatureVerificationError as e: - logger.error(f"Webhook signature verification failed: {e}") - raise ValueError("Invalid webhook signature") - - # ========================================================================= - # SetupIntent & Payment Method Management - # ========================================================================= - - def create_setup_intent( - self, - customer_id: str, - metadata: dict | None = None, - ) -> stripe.SetupIntent: - """ - Create a SetupIntent to collect card without charging. - - Used for trial signups where we collect card upfront - but don't charge until trial ends. - - Args: - customer_id: Stripe customer ID - metadata: Optional metadata to attach - - Returns: - Stripe SetupIntent object with client_secret for frontend - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - setup_intent = stripe.SetupIntent.create( - customer=customer_id, - payment_method_types=["card"], - metadata=metadata or {}, - ) - - logger.info(f"Created SetupIntent {setup_intent.id} for customer {customer_id}") - return setup_intent - - def attach_payment_method_to_customer( - self, - customer_id: str, - payment_method_id: str, - set_as_default: bool = True, - ) -> None: - """ - Attach a payment method to customer and optionally set as default. - - Args: - customer_id: Stripe customer ID - payment_method_id: Payment method ID from confirmed SetupIntent - set_as_default: Whether to set as default payment method - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - # Attach the payment method to the customer - stripe.PaymentMethod.attach(payment_method_id, customer=customer_id) - - if set_as_default: - stripe.Customer.modify( - customer_id, - invoice_settings={"default_payment_method": payment_method_id}, - ) - - logger.info( - f"Attached payment method {payment_method_id} to customer {customer_id} " - f"(default={set_as_default})" - ) - - def create_subscription_with_trial( - self, - customer_id: str, - price_id: str, - trial_days: int = 30, - metadata: dict | None = None, - ) -> stripe.Subscription: - """ - Create subscription with trial period. - - Customer must have a default payment method attached. - Card will be charged automatically after trial ends. - - Args: - customer_id: Stripe customer ID (must have default payment method) - price_id: Stripe price ID for the subscription tier - trial_days: Number of trial days (default 30) - metadata: Optional metadata to attach - - Returns: - Stripe Subscription object - """ - if not self.is_configured: - raise ValueError("Stripe is not configured") - - subscription = stripe.Subscription.create( - customer=customer_id, - items=[{"price": price_id}], - trial_period_days=trial_days, - metadata=metadata or {}, - # Use default payment method for future charges - default_payment_method=None, # Uses customer's default - ) - - logger.info( - f"Created subscription {subscription.id} with {trial_days}-day trial " - f"for customer {customer_id}" - ) - return subscription - - def get_setup_intent(self, setup_intent_id: str) -> stripe.SetupIntent: - """Get a SetupIntent by ID.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - return stripe.SetupIntent.retrieve(setup_intent_id) - - # ========================================================================= - # Price/Product Management - # ========================================================================= - - def get_price(self, price_id: str) -> stripe.Price: - """Get a Stripe price by ID.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - return stripe.Price.retrieve(price_id) - - def get_product(self, product_id: str) -> stripe.Product: - """Get a Stripe product by ID.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - return stripe.Product.retrieve(product_id) - - def list_prices( - self, - product_id: str | None = None, - active: bool = True, - ) -> list[stripe.Price]: - """List Stripe prices, optionally filtered by product.""" - if not self.is_configured: - raise ValueError("Stripe is not configured") - - params = {"active": active} - if product_id: - params["product"] = product_id - - prices = stripe.Price.list(**params) - return list(prices.data) - - -# Create service instance -stripe_service = StripeService() +__all__ = [ + "StripeService", + "stripe_service", +] diff --git a/app/services/subscription_service.py b/app/services/subscription_service.py index 83c08226..c31e8547 100644 --- a/app/services/subscription_service.py +++ b/app/services/subscription_service.py @@ -2,654 +2,29 @@ """ Subscription service for tier-based access control. -Handles: -- Subscription creation and management -- Tier limit enforcement -- Usage tracking -- Feature gating +DEPRECATED: This file is maintained for backward compatibility. +Import from app.modules.billing.services instead: -Usage: - from app.services.subscription_service import subscription_service + from app.modules.billing.services import subscription_service - # Check if vendor can create an order - can_create, message = subscription_service.can_create_order(db, vendor_id) - - # Increment order counter after successful order - subscription_service.increment_order_count(db, vendor_id) +This file re-exports the service from its new location in the billing module. """ -import logging -from datetime import UTC, datetime, timedelta -from typing import Any - -from sqlalchemy import func -from sqlalchemy.orm import Session - -from models.database.product import Product -from models.database.subscription import ( - SubscriptionStatus, - SubscriptionTier, - TIER_LIMITS, - TierCode, - VendorSubscription, +# Re-export from new location for backward compatibility +from app.modules.billing.services.subscription_service import ( + SubscriptionService, + subscription_service, ) -from models.database.vendor import Vendor, VendorUser -from models.schema.subscription import ( - SubscriptionCreate, - SubscriptionUpdate, - SubscriptionUsage, - TierInfo, - TierLimits, - UsageSummary, +from app.modules.billing.exceptions import ( + SubscriptionNotFoundException, + TierLimitExceededException, + FeatureNotAvailableException, ) -logger = logging.getLogger(__name__) - - -class SubscriptionNotFoundException(Exception): - """Raised when subscription not found.""" - - pass - - -class TierLimitExceededException(Exception): - """Raised when a tier limit is exceeded.""" - - def __init__(self, message: str, limit_type: str, current: int, limit: int): - super().__init__(message) - self.limit_type = limit_type - self.current = current - self.limit = limit - - -class FeatureNotAvailableException(Exception): - """Raised when a feature is not available in current tier.""" - - def __init__(self, feature: str, current_tier: str, required_tier: str): - message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})" - super().__init__(message) - self.feature = feature - self.current_tier = current_tier - self.required_tier = required_tier - - -class SubscriptionService: - """Service for subscription and tier limit operations.""" - - # ========================================================================= - # Tier Information - # ========================================================================= - - def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo: - """ - Get full tier information. - - Queries database if db session provided, otherwise falls back to TIER_LIMITS. - """ - # Try database first if session provided - if db is not None: - db_tier = self.get_tier_by_code(db, tier_code) - if db_tier: - return TierInfo( - code=db_tier.code, - name=db_tier.name, - price_monthly_cents=db_tier.price_monthly_cents, - price_annual_cents=db_tier.price_annual_cents, - limits=TierLimits( - orders_per_month=db_tier.orders_per_month, - products_limit=db_tier.products_limit, - team_members=db_tier.team_members, - order_history_months=db_tier.order_history_months, - ), - features=db_tier.features or [], - ) - - # Fallback to hardcoded TIER_LIMITS - return self._get_tier_from_legacy(tier_code) - - def _get_tier_from_legacy(self, tier_code: str) -> TierInfo: - """Get tier info from hardcoded TIER_LIMITS (fallback).""" - try: - tier = TierCode(tier_code) - except ValueError: - tier = TierCode.ESSENTIAL - - limits = TIER_LIMITS[tier] - return TierInfo( - code=tier.value, - name=limits["name"], - price_monthly_cents=limits["price_monthly_cents"], - price_annual_cents=limits.get("price_annual_cents"), - limits=TierLimits( - orders_per_month=limits.get("orders_per_month"), - products_limit=limits.get("products_limit"), - team_members=limits.get("team_members"), - order_history_months=limits.get("order_history_months"), - ), - features=limits.get("features", []), - ) - - def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]: - """ - Get information for all tiers. - - Queries database if db session provided, otherwise falls back to TIER_LIMITS. - """ - if db is not None: - db_tiers = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.is_active == True, # noqa: E712 - SubscriptionTier.is_public == True, # noqa: E712 - ) - .order_by(SubscriptionTier.display_order) - .all() - ) - if db_tiers: - return [ - TierInfo( - code=t.code, - name=t.name, - price_monthly_cents=t.price_monthly_cents, - price_annual_cents=t.price_annual_cents, - limits=TierLimits( - orders_per_month=t.orders_per_month, - products_limit=t.products_limit, - team_members=t.team_members, - order_history_months=t.order_history_months, - ), - features=t.features or [], - ) - for t in db_tiers - ] - - # Fallback to hardcoded - return [ - self._get_tier_from_legacy(tier.value) - for tier in TierCode - ] - - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None: - """Get subscription tier by code.""" - return ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == tier_code) - .first() - ) - - def get_tier_id(self, db: Session, tier_code: str) -> int | None: - """Get tier ID from tier code. Returns None if tier not found.""" - tier = self.get_tier_by_code(db, tier_code) - return tier.id if tier else None - - # ========================================================================= - # Subscription CRUD - # ========================================================================= - - def get_subscription( - self, db: Session, vendor_id: int - ) -> VendorSubscription | None: - """Get vendor subscription.""" - return ( - db.query(VendorSubscription) - .filter(VendorSubscription.vendor_id == vendor_id) - .first() - ) - - def get_subscription_or_raise( - self, db: Session, vendor_id: int - ) -> VendorSubscription: - """Get vendor subscription or raise exception.""" - subscription = self.get_subscription(db, vendor_id) - if not subscription: - raise SubscriptionNotFoundException( - f"No subscription found for vendor {vendor_id}" - ) - return subscription - - def get_current_tier( - self, db: Session, vendor_id: int - ) -> TierCode | None: - """Get vendor's current subscription tier code.""" - subscription = self.get_subscription(db, vendor_id) - if subscription: - try: - return TierCode(subscription.tier) - except ValueError: - return None - return None - - def get_or_create_subscription( - self, - db: Session, - vendor_id: int, - tier: str = TierCode.ESSENTIAL.value, - trial_days: int = 14, - ) -> VendorSubscription: - """ - Get existing subscription or create a new trial subscription. - - Used when a vendor first accesses the system. - """ - subscription = self.get_subscription(db, vendor_id) - if subscription: - return subscription - - # Create new trial subscription - now = datetime.now(UTC) - trial_end = now + timedelta(days=trial_days) - - # Lookup tier_id from tier code - tier_id = self.get_tier_id(db, tier) - - subscription = VendorSubscription( - vendor_id=vendor_id, - tier=tier, - tier_id=tier_id, - status=SubscriptionStatus.TRIAL.value, - period_start=now, - period_end=trial_end, - trial_ends_at=trial_end, - is_annual=False, - ) - - db.add(subscription) - db.flush() - db.refresh(subscription) - - logger.info( - f"Created trial subscription for vendor {vendor_id} " - f"(tier={tier}, trial_ends={trial_end})" - ) - - return subscription - - def create_subscription( - self, - db: Session, - vendor_id: int, - data: SubscriptionCreate, - ) -> VendorSubscription: - """Create a subscription for a vendor.""" - # Check if subscription exists - existing = self.get_subscription(db, vendor_id) - if existing: - raise ValueError("Vendor already has a subscription") - - now = datetime.now(UTC) - - # Calculate period end based on billing cycle - if data.is_annual: - period_end = now + timedelta(days=365) - else: - period_end = now + timedelta(days=30) - - # Handle trial - trial_ends_at = None - status = SubscriptionStatus.ACTIVE.value - if data.trial_days > 0: - trial_ends_at = now + timedelta(days=data.trial_days) - status = SubscriptionStatus.TRIAL.value - period_end = trial_ends_at - - # Lookup tier_id from tier code - tier_id = self.get_tier_id(db, data.tier) - - subscription = VendorSubscription( - vendor_id=vendor_id, - tier=data.tier, - tier_id=tier_id, - status=status, - period_start=now, - period_end=period_end, - trial_ends_at=trial_ends_at, - is_annual=data.is_annual, - ) - - db.add(subscription) - db.flush() - db.refresh(subscription) - - logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}") - return subscription - - def update_subscription( - self, - db: Session, - vendor_id: int, - data: SubscriptionUpdate, - ) -> VendorSubscription: - """Update a vendor subscription.""" - subscription = self.get_subscription_or_raise(db, vendor_id) - - update_data = data.model_dump(exclude_unset=True) - - # If tier is being updated, also update tier_id - if "tier" in update_data: - tier_id = self.get_tier_id(db, update_data["tier"]) - update_data["tier_id"] = tier_id - - for key, value in update_data.items(): - setattr(subscription, key, value) - - subscription.updated_at = datetime.now(UTC) - db.flush() - db.refresh(subscription) - - logger.info(f"Updated subscription for vendor {vendor_id}") - return subscription - - def upgrade_tier( - self, - db: Session, - vendor_id: int, - new_tier: str, - ) -> VendorSubscription: - """Upgrade vendor to a new tier.""" - subscription = self.get_subscription_or_raise(db, vendor_id) - - old_tier = subscription.tier - subscription.tier = new_tier - subscription.tier_id = self.get_tier_id(db, new_tier) - subscription.updated_at = datetime.now(UTC) - - # If upgrading from trial, mark as active - if subscription.status == SubscriptionStatus.TRIAL.value: - subscription.status = SubscriptionStatus.ACTIVE.value - - db.flush() - db.refresh(subscription) - - logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}") - return subscription - - def cancel_subscription( - self, - db: Session, - vendor_id: int, - reason: str | None = None, - ) -> VendorSubscription: - """Cancel a vendor subscription (access until period end).""" - subscription = self.get_subscription_or_raise(db, vendor_id) - - subscription.status = SubscriptionStatus.CANCELLED.value - subscription.cancelled_at = datetime.now(UTC) - subscription.cancellation_reason = reason - subscription.updated_at = datetime.now(UTC) - - db.flush() - db.refresh(subscription) - - logger.info(f"Cancelled subscription for vendor {vendor_id}") - return subscription - - # ========================================================================= - # Usage Tracking - # ========================================================================= - - def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage: - """Get current subscription usage statistics.""" - subscription = self.get_or_create_subscription(db, vendor_id) - - # Get actual counts - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) - .scalar() - or 0 - ) - - # Calculate usage stats - orders_limit = subscription.orders_limit - products_limit = subscription.products_limit - team_limit = subscription.team_members_limit - - def calc_remaining(current: int, limit: int | None) -> int | None: - if limit is None: - return None - return max(0, limit - current) - - def calc_percent(current: int, limit: int | None) -> float | None: - if limit is None or limit == 0: - return None - return min(100.0, (current / limit) * 100) - - return SubscriptionUsage( - orders_used=subscription.orders_this_period, - orders_limit=orders_limit, - orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit), - orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit), - products_used=products_count, - products_limit=products_limit, - products_remaining=calc_remaining(products_count, products_limit), - products_percent_used=calc_percent(products_count, products_limit), - team_members_used=team_count, - team_members_limit=team_limit, - team_members_remaining=calc_remaining(team_count, team_limit), - team_members_percent_used=calc_percent(team_count, team_limit), - ) - - def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary: - """Get usage summary for billing page display.""" - subscription = self.get_or_create_subscription(db, vendor_id) - - # Get actual counts - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) - .scalar() - or 0 - ) - - # Get limits - orders_limit = subscription.orders_limit - products_limit = subscription.products_limit - team_limit = subscription.team_members_limit - - def calc_remaining(current: int, limit: int | None) -> int | None: - if limit is None: - return None - return max(0, limit - current) - - return UsageSummary( - orders_this_period=subscription.orders_this_period, - orders_limit=orders_limit, - orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit), - products_count=products_count, - products_limit=products_limit, - products_remaining=calc_remaining(products_count, products_limit), - team_count=team_count, - team_limit=team_limit, - team_remaining=calc_remaining(team_count, team_limit), - ) - - def increment_order_count(self, db: Session, vendor_id: int) -> None: - """ - Increment the order counter for the current period. - - Call this after successfully creating/importing an order. - """ - subscription = self.get_or_create_subscription(db, vendor_id) - subscription.increment_order_count() - db.flush() - - def reset_period_counters(self, db: Session, vendor_id: int) -> None: - """Reset counters for a new billing period.""" - subscription = self.get_subscription_or_raise(db, vendor_id) - subscription.reset_period_counters() - db.flush() - logger.info(f"Reset period counters for vendor {vendor_id}") - - # ========================================================================= - # Limit Checks - # ========================================================================= - - def can_create_order( - self, db: Session, vendor_id: int - ) -> tuple[bool, str | None]: - """ - Check if vendor can create/import another order. - - Returns: (allowed, error_message) - """ - subscription = self.get_or_create_subscription(db, vendor_id) - return subscription.can_create_order() - - def check_order_limit(self, db: Session, vendor_id: int) -> None: - """ - Check order limit and raise exception if exceeded. - - Use this in order creation flows. - """ - can_create, message = self.can_create_order(db, vendor_id) - if not can_create: - subscription = self.get_subscription(db, vendor_id) - raise TierLimitExceededException( - message=message or "Order limit exceeded", - limit_type="orders", - current=subscription.orders_this_period if subscription else 0, - limit=subscription.orders_limit if subscription else 0, - ) - - def can_add_product( - self, db: Session, vendor_id: int - ) -> tuple[bool, str | None]: - """ - Check if vendor can add another product. - - Returns: (allowed, error_message) - """ - subscription = self.get_or_create_subscription(db, vendor_id) - - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - return subscription.can_add_product(products_count) - - def check_product_limit(self, db: Session, vendor_id: int) -> None: - """ - Check product limit and raise exception if exceeded. - - Use this in product creation flows. - """ - can_add, message = self.can_add_product(db, vendor_id) - if not can_add: - subscription = self.get_subscription(db, vendor_id) - products_count = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - raise TierLimitExceededException( - message=message or "Product limit exceeded", - limit_type="products", - current=products_count, - limit=subscription.products_limit if subscription else 0, - ) - - def can_add_team_member( - self, db: Session, vendor_id: int - ) -> tuple[bool, str | None]: - """ - Check if vendor can add another team member. - - Returns: (allowed, error_message) - """ - subscription = self.get_or_create_subscription(db, vendor_id) - - team_count = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) - .scalar() - or 0 - ) - - return subscription.can_add_team_member(team_count) - - def check_team_limit(self, db: Session, vendor_id: int) -> None: - """ - Check team member limit and raise exception if exceeded. - - Use this in team member invitation flows. - """ - can_add, message = self.can_add_team_member(db, vendor_id) - if not can_add: - subscription = self.get_subscription(db, vendor_id) - team_count = ( - db.query(func.count(VendorUser.id)) - .filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True) - .scalar() - or 0 - ) - raise TierLimitExceededException( - message=message or "Team member limit exceeded", - limit_type="team_members", - current=team_count, - limit=subscription.team_members_limit if subscription else 0, - ) - - # ========================================================================= - # Feature Gating - # ========================================================================= - - def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool: - """Check if vendor has access to a feature.""" - subscription = self.get_or_create_subscription(db, vendor_id) - return subscription.has_feature(feature) - - def check_feature(self, db: Session, vendor_id: int, feature: str) -> None: - """ - Check feature access and raise exception if not available. - - Use this to gate premium features. - """ - if not self.has_feature(db, vendor_id, feature): - subscription = self.get_or_create_subscription(db, vendor_id) - - # Find which tier has this feature - required_tier = None - for tier_code, limits in TIER_LIMITS.items(): - if feature in limits.get("features", []): - required_tier = limits["name"] - break - - raise FeatureNotAvailableException( - feature=feature, - current_tier=subscription.tier, - required_tier=required_tier or "higher", - ) - - def get_feature_tier(self, feature: str) -> str | None: - """Get the minimum tier required for a feature.""" - for tier_code in [ - TierCode.ESSENTIAL, - TierCode.PROFESSIONAL, - TierCode.BUSINESS, - TierCode.ENTERPRISE, - ]: - if feature in TIER_LIMITS[tier_code].get("features", []): - return tier_code.value - return None - - -# Singleton instance -subscription_service = SubscriptionService() +__all__ = [ + "SubscriptionService", + "subscription_service", + "SubscriptionNotFoundException", + "TierLimitExceededException", + "FeatureNotAvailableException", +] diff --git a/app/tasks/celery_tasks/subscription.py b/app/tasks/celery_tasks/subscription.py index 73227c68..ed231cbf 100644 --- a/app/tasks/celery_tasks/subscription.py +++ b/app/tasks/celery_tasks/subscription.py @@ -1,265 +1,27 @@ # app/tasks/celery_tasks/subscription.py """ -Celery tasks for subscription management. +Legacy subscription tasks. -Scheduled tasks for: -- Resetting period counters -- Checking trial expirations -- Syncing with Stripe -- Cleaning up stale subscriptions -- Capturing capacity snapshots +MOSTLY MIGRATED: Most tasks have been migrated to app.modules.billing.tasks. + +The following tasks now live in the billing module: +- reset_period_counters -> app.modules.billing.tasks.subscription +- check_trial_expirations -> app.modules.billing.tasks.subscription +- sync_stripe_status -> app.modules.billing.tasks.subscription +- cleanup_stale_subscriptions -> app.modules.billing.tasks.subscription + +Remaining task (to be migrated to monitoring module): +- capture_capacity_snapshot """ import logging -from datetime import UTC, datetime, timedelta from app.core.celery_config import celery_app -from app.services.stripe_service import stripe_service from app.tasks.celery_tasks.base import DatabaseTask -from models.database.subscription import SubscriptionStatus, VendorSubscription logger = logging.getLogger(__name__) -@celery_app.task( - bind=True, - base=DatabaseTask, - name="app.tasks.celery_tasks.subscription.reset_period_counters", -) -def reset_period_counters(self): - """ - Reset order counters for subscriptions whose billing period has ended. - - Runs daily at 00:05. Resets orders_this_period to 0 and updates period dates. - """ - now = datetime.now(UTC) - reset_count = 0 - - with self.get_db() as db: - # Find subscriptions where period has ended - expired_periods = ( - db.query(VendorSubscription) - .filter( - VendorSubscription.period_end <= now, - VendorSubscription.status.in_(["active", "trial"]), - ) - .all() - ) - - for subscription in expired_periods: - old_period_end = subscription.period_end - - # Reset counters - subscription.orders_this_period = 0 - subscription.orders_limit_reached_at = None - - # Set new period dates - if subscription.is_annual: - subscription.period_start = now - subscription.period_end = now + timedelta(days=365) - else: - subscription.period_start = now - subscription.period_end = now + timedelta(days=30) - - subscription.updated_at = now - reset_count += 1 - - logger.info( - f"Reset period counters for vendor {subscription.vendor_id}: " - f"old_period_end={old_period_end}, new_period_end={subscription.period_end}" - ) - - db.commit() - logger.info(f"Reset period counters for {reset_count} subscriptions") - - return {"reset_count": reset_count} - - -@celery_app.task( - bind=True, - base=DatabaseTask, - name="app.tasks.celery_tasks.subscription.check_trial_expirations", -) -def check_trial_expirations(self): - """ - Check for expired trials and update their status. - - Runs daily at 01:00. - - Trials without payment method -> expired - - Trials with payment method -> active - """ - now = datetime.now(UTC) - expired_count = 0 - activated_count = 0 - - with self.get_db() as db: - # Find expired trials - expired_trials = ( - db.query(VendorSubscription) - .filter( - VendorSubscription.status == SubscriptionStatus.TRIAL.value, - VendorSubscription.trial_ends_at <= now, - ) - .all() - ) - - for subscription in expired_trials: - if subscription.stripe_payment_method_id: - # Has payment method - activate - subscription.status = SubscriptionStatus.ACTIVE.value - activated_count += 1 - logger.info( - f"Activated subscription for vendor {subscription.vendor_id} " - f"(trial ended with payment method)" - ) - else: - # No payment method - expire - subscription.status = SubscriptionStatus.EXPIRED.value - expired_count += 1 - logger.info( - f"Expired trial for vendor {subscription.vendor_id} " - f"(no payment method)" - ) - - subscription.updated_at = now - - db.commit() - logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated") - - return {"expired_count": expired_count, "activated_count": activated_count} - - -@celery_app.task( - bind=True, - base=DatabaseTask, - name="app.tasks.celery_tasks.subscription.sync_stripe_status", - max_retries=3, - default_retry_delay=300, -) -def sync_stripe_status(self): - """ - Sync subscription status with Stripe. - - Runs hourly at :30. Fetches current status from Stripe and updates local records. - """ - if not stripe_service.is_configured: - logger.warning("Stripe not configured, skipping sync") - return {"synced": 0, "skipped": True} - - synced_count = 0 - error_count = 0 - - with self.get_db() as db: - # Find subscriptions with Stripe IDs - subscriptions = ( - db.query(VendorSubscription) - .filter(VendorSubscription.stripe_subscription_id.isnot(None)) - .all() - ) - - for subscription in subscriptions: - try: - # Fetch from Stripe - stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id) - - if not stripe_sub: - logger.warning( - f"Stripe subscription {subscription.stripe_subscription_id} " - f"not found for vendor {subscription.vendor_id}" - ) - continue - - # Map Stripe status to local status - status_map = { - "active": SubscriptionStatus.ACTIVE.value, - "trialing": SubscriptionStatus.TRIAL.value, - "past_due": SubscriptionStatus.PAST_DUE.value, - "canceled": SubscriptionStatus.CANCELLED.value, - "unpaid": SubscriptionStatus.PAST_DUE.value, - "incomplete": SubscriptionStatus.TRIAL.value, - "incomplete_expired": SubscriptionStatus.EXPIRED.value, - } - - new_status = status_map.get(stripe_sub.status) - if new_status and new_status != subscription.status: - old_status = subscription.status - subscription.status = new_status - subscription.updated_at = datetime.now(UTC) - logger.info( - f"Updated vendor {subscription.vendor_id} status: " - f"{old_status} -> {new_status} (from Stripe)" - ) - - # Update period dates from Stripe - if stripe_sub.current_period_start: - subscription.period_start = datetime.fromtimestamp( - stripe_sub.current_period_start, tz=UTC - ) - if stripe_sub.current_period_end: - subscription.period_end = datetime.fromtimestamp( - stripe_sub.current_period_end, tz=UTC - ) - - # Update payment method - if stripe_sub.default_payment_method: - subscription.stripe_payment_method_id = ( - stripe_sub.default_payment_method - if isinstance(stripe_sub.default_payment_method, str) - else stripe_sub.default_payment_method.id - ) - - synced_count += 1 - - except Exception as e: - logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}") - error_count += 1 - - db.commit() - logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors") - - return {"synced_count": synced_count, "error_count": error_count} - - -@celery_app.task( - bind=True, - base=DatabaseTask, - name="app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions", -) -def cleanup_stale_subscriptions(self): - """ - Clean up subscriptions in inconsistent states. - - Runs weekly on Sunday at 03:00. - """ - now = datetime.now(UTC) - cleaned_count = 0 - - with self.get_db() as db: - # Find cancelled subscriptions past their period end - stale_cancelled = ( - db.query(VendorSubscription) - .filter( - VendorSubscription.status == SubscriptionStatus.CANCELLED.value, - VendorSubscription.period_end < now - timedelta(days=30), - ) - .all() - ) - - for subscription in stale_cancelled: - # Mark as expired (fully terminated) - subscription.status = SubscriptionStatus.EXPIRED.value - subscription.updated_at = now - cleaned_count += 1 - logger.info( - f"Marked stale cancelled subscription as expired: vendor {subscription.vendor_id}" - ) - - db.commit() - logger.info(f"Cleaned up {cleaned_count} stale subscriptions") - - return {"cleaned_count": cleaned_count} - - @celery_app.task( bind=True, base=DatabaseTask, @@ -270,6 +32,8 @@ def capture_capacity_snapshot(self): Capture a daily snapshot of platform capacity metrics. Runs daily at midnight. + + TODO: Migrate to app.modules.monitoring.tasks """ from app.services.capacity_forecast_service import capacity_forecast_service diff --git a/docs/proposals/module-migration-plan.md b/docs/proposals/module-migration-plan.md index d9b94fa6..d12e0809 100644 --- a/docs/proposals/module-migration-plan.md +++ b/docs/proposals/module-migration-plan.md @@ -29,7 +29,7 @@ Transform the platform from a monolithic structure to self-contained modules whe |--------|---------------|---------------|----------|--------|-------|--------| | `cms` | Core | ✅ **Complete** | ✅ | ✅ | - | Done | | `payments` | Optional | 🟡 Partial | ✅ | ✅ | - | Done | -| `billing` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full | +| `billing` | Optional | ✅ **Complete** | ✅ | ✅ | ✅ | Done | | `marketplace` | Optional | 🔴 Shell | ❌ | ❌ | ❌ | Full | | `orders` | Optional | 🔴 Shell | ❌ | ❌ | - | Full | | `inventory` | Optional | 🔴 Shell | ❌ | ❌ | - | Full | @@ -110,85 +110,22 @@ app/tasks/celery_tasks/ # → Move to respective modules - Observability docs - Creating modules developer guide ---- +### ✅ Phase 4: Celery Task Infrastructure +- Added `ScheduledTask` dataclass to `ModuleDefinition` +- Added `tasks_path` and `scheduled_tasks` fields +- Created `app/modules/task_base.py` with `ModuleTask` base class +- Created `app/modules/tasks.py` with discovery utilities +- Updated `celery_config.py` to use module discovery -## Infrastructure Work (Remaining) - -### Phase 4: Celery Task Infrastructure - -Add task support to module system before migrating modules. - -#### 4.1 Update ModuleDefinition - -```python -# app/modules/base.py - Add fields - -@dataclass -class ScheduledTask: - """Celery Beat scheduled task definition.""" - name: str # e.g., "billing.reset_counters" - task: str # Full task path - schedule: str | dict # Cron string or crontab dict - args: tuple = () - kwargs: dict = field(default_factory=dict) - -@dataclass -class ModuleDefinition: - # ... existing fields ... - - # Task configuration (NEW) - tasks_path: str | None = None - scheduled_tasks: list[ScheduledTask] = field(default_factory=list) -``` - -#### 4.2 Create Task Discovery - -```python -# app/modules/tasks.py (NEW) - -def discover_module_tasks() -> list[str]: - """Discover task modules from all registered modules.""" - ... - -def build_beat_schedule() -> dict: - """Build Celery Beat schedule from module definitions.""" - ... -``` - -#### 4.3 Create Module Task Base - -```python -# app/modules/task_base.py (NEW) - -class ModuleTask(Task): - """Base Celery task with DB session management.""" - - @contextmanager - def get_db(self): - ... -``` - -#### 4.4 Update Celery Config - -```python -# app/core/celery_config.py - -from app.modules.tasks import discover_module_tasks, build_beat_schedule - -# Auto-discover from modules -celery_app.autodiscover_tasks(discover_module_tasks()) - -# Build schedule from modules -celery_app.conf.beat_schedule = build_beat_schedule() -``` - -**Files to create:** -- `app/modules/tasks.py` -- `app/modules/task_base.py` - -**Files to modify:** -- `app/modules/base.py` -- `app/core/celery_config.py` +### ✅ Phase 5: Billing Module Migration +- Created `app/modules/billing/services/` with subscription, stripe, admin services +- Created `app/modules/billing/models/` re-exporting from central location +- Created `app/modules/billing/schemas/` re-exporting from central location +- Created `app/modules/billing/tasks/` with 4 scheduled tasks +- Created `app/modules/billing/exceptions.py` +- Updated `definition.py` with self-contained configuration +- Created backward-compatible re-exports in `app/services/` +- Updated legacy celery_config.py to not duplicate scheduled tasks ---