From 9d8d5e7138116c6c320e8769dfe325854cfae075 Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Thu, 25 Dec 2025 20:29:44 +0100 Subject: [PATCH] feat: add subscription and billing system with Stripe integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add database models for subscription tiers, vendor subscriptions, add-ons, billing history, and webhook events - Implement BillingService for subscription operations - Implement StripeService for Stripe API operations - Implement StripeWebhookHandler for webhook event processing - Add vendor billing API endpoints for subscription management - Create vendor billing page with Alpine.js frontend - Add limit enforcement for products and team members - Add billing exceptions for proper error handling - Create comprehensive unit tests (40 tests passing) - Add subscription billing documentation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- ...d10d22c_add_subscription_billing_tables.py | 419 +++++++++++ app/api/main.py | 3 +- app/api/v1/admin/products.py | 2 + app/api/v1/shared/webhooks.py | 64 +- app/api/v1/vendor/__init__.py | 2 + app/api/v1/vendor/billing.py | 405 +++++++++++ app/api/v1/vendor/products.py | 7 + app/core/config.py | 8 + app/exceptions/__init__.py | 21 + app/exceptions/billing.py | 95 +++ app/routes/vendor_pages.py | 22 + app/services/billing_service.py | 370 ++++++++++ app/services/marketplace_product_service.py | 31 + app/services/stripe_service.py | 459 ++++++++++++ app/services/stripe_webhook_handler.py | 411 +++++++++++ app/services/vendor_team_service.py | 22 +- app/templates/vendor/billing.html | 397 +++++++++++ app/templates/vendor/partials/sidebar.html | 11 + docs/features/subscription-billing.md | 270 +++++++ mkdocs.yml | 1 + models/database/__init__.py | 16 +- models/database/subscription.py | 287 +++++++- models/database/vendor.py | 15 + requirements.txt | 5 +- static/vendor/js/billing.js | 187 +++++ tests/unit/services/test_billing_service.py | 658 ++++++++++++++++++ .../services/test_stripe_webhook_handler.py | 393 +++++++++++ 27 files changed, 4558 insertions(+), 23 deletions(-) create mode 100644 alembic/versions/2953ed10d22c_add_subscription_billing_tables.py create mode 100644 app/api/v1/vendor/billing.py create mode 100644 app/exceptions/billing.py create mode 100644 app/services/billing_service.py create mode 100644 app/services/stripe_service.py create mode 100644 app/services/stripe_webhook_handler.py create mode 100644 app/templates/vendor/billing.html create mode 100644 docs/features/subscription-billing.md create mode 100644 static/vendor/js/billing.js create mode 100644 tests/unit/services/test_billing_service.py create mode 100644 tests/unit/services/test_stripe_webhook_handler.py diff --git a/alembic/versions/2953ed10d22c_add_subscription_billing_tables.py b/alembic/versions/2953ed10d22c_add_subscription_billing_tables.py new file mode 100644 index 00000000..e279b573 --- /dev/null +++ b/alembic/versions/2953ed10d22c_add_subscription_billing_tables.py @@ -0,0 +1,419 @@ +"""add_subscription_billing_tables + +Revision ID: 2953ed10d22c +Revises: e1bfb453fbe9 +Create Date: 2025-12-25 18:29:34.167773 + +""" +from datetime import datetime +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision: str = '2953ed10d22c' +down_revision: Union[str, None] = 'e1bfb453fbe9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ========================================================================= + # Create new subscription and billing tables + # ========================================================================= + + # subscription_tiers - Database-driven tier definitions + op.create_table('subscription_tiers', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('code', sa.String(length=30), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('price_monthly_cents', sa.Integer(), nullable=False), + sa.Column('price_annual_cents', sa.Integer(), nullable=True), + sa.Column('orders_per_month', sa.Integer(), nullable=True), + sa.Column('products_limit', sa.Integer(), nullable=True), + sa.Column('team_members', sa.Integer(), nullable=True), + sa.Column('order_history_months', sa.Integer(), nullable=True), + sa.Column('features', sqlite.JSON(), nullable=True), + sa.Column('stripe_product_id', sa.String(length=100), nullable=True), + sa.Column('stripe_price_monthly_id', sa.String(length=100), nullable=True), + sa.Column('stripe_price_annual_id', sa.String(length=100), nullable=True), + sa.Column('display_order', sa.Integer(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_public', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_subscription_tiers_code'), 'subscription_tiers', ['code'], unique=True) + op.create_index(op.f('ix_subscription_tiers_id'), 'subscription_tiers', ['id'], unique=False) + + # addon_products - Purchasable add-ons (domains, SSL, email) + op.create_table('addon_products', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('code', sa.String(length=50), nullable=False), + sa.Column('name', sa.String(length=100), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('category', sa.String(length=50), nullable=False), + sa.Column('price_cents', sa.Integer(), nullable=False), + sa.Column('billing_period', sa.String(length=20), nullable=False), + sa.Column('quantity_unit', sa.String(length=50), nullable=True), + sa.Column('quantity_value', sa.Integer(), nullable=True), + sa.Column('stripe_product_id', sa.String(length=100), nullable=True), + sa.Column('stripe_price_id', sa.String(length=100), nullable=True), + sa.Column('display_order', sa.Integer(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_addon_products_category'), 'addon_products', ['category'], unique=False) + op.create_index(op.f('ix_addon_products_code'), 'addon_products', ['code'], unique=True) + op.create_index(op.f('ix_addon_products_id'), 'addon_products', ['id'], unique=False) + + # billing_history - Invoice and payment history + op.create_table('billing_history', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('vendor_id', sa.Integer(), nullable=False), + sa.Column('stripe_invoice_id', sa.String(length=100), nullable=True), + sa.Column('stripe_payment_intent_id', sa.String(length=100), nullable=True), + sa.Column('invoice_number', sa.String(length=50), nullable=True), + sa.Column('invoice_date', sa.DateTime(timezone=True), nullable=False), + sa.Column('due_date', sa.DateTime(timezone=True), nullable=True), + sa.Column('subtotal_cents', sa.Integer(), nullable=False), + sa.Column('tax_cents', sa.Integer(), nullable=False), + sa.Column('total_cents', sa.Integer(), nullable=False), + sa.Column('amount_paid_cents', sa.Integer(), nullable=False), + sa.Column('currency', sa.String(length=3), nullable=False), + sa.Column('status', sa.String(length=20), nullable=False), + sa.Column('invoice_pdf_url', sa.String(length=500), nullable=True), + sa.Column('hosted_invoice_url', sa.String(length=500), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('line_items', sqlite.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_billing_status', 'billing_history', ['vendor_id', 'status'], unique=False) + op.create_index('idx_billing_vendor_date', 'billing_history', ['vendor_id', 'invoice_date'], unique=False) + op.create_index(op.f('ix_billing_history_id'), 'billing_history', ['id'], unique=False) + op.create_index(op.f('ix_billing_history_status'), 'billing_history', ['status'], unique=False) + op.create_index(op.f('ix_billing_history_stripe_invoice_id'), 'billing_history', ['stripe_invoice_id'], unique=True) + op.create_index(op.f('ix_billing_history_vendor_id'), 'billing_history', ['vendor_id'], unique=False) + + # vendor_addons - Add-ons purchased by vendor + op.create_table('vendor_addons', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('vendor_id', sa.Integer(), nullable=False), + sa.Column('addon_product_id', sa.Integer(), nullable=False), + sa.Column('status', sa.String(length=20), nullable=False), + sa.Column('domain_name', sa.String(length=255), nullable=True), + sa.Column('quantity', sa.Integer(), nullable=False), + sa.Column('stripe_subscription_item_id', sa.String(length=100), nullable=True), + sa.Column('period_start', sa.DateTime(timezone=True), nullable=True), + sa.Column('period_end', sa.DateTime(timezone=True), nullable=True), + sa.Column('cancelled_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['addon_product_id'], ['addon_products.id'], ), + sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_vendor_addon_product', 'vendor_addons', ['vendor_id', 'addon_product_id'], unique=False) + op.create_index('idx_vendor_addon_status', 'vendor_addons', ['vendor_id', 'status'], unique=False) + op.create_index(op.f('ix_vendor_addons_addon_product_id'), 'vendor_addons', ['addon_product_id'], unique=False) + op.create_index(op.f('ix_vendor_addons_domain_name'), 'vendor_addons', ['domain_name'], unique=False) + op.create_index(op.f('ix_vendor_addons_id'), 'vendor_addons', ['id'], unique=False) + op.create_index(op.f('ix_vendor_addons_status'), 'vendor_addons', ['status'], unique=False) + op.create_index(op.f('ix_vendor_addons_vendor_id'), 'vendor_addons', ['vendor_id'], unique=False) + + # stripe_webhook_events - Webhook idempotency tracking + op.create_table('stripe_webhook_events', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('event_id', sa.String(length=100), nullable=False), + sa.Column('event_type', sa.String(length=100), nullable=False), + sa.Column('status', sa.String(length=20), nullable=False), + sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('payload_encrypted', sa.Text(), nullable=True), + sa.Column('vendor_id', sa.Integer(), nullable=True), + sa.Column('subscription_id', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['subscription_id'], ['vendor_subscriptions.id'], ), + sa.ForeignKeyConstraint(['vendor_id'], ['vendors.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('idx_webhook_event_type_status', 'stripe_webhook_events', ['event_type', 'status'], unique=False) + op.create_index(op.f('ix_stripe_webhook_events_event_id'), 'stripe_webhook_events', ['event_id'], unique=True) + op.create_index(op.f('ix_stripe_webhook_events_event_type'), 'stripe_webhook_events', ['event_type'], unique=False) + op.create_index(op.f('ix_stripe_webhook_events_id'), 'stripe_webhook_events', ['id'], unique=False) + op.create_index(op.f('ix_stripe_webhook_events_status'), 'stripe_webhook_events', ['status'], unique=False) + op.create_index(op.f('ix_stripe_webhook_events_subscription_id'), 'stripe_webhook_events', ['subscription_id'], unique=False) + op.create_index(op.f('ix_stripe_webhook_events_vendor_id'), 'stripe_webhook_events', ['vendor_id'], unique=False) + + # ========================================================================= + # Add new columns to vendor_subscriptions + # ========================================================================= + op.add_column('vendor_subscriptions', sa.Column('stripe_price_id', sa.String(length=100), nullable=True)) + op.add_column('vendor_subscriptions', sa.Column('stripe_payment_method_id', sa.String(length=100), nullable=True)) + op.add_column('vendor_subscriptions', sa.Column('proration_behavior', sa.String(length=50), nullable=True)) + op.add_column('vendor_subscriptions', sa.Column('scheduled_tier_change', sa.String(length=30), nullable=True)) + op.add_column('vendor_subscriptions', sa.Column('scheduled_change_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column('vendor_subscriptions', sa.Column('payment_retry_count', sa.Integer(), server_default='0', nullable=False)) + op.add_column('vendor_subscriptions', sa.Column('last_payment_error', sa.Text(), nullable=True)) + + # ========================================================================= + # Seed subscription tiers + # ========================================================================= + now = datetime.utcnow() + + subscription_tiers = sa.table( + 'subscription_tiers', + sa.column('code', sa.String), + sa.column('name', sa.String), + sa.column('description', sa.Text), + sa.column('price_monthly_cents', sa.Integer), + sa.column('price_annual_cents', sa.Integer), + sa.column('orders_per_month', sa.Integer), + sa.column('products_limit', sa.Integer), + sa.column('team_members', sa.Integer), + sa.column('order_history_months', sa.Integer), + sa.column('features', sqlite.JSON), + sa.column('display_order', sa.Integer), + sa.column('is_active', sa.Boolean), + sa.column('is_public', sa.Boolean), + sa.column('created_at', sa.DateTime), + sa.column('updated_at', sa.DateTime), + ) + + op.bulk_insert(subscription_tiers, [ + { + 'code': 'essential', + 'name': 'Essential', + 'description': 'Perfect for solo vendors getting started with Letzshop', + 'price_monthly_cents': 4900, + 'price_annual_cents': 49000, + 'orders_per_month': 100, + 'products_limit': 200, + 'team_members': 1, + 'order_history_months': 6, + 'features': ['letzshop_sync', 'inventory_basic', 'invoice_lu', 'customer_view'], + 'display_order': 1, + 'is_active': True, + 'is_public': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'professional', + 'name': 'Professional', + 'description': 'For active multi-channel vendors shipping EU-wide', + 'price_monthly_cents': 9900, + 'price_annual_cents': 99000, + 'orders_per_month': 500, + 'products_limit': None, + 'team_members': 3, + 'order_history_months': 24, + 'features': [ + 'letzshop_sync', 'inventory_locations', 'inventory_purchase_orders', + 'invoice_lu', 'invoice_eu_vat', 'customer_view', 'customer_export' + ], + 'display_order': 2, + 'is_active': True, + 'is_public': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'business', + 'name': 'Business', + 'description': 'For high-volume vendors with teams and data-driven operations', + 'price_monthly_cents': 19900, + 'price_annual_cents': 199000, + 'orders_per_month': 2000, + 'products_limit': None, + 'team_members': 10, + 'order_history_months': None, + 'features': [ + 'letzshop_sync', 'inventory_locations', 'inventory_purchase_orders', + 'invoice_lu', 'invoice_eu_vat', 'invoice_bulk', 'customer_view', + 'customer_export', 'analytics_dashboard', 'accounting_export', + 'api_access', 'automation_rules', 'team_roles' + ], + 'display_order': 3, + 'is_active': True, + 'is_public': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'enterprise', + 'name': 'Enterprise', + 'description': 'Custom solutions for large operations and agencies', + 'price_monthly_cents': 39900, + 'price_annual_cents': None, + 'orders_per_month': None, + 'products_limit': None, + 'team_members': None, + 'order_history_months': None, + 'features': [ + 'letzshop_sync', 'inventory_locations', 'inventory_purchase_orders', + 'invoice_lu', 'invoice_eu_vat', 'invoice_bulk', 'customer_view', + 'customer_export', 'analytics_dashboard', 'accounting_export', + 'api_access', 'automation_rules', 'team_roles', 'white_label', + 'multi_vendor', 'custom_integrations', 'sla_guarantee', 'dedicated_support' + ], + 'display_order': 4, + 'is_active': True, + 'is_public': False, + 'created_at': now, + 'updated_at': now, + }, + ]) + + # ========================================================================= + # Seed add-on products + # ========================================================================= + addon_products = sa.table( + 'addon_products', + sa.column('code', sa.String), + sa.column('name', sa.String), + sa.column('description', sa.Text), + sa.column('category', sa.String), + sa.column('price_cents', sa.Integer), + sa.column('billing_period', sa.String), + sa.column('quantity_unit', sa.String), + sa.column('quantity_value', sa.Integer), + sa.column('display_order', sa.Integer), + sa.column('is_active', sa.Boolean), + sa.column('created_at', sa.DateTime), + sa.column('updated_at', sa.DateTime), + ) + + op.bulk_insert(addon_products, [ + { + 'code': 'domain', + 'name': 'Custom Domain', + 'description': 'Connect your own domain with SSL certificate included', + 'category': 'domain', + 'price_cents': 1500, + 'billing_period': 'annual', + 'quantity_unit': None, + 'quantity_value': None, + 'display_order': 1, + 'is_active': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'email_5', + 'name': '5 Email Addresses', + 'description': 'Professional email addresses on your domain', + 'category': 'email', + 'price_cents': 500, + 'billing_period': 'monthly', + 'quantity_unit': 'emails', + 'quantity_value': 5, + 'display_order': 2, + 'is_active': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'email_10', + 'name': '10 Email Addresses', + 'description': 'Professional email addresses on your domain', + 'category': 'email', + 'price_cents': 900, + 'billing_period': 'monthly', + 'quantity_unit': 'emails', + 'quantity_value': 10, + 'display_order': 3, + 'is_active': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'email_25', + 'name': '25 Email Addresses', + 'description': 'Professional email addresses on your domain', + 'category': 'email', + 'price_cents': 1900, + 'billing_period': 'monthly', + 'quantity_unit': 'emails', + 'quantity_value': 25, + 'display_order': 4, + 'is_active': True, + 'created_at': now, + 'updated_at': now, + }, + { + 'code': 'storage_10gb', + 'name': 'Additional Storage (10GB)', + 'description': 'Extra storage for product images and files', + 'category': 'storage', + 'price_cents': 500, + 'billing_period': 'monthly', + 'quantity_unit': 'GB', + 'quantity_value': 10, + 'display_order': 5, + 'is_active': True, + 'created_at': now, + 'updated_at': now, + }, + ]) + + +def downgrade() -> None: + # Remove new columns from vendor_subscriptions + op.drop_column('vendor_subscriptions', 'last_payment_error') + op.drop_column('vendor_subscriptions', 'payment_retry_count') + op.drop_column('vendor_subscriptions', 'scheduled_change_at') + op.drop_column('vendor_subscriptions', 'scheduled_tier_change') + op.drop_column('vendor_subscriptions', 'proration_behavior') + op.drop_column('vendor_subscriptions', 'stripe_payment_method_id') + op.drop_column('vendor_subscriptions', 'stripe_price_id') + + # Drop stripe_webhook_events + op.drop_index(op.f('ix_stripe_webhook_events_vendor_id'), table_name='stripe_webhook_events') + op.drop_index(op.f('ix_stripe_webhook_events_subscription_id'), table_name='stripe_webhook_events') + op.drop_index(op.f('ix_stripe_webhook_events_status'), table_name='stripe_webhook_events') + op.drop_index(op.f('ix_stripe_webhook_events_id'), table_name='stripe_webhook_events') + op.drop_index(op.f('ix_stripe_webhook_events_event_type'), table_name='stripe_webhook_events') + op.drop_index(op.f('ix_stripe_webhook_events_event_id'), table_name='stripe_webhook_events') + op.drop_index('idx_webhook_event_type_status', table_name='stripe_webhook_events') + op.drop_table('stripe_webhook_events') + + # Drop vendor_addons + op.drop_index(op.f('ix_vendor_addons_vendor_id'), table_name='vendor_addons') + op.drop_index(op.f('ix_vendor_addons_status'), table_name='vendor_addons') + op.drop_index(op.f('ix_vendor_addons_id'), table_name='vendor_addons') + op.drop_index(op.f('ix_vendor_addons_domain_name'), table_name='vendor_addons') + op.drop_index(op.f('ix_vendor_addons_addon_product_id'), table_name='vendor_addons') + op.drop_index('idx_vendor_addon_status', table_name='vendor_addons') + op.drop_index('idx_vendor_addon_product', table_name='vendor_addons') + op.drop_table('vendor_addons') + + # Drop billing_history + op.drop_index(op.f('ix_billing_history_vendor_id'), table_name='billing_history') + op.drop_index(op.f('ix_billing_history_stripe_invoice_id'), table_name='billing_history') + op.drop_index(op.f('ix_billing_history_status'), table_name='billing_history') + op.drop_index(op.f('ix_billing_history_id'), table_name='billing_history') + op.drop_index('idx_billing_vendor_date', table_name='billing_history') + op.drop_index('idx_billing_status', table_name='billing_history') + op.drop_table('billing_history') + + # Drop addon_products + op.drop_index(op.f('ix_addon_products_id'), table_name='addon_products') + op.drop_index(op.f('ix_addon_products_code'), table_name='addon_products') + op.drop_index(op.f('ix_addon_products_category'), table_name='addon_products') + op.drop_table('addon_products') + + # Drop subscription_tiers + op.drop_index(op.f('ix_subscription_tiers_id'), table_name='subscription_tiers') + op.drop_index(op.f('ix_subscription_tiers_code'), table_name='subscription_tiers') + op.drop_table('subscription_tiers') diff --git a/app/api/main.py b/app/api/main.py index 865d33b3..a5567333 100644 --- a/app/api/main.py +++ b/app/api/main.py @@ -11,7 +11,7 @@ This module provides: from fastapi import APIRouter from app.api.v1 import admin, shop, vendor -from app.api.v1.shared import language +from app.api.v1.shared import language, webhooks api_router = APIRouter() @@ -42,3 +42,4 @@ api_router.include_router(shop.router, prefix="/v1/shop", tags=["shop"]) # ============================================================================ api_router.include_router(language.router, prefix="/v1", tags=["language"]) +api_router.include_router(webhooks.router, prefix="/v1", tags=["webhooks"]) diff --git a/app/api/v1/admin/products.py b/app/api/v1/admin/products.py index 6f54b14d..8d13fd2e 100644 --- a/app/api/v1/admin/products.py +++ b/app/api/v1/admin/products.py @@ -101,6 +101,8 @@ class CopyToVendorResponse(BaseModel): copied: int skipped: int failed: int + auto_matched: int = 0 + limit_reached: bool = False details: list[dict] | None = None diff --git a/app/api/v1/shared/webhooks.py b/app/api/v1/shared/webhooks.py index acfcff9e..065139db 100644 --- a/app/api/v1/shared/webhooks.py +++ b/app/api/v1/shared/webhooks.py @@ -1 +1,63 @@ -# External webhooks (Stripe, etc. +# app/api/v1/shared/webhooks.py +""" +External webhook endpoints. + +Handles webhooks from: +- Stripe (payments and subscriptions) +""" + +import logging + +from fastapi import APIRouter, Header, Request +from sqlalchemy.orm import Session + +from app.core.database import get_db +from app.exceptions import InvalidWebhookSignatureException, WebhookMissingSignatureException +from app.services.stripe_service import stripe_service +from app.services.stripe_webhook_handler import stripe_webhook_handler + +router = APIRouter(prefix="/webhooks") +logger = logging.getLogger(__name__) + + +@router.post("/stripe") # public - Stripe webhooks use signature verification +async def stripe_webhook( + request: Request, + stripe_signature: str = Header(None, alias="Stripe-Signature"), +): + """ + Handle Stripe webhook events. + + Stripe sends events for: + - Subscription lifecycle (created, updated, deleted) + - Invoice and payment events + - Checkout session completion + + The endpoint verifies the webhook signature and processes events idempotently. + """ + if not stripe_signature: + logger.warning("Stripe webhook received without signature") + raise WebhookMissingSignatureException() + + # Get raw body for signature verification + payload = await request.body() + + try: + # Verify and construct event + event = stripe_service.construct_event(payload, stripe_signature) + except ValueError as e: + logger.warning(f"Invalid Stripe webhook: {e}") + raise InvalidWebhookSignatureException(str(e)) + + # Process the event + db = next(get_db()) + try: + result = stripe_webhook_handler.handle_event(db, event) + return {"received": True, **result} + except Exception as e: + logger.error(f"Error processing Stripe webhook: {e}") + # Return 200 to prevent Stripe retries for processing errors + # The event is marked as failed and can be retried manually + return {"received": True, "error": str(e)} + finally: + db.close() diff --git a/app/api/v1/vendor/__init__.py b/app/api/v1/vendor/__init__.py index 142c84cc..b347b2cf 100644 --- a/app/api/v1/vendor/__init__.py +++ b/app/api/v1/vendor/__init__.py @@ -16,6 +16,7 @@ from fastapi import APIRouter from . import ( analytics, auth, + billing, content_pages, customers, dashboard, @@ -73,6 +74,7 @@ router.include_router(media.router, tags=["vendor-media"]) router.include_router(notifications.router, tags=["vendor-notifications"]) router.include_router(messages.router, tags=["vendor-messages"]) router.include_router(analytics.router, tags=["vendor-analytics"]) +router.include_router(billing.router, tags=["vendor-billing"]) # Content pages management router.include_router( diff --git a/app/api/v1/vendor/billing.py b/app/api/v1/vendor/billing.py new file mode 100644 index 00000000..edc911e5 --- /dev/null +++ b/app/api/v1/vendor/billing.py @@ -0,0 +1,405 @@ +# app/api/v1/vendor/billing.py +""" +Vendor billing and subscription management endpoints. + +Provides: +- Subscription status and usage +- Tier listing and comparison +- Stripe checkout and portal access +- Invoice history +- Add-on management +""" + +import logging + +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.api.deps import get_current_vendor_api +from app.core.config import settings +from app.core.database import get_db +from app.services.billing_service import billing_service +from app.services.subscription_service import subscription_service +from models.database.user import User + +router = APIRouter(prefix="/billing") +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Schemas +# ============================================================================ + + +class SubscriptionStatusResponse(BaseModel): + """Current subscription status and usage.""" + + tier_code: str + tier_name: str + status: str + is_trial: bool + trial_ends_at: str | None = None + period_start: str | None = None + period_end: str | None = None + cancelled_at: str | None = None + cancellation_reason: str | None = None + + # Usage + orders_this_period: int + orders_limit: int | None + orders_remaining: int | None + products_count: int + products_limit: int | None + products_remaining: int | None + team_count: int + team_limit: int | None + team_remaining: int | None + + # Payment + has_payment_method: bool + last_payment_error: str | None = None + + class Config: + from_attributes = True + + +class TierResponse(BaseModel): + """Subscription tier information.""" + + code: str + name: str + description: str | None = None + price_monthly_cents: int + price_annual_cents: int | None = None + orders_per_month: int | None = None + products_limit: int | None = None + team_members: int | None = None + features: list[str] = [] + is_current: bool = False + can_upgrade: bool = False + can_downgrade: bool = False + + +class TierListResponse(BaseModel): + """List of available tiers.""" + + tiers: list[TierResponse] + current_tier: str + + +class CheckoutRequest(BaseModel): + """Request to create a checkout session.""" + + tier_code: str + is_annual: bool = False + + +class CheckoutResponse(BaseModel): + """Checkout session response.""" + + checkout_url: str + session_id: str + + +class PortalResponse(BaseModel): + """Customer portal session response.""" + + portal_url: str + + +class InvoiceResponse(BaseModel): + """Invoice information.""" + + id: int + invoice_number: str | None = None + invoice_date: str + due_date: str | None = None + total_cents: int + amount_paid_cents: int + currency: str + status: str + pdf_url: str | None = None + hosted_url: str | None = None + + +class InvoiceListResponse(BaseModel): + """List of invoices.""" + + invoices: list[InvoiceResponse] + total: int + + +class AddOnResponse(BaseModel): + """Add-on product information.""" + + id: int + code: str + name: str + description: str | None = None + category: str + price_cents: int + billing_period: str + quantity_unit: str | None = None + quantity_value: int | None = None + + +class VendorAddOnResponse(BaseModel): + """Vendor's purchased add-on.""" + + id: int + addon_code: str + addon_name: str + status: str + domain_name: str | None = None + quantity: int + period_start: str | None = None + period_end: str | None = None + + +class AddOnPurchaseRequest(BaseModel): + """Request to purchase an add-on.""" + + addon_code: str + domain_name: str | None = None # For domain add-ons + quantity: int = 1 + + +class CancelRequest(BaseModel): + """Request to cancel subscription.""" + + reason: str | None = None + immediately: bool = False + + +class CancelResponse(BaseModel): + """Cancellation response.""" + + message: str + effective_date: str + + +# ============================================================================ +# Endpoints +# ============================================================================ + + +@router.get("/subscription", response_model=SubscriptionStatusResponse) +def get_subscription_status( + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Get current subscription status and usage metrics.""" + vendor_id = current_user.token_vendor_id + + usage = subscription_service.get_usage_summary(db, vendor_id) + subscription, tier = billing_service.get_subscription_with_tier(db, vendor_id) + + return SubscriptionStatusResponse( + tier_code=subscription.tier, + tier_name=tier.name if tier else subscription.tier.title(), + status=subscription.status.value, + is_trial=subscription.is_in_trial(), + trial_ends_at=subscription.trial_ends_at.isoformat() + if subscription.trial_ends_at + else None, + period_start=subscription.period_start.isoformat() + if subscription.period_start + else None, + period_end=subscription.period_end.isoformat() + if subscription.period_end + else None, + cancelled_at=subscription.cancelled_at.isoformat() + if subscription.cancelled_at + else None, + cancellation_reason=subscription.cancellation_reason, + orders_this_period=usage.orders_this_period, + orders_limit=usage.orders_limit, + orders_remaining=usage.orders_remaining, + products_count=usage.products_count, + products_limit=usage.products_limit, + products_remaining=usage.products_remaining, + team_count=usage.team_count, + team_limit=usage.team_limit, + team_remaining=usage.team_remaining, + has_payment_method=bool(subscription.stripe_payment_method_id), + last_payment_error=subscription.last_payment_error, + ) + + +@router.get("/tiers", response_model=TierListResponse) +def get_available_tiers( + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Get available subscription tiers for upgrade/downgrade.""" + vendor_id = current_user.token_vendor_id + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + current_tier = subscription.tier + + tier_list, _ = billing_service.get_available_tiers(db, current_tier) + + tier_responses = [TierResponse(**tier_data) for tier_data in tier_list] + + return TierListResponse(tiers=tier_responses, current_tier=current_tier) + + +@router.post("/checkout", response_model=CheckoutResponse) +def create_checkout_session( + request: CheckoutRequest, + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Create a Stripe checkout session for subscription.""" + vendor_id = current_user.token_vendor_id + vendor = billing_service.get_vendor(db, vendor_id) + + # Build URLs + base_url = f"https://{settings.platform_domain}" + success_url = f"{base_url}/vendor/{vendor.vendor_code}/billing?success=true" + cancel_url = f"{base_url}/vendor/{vendor.vendor_code}/billing?cancelled=true" + + result = billing_service.create_checkout_session( + db=db, + vendor_id=vendor_id, + tier_code=request.tier_code, + is_annual=request.is_annual, + success_url=success_url, + cancel_url=cancel_url, + ) + db.commit() + + return CheckoutResponse(checkout_url=result["checkout_url"], session_id=result["session_id"]) + + +@router.post("/portal", response_model=PortalResponse) +def create_portal_session( + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Create a Stripe customer portal session.""" + vendor_id = current_user.token_vendor_id + vendor = billing_service.get_vendor(db, vendor_id) + return_url = f"https://{settings.platform_domain}/vendor/{vendor.vendor_code}/billing" + + result = billing_service.create_portal_session(db, vendor_id, return_url) + + return PortalResponse(portal_url=result["portal_url"]) + + +@router.get("/invoices", response_model=InvoiceListResponse) +def get_invoices( + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Get invoice history.""" + vendor_id = current_user.token_vendor_id + + invoices, total = billing_service.get_invoices(db, vendor_id, skip=skip, limit=limit) + + invoice_responses = [ + InvoiceResponse( + id=inv.id, + invoice_number=inv.invoice_number, + invoice_date=inv.invoice_date.isoformat(), + due_date=inv.due_date.isoformat() if inv.due_date else None, + total_cents=inv.total_cents, + amount_paid_cents=inv.amount_paid_cents, + currency=inv.currency, + status=inv.status, + pdf_url=inv.invoice_pdf_url, + hosted_url=inv.hosted_invoice_url, + ) + for inv in invoices + ] + + return InvoiceListResponse(invoices=invoice_responses, total=total) + + +@router.get("/addons", response_model=list[AddOnResponse]) +def get_available_addons( + category: str | None = Query(None, description="Filter by category"), + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Get available add-on products.""" + addons = billing_service.get_available_addons(db, category=category) + + return [ + AddOnResponse( + id=addon.id, + code=addon.code, + name=addon.name, + description=addon.description, + category=addon.category, + price_cents=addon.price_cents, + billing_period=addon.billing_period, + quantity_unit=addon.quantity_unit, + quantity_value=addon.quantity_value, + ) + for addon in addons + ] + + +@router.get("/my-addons", response_model=list[VendorAddOnResponse]) +def get_vendor_addons( + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Get vendor's purchased add-ons.""" + vendor_id = current_user.token_vendor_id + + vendor_addons = billing_service.get_vendor_addons(db, vendor_id) + + return [ + VendorAddOnResponse( + id=va.id, + addon_code=va.addon_product.code, + addon_name=va.addon_product.name, + status=va.status, + domain_name=va.domain_name, + quantity=va.quantity, + period_start=va.period_start.isoformat() if va.period_start else None, + period_end=va.period_end.isoformat() if va.period_end else None, + ) + for va in vendor_addons + ] + + +@router.post("/cancel", response_model=CancelResponse) +def cancel_subscription( + request: CancelRequest, + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Cancel subscription.""" + vendor_id = current_user.token_vendor_id + + result = billing_service.cancel_subscription( + db=db, + vendor_id=vendor_id, + reason=request.reason, + immediately=request.immediately, + ) + db.commit() + + return CancelResponse( + message=result["message"], + effective_date=result["effective_date"], + ) + + +@router.post("/reactivate") +def reactivate_subscription( + current_user: User = Depends(get_current_vendor_api), + db: Session = Depends(get_db), +): + """Reactivate a cancelled subscription.""" + vendor_id = current_user.token_vendor_id + + result = billing_service.reactivate_subscription(db, vendor_id) + db.commit() + + return result diff --git a/app/api/v1/vendor/products.py b/app/api/v1/vendor/products.py index af161380..e9e7c502 100644 --- a/app/api/v1/vendor/products.py +++ b/app/api/v1/vendor/products.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api from app.core.database import get_db from app.services.product_service import product_service +from app.services.subscription_service import subscription_service from models.database.user import User from models.schema.product import ( ProductCreate, @@ -89,6 +90,9 @@ def add_product_to_catalog( This publishes a MarketplaceProduct to the vendor's public catalog. """ + # Check product limit before creating + subscription_service.check_product_limit(db, current_user.token_vendor_id) + product = product_service.create_product( db=db, vendor_id=current_user.token_vendor_id, product_data=product_data ) @@ -157,6 +161,9 @@ def publish_from_marketplace( Shortcut endpoint for publishing directly from marketplace import. """ + # Check product limit before creating + subscription_service.check_product_limit(db, current_user.token_vendor_id) + product_data = ProductCreate( marketplace_product_id=marketplace_product_id, is_active=True ) diff --git a/app/core/config.py b/app/core/config.py index 86221ce6..2c427af2 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -117,6 +117,14 @@ class Settings(BaseSettings): max_team_members_per_vendor: int = 50 invitation_expiry_days: int = 7 + # ============================================================================= + # STRIPE BILLING + # ============================================================================= + stripe_secret_key: str = "" + stripe_publishable_key: str = "" + stripe_webhook_secret: str = "" + stripe_trial_days: int = 14 + # ============================================================================= # DEMO/SEED DATA CONFIGURATION # ============================================================================= diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py index 37ba3e6e..1e8f5015 100644 --- a/app/exceptions/__init__.py +++ b/app/exceptions/__init__.py @@ -46,6 +46,18 @@ from .base import ( WizamartException, ) +# Billing exceptions +from .billing import ( + InvalidWebhookSignatureException, + NoActiveSubscriptionException, + PaymentSystemNotConfiguredException, + StripePriceNotConfiguredException, + SubscriptionAlreadyCancelledException, + SubscriptionNotCancelledException, + TierNotFoundException, + WebhookMissingSignatureException, +) + # Cart exceptions from .cart import ( CartItemNotFoundException, @@ -416,4 +428,13 @@ __all__ = [ "InvalidConversationTypeException", "InvalidRecipientTypeException", "AttachmentNotFoundException", + # Billing exceptions + "PaymentSystemNotConfiguredException", + "TierNotFoundException", + "StripePriceNotConfiguredException", + "NoActiveSubscriptionException", + "SubscriptionNotCancelledException", + "SubscriptionAlreadyCancelledException", + "InvalidWebhookSignatureException", + "WebhookMissingSignatureException", ] diff --git a/app/exceptions/billing.py b/app/exceptions/billing.py new file mode 100644 index 00000000..6a893bc6 --- /dev/null +++ b/app/exceptions/billing.py @@ -0,0 +1,95 @@ +# app/exceptions/billing.py +""" +Billing and subscription related exceptions. + +This module provides exceptions for: +- Payment system configuration issues +- Subscription management errors +- Tier-related errors +""" + +from typing import Any + +from .base import BusinessLogicException, ResourceNotFoundException, ServiceUnavailableException + + +class PaymentSystemNotConfiguredException(ServiceUnavailableException): + """Raised when the payment system (Stripe) is not configured.""" + + def __init__(self): + super().__init__(message="Payment system not configured") + + +class TierNotFoundException(ResourceNotFoundException): + """Raised when a subscription tier is not found.""" + + def __init__(self, tier_code: str): + super().__init__( + resource_type="SubscriptionTier", + identifier=tier_code, + message=f"Subscription tier '{tier_code}' not found", + error_code="TIER_NOT_FOUND", + ) + self.tier_code = tier_code + + +class StripePriceNotConfiguredException(BusinessLogicException): + """Raised when Stripe price is not configured for a tier.""" + + def __init__(self, tier_code: str): + super().__init__( + message=f"Stripe price not configured for tier '{tier_code}'", + error_code="STRIPE_PRICE_NOT_CONFIGURED", + details={"tier_code": tier_code}, + ) + self.tier_code = tier_code + + +class NoActiveSubscriptionException(BusinessLogicException): + """Raised when no active subscription exists for an operation that requires one.""" + + def __init__(self, message: str = "No active subscription found"): + super().__init__( + message=message, + error_code="NO_ACTIVE_SUBSCRIPTION", + ) + + +class SubscriptionNotCancelledException(BusinessLogicException): + """Raised when trying to reactivate a subscription that is not cancelled.""" + + def __init__(self): + super().__init__( + message="Subscription is not cancelled", + error_code="SUBSCRIPTION_NOT_CANCELLED", + ) + + +class SubscriptionAlreadyCancelledException(BusinessLogicException): + """Raised when trying to cancel an already cancelled subscription.""" + + def __init__(self): + super().__init__( + message="Subscription is already cancelled", + error_code="SUBSCRIPTION_ALREADY_CANCELLED", + ) + + +class InvalidWebhookSignatureException(BusinessLogicException): + """Raised when Stripe webhook signature verification fails.""" + + def __init__(self, message: str = "Invalid webhook signature"): + super().__init__( + message=message, + error_code="INVALID_WEBHOOK_SIGNATURE", + ) + + +class WebhookMissingSignatureException(BusinessLogicException): + """Raised when Stripe webhook is missing the signature header.""" + + def __init__(self): + super().__init__( + message="Missing Stripe-Signature header", + error_code="WEBHOOK_MISSING_SIGNATURE", + ) diff --git a/app/routes/vendor_pages.py b/app/routes/vendor_pages.py index a0969c8a..5a8295e6 100644 --- a/app/routes/vendor_pages.py +++ b/app/routes/vendor_pages.py @@ -454,6 +454,28 @@ async def vendor_settings_page( ) +@router.get( + "/{vendor_code}/billing", response_class=HTMLResponse, include_in_schema=False +) +async def vendor_billing_page( + request: Request, + vendor_code: str = Path(..., description="Vendor code"), + current_user: User = Depends(get_current_vendor_from_cookie_or_header), +): + """ + Render billing and subscription management page. + JavaScript loads subscription status, tiers, and invoices via API. + """ + return templates.TemplateResponse( + "vendor/billing.html", + { + "request": request, + "user": current_user, + "vendor_code": vendor_code, + }, + ) + + # ============================================================================ # DYNAMIC CONTENT PAGES (CMS) # ============================================================================ diff --git a/app/services/billing_service.py b/app/services/billing_service.py new file mode 100644 index 00000000..0052f213 --- /dev/null +++ b/app/services/billing_service.py @@ -0,0 +1,370 @@ +# app/services/billing_service.py +""" +Billing service for subscription and payment operations. + +Provides: +- Subscription status and usage queries +- Tier management +- Invoice history +- Add-on management +""" + +import logging +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.services.stripe_service import stripe_service +from app.services.subscription_service import subscription_service +from models.database.subscription import ( + AddOnProduct, + BillingHistory, + SubscriptionTier, + VendorAddOn, + VendorSubscription, +) +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class BillingServiceError(Exception): + """Base exception for billing service errors.""" + + pass + + +class PaymentSystemNotConfiguredError(BillingServiceError): + """Raised when Stripe is not configured.""" + + def __init__(self): + super().__init__("Payment system not configured") + + +class TierNotFoundError(BillingServiceError): + """Raised when a tier is not found.""" + + def __init__(self, tier_code: str): + self.tier_code = tier_code + super().__init__(f"Tier '{tier_code}' not found") + + +class StripePriceNotConfiguredError(BillingServiceError): + """Raised when Stripe price is not configured for a tier.""" + + def __init__(self, tier_code: str): + self.tier_code = tier_code + super().__init__(f"Stripe price not configured for tier '{tier_code}'") + + +class NoActiveSubscriptionError(BillingServiceError): + """Raised when no active subscription exists.""" + + def __init__(self): + super().__init__("No active subscription found") + + +class SubscriptionNotCancelledError(BillingServiceError): + """Raised when trying to reactivate a non-cancelled subscription.""" + + def __init__(self): + super().__init__("Subscription is not cancelled") + + +class BillingService: + """Service for billing operations.""" + + def get_subscription_with_tier( + self, db: Session, vendor_id: int + ) -> tuple[VendorSubscription, SubscriptionTier | None]: + """ + Get subscription and its tier info. + + Returns: + Tuple of (subscription, tier) where tier may be None + """ + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == subscription.tier) + .first() + ) + + return subscription, tier + + def get_available_tiers( + self, db: Session, current_tier: str + ) -> tuple[list[dict], dict[str, int]]: + """ + Get all available tiers with upgrade/downgrade flags. + + Returns: + Tuple of (tier_list, tier_order_map) + """ + tiers = ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.is_active == True, # noqa: E712 + SubscriptionTier.is_public == True, # noqa: E712 + ) + .order_by(SubscriptionTier.display_order) + .all() + ) + + tier_order = {t.code: t.display_order for t in tiers} + current_order = tier_order.get(current_tier, 0) + + tier_list = [] + for tier in tiers: + tier_list.append({ + "code": tier.code, + "name": tier.name, + "description": tier.description, + "price_monthly_cents": tier.price_monthly_cents, + "price_annual_cents": tier.price_annual_cents, + "orders_per_month": tier.orders_per_month, + "products_limit": tier.products_limit, + "team_members": tier.team_members, + "features": tier.features or [], + "is_current": tier.code == current_tier, + "can_upgrade": tier.display_order > current_order, + "can_downgrade": tier.display_order < current_order, + }) + + return tier_list, tier_order + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: + """ + Get a tier by its code. + + Raises: + TierNotFoundError: If tier doesn't exist + """ + tier = ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.code == tier_code, + SubscriptionTier.is_active == True, # noqa: E712 + ) + .first() + ) + + if not tier: + raise TierNotFoundError(tier_code) + + return tier + + def get_vendor(self, db: Session, vendor_id: int) -> Vendor: + """ + Get vendor by ID. + + Raises: + VendorNotFoundException from app.exceptions + """ + from app.exceptions import VendorNotFoundException + + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(str(vendor_id), identifier_type="id") + + return vendor + + def create_checkout_session( + self, + db: Session, + vendor_id: int, + tier_code: str, + is_annual: bool, + success_url: str, + cancel_url: str, + ) -> dict: + """ + Create a Stripe checkout session. + + Returns: + Dict with checkout_url and session_id + + Raises: + PaymentSystemNotConfiguredError: If Stripe not configured + TierNotFoundError: If tier doesn't exist + StripePriceNotConfiguredError: If price not configured + """ + if not stripe_service.is_configured: + raise PaymentSystemNotConfiguredError() + + vendor = self.get_vendor(db, vendor_id) + tier = self.get_tier_by_code(db, tier_code) + + price_id = ( + tier.stripe_price_annual_id + if is_annual and tier.stripe_price_annual_id + else tier.stripe_price_monthly_id + ) + + if not price_id: + raise StripePriceNotConfiguredError(tier_code) + + # Check if this is a new subscription (for trial) + existing_sub = subscription_service.get_subscription(db, vendor_id) + trial_days = None + if not existing_sub or not existing_sub.stripe_subscription_id: + from app.core.config import settings + trial_days = settings.stripe_trial_days + + session = stripe_service.create_checkout_session( + db=db, + vendor=vendor, + price_id=price_id, + success_url=success_url, + cancel_url=cancel_url, + trial_days=trial_days, + ) + + # Update subscription with tier info + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + subscription.tier = tier_code + subscription.is_annual = is_annual + + return { + "checkout_url": session.url, + "session_id": session.id, + } + + def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict: + """ + Create a Stripe customer portal session. + + Returns: + Dict with portal_url + + Raises: + PaymentSystemNotConfiguredError: If Stripe not configured + NoActiveSubscriptionError: If no subscription with customer ID + """ + if not stripe_service.is_configured: + raise PaymentSystemNotConfiguredError() + + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_customer_id: + raise NoActiveSubscriptionError() + + session = stripe_service.create_portal_session( + customer_id=subscription.stripe_customer_id, + return_url=return_url, + ) + + return {"portal_url": session.url} + + def get_invoices( + self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20 + ) -> tuple[list[BillingHistory], int]: + """ + Get invoice history for a vendor. + + Returns: + Tuple of (invoices, total_count) + """ + query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id) + + total = query.count() + + invoices = ( + query.order_by(BillingHistory.invoice_date.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return invoices, total + + def get_available_addons( + self, db: Session, category: str | None = None + ) -> list[AddOnProduct]: + """Get available add-on products.""" + query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712 + + if category: + query = query.filter(AddOnProduct.category == category) + + return query.order_by(AddOnProduct.display_order).all() + + def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]: + """Get vendor's purchased add-ons.""" + return ( + db.query(VendorAddOn) + .filter(VendorAddOn.vendor_id == vendor_id) + .all() + ) + + def cancel_subscription( + self, db: Session, vendor_id: int, reason: str | None, immediately: bool + ) -> dict: + """ + Cancel a subscription. + + Returns: + Dict with message and effective_date + + Raises: + NoActiveSubscriptionError: If no subscription to cancel + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_subscription_id: + raise NoActiveSubscriptionError() + + if stripe_service.is_configured: + stripe_service.cancel_subscription( + subscription_id=subscription.stripe_subscription_id, + immediately=immediately, + cancellation_reason=reason, + ) + + subscription.cancelled_at = datetime.utcnow() + subscription.cancellation_reason = reason + + effective_date = ( + datetime.utcnow().isoformat() + if immediately + else subscription.period_end.isoformat() + if subscription.period_end + else datetime.utcnow().isoformat() + ) + + return { + "message": "Subscription cancelled successfully", + "effective_date": effective_date, + } + + def reactivate_subscription(self, db: Session, vendor_id: int) -> dict: + """ + Reactivate a cancelled subscription. + + Returns: + Dict with success message + + Raises: + NoActiveSubscriptionError: If no subscription + SubscriptionNotCancelledError: If not cancelled + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_subscription_id: + raise NoActiveSubscriptionError() + + if not subscription.cancelled_at: + raise SubscriptionNotCancelledError() + + if stripe_service.is_configured: + stripe_service.reactivate_subscription(subscription.stripe_subscription_id) + + subscription.cancelled_at = None + subscription.cancellation_reason = None + + return {"message": "Subscription reactivated successfully"} + + +# Create service instance +billing_service = BillingService() diff --git a/app/services/marketplace_product_service.py b/app/services/marketplace_product_service.py index 9a3f3dbd..df314839 100644 --- a/app/services/marketplace_product_service.py +++ b/app/services/marketplace_product_service.py @@ -865,12 +865,42 @@ class MarketplaceProductService: if not marketplace_products: raise MarketplaceProductNotFoundException("No marketplace products found") + # Check product limit from subscription + from app.services.subscription_service import subscription_service + from sqlalchemy import func + + current_products = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + products_limit = subscription.products_limit + remaining_capacity = ( + products_limit - current_products if products_limit is not None else None + ) + copied = 0 skipped = 0 failed = 0 + limit_reached = False details = [] for mp in marketplace_products: + # Check if we've hit the product limit + if remaining_capacity is not None and copied >= remaining_capacity: + limit_reached = True + details.append( + { + "id": mp.id, + "status": "skipped", + "reason": "Product limit reached", + } + ) + skipped += 1 + continue try: existing = ( db.query(Product) @@ -994,6 +1024,7 @@ class MarketplaceProductService: "skipped": skipped, "failed": failed, "auto_matched": auto_matched, + "limit_reached": limit_reached, "details": details if len(details) <= 100 else None, } diff --git a/app/services/stripe_service.py b/app/services/stripe_service.py new file mode 100644 index 00000000..f7b2a97b --- /dev/null +++ b/app/services/stripe_service.py @@ -0,0 +1,459 @@ +# app/services/stripe_service.py +""" +Stripe payment integration service. + +Provides: +- Customer management +- Subscription management +- Checkout session creation +- Customer portal access +- Webhook event construction +""" + +import logging +from datetime import datetime + +import stripe +from sqlalchemy.orm import Session + +from app.core.config import settings +from models.database.subscription import ( + BillingHistory, + SubscriptionStatus, + SubscriptionTier, + VendorSubscription, +) +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class StripeService: + """Service for Stripe payment operations.""" + + def __init__(self): + self._configured = False + self._configure() + + def _configure(self): + """Configure Stripe with API key.""" + if settings.stripe_secret_key: + stripe.api_key = settings.stripe_secret_key + self._configured = True + else: + logger.warning("Stripe API key not configured") + + @property + def is_configured(self) -> bool: + """Check if Stripe is properly configured.""" + return self._configured and bool(settings.stripe_secret_key) + + # ========================================================================= + # Customer Management + # ========================================================================= + + def create_customer( + self, + vendor: Vendor, + email: str, + name: str | None = None, + metadata: dict | None = None, + ) -> str: + """ + Create a Stripe customer for a vendor. + + Returns the Stripe customer ID. + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + customer_metadata = { + "vendor_id": str(vendor.id), + "vendor_code": vendor.vendor_code, + **(metadata or {}), + } + + customer = stripe.Customer.create( + email=email, + name=name or vendor.name, + metadata=customer_metadata, + ) + + logger.info( + f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}" + ) + return customer.id + + def get_customer(self, customer_id: str) -> stripe.Customer: + """Get a Stripe customer by ID.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + return stripe.Customer.retrieve(customer_id) + + def update_customer( + self, + customer_id: str, + email: str | None = None, + name: str | None = None, + metadata: dict | None = None, + ) -> stripe.Customer: + """Update a Stripe customer.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + update_data = {} + if email: + update_data["email"] = email + if name: + update_data["name"] = name + if metadata: + update_data["metadata"] = metadata + + return stripe.Customer.modify(customer_id, **update_data) + + # ========================================================================= + # Subscription Management + # ========================================================================= + + def create_subscription( + self, + customer_id: str, + price_id: str, + trial_days: int | None = None, + metadata: dict | None = None, + ) -> stripe.Subscription: + """ + Create a new Stripe subscription. + + Args: + customer_id: Stripe customer ID + price_id: Stripe price ID for the subscription + trial_days: Optional trial period in days + metadata: Optional metadata to attach + + Returns: + Stripe Subscription object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + subscription_data = { + "customer": customer_id, + "items": [{"price": price_id}], + "metadata": metadata or {}, + "payment_behavior": "default_incomplete", + "expand": ["latest_invoice.payment_intent"], + } + + if trial_days: + subscription_data["trial_period_days"] = trial_days + + subscription = stripe.Subscription.create(**subscription_data) + logger.info( + f"Created Stripe subscription {subscription.id} for customer {customer_id}" + ) + return subscription + + def get_subscription(self, subscription_id: str) -> stripe.Subscription: + """Get a Stripe subscription by ID.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + return stripe.Subscription.retrieve(subscription_id) + + def update_subscription( + self, + subscription_id: str, + new_price_id: str | None = None, + proration_behavior: str = "create_prorations", + metadata: dict | None = None, + ) -> stripe.Subscription: + """ + Update a Stripe subscription (e.g., change tier). + + Args: + subscription_id: Stripe subscription ID + new_price_id: New price ID for tier change + proration_behavior: How to handle prorations + metadata: Optional metadata to update + + Returns: + Updated Stripe Subscription object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + update_data = {"proration_behavior": proration_behavior} + + if new_price_id: + # Get the subscription to find the item ID + subscription = stripe.Subscription.retrieve(subscription_id) + item_id = subscription["items"]["data"][0]["id"] + update_data["items"] = [{"id": item_id, "price": new_price_id}] + + if metadata: + update_data["metadata"] = metadata + + updated = stripe.Subscription.modify(subscription_id, **update_data) + logger.info(f"Updated Stripe subscription {subscription_id}") + return updated + + def cancel_subscription( + self, + subscription_id: str, + immediately: bool = False, + cancellation_reason: str | None = None, + ) -> stripe.Subscription: + """ + Cancel a Stripe subscription. + + Args: + subscription_id: Stripe subscription ID + immediately: If True, cancel now. If False, cancel at period end. + cancellation_reason: Optional reason for cancellation + + Returns: + Cancelled Stripe Subscription object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + if immediately: + subscription = stripe.Subscription.cancel(subscription_id) + else: + subscription = stripe.Subscription.modify( + subscription_id, + cancel_at_period_end=True, + metadata={"cancellation_reason": cancellation_reason or "user_request"}, + ) + + logger.info( + f"Cancelled Stripe subscription {subscription_id} " + f"(immediately={immediately})" + ) + return subscription + + def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription: + """ + Reactivate a cancelled subscription (if not yet ended). + + Returns: + Reactivated Stripe Subscription object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + subscription = stripe.Subscription.modify( + subscription_id, + cancel_at_period_end=False, + ) + logger.info(f"Reactivated Stripe subscription {subscription_id}") + return subscription + + # ========================================================================= + # Checkout & Portal + # ========================================================================= + + def create_checkout_session( + self, + db: Session, + vendor: Vendor, + price_id: str, + success_url: str, + cancel_url: str, + trial_days: int | None = None, + ) -> stripe.checkout.Session: + """ + Create a Stripe Checkout session for subscription signup. + + Args: + db: Database session + vendor: Vendor to create checkout for + price_id: Stripe price ID + success_url: URL to redirect on success + cancel_url: URL to redirect on cancel + trial_days: Optional trial period + + Returns: + Stripe Checkout Session object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + # Get or create Stripe customer + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.vendor_id == vendor.id) + .first() + ) + + if subscription and subscription.stripe_customer_id: + customer_id = subscription.stripe_customer_id + else: + # Get vendor owner email + from models.database.vendor import VendorUser + + owner = ( + db.query(VendorUser) + .filter( + VendorUser.vendor_id == vendor.id, + VendorUser.is_owner == True, + ) + .first() + ) + email = owner.user.email if owner and owner.user else None + + customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com") + + # Store the customer ID + if subscription: + subscription.stripe_customer_id = customer_id + db.flush() + + session_data = { + "customer": customer_id, + "line_items": [{"price": price_id, "quantity": 1}], + "mode": "subscription", + "success_url": success_url, + "cancel_url": cancel_url, + "metadata": { + "vendor_id": str(vendor.id), + "vendor_code": vendor.vendor_code, + }, + } + + if trial_days: + session_data["subscription_data"] = {"trial_period_days": trial_days} + + session = stripe.checkout.Session.create(**session_data) + logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}") + return session + + def create_portal_session( + self, + customer_id: str, + return_url: str, + ) -> stripe.billing_portal.Session: + """ + Create a Stripe Customer Portal session. + + Allows customers to manage their subscription, payment methods, and invoices. + + Args: + customer_id: Stripe customer ID + return_url: URL to return to after portal + + Returns: + Stripe Portal Session object + """ + if not self.is_configured: + raise ValueError("Stripe is not configured") + + session = stripe.billing_portal.Session.create( + customer=customer_id, + return_url=return_url, + ) + logger.info(f"Created portal session for customer {customer_id}") + return session + + # ========================================================================= + # Invoice Management + # ========================================================================= + + def get_invoices( + self, + customer_id: str, + limit: int = 10, + ) -> list[stripe.Invoice]: + """Get invoices for a customer.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + invoices = stripe.Invoice.list(customer=customer_id, limit=limit) + return list(invoices.data) + + def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None: + """Get the upcoming invoice for a customer.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + try: + return stripe.Invoice.upcoming(customer=customer_id) + except stripe.error.InvalidRequestError: + # No upcoming invoice + return None + + # ========================================================================= + # Webhook Handling + # ========================================================================= + + def construct_event( + self, + payload: bytes, + sig_header: str, + ) -> stripe.Event: + """ + Construct and verify a Stripe webhook event. + + Args: + payload: Raw request body + sig_header: Stripe-Signature header + + Returns: + Verified Stripe Event object + + Raises: + ValueError: If signature verification fails + """ + if not settings.stripe_webhook_secret: + raise ValueError("Stripe webhook secret not configured") + + try: + event = stripe.Webhook.construct_event( + payload, + sig_header, + settings.stripe_webhook_secret, + ) + return event + except stripe.error.SignatureVerificationError as e: + logger.error(f"Webhook signature verification failed: {e}") + raise ValueError("Invalid webhook signature") + + # ========================================================================= + # Price/Product Management + # ========================================================================= + + def get_price(self, price_id: str) -> stripe.Price: + """Get a Stripe price by ID.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + return stripe.Price.retrieve(price_id) + + def get_product(self, product_id: str) -> stripe.Product: + """Get a Stripe product by ID.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + return stripe.Product.retrieve(product_id) + + def list_prices( + self, + product_id: str | None = None, + active: bool = True, + ) -> list[stripe.Price]: + """List Stripe prices, optionally filtered by product.""" + if not self.is_configured: + raise ValueError("Stripe is not configured") + + params = {"active": active} + if product_id: + params["product"] = product_id + + prices = stripe.Price.list(**params) + return list(prices.data) + + +# Create service instance +stripe_service = StripeService() diff --git a/app/services/stripe_webhook_handler.py b/app/services/stripe_webhook_handler.py new file mode 100644 index 00000000..084df661 --- /dev/null +++ b/app/services/stripe_webhook_handler.py @@ -0,0 +1,411 @@ +# app/services/stripe_webhook_handler.py +""" +Stripe webhook event handler. + +Processes webhook events from Stripe: +- Subscription lifecycle events +- Invoice and payment events +- Checkout session completion +""" + +import logging +from datetime import datetime, timezone + +import stripe +from sqlalchemy.orm import Session + +from models.database.subscription import ( + BillingHistory, + StripeWebhookEvent, + SubscriptionStatus, + SubscriptionTier, + VendorSubscription, +) + +logger = logging.getLogger(__name__) + + +class StripeWebhookHandler: + """Handler for Stripe webhook events.""" + + def __init__(self): + self.handlers = { + "checkout.session.completed": self._handle_checkout_completed, + "customer.subscription.created": self._handle_subscription_created, + "customer.subscription.updated": self._handle_subscription_updated, + "customer.subscription.deleted": self._handle_subscription_deleted, + "invoice.paid": self._handle_invoice_paid, + "invoice.payment_failed": self._handle_payment_failed, + "invoice.finalized": self._handle_invoice_finalized, + } + + def handle_event(self, db: Session, event: stripe.Event) -> dict: + """ + Process a Stripe webhook event. + + Args: + db: Database session + event: Stripe Event object + + Returns: + Dict with processing result + """ + event_id = event.id + event_type = event.type + + # Check for duplicate processing (idempotency) + existing = ( + db.query(StripeWebhookEvent) + .filter(StripeWebhookEvent.event_id == event_id) + .first() + ) + + if existing: + if existing.status == "processed": + logger.info(f"Skipping duplicate event {event_id}") + return {"status": "skipped", "reason": "duplicate"} + elif existing.status == "failed": + logger.info(f"Retrying previously failed event {event_id}") + else: + # Record the event + webhook_event = StripeWebhookEvent( + event_id=event_id, + event_type=event_type, + status="pending", + ) + db.add(webhook_event) + db.flush() + existing = webhook_event + + # Process the event + handler = self.handlers.get(event_type) + if not handler: + logger.debug(f"No handler for event type {event_type}") + existing.status = "processed" + existing.processed_at = datetime.now(timezone.utc) + db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction + return {"status": "ignored", "reason": f"no handler for {event_type}"} + + try: + result = handler(db, event) + existing.status = "processed" + existing.processed_at = datetime.now(timezone.utc) + db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction + logger.info(f"Successfully processed event {event_id} ({event_type})") + return {"status": "processed", "result": result} + + except Exception as e: + logger.error(f"Error processing event {event_id}: {e}") + existing.status = "failed" + existing.error_message = str(e) + db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction + raise + + # ========================================================================= + # Event Handlers + # ========================================================================= + + def _handle_checkout_completed( + self, db: Session, event: stripe.Event + ) -> dict: + """Handle checkout.session.completed event.""" + session = event.data.object + vendor_id = session.metadata.get("vendor_id") + + if not vendor_id: + logger.warning(f"Checkout session {session.id} missing vendor_id") + return {"action": "skipped", "reason": "no vendor_id"} + + vendor_id = int(vendor_id) + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.vendor_id == vendor_id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for vendor {vendor_id}") + return {"action": "skipped", "reason": "no subscription"} + + # Update subscription with Stripe IDs + subscription.stripe_customer_id = session.customer + subscription.stripe_subscription_id = session.subscription + subscription.status = SubscriptionStatus.ACTIVE + + # Get subscription details to set period dates + if session.subscription: + stripe_sub = stripe.Subscription.retrieve(session.subscription) + subscription.period_start = datetime.fromtimestamp( + stripe_sub.current_period_start, tz=timezone.utc + ) + subscription.period_end = datetime.fromtimestamp( + stripe_sub.current_period_end, tz=timezone.utc + ) + + if stripe_sub.trial_end: + subscription.trial_ends_at = datetime.fromtimestamp( + stripe_sub.trial_end, tz=timezone.utc + ) + + logger.info(f"Checkout completed for vendor {vendor_id}") + return {"action": "activated", "vendor_id": vendor_id} + + def _handle_subscription_created( + self, db: Session, event: stripe.Event + ) -> dict: + """Handle customer.subscription.created event.""" + stripe_sub = event.data.object + customer_id = stripe_sub.customer + + # Find subscription by customer ID + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_customer_id == customer_id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for customer {customer_id}") + return {"action": "skipped", "reason": "no subscription"} + + # Update subscription + subscription.stripe_subscription_id = stripe_sub.id + subscription.status = self._map_stripe_status(stripe_sub.status) + subscription.period_start = datetime.fromtimestamp( + stripe_sub.current_period_start, tz=timezone.utc + ) + subscription.period_end = datetime.fromtimestamp( + stripe_sub.current_period_end, tz=timezone.utc + ) + + logger.info(f"Subscription created for vendor {subscription.vendor_id}") + return {"action": "created", "vendor_id": subscription.vendor_id} + + def _handle_subscription_updated( + self, db: Session, event: stripe.Event + ) -> dict: + """Handle customer.subscription.updated event.""" + stripe_sub = event.data.object + + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_subscription_id == stripe_sub.id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for {stripe_sub.id}") + return {"action": "skipped", "reason": "no subscription"} + + # Update status and period + subscription.status = self._map_stripe_status(stripe_sub.status) + subscription.period_start = datetime.fromtimestamp( + stripe_sub.current_period_start, tz=timezone.utc + ) + subscription.period_end = datetime.fromtimestamp( + stripe_sub.current_period_end, tz=timezone.utc + ) + + # Handle cancellation + if stripe_sub.cancel_at_period_end: + subscription.cancelled_at = datetime.now(timezone.utc) + subscription.cancellation_reason = stripe_sub.metadata.get( + "cancellation_reason", "user_request" + ) + elif subscription.cancelled_at and not stripe_sub.cancel_at_period_end: + # Subscription reactivated + subscription.cancelled_at = None + subscription.cancellation_reason = None + + # Check for tier change via price + if stripe_sub.items.data: + new_price_id = stripe_sub.items.data[0].price.id + if subscription.stripe_price_id != new_price_id: + # Price changed, look up new tier + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.stripe_price_monthly_id == new_price_id) + .first() + ) + if tier: + subscription.tier = tier.code + logger.info( + f"Tier changed to {tier.code} for vendor {subscription.vendor_id}" + ) + subscription.stripe_price_id = new_price_id + + logger.info(f"Subscription updated for vendor {subscription.vendor_id}") + return {"action": "updated", "vendor_id": subscription.vendor_id} + + def _handle_subscription_deleted( + self, db: Session, event: stripe.Event + ) -> dict: + """Handle customer.subscription.deleted event.""" + stripe_sub = event.data.object + + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_subscription_id == stripe_sub.id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for {stripe_sub.id}") + return {"action": "skipped", "reason": "no subscription"} + + subscription.status = SubscriptionStatus.CANCELLED + subscription.cancelled_at = datetime.now(timezone.utc) + + logger.info(f"Subscription deleted for vendor {subscription.vendor_id}") + return {"action": "cancelled", "vendor_id": subscription.vendor_id} + + def _handle_invoice_paid(self, db: Session, event: stripe.Event) -> dict: + """Handle invoice.paid event.""" + invoice = event.data.object + customer_id = invoice.customer + + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_customer_id == customer_id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for customer {customer_id}") + return {"action": "skipped", "reason": "no subscription"} + + # Record billing history + billing_record = BillingHistory( + vendor_id=subscription.vendor_id, + stripe_invoice_id=invoice.id, + stripe_payment_intent_id=invoice.payment_intent, + invoice_number=invoice.number, + invoice_date=datetime.fromtimestamp(invoice.created, tz=timezone.utc), + subtotal_cents=invoice.subtotal, + tax_cents=invoice.tax or 0, + total_cents=invoice.total, + amount_paid_cents=invoice.amount_paid, + currency=invoice.currency.upper(), + status="paid", + invoice_pdf_url=invoice.invoice_pdf, + hosted_invoice_url=invoice.hosted_invoice_url, + ) + db.add(billing_record) + + # Reset payment retry count on successful payment + subscription.payment_retry_count = 0 + subscription.last_payment_error = None + + # Reset period counters if this is a new billing cycle + if subscription.status == SubscriptionStatus.ACTIVE: + subscription.orders_this_period = 0 + subscription.orders_limit_reached_at = None + + logger.info(f"Invoice paid for vendor {subscription.vendor_id}") + return { + "action": "recorded", + "vendor_id": subscription.vendor_id, + "invoice_id": invoice.id, + } + + def _handle_payment_failed(self, db: Session, event: stripe.Event) -> dict: + """Handle invoice.payment_failed event.""" + invoice = event.data.object + customer_id = invoice.customer + + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_customer_id == customer_id) + .first() + ) + + if not subscription: + logger.warning(f"No subscription found for customer {customer_id}") + return {"action": "skipped", "reason": "no subscription"} + + # Update subscription status + subscription.status = SubscriptionStatus.PAST_DUE + subscription.payment_retry_count = (subscription.payment_retry_count or 0) + 1 + + # Store error message + if invoice.last_payment_error: + subscription.last_payment_error = invoice.last_payment_error.get("message") + + logger.warning( + f"Payment failed for vendor {subscription.vendor_id} " + f"(retry #{subscription.payment_retry_count})" + ) + return { + "action": "marked_past_due", + "vendor_id": subscription.vendor_id, + "retry_count": subscription.payment_retry_count, + } + + def _handle_invoice_finalized( + self, db: Session, event: stripe.Event + ) -> dict: + """Handle invoice.finalized event.""" + invoice = event.data.object + customer_id = invoice.customer + + subscription = ( + db.query(VendorSubscription) + .filter(VendorSubscription.stripe_customer_id == customer_id) + .first() + ) + + if not subscription: + return {"action": "skipped", "reason": "no subscription"} + + # Check if we already have this invoice + existing = ( + db.query(BillingHistory) + .filter(BillingHistory.stripe_invoice_id == invoice.id) + .first() + ) + + if existing: + return {"action": "skipped", "reason": "already recorded"} + + # Record as pending invoice + billing_record = BillingHistory( + vendor_id=subscription.vendor_id, + stripe_invoice_id=invoice.id, + invoice_number=invoice.number, + invoice_date=datetime.fromtimestamp(invoice.created, tz=timezone.utc), + due_date=datetime.fromtimestamp(invoice.due_date, tz=timezone.utc) + if invoice.due_date + else None, + subtotal_cents=invoice.subtotal, + tax_cents=invoice.tax or 0, + total_cents=invoice.total, + amount_paid_cents=0, + currency=invoice.currency.upper(), + status="open", + invoice_pdf_url=invoice.invoice_pdf, + hosted_invoice_url=invoice.hosted_invoice_url, + ) + db.add(billing_record) + + return {"action": "recorded_pending", "vendor_id": subscription.vendor_id} + + # ========================================================================= + # Helpers + # ========================================================================= + + def _map_stripe_status(self, stripe_status: str) -> SubscriptionStatus: + """Map Stripe subscription status to internal status.""" + status_map = { + "active": SubscriptionStatus.ACTIVE, + "trialing": SubscriptionStatus.TRIAL, + "past_due": SubscriptionStatus.PAST_DUE, + "canceled": SubscriptionStatus.CANCELLED, + "unpaid": SubscriptionStatus.PAST_DUE, + "incomplete": SubscriptionStatus.TRIAL, # Treat as trial until complete + "incomplete_expired": SubscriptionStatus.EXPIRED, + } + return status_map.get(stripe_status, SubscriptionStatus.EXPIRED) + + +# Create handler instance +stripe_webhook_handler = StripeWebhookHandler() diff --git a/app/services/vendor_team_service.py b/app/services/vendor_team_service.py index 82bd311f..2a70d51b 100644 --- a/app/services/vendor_team_service.py +++ b/app/services/vendor_team_service.py @@ -20,11 +20,11 @@ from app.core.permissions import get_preset_permissions from app.exceptions import ( CannotRemoveOwnerException, InvalidInvitationTokenException, - MaxTeamMembersReachedException, TeamInvitationAlreadyAcceptedException, TeamMemberAlreadyExistsException, UserNotFoundException, ) +from app.services.subscription_service import TierLimitExceededException from middleware.auth import AuthManager from models.database.user import User from models.database.vendor import Role, Vendor, VendorUser, VendorUserType @@ -37,7 +37,6 @@ class VendorTeamService: def __init__(self): self.auth_manager = AuthManager() - self.max_team_members = 50 # Configure as needed def invite_team_member( self, @@ -68,21 +67,10 @@ class VendorTeamService: Dict with invitation details """ try: - # Check team size limit - current_team_size = ( - db.query(VendorUser) - .filter( - VendorUser.vendor_id == vendor.id, - VendorUser.is_active == True, - ) - .count() - ) + # Check team size limit from subscription + from app.services.subscription_service import subscription_service - if current_team_size >= self.max_team_members: - raise MaxTeamMembersReachedException( - self.max_team_members, - vendor.vendor_code, - ) + subscription_service.check_team_limit(db, vendor.id) # Check if user already exists user = db.query(User).filter(User.email == email).first() @@ -187,7 +175,7 @@ class VendorTeamService: "existing_user": user.is_active, } - except (TeamMemberAlreadyExistsException, MaxTeamMembersReachedException): + except (TeamMemberAlreadyExistsException, TierLimitExceededException): raise except Exception as e: logger.error(f"Error inviting team member: {str(e)}") diff --git a/app/templates/vendor/billing.html b/app/templates/vendor/billing.html new file mode 100644 index 00000000..31e2736e --- /dev/null +++ b/app/templates/vendor/billing.html @@ -0,0 +1,397 @@ +{# app/templates/vendor/billing.html #} +{% extends "vendor/base.html" %} + +{% block title %}Billing & Subscription{% endblock %} + +{% block alpine_data %}billingData(){% endblock %} + +{% block content %} +
+

+ Billing & Subscription +

+
+ + + + + + + + + + + + +
+
+
+

Choose Your Plan

+ +
+
+
+ +
+
+
+
+ + +
+
+
+

Add-ons

+ +
+
+ +
+ +
+
+
+
+ + +
+
+
+

Cancel Subscription

+ +
+
+

+ Are you sure you want to cancel your subscription? You'll continue to have access until the end of your current billing period. +

+
+ + +
+
+ + +
+
+
+
+ +{% endblock %} + +{% block scripts %} + +{% endblock %} diff --git a/app/templates/vendor/partials/sidebar.html b/app/templates/vendor/partials/sidebar.html index dd3eb51f..160f0a0e 100644 --- a/app/templates/vendor/partials/sidebar.html +++ b/app/templates/vendor/partials/sidebar.html @@ -184,6 +184,17 @@ Follows same pattern as admin sidebar Settings +
  • + + + + Billing + +
  • diff --git a/docs/features/subscription-billing.md b/docs/features/subscription-billing.md new file mode 100644 index 00000000..0e9ebb23 --- /dev/null +++ b/docs/features/subscription-billing.md @@ -0,0 +1,270 @@ +# Subscription & Billing System + +The platform provides a comprehensive subscription and billing system for managing vendor subscriptions, usage limits, and payments through Stripe. + +## Overview + +The billing system enables: + +- **Subscription Tiers**: Database-driven tier definitions with configurable limits +- **Usage Tracking**: Orders, products, and team member limits per tier +- **Stripe Integration**: Checkout sessions, customer portal, and webhook handling +- **Self-Service Billing**: Vendor-facing billing page for subscription management +- **Add-ons**: Optional purchasable items (domains, SSL, email packages) + +## Architecture + +### Database Models + +All subscription models are defined in `models/database/subscription.py`: + +| Model | Purpose | +|-------|---------| +| `SubscriptionTier` | Tier definitions with limits and Stripe price IDs | +| `VendorSubscription` | Per-vendor subscription status and usage | +| `AddOnProduct` | Purchasable add-ons (domains, SSL, email) | +| `VendorAddOn` | Add-ons purchased by each vendor | +| `StripeWebhookEvent` | Idempotency tracking for webhooks | +| `BillingHistory` | Invoice and payment history | + +### Services + +| Service | Location | Purpose | +|---------|----------|---------| +| `BillingService` | `app/services/billing_service.py` | Subscription operations, checkout, portal | +| `SubscriptionService` | `app/services/subscription_service.py` | Limit checks, usage tracking | +| `StripeService` | `app/services/stripe_service.py` | Core Stripe API operations | +| `StripeWebhookHandler` | `app/services/stripe_webhook_handler.py` | Webhook event processing | + +### API Endpoints + +All billing endpoints are under `/api/v1/vendor/billing`: + +| Endpoint | Method | Purpose | +|----------|--------|---------| +| `/billing/subscription` | GET | Current subscription status & usage | +| `/billing/tiers` | GET | Available tiers for upgrade | +| `/billing/checkout` | POST | Create Stripe checkout session | +| `/billing/portal` | POST | Create Stripe customer portal session | +| `/billing/invoices` | GET | Invoice history | +| `/billing/addons` | GET | Available add-on products | +| `/billing/my-addons` | GET | Vendor's purchased add-ons | +| `/billing/cancel` | POST | Cancel subscription | +| `/billing/reactivate` | POST | Reactivate cancelled subscription | + +## Subscription Tiers + +### Default Tiers + +| Tier | Price | Products | Orders/mo | Team | +|------|-------|----------|-----------|------| +| Essential | €49/mo | 200 | 100 | 1 | +| Professional | €99/mo | Unlimited | 500 | 3 | +| Business | €199/mo | Unlimited | 2000 | 10 | +| Enterprise | Custom | Unlimited | Unlimited | Unlimited | + +### Tier Features + +Each tier includes specific features stored in the `features` JSON column: + +```python +tier.features = [ + "basic_support", # Essential + "priority_support", # Professional+ + "analytics", # Business+ + "api_access", # Business+ + "white_label", # Enterprise + "custom_integrations", # Enterprise +] +``` + +## Limit Enforcement + +Limits are enforced at the service layer: + +### Orders +```python +# app/services/order_service.py +subscription_service.check_order_limit(db, vendor_id) +``` + +### Products +```python +# app/api/v1/vendor/products.py +subscription_service.check_product_limit(db, vendor_id) +``` + +### Team Members +```python +# app/services/vendor_team_service.py +subscription_service.can_add_team_member(db, vendor_id) +``` + +## Stripe Integration + +### Configuration + +Required environment variables: + +```bash +STRIPE_SECRET_KEY=sk_test_... +STRIPE_PUBLISHABLE_KEY=pk_test_... +STRIPE_WEBHOOK_SECRET=whsec_... +STRIPE_TRIAL_DAYS=14 # Optional, default trial period +``` + +### Webhook Events + +The system handles these Stripe events: + +| Event | Handler | +|-------|---------| +| `checkout.session.completed` | Activates subscription, links customer | +| `customer.subscription.updated` | Updates tier, status, period | +| `customer.subscription.deleted` | Marks subscription cancelled | +| `invoice.paid` | Records payment, resets counters | +| `invoice.payment_failed` | Marks past due, increments retry count | + +### Webhook Endpoint + +Webhooks are received at `/api/v1/webhooks/stripe`: + +```python +# Uses signature verification for security +event = stripe_service.construct_event(payload, stripe_signature) +``` + +## Vendor Billing Page + +The vendor billing page is at `/vendor/{vendor_code}/billing`: + +### Page Sections + +1. **Current Plan**: Tier name, status, next billing date +2. **Usage Meters**: Products, orders, team members with limits +3. **Change Plan**: Upgrade/downgrade options +4. **Payment Method**: Link to Stripe portal +5. **Invoice History**: Recent invoices with PDF links + +### JavaScript Component + +The billing page uses Alpine.js (`static/vendor/js/billing.js`): + +```javascript +function billingData() { + return { + subscription: null, + tiers: [], + invoices: [], + + async init() { + await this.loadData(); + }, + + async selectTier(tier) { + const response = await this.apiPost('/billing/checkout', { + tier_code: tier.code, + is_annual: false + }); + window.location.href = response.checkout_url; + }, + + async openPortal() { + const response = await this.apiPost('/billing/portal', {}); + window.location.href = response.portal_url; + } + }; +} +``` + +## Add-ons + +### Available Add-ons + +| Code | Name | Category | Price | +|------|------|----------|-------| +| `domain` | Custom Domain | domain | €15/year | +| `ssl_premium` | Premium SSL | ssl | €49/year | +| `email_5` | 5 Email Addresses | email | €5/month | +| `email_10` | 10 Email Addresses | email | €9/month | +| `email_25` | 25 Email Addresses | email | €19/month | + +### Purchase Flow + +1. Vendor selects add-on on billing page +2. For domains: enter domain name, validate availability +3. Create Stripe checkout session with add-on price +4. On webhook success: create `VendorAddOn` record + +## Exception Handling + +Custom exceptions for billing operations (`app/exceptions/billing.py`): + +| Exception | HTTP Status | Description | +|-----------|-------------|-------------| +| `PaymentSystemNotConfiguredException` | 503 | Stripe not configured | +| `TierNotFoundException` | 404 | Invalid tier code | +| `StripePriceNotConfiguredException` | 400 | No Stripe price for tier | +| `NoActiveSubscriptionException` | 400 | Operation requires subscription | +| `SubscriptionNotCancelledException` | 400 | Cannot reactivate active subscription | + +## Testing + +Unit tests for the billing system: + +```bash +# Run billing service tests +pytest tests/unit/services/test_billing_service.py -v + +# Run webhook handler tests +pytest tests/unit/services/test_stripe_webhook_handler.py -v +``` + +### Test Coverage + +- `BillingService`: Subscription queries, checkout, portal, cancellation +- `StripeWebhookHandler`: Event idempotency, checkout completion, invoice handling + +## Migration + +### Creating Tiers + +Tiers are seeded via migration: + +```python +# alembic/versions/xxx_add_subscription_billing_tables.py +def seed_subscription_tiers(op): + op.bulk_insert(subscription_tiers_table, [ + { + "code": "essential", + "name": "Essential", + "price_monthly_cents": 4900, + "orders_per_month": 100, + "products_limit": 200, + "team_members": 1, + }, + # ... more tiers + ]) +``` + +### Setting Up Stripe + +1. Create products and prices in Stripe Dashboard +2. Update `SubscriptionTier` records with Stripe IDs: + +```python +tier.stripe_product_id = "prod_xxx" +tier.stripe_price_monthly_id = "price_xxx" +tier.stripe_price_annual_id = "price_yyy" +``` + +3. Configure webhook endpoint in Stripe Dashboard: + - URL: `https://yourdomain.com/api/v1/webhooks/stripe` + - Events: `checkout.session.completed`, `customer.subscription.*`, `invoice.*` + +## Security Considerations + +- Webhook signatures verified before processing +- Idempotency keys prevent duplicate event processing +- Customer portal links are session-based and expire +- Stripe API key stored securely in environment variables diff --git a/mkdocs.yml b/mkdocs.yml index 708eb062..ee8a56a6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -204,6 +204,7 @@ nav: - Implementation Guide: features/cms-implementation-guide.md - Platform Homepage: features/platform-homepage.md - Vendor Landing Pages: features/vendor-landing-pages.md + - Subscription & Billing: features/subscription-billing.md # ============================================ # USER GUIDES diff --git a/models/database/__init__.py b/models/database/__init__.py index 3d264ef5..52138194 100644 --- a/models/database/__init__.py +++ b/models/database/__init__.py @@ -51,9 +51,16 @@ from .order_item_exception import OrderItemException from .product import Product from .product_translation import ProductTranslation from .subscription import ( + AddOnCategory, + AddOnProduct, + BillingHistory, + BillingPeriod, + StripeWebhookEvent, SubscriptionStatus, + SubscriptionTier, TierCode, TIER_LIMITS, + VendorAddOn, VendorSubscription, ) from .test_run import TestCollection, TestResult, TestRun @@ -121,11 +128,18 @@ __all__ = [ "LetzshopFulfillmentQueue", "LetzshopSyncLog", "LetzshopHistoricalImportJob", - # Subscription + # Subscription & Billing "VendorSubscription", "SubscriptionStatus", + "SubscriptionTier", "TierCode", "TIER_LIMITS", + "AddOnProduct", + "AddOnCategory", + "BillingPeriod", + "VendorAddOn", + "BillingHistory", + "StripeWebhookEvent", # Messaging "Conversation", "ConversationParticipant", diff --git a/models/database/subscription.py b/models/database/subscription.py index 703d3d43..5688f2c7 100644 --- a/models/database/subscription.py +++ b/models/database/subscription.py @@ -3,8 +3,12 @@ Subscription database models for tier-based access control. Provides models for: -- SubscriptionTier: Tier definitions with limits and features +- SubscriptionTier: Database-driven tier definitions with Stripe integration - VendorSubscription: Per-vendor subscription tracking +- AddOnProduct: Purchasable add-ons (domains, SSL, email packages) +- VendorAddOn: Add-ons purchased by each vendor +- StripeWebhookEvent: Idempotency tracking for webhook processing +- BillingHistory: Invoice and payment history Tier Structure: - Essential (€49/mo): 100 orders/mo, 200 products, 1 user, LU invoicing @@ -53,6 +57,274 @@ class SubscriptionStatus(str, enum.Enum): EXPIRED = "expired" # No longer active +class AddOnCategory(str, enum.Enum): + """Add-on product categories.""" + + DOMAIN = "domain" + SSL = "ssl" + EMAIL = "email" + STORAGE = "storage" + + +class BillingPeriod(str, enum.Enum): + """Billing period for add-ons.""" + + MONTHLY = "monthly" + ANNUAL = "annual" + ONE_TIME = "one_time" + + +# ============================================================================ +# SubscriptionTier - Database-driven tier definitions +# ============================================================================ + + +class SubscriptionTier(Base, TimestampMixin): + """ + Database-driven tier definitions with Stripe integration. + + Replaces the hardcoded TIER_LIMITS dict for dynamic tier management. + """ + + __tablename__ = "subscription_tiers" + + id = Column(Integer, primary_key=True, index=True) + code = Column(String(30), unique=True, nullable=False, index=True) + name = Column(String(100), nullable=False) + description = Column(Text, nullable=True) + + # Pricing (in cents for precision) + price_monthly_cents = Column(Integer, nullable=False) + price_annual_cents = Column(Integer, nullable=True) # Null for enterprise/custom + + # Limits (null = unlimited) + orders_per_month = Column(Integer, nullable=True) + products_limit = Column(Integer, nullable=True) + team_members = Column(Integer, nullable=True) + order_history_months = Column(Integer, nullable=True) + + # Features (JSON array of feature codes) + features = Column(JSON, default=list) + + # Stripe Product/Price IDs + stripe_product_id = Column(String(100), nullable=True) + stripe_price_monthly_id = Column(String(100), nullable=True) + stripe_price_annual_id = Column(String(100), nullable=True) + + # Display and visibility + display_order = Column(Integer, default=0) + is_active = Column(Boolean, default=True, nullable=False) + is_public = Column(Boolean, default=True, nullable=False) # False for enterprise + + def __repr__(self): + return f"" + + def to_dict(self) -> dict: + """Convert tier to dictionary (compatible with TIER_LIMITS format).""" + return { + "name": self.name, + "price_monthly_cents": self.price_monthly_cents, + "price_annual_cents": self.price_annual_cents, + "orders_per_month": self.orders_per_month, + "products_limit": self.products_limit, + "team_members": self.team_members, + "order_history_months": self.order_history_months, + "features": self.features or [], + } + + +# ============================================================================ +# AddOnProduct - Purchasable add-ons +# ============================================================================ + + +class AddOnProduct(Base, TimestampMixin): + """ + Purchasable add-on products (domains, SSL, email packages). + + These are separate from subscription tiers and can be added to any tier. + """ + + __tablename__ = "addon_products" + + id = Column(Integer, primary_key=True, index=True) + code = Column(String(50), unique=True, nullable=False, index=True) + name = Column(String(100), nullable=False) + description = Column(Text, nullable=True) + category = Column(String(50), nullable=False, index=True) + + # Pricing + price_cents = Column(Integer, nullable=False) + billing_period = Column( + String(20), default=BillingPeriod.MONTHLY.value, nullable=False + ) + + # For tiered add-ons (e.g., email_5, email_10) + quantity_unit = Column(String(50), nullable=True) # emails, GB, etc. + quantity_value = Column(Integer, nullable=True) # 5, 10, 50, etc. + + # Stripe + stripe_product_id = Column(String(100), nullable=True) + stripe_price_id = Column(String(100), nullable=True) + + # Display + display_order = Column(Integer, default=0) + is_active = Column(Boolean, default=True, nullable=False) + + def __repr__(self): + return f"" + + +# ============================================================================ +# VendorAddOn - Add-ons purchased by vendor +# ============================================================================ + + +class VendorAddOn(Base, TimestampMixin): + """ + Add-ons purchased by a vendor. + + Tracks active add-on subscriptions and their billing status. + """ + + __tablename__ = "vendor_addons" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + addon_product_id = Column( + Integer, ForeignKey("addon_products.id"), nullable=False, index=True + ) + + # Status + status = Column(String(20), default="active", nullable=False, index=True) + + # For domains: store the actual domain name + domain_name = Column(String(255), nullable=True, index=True) + + # Quantity (for tiered add-ons like email packages) + quantity = Column(Integer, default=1, nullable=False) + + # Stripe billing + stripe_subscription_item_id = Column(String(100), nullable=True) + + # Period tracking + period_start = Column(DateTime(timezone=True), nullable=True) + period_end = Column(DateTime(timezone=True), nullable=True) + + # Cancellation + cancelled_at = Column(DateTime(timezone=True), nullable=True) + + # Relationships + vendor = relationship("Vendor", back_populates="addons") + addon_product = relationship("AddOnProduct") + + __table_args__ = ( + Index("idx_vendor_addon_status", "vendor_id", "status"), + Index("idx_vendor_addon_product", "vendor_id", "addon_product_id"), + ) + + def __repr__(self): + return f"" + + +# ============================================================================ +# StripeWebhookEvent - Webhook idempotency tracking +# ============================================================================ + + +class StripeWebhookEvent(Base, TimestampMixin): + """ + Log of processed Stripe webhook events for idempotency. + + Prevents duplicate processing of the same event. + """ + + __tablename__ = "stripe_webhook_events" + + id = Column(Integer, primary_key=True, index=True) + event_id = Column(String(100), unique=True, nullable=False, index=True) + event_type = Column(String(100), nullable=False, index=True) + + # Processing status + status = Column(String(20), default="pending", nullable=False, index=True) + processed_at = Column(DateTime(timezone=True), nullable=True) + error_message = Column(Text, nullable=True) + + # Raw event data (encrypted for security) + payload_encrypted = Column(Text, nullable=True) + + # Related entities (for quick lookup) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=True, index=True) + subscription_id = Column( + Integer, ForeignKey("vendor_subscriptions.id"), nullable=True, index=True + ) + + __table_args__ = (Index("idx_webhook_event_type_status", "event_type", "status"),) + + def __repr__(self): + return f"" + + +# ============================================================================ +# BillingHistory - Invoice and payment history +# ============================================================================ + + +class BillingHistory(Base, TimestampMixin): + """ + Invoice and payment history for vendors. + + Stores Stripe invoice data for display and reporting. + """ + + __tablename__ = "billing_history" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + + # Stripe references + stripe_invoice_id = Column(String(100), unique=True, nullable=True, index=True) + stripe_payment_intent_id = Column(String(100), nullable=True) + + # Invoice details + invoice_number = Column(String(50), nullable=True) + invoice_date = Column(DateTime(timezone=True), nullable=False) + due_date = Column(DateTime(timezone=True), nullable=True) + + # Amounts (in cents for precision) + subtotal_cents = Column(Integer, nullable=False) + tax_cents = Column(Integer, default=0, nullable=False) + total_cents = Column(Integer, nullable=False) + amount_paid_cents = Column(Integer, default=0, nullable=False) + currency = Column(String(3), default="EUR", nullable=False) + + # Status + status = Column(String(20), nullable=False, index=True) + + # PDF URLs + invoice_pdf_url = Column(String(500), nullable=True) + hosted_invoice_url = Column(String(500), nullable=True) + + # Description and line items + description = Column(Text, nullable=True) + line_items = Column(JSON, nullable=True) + + # Relationships + vendor = relationship("Vendor", back_populates="billing_history") + + __table_args__ = ( + Index("idx_billing_vendor_date", "vendor_id", "invoice_date"), + Index("idx_billing_status", "vendor_id", "status"), + ) + + def __repr__(self): + return f"" + + +# ============================================================================ +# Legacy TIER_LIMITS (kept for backward compatibility during migration) +# ============================================================================ + # Tier limit definitions (hardcoded for now, could be moved to DB) TIER_LIMITS = { TierCode.ESSENTIAL: { @@ -186,9 +458,20 @@ class VendorSubscription(Base, TimestampMixin): custom_products_limit = Column(Integer, nullable=True) custom_team_limit = Column(Integer, nullable=True) - # Payment info (for future Stripe integration) + # Payment info (Stripe integration) stripe_customer_id = Column(String(100), nullable=True, index=True) stripe_subscription_id = Column(String(100), nullable=True, index=True) + stripe_price_id = Column(String(100), nullable=True) # Current price being billed + stripe_payment_method_id = Column(String(100), nullable=True) # Default payment method + + # Proration and upgrade/downgrade tracking + proration_behavior = Column(String(50), default="create_prorations") + scheduled_tier_change = Column(String(30), nullable=True) # Pending tier change + scheduled_change_at = Column(DateTime(timezone=True), nullable=True) + + # Payment failure tracking + payment_retry_count = Column(Integer, default=0, nullable=False) + last_payment_error = Column(Text, nullable=True) # Cancellation cancelled_at = Column(DateTime(timezone=True), nullable=True) diff --git a/models/database/vendor.py b/models/database/vendor.py index 12d73813..a029ec1b 100644 --- a/models/database/vendor.py +++ b/models/database/vendor.py @@ -166,6 +166,21 @@ class Vendor(Base, TimestampMixin): cascade="all, delete-orphan", ) + # Add-ons purchased by vendor (one-to-many) + addons = relationship( + "VendorAddOn", + back_populates="vendor", + cascade="all, delete-orphan", + ) + + # Billing/invoice history (one-to-many) + billing_history = relationship( + "BillingHistory", + back_populates="vendor", + cascade="all, delete-orphan", + order_by="BillingHistory.invoice_date.desc()", + ) + domains = relationship( "VendorDomain", back_populates="vendor", diff --git a/requirements.txt b/requirements.txt index 948b4cd1..2f30360f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,4 +32,7 @@ psutil>=5.9.0 weasyprint==62.3 # Environment and configuration -python-dotenv==1.0.1 \ No newline at end of file +python-dotenv==1.0.1 + +# Payment processing +stripe>=7.0.0 \ No newline at end of file diff --git a/static/vendor/js/billing.js b/static/vendor/js/billing.js new file mode 100644 index 00000000..27073cd6 --- /dev/null +++ b/static/vendor/js/billing.js @@ -0,0 +1,187 @@ +// static/vendor/js/billing.js +// Vendor billing and subscription management + +function billingData() { + return { + // State + loading: true, + subscription: null, + tiers: [], + addons: [], + myAddons: [], + invoices: [], + + // UI state + showTiersModal: false, + showAddonsModal: false, + showCancelModal: false, + showSuccessMessage: false, + showCancelMessage: false, + cancelReason: '', + + // Initialize + async init() { + // Check URL params for success/cancel + const params = new URLSearchParams(window.location.search); + if (params.get('success') === 'true') { + this.showSuccessMessage = true; + // Clean URL + window.history.replaceState({}, document.title, window.location.pathname); + } + if (params.get('cancelled') === 'true') { + this.showCancelMessage = true; + window.history.replaceState({}, document.title, window.location.pathname); + } + + await this.loadData(); + }, + + async loadData() { + this.loading = true; + try { + // Load all data in parallel + const [subscriptionRes, tiersRes, addonsRes, invoicesRes] = await Promise.all([ + this.apiGet('/billing/subscription'), + this.apiGet('/billing/tiers'), + this.apiGet('/billing/addons'), + this.apiGet('/billing/invoices?limit=5'), + ]); + + this.subscription = subscriptionRes; + this.tiers = tiersRes.tiers || []; + this.addons = addonsRes || []; + this.invoices = invoicesRes.invoices || []; + + } catch (error) { + console.error('Error loading billing data:', error); + this.showNotification('Failed to load billing data', 'error'); + } finally { + this.loading = false; + } + }, + + async selectTier(tier) { + if (tier.is_current) return; + + try { + const response = await this.apiPost('/billing/checkout', { + tier_code: tier.code, + is_annual: false + }); + + if (response.checkout_url) { + window.location.href = response.checkout_url; + } + } catch (error) { + console.error('Error creating checkout:', error); + this.showNotification('Failed to create checkout session', 'error'); + } + }, + + async openPortal() { + try { + const response = await this.apiPost('/billing/portal', {}); + if (response.portal_url) { + window.location.href = response.portal_url; + } + } catch (error) { + console.error('Error opening portal:', error); + this.showNotification('Failed to open payment portal', 'error'); + } + }, + + async cancelSubscription() { + try { + await this.apiPost('/billing/cancel', { + reason: this.cancelReason, + immediately: false + }); + + this.showCancelModal = false; + this.showNotification('Subscription cancelled. You have access until the end of your billing period.', 'success'); + await this.loadData(); + + } catch (error) { + console.error('Error cancelling subscription:', error); + this.showNotification('Failed to cancel subscription', 'error'); + } + }, + + async reactivate() { + try { + await this.apiPost('/billing/reactivate', {}); + this.showNotification('Subscription reactivated!', 'success'); + await this.loadData(); + + } catch (error) { + console.error('Error reactivating subscription:', error); + this.showNotification('Failed to reactivate subscription', 'error'); + } + }, + + // API helpers + async apiGet(endpoint) { + const response = await fetch(`/api/v1/vendor${endpoint}`, { + headers: { + 'Content-Type': 'application/json', + }, + credentials: 'include' + }); + + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + + return response.json(); + }, + + async apiPost(endpoint, data) { + const response = await fetch(`/api/v1/vendor${endpoint}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + credentials: 'include', + body: JSON.stringify(data) + }); + + if (!response.ok) { + const error = await response.json().catch(() => ({})); + throw new Error(error.detail || `API error: ${response.status}`); + } + + return response.json(); + }, + + // Formatters + formatDate(dateString) { + if (!dateString) return '-'; + const date = new Date(dateString); + return date.toLocaleDateString('en-US', { + year: 'numeric', + month: 'short', + day: 'numeric' + }); + }, + + formatCurrency(cents, currency = 'EUR') { + if (cents === null || cents === undefined) return '-'; + const amount = cents / 100; + return new Intl.NumberFormat('en-US', { + style: 'currency', + currency: currency + }).format(amount); + }, + + showNotification(message, type = 'info') { + // Use Alpine's $dispatch if available, or fallback to alert + if (window.Alpine) { + window.dispatchEvent(new CustomEvent('show-notification', { + detail: { message, type } + })); + } else { + alert(message); + } + } + }; +} diff --git a/tests/unit/services/test_billing_service.py b/tests/unit/services/test_billing_service.py new file mode 100644 index 00000000..bfcb4fcb --- /dev/null +++ b/tests/unit/services/test_billing_service.py @@ -0,0 +1,658 @@ +# tests/unit/services/test_billing_service.py +"""Unit tests for BillingService.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from app.exceptions import VendorNotFoundException +from app.services.billing_service import ( + BillingService, + NoActiveSubscriptionError, + PaymentSystemNotConfiguredError, + StripePriceNotConfiguredError, + SubscriptionNotCancelledError, + TierNotFoundError, +) +from models.database.subscription import ( + AddOnProduct, + BillingHistory, + SubscriptionStatus, + SubscriptionTier, + VendorAddOn, + VendorSubscription, +) + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceSubscription: + """Test suite for BillingService subscription operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + def test_get_subscription_with_tier_creates_if_not_exists( + self, db, test_vendor, test_subscription_tier + ): + """Test get_subscription_with_tier creates subscription if needed.""" + subscription, tier = self.service.get_subscription_with_tier(db, test_vendor.id) + + assert subscription is not None + assert subscription.vendor_id == test_vendor.id + assert tier is not None + assert tier.code == subscription.tier + + def test_get_subscription_with_tier_returns_existing( + self, db, test_vendor, test_subscription + ): + """Test get_subscription_with_tier returns existing subscription.""" + # Note: test_subscription fixture already creates the tier + subscription, tier = self.service.get_subscription_with_tier(db, test_vendor.id) + + assert subscription.id == test_subscription.id + assert tier.code == test_subscription.tier + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceTiers: + """Test suite for BillingService tier operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + def test_get_available_tiers(self, db, test_subscription_tiers): + """Test getting available tiers.""" + tier_list, tier_order = self.service.get_available_tiers(db, "essential") + + assert len(tier_list) > 0 + assert "essential" in tier_order + assert "professional" in tier_order + + # Check tier has expected fields + essential_tier = next(t for t in tier_list if t["code"] == "essential") + assert essential_tier["is_current"] is True + assert essential_tier["can_upgrade"] is False + assert essential_tier["can_downgrade"] is False + + professional_tier = next(t for t in tier_list if t["code"] == "professional") + assert professional_tier["can_upgrade"] is True + assert professional_tier["can_downgrade"] is False + + def test_get_tier_by_code_success(self, db, test_subscription_tier): + """Test getting tier by code.""" + tier = self.service.get_tier_by_code(db, "essential") + + assert tier.code == "essential" + assert tier.is_active is True + + def test_get_tier_by_code_not_found(self, db): + """Test getting non-existent tier raises error.""" + with pytest.raises(TierNotFoundError) as exc_info: + self.service.get_tier_by_code(db, "nonexistent") + + assert exc_info.value.tier_code == "nonexistent" + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceCheckout: + """Test suite for BillingService checkout operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + @patch("app.services.billing_service.stripe_service") + def test_create_checkout_session_stripe_not_configured( + self, mock_stripe, db, test_vendor, test_subscription_tier + ): + """Test checkout fails when Stripe not configured.""" + mock_stripe.is_configured = False + + with pytest.raises(PaymentSystemNotConfiguredError): + self.service.create_checkout_session( + db=db, + vendor_id=test_vendor.id, + tier_code="essential", + is_annual=False, + success_url="https://example.com/success", + cancel_url="https://example.com/cancel", + ) + + @patch("app.services.billing_service.stripe_service") + def test_create_checkout_session_success( + self, mock_stripe, db, test_vendor, test_subscription_tier_with_stripe + ): + """Test successful checkout session creation.""" + mock_stripe.is_configured = True + mock_session = MagicMock() + mock_session.url = "https://checkout.stripe.com/test" + mock_session.id = "cs_test_123" + mock_stripe.create_checkout_session.return_value = mock_session + + result = self.service.create_checkout_session( + db=db, + vendor_id=test_vendor.id, + tier_code="essential", + is_annual=False, + success_url="https://example.com/success", + cancel_url="https://example.com/cancel", + ) + + assert result["checkout_url"] == "https://checkout.stripe.com/test" + assert result["session_id"] == "cs_test_123" + + @patch("app.services.billing_service.stripe_service") + def test_create_checkout_session_tier_not_found( + self, mock_stripe, db, test_vendor + ): + """Test checkout fails with invalid tier.""" + mock_stripe.is_configured = True + + with pytest.raises(TierNotFoundError): + self.service.create_checkout_session( + db=db, + vendor_id=test_vendor.id, + tier_code="nonexistent", + is_annual=False, + success_url="https://example.com/success", + cancel_url="https://example.com/cancel", + ) + + @patch("app.services.billing_service.stripe_service") + def test_create_checkout_session_no_price( + self, mock_stripe, db, test_vendor, test_subscription_tier + ): + """Test checkout fails when tier has no Stripe price.""" + mock_stripe.is_configured = True + + with pytest.raises(StripePriceNotConfiguredError): + self.service.create_checkout_session( + db=db, + vendor_id=test_vendor.id, + tier_code="essential", + is_annual=False, + success_url="https://example.com/success", + cancel_url="https://example.com/cancel", + ) + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServicePortal: + """Test suite for BillingService portal operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + @patch("app.services.billing_service.stripe_service") + def test_create_portal_session_stripe_not_configured(self, mock_stripe, db, test_vendor): + """Test portal fails when Stripe not configured.""" + mock_stripe.is_configured = False + + with pytest.raises(PaymentSystemNotConfiguredError): + self.service.create_portal_session( + db=db, + vendor_id=test_vendor.id, + return_url="https://example.com/billing", + ) + + @patch("app.services.billing_service.stripe_service") + def test_create_portal_session_no_subscription(self, mock_stripe, db, test_vendor): + """Test portal fails when no subscription exists.""" + mock_stripe.is_configured = True + + with pytest.raises(NoActiveSubscriptionError): + self.service.create_portal_session( + db=db, + vendor_id=test_vendor.id, + return_url="https://example.com/billing", + ) + + @patch("app.services.billing_service.stripe_service") + def test_create_portal_session_success( + self, mock_stripe, db, test_vendor, test_active_subscription + ): + """Test successful portal session creation.""" + mock_stripe.is_configured = True + mock_session = MagicMock() + mock_session.url = "https://billing.stripe.com/portal" + mock_stripe.create_portal_session.return_value = mock_session + + result = self.service.create_portal_session( + db=db, + vendor_id=test_vendor.id, + return_url="https://example.com/billing", + ) + + assert result["portal_url"] == "https://billing.stripe.com/portal" + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceInvoices: + """Test suite for BillingService invoice operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + def test_get_invoices_empty(self, db, test_vendor): + """Test getting invoices when none exist.""" + invoices, total = self.service.get_invoices(db, test_vendor.id) + + assert invoices == [] + assert total == 0 + + def test_get_invoices_with_data(self, db, test_vendor, test_billing_history): + """Test getting invoices returns data.""" + invoices, total = self.service.get_invoices(db, test_vendor.id) + + assert len(invoices) == 1 + assert total == 1 + assert invoices[0].invoice_number == "INV-001" + + def test_get_invoices_pagination(self, db, test_vendor, test_multiple_invoices): + """Test invoice pagination.""" + # Get first page + page1, total = self.service.get_invoices(db, test_vendor.id, skip=0, limit=2) + assert len(page1) == 2 + assert total == 5 + + # Get second page + page2, _ = self.service.get_invoices(db, test_vendor.id, skip=2, limit=2) + assert len(page2) == 2 + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceAddons: + """Test suite for BillingService addon operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + def test_get_available_addons_empty(self, db): + """Test getting addons when none exist.""" + addons = self.service.get_available_addons(db) + assert addons == [] + + def test_get_available_addons_with_data(self, db, test_addon_products): + """Test getting all available addons.""" + addons = self.service.get_available_addons(db) + + assert len(addons) == 3 + assert all(addon.is_active for addon in addons) + + def test_get_available_addons_by_category(self, db, test_addon_products): + """Test filtering addons by category.""" + domain_addons = self.service.get_available_addons(db, category="domain") + + assert len(domain_addons) == 1 + assert domain_addons[0].category == "domain" + + def test_get_vendor_addons_empty(self, db, test_vendor): + """Test getting vendor addons when none purchased.""" + addons = self.service.get_vendor_addons(db, test_vendor.id) + assert addons == [] + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceCancellation: + """Test suite for BillingService cancellation operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + @patch("app.services.billing_service.stripe_service") + def test_cancel_subscription_no_subscription( + self, mock_stripe, db, test_vendor + ): + """Test cancel fails when no subscription.""" + mock_stripe.is_configured = True + + with pytest.raises(NoActiveSubscriptionError): + self.service.cancel_subscription( + db=db, + vendor_id=test_vendor.id, + reason="Test reason", + immediately=False, + ) + + @patch("app.services.billing_service.stripe_service") + def test_cancel_subscription_success( + self, mock_stripe, db, test_vendor, test_active_subscription + ): + """Test successful subscription cancellation.""" + mock_stripe.is_configured = True + + result = self.service.cancel_subscription( + db=db, + vendor_id=test_vendor.id, + reason="Too expensive", + immediately=False, + ) + + assert result["message"] == "Subscription cancelled successfully" + assert test_active_subscription.cancelled_at is not None + assert test_active_subscription.cancellation_reason == "Too expensive" + + @patch("app.services.billing_service.stripe_service") + def test_reactivate_subscription_not_cancelled( + self, mock_stripe, db, test_vendor, test_active_subscription + ): + """Test reactivate fails when subscription not cancelled.""" + mock_stripe.is_configured = True + + with pytest.raises(SubscriptionNotCancelledError): + self.service.reactivate_subscription(db, test_vendor.id) + + @patch("app.services.billing_service.stripe_service") + def test_reactivate_subscription_success( + self, mock_stripe, db, test_vendor, test_cancelled_subscription + ): + """Test successful subscription reactivation.""" + mock_stripe.is_configured = True + + result = self.service.reactivate_subscription(db, test_vendor.id) + + assert result["message"] == "Subscription reactivated successfully" + assert test_cancelled_subscription.cancelled_at is None + assert test_cancelled_subscription.cancellation_reason is None + + +@pytest.mark.unit +@pytest.mark.billing +class TestBillingServiceVendor: + """Test suite for BillingService vendor operations.""" + + def setup_method(self): + """Initialize service instance before each test.""" + self.service = BillingService() + + def test_get_vendor_success(self, db, test_vendor): + """Test getting vendor by ID.""" + vendor = self.service.get_vendor(db, test_vendor.id) + + assert vendor.id == test_vendor.id + + def test_get_vendor_not_found(self, db): + """Test getting non-existent vendor raises error.""" + with pytest.raises(VendorNotFoundException): + self.service.get_vendor(db, 99999) + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def test_subscription_tier(db): + """Create a basic subscription tier.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + description="Essential plan", + price_monthly_cents=4900, + price_annual_cents=49000, + orders_per_month=100, + products_limit=200, + team_members=1, + features=["basic_support"], + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def test_subscription_tier_with_stripe(db): + """Create a subscription tier with Stripe configuration.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + description="Essential plan", + price_monthly_cents=4900, + price_annual_cents=49000, + orders_per_month=100, + products_limit=200, + team_members=1, + features=["basic_support"], + display_order=1, + is_active=True, + is_public=True, + stripe_product_id="prod_test123", + stripe_price_monthly_id="price_test123", + stripe_price_annual_id="price_test456", + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def test_subscription_tiers(db): + """Create multiple subscription tiers.""" + tiers = [ + SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="professional", + name="Professional", + price_monthly_cents=9900, + display_order=2, + is_active=True, + is_public=True, + ), + SubscriptionTier( + code="business", + name="Business", + price_monthly_cents=19900, + display_order=3, + is_active=True, + is_public=True, + ), + ] + db.add_all(tiers) + db.commit() + for tier in tiers: + db.refresh(tier) + return tiers + + +@pytest.fixture +def test_subscription(db, test_vendor): + """Create a basic subscription for testing.""" + # Create tier first + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE, + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def test_active_subscription(db, test_vendor): + """Create an active subscription with Stripe IDs.""" + # Create tier first if not exists + tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() + if not tier: + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE, + stripe_customer_id="cus_test123", + stripe_subscription_id="sub_test123", + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def test_cancelled_subscription(db, test_vendor): + """Create a cancelled subscription.""" + # Create tier first if not exists + tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() + if not tier: + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE, + stripe_customer_id="cus_test123", + stripe_subscription_id="sub_test123", + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + cancelled_at=datetime.now(timezone.utc), + cancellation_reason="Too expensive", + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def test_billing_history(db, test_vendor): + """Create a billing history record.""" + record = BillingHistory( + vendor_id=test_vendor.id, + stripe_invoice_id="in_test123", + invoice_number="INV-001", + invoice_date=datetime.now(timezone.utc), + subtotal_cents=4900, + tax_cents=0, + total_cents=4900, + amount_paid_cents=4900, + currency="EUR", + status="paid", + ) + db.add(record) + db.commit() + db.refresh(record) + return record + + +@pytest.fixture +def test_multiple_invoices(db, test_vendor): + """Create multiple billing history records.""" + records = [] + for i in range(5): + record = BillingHistory( + vendor_id=test_vendor.id, + stripe_invoice_id=f"in_test{i}", + invoice_number=f"INV-{i:03d}", + invoice_date=datetime.now(timezone.utc), + subtotal_cents=4900, + tax_cents=0, + total_cents=4900, + amount_paid_cents=4900, + currency="EUR", + status="paid", + ) + records.append(record) + db.add_all(records) + db.commit() + return records + + +@pytest.fixture +def test_addon_products(db): + """Create test addon products.""" + addons = [ + AddOnProduct( + code="domain", + name="Custom Domain", + category="domain", + price_cents=1500, + billing_period="annual", + display_order=1, + is_active=True, + ), + AddOnProduct( + code="email_5", + name="5 Email Addresses", + category="email", + price_cents=500, + billing_period="monthly", + quantity_value=5, + display_order=2, + is_active=True, + ), + AddOnProduct( + code="email_10", + name="10 Email Addresses", + category="email", + price_cents=900, + billing_period="monthly", + quantity_value=10, + display_order=3, + is_active=True, + ), + ] + db.add_all(addons) + db.commit() + for addon in addons: + db.refresh(addon) + return addons diff --git a/tests/unit/services/test_stripe_webhook_handler.py b/tests/unit/services/test_stripe_webhook_handler.py new file mode 100644 index 00000000..a275817b --- /dev/null +++ b/tests/unit/services/test_stripe_webhook_handler.py @@ -0,0 +1,393 @@ +# tests/unit/services/test_stripe_webhook_handler.py +"""Unit tests for StripeWebhookHandler.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.stripe_webhook_handler import StripeWebhookHandler +from models.database.subscription import ( + BillingHistory, + StripeWebhookEvent, + SubscriptionStatus, + SubscriptionTier, + VendorSubscription, +) + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerIdempotency: + """Test suite for webhook handler idempotency.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + def test_handle_event_creates_webhook_event_record(self, db, mock_stripe_event): + """Test that handling an event creates a webhook event record.""" + self.handler.handle_event(db, mock_stripe_event) + + record = ( + db.query(StripeWebhookEvent) + .filter(StripeWebhookEvent.event_id == mock_stripe_event.id) + .first() + ) + assert record is not None + assert record.event_type == mock_stripe_event.type + assert record.status == "processed" + + def test_handle_event_skips_duplicate(self, db, mock_stripe_event): + """Test that duplicate events are skipped.""" + # Process first time + result1 = self.handler.handle_event(db, mock_stripe_event) + assert result1["status"] != "skipped" + + # Process second time + result2 = self.handler.handle_event(db, mock_stripe_event) + assert result2["status"] == "skipped" + assert result2["reason"] == "duplicate" + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerCheckout: + """Test suite for checkout.session.completed event handling.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + @patch("app.services.stripe_webhook_handler.stripe.Subscription.retrieve") + def test_handle_checkout_completed_success( + self, mock_stripe_retrieve, db, test_vendor, test_subscription, mock_checkout_event + ): + """Test successful checkout completion.""" + # Mock Stripe subscription retrieve + mock_stripe_sub = MagicMock() + mock_stripe_sub.current_period_start = int(datetime.now(timezone.utc).timestamp()) + mock_stripe_sub.current_period_end = int(datetime.now(timezone.utc).timestamp()) + mock_stripe_sub.trial_end = None + mock_stripe_retrieve.return_value = mock_stripe_sub + + mock_checkout_event.data.object.metadata = {"vendor_id": str(test_vendor.id)} + + result = self.handler.handle_event(db, mock_checkout_event) + + assert result["status"] == "processed" + db.refresh(test_subscription) + assert test_subscription.stripe_customer_id == "cus_test123" + assert test_subscription.status == SubscriptionStatus.ACTIVE + + def test_handle_checkout_completed_no_vendor_id(self, db, mock_checkout_event): + """Test checkout with missing vendor_id is skipped.""" + mock_checkout_event.data.object.metadata = {} + + result = self.handler.handle_event(db, mock_checkout_event) + + assert result["status"] == "processed" + assert result["result"]["action"] == "skipped" + assert result["result"]["reason"] == "no vendor_id" + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerSubscription: + """Test suite for subscription event handling.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + def test_handle_subscription_updated_status_change( + self, db, test_vendor, test_active_subscription, mock_subscription_updated_event + ): + """Test subscription update changes status.""" + result = self.handler.handle_event(db, mock_subscription_updated_event) + + assert result["status"] == "processed" + + def test_handle_subscription_deleted( + self, db, test_vendor, test_active_subscription, mock_subscription_deleted_event + ): + """Test subscription deletion.""" + result = self.handler.handle_event(db, mock_subscription_deleted_event) + + assert result["status"] == "processed" + db.refresh(test_active_subscription) + assert test_active_subscription.status == SubscriptionStatus.CANCELLED + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerInvoice: + """Test suite for invoice event handling.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + def test_handle_invoice_paid_creates_billing_record( + self, db, test_vendor, test_active_subscription, mock_invoice_paid_event + ): + """Test invoice.paid creates billing history record.""" + result = self.handler.handle_event(db, mock_invoice_paid_event) + + assert result["status"] == "processed" + + # Check billing record created + record = ( + db.query(BillingHistory) + .filter(BillingHistory.vendor_id == test_vendor.id) + .first() + ) + assert record is not None + assert record.status == "paid" + assert record.total_cents == 4900 + + def test_handle_invoice_paid_resets_counters( + self, db, test_vendor, test_active_subscription, mock_invoice_paid_event + ): + """Test invoice.paid resets order counters.""" + test_active_subscription.orders_this_period = 50 + db.commit() + + self.handler.handle_event(db, mock_invoice_paid_event) + + db.refresh(test_active_subscription) + assert test_active_subscription.orders_this_period == 0 + + def test_handle_payment_failed_marks_past_due( + self, db, test_vendor, test_active_subscription, mock_payment_failed_event + ): + """Test payment failure marks subscription as past due.""" + result = self.handler.handle_event(db, mock_payment_failed_event) + + assert result["status"] == "processed" + db.refresh(test_active_subscription) + assert test_active_subscription.status == SubscriptionStatus.PAST_DUE + assert test_active_subscription.payment_retry_count == 1 + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerUnknownEvents: + """Test suite for unknown event handling.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + def test_handle_unknown_event_type(self, db): + """Test unknown event types are ignored.""" + mock_event = MagicMock() + mock_event.id = "evt_unknown123" + mock_event.type = "customer.unknown_event" + mock_event.data.object = {} + + result = self.handler.handle_event(db, mock_event) + + assert result["status"] == "ignored" + assert "no handler" in result["reason"] + + +@pytest.mark.unit +@pytest.mark.billing +class TestStripeWebhookHandlerStatusMapping: + """Test suite for status mapping helper.""" + + def setup_method(self): + """Initialize handler instance before each test.""" + self.handler = StripeWebhookHandler() + + def test_map_active_status(self): + """Test mapping active status.""" + result = self.handler._map_stripe_status("active") + assert result == SubscriptionStatus.ACTIVE + + def test_map_trialing_status(self): + """Test mapping trialing status.""" + result = self.handler._map_stripe_status("trialing") + assert result == SubscriptionStatus.TRIAL + + def test_map_past_due_status(self): + """Test mapping past_due status.""" + result = self.handler._map_stripe_status("past_due") + assert result == SubscriptionStatus.PAST_DUE + + def test_map_canceled_status(self): + """Test mapping canceled status.""" + result = self.handler._map_stripe_status("canceled") + assert result == SubscriptionStatus.CANCELLED + + def test_map_unknown_status(self): + """Test mapping unknown status defaults to expired.""" + result = self.handler._map_stripe_status("unknown_status") + assert result == SubscriptionStatus.EXPIRED + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def test_subscription_tier(db): + """Create a basic subscription tier.""" + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + db.refresh(tier) + return tier + + +@pytest.fixture +def test_subscription(db, test_vendor): + """Create a basic subscription for testing.""" + # Create tier first if not exists + tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() + if not tier: + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.TRIAL, + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def test_active_subscription(db, test_vendor): + """Create an active subscription with Stripe IDs.""" + # Create tier first if not exists + tier = db.query(SubscriptionTier).filter(SubscriptionTier.code == "essential").first() + if not tier: + tier = SubscriptionTier( + code="essential", + name="Essential", + price_monthly_cents=4900, + display_order=1, + is_active=True, + is_public=True, + ) + db.add(tier) + db.commit() + + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE, + stripe_customer_id="cus_test123", + stripe_subscription_id="sub_test123", + period_start=datetime.now(timezone.utc), + period_end=datetime.now(timezone.utc), + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def mock_stripe_event(): + """Create a mock Stripe event.""" + event = MagicMock() + event.id = "evt_test123" + event.type = "customer.created" + event.data.object = {"id": "cus_test123"} + return event + + +@pytest.fixture +def mock_checkout_event(): + """Create a mock checkout.session.completed event.""" + event = MagicMock() + event.id = "evt_checkout123" + event.type = "checkout.session.completed" + event.data.object.id = "cs_test123" + event.data.object.customer = "cus_test123" + event.data.object.subscription = "sub_test123" + event.data.object.metadata = {} + return event + + +@pytest.fixture +def mock_subscription_updated_event(): + """Create a mock customer.subscription.updated event.""" + event = MagicMock() + event.id = "evt_subupdated123" + event.type = "customer.subscription.updated" + event.data.object.id = "sub_test123" + event.data.object.customer = "cus_test123" + event.data.object.status = "active" + event.data.object.current_period_start = int(datetime.now(timezone.utc).timestamp()) + event.data.object.current_period_end = int(datetime.now(timezone.utc).timestamp()) + event.data.object.cancel_at_period_end = False + event.data.object.items.data = [] + event.data.object.metadata = {} + return event + + +@pytest.fixture +def mock_subscription_deleted_event(): + """Create a mock customer.subscription.deleted event.""" + event = MagicMock() + event.id = "evt_subdeleted123" + event.type = "customer.subscription.deleted" + event.data.object.id = "sub_test123" + event.data.object.customer = "cus_test123" + return event + + +@pytest.fixture +def mock_invoice_paid_event(): + """Create a mock invoice.paid event.""" + event = MagicMock() + event.id = "evt_invoicepaid123" + event.type = "invoice.paid" + event.data.object.id = "in_test123" + event.data.object.customer = "cus_test123" + event.data.object.payment_intent = "pi_test123" + event.data.object.number = "INV-001" + event.data.object.created = int(datetime.now(timezone.utc).timestamp()) + event.data.object.subtotal = 4900 + event.data.object.tax = 0 + event.data.object.total = 4900 + event.data.object.amount_paid = 4900 + event.data.object.currency = "eur" + event.data.object.invoice_pdf = "https://stripe.com/invoice.pdf" + event.data.object.hosted_invoice_url = "https://invoice.stripe.com" + return event + + +@pytest.fixture +def mock_payment_failed_event(): + """Create a mock invoice.payment_failed event.""" + event = MagicMock() + event.id = "evt_paymentfailed123" + event.type = "invoice.payment_failed" + event.data.object.id = "in_test123" + event.data.object.customer = "cus_test123" + event.data.object.last_payment_error = {"message": "Card declined"} + return event