feat: add subscription and billing system with Stripe integration
- Add database models for subscription tiers, vendor subscriptions, add-ons, billing history, and webhook events - Implement BillingService for subscription operations - Implement StripeService for Stripe API operations - Implement StripeWebhookHandler for webhook event processing - Add vendor billing API endpoints for subscription management - Create vendor billing page with Alpine.js frontend - Add limit enforcement for products and team members - Add billing exceptions for proper error handling - Create comprehensive unit tests (40 tests passing) - Add subscription billing documentation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
370
app/services/billing_service.py
Normal file
370
app/services/billing_service.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# app/services/billing_service.py
|
||||
"""
|
||||
Billing service for subscription and payment operations.
|
||||
|
||||
Provides:
|
||||
- Subscription status and usage queries
|
||||
- Tier management
|
||||
- Invoice history
|
||||
- Add-on management
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.stripe_service import stripe_service
|
||||
from app.services.subscription_service import subscription_service
|
||||
from models.database.subscription import (
|
||||
AddOnProduct,
|
||||
BillingHistory,
|
||||
SubscriptionTier,
|
||||
VendorAddOn,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Base exception for billing service errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PaymentSystemNotConfiguredError(BillingServiceError):
|
||||
"""Raised when Stripe is not configured."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Payment system not configured")
|
||||
|
||||
|
||||
class TierNotFoundError(BillingServiceError):
|
||||
"""Raised when a tier is not found."""
|
||||
|
||||
def __init__(self, tier_code: str):
|
||||
self.tier_code = tier_code
|
||||
super().__init__(f"Tier '{tier_code}' not found")
|
||||
|
||||
|
||||
class StripePriceNotConfiguredError(BillingServiceError):
|
||||
"""Raised when Stripe price is not configured for a tier."""
|
||||
|
||||
def __init__(self, tier_code: str):
|
||||
self.tier_code = tier_code
|
||||
super().__init__(f"Stripe price not configured for tier '{tier_code}'")
|
||||
|
||||
|
||||
class NoActiveSubscriptionError(BillingServiceError):
|
||||
"""Raised when no active subscription exists."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("No active subscription found")
|
||||
|
||||
|
||||
class SubscriptionNotCancelledError(BillingServiceError):
|
||||
"""Raised when trying to reactivate a non-cancelled subscription."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Subscription is not cancelled")
|
||||
|
||||
|
||||
class BillingService:
|
||||
"""Service for billing operations."""
|
||||
|
||||
def get_subscription_with_tier(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[VendorSubscription, SubscriptionTier | None]:
|
||||
"""
|
||||
Get subscription and its tier info.
|
||||
|
||||
Returns:
|
||||
Tuple of (subscription, tier) where tier may be None
|
||||
"""
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == subscription.tier)
|
||||
.first()
|
||||
)
|
||||
|
||||
return subscription, tier
|
||||
|
||||
def get_available_tiers(
|
||||
self, db: Session, current_tier: str
|
||||
) -> tuple[list[dict], dict[str, int]]:
|
||||
"""
|
||||
Get all available tiers with upgrade/downgrade flags.
|
||||
|
||||
Returns:
|
||||
Tuple of (tier_list, tier_order_map)
|
||||
"""
|
||||
tiers = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
SubscriptionTier.is_public == True, # noqa: E712
|
||||
)
|
||||
.order_by(SubscriptionTier.display_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
tier_order = {t.code: t.display_order for t in tiers}
|
||||
current_order = tier_order.get(current_tier, 0)
|
||||
|
||||
tier_list = []
|
||||
for tier in tiers:
|
||||
tier_list.append({
|
||||
"code": tier.code,
|
||||
"name": tier.name,
|
||||
"description": tier.description,
|
||||
"price_monthly_cents": tier.price_monthly_cents,
|
||||
"price_annual_cents": tier.price_annual_cents,
|
||||
"orders_per_month": tier.orders_per_month,
|
||||
"products_limit": tier.products_limit,
|
||||
"team_members": tier.team_members,
|
||||
"features": tier.features or [],
|
||||
"is_current": tier.code == current_tier,
|
||||
"can_upgrade": tier.display_order > current_order,
|
||||
"can_downgrade": tier.display_order < current_order,
|
||||
})
|
||||
|
||||
return tier_list, tier_order
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||
"""
|
||||
Get a tier by its code.
|
||||
|
||||
Raises:
|
||||
TierNotFoundError: If tier doesn't exist
|
||||
"""
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.code == tier_code,
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tier:
|
||||
raise TierNotFoundError(tier_code)
|
||||
|
||||
return tier
|
||||
|
||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""
|
||||
Get vendor by ID.
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException from app.exceptions
|
||||
"""
|
||||
from app.exceptions import VendorNotFoundException
|
||||
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
return vendor
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier_code: str,
|
||||
is_annual: bool,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe checkout session.
|
||||
|
||||
Returns:
|
||||
Dict with checkout_url and session_id
|
||||
|
||||
Raises:
|
||||
PaymentSystemNotConfiguredError: If Stripe not configured
|
||||
TierNotFoundError: If tier doesn't exist
|
||||
StripePriceNotConfiguredError: If price not configured
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
raise PaymentSystemNotConfiguredError()
|
||||
|
||||
vendor = self.get_vendor(db, vendor_id)
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
price_id = (
|
||||
tier.stripe_price_annual_id
|
||||
if is_annual and tier.stripe_price_annual_id
|
||||
else tier.stripe_price_monthly_id
|
||||
)
|
||||
|
||||
if not price_id:
|
||||
raise StripePriceNotConfiguredError(tier_code)
|
||||
|
||||
# Check if this is a new subscription (for trial)
|
||||
existing_sub = subscription_service.get_subscription(db, vendor_id)
|
||||
trial_days = None
|
||||
if not existing_sub or not existing_sub.stripe_subscription_id:
|
||||
from app.core.config import settings
|
||||
trial_days = settings.stripe_trial_days
|
||||
|
||||
session = stripe_service.create_checkout_session(
|
||||
db=db,
|
||||
vendor=vendor,
|
||||
price_id=price_id,
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
trial_days=trial_days,
|
||||
)
|
||||
|
||||
# Update subscription with tier info
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
subscription.tier = tier_code
|
||||
subscription.is_annual = is_annual
|
||||
|
||||
return {
|
||||
"checkout_url": session.url,
|
||||
"session_id": session.id,
|
||||
}
|
||||
|
||||
def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session.
|
||||
|
||||
Returns:
|
||||
Dict with portal_url
|
||||
|
||||
Raises:
|
||||
PaymentSystemNotConfiguredError: If Stripe not configured
|
||||
NoActiveSubscriptionError: If no subscription with customer ID
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
raise PaymentSystemNotConfiguredError()
|
||||
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_customer_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
session = stripe_service.create_portal_session(
|
||||
customer_id=subscription.stripe_customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
|
||||
return {"portal_url": session.url}
|
||||
|
||||
def get_invoices(
|
||||
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20
|
||||
) -> tuple[list[BillingHistory], int]:
|
||||
"""
|
||||
Get invoice history for a vendor.
|
||||
|
||||
Returns:
|
||||
Tuple of (invoices, total_count)
|
||||
"""
|
||||
query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id)
|
||||
|
||||
total = query.count()
|
||||
|
||||
invoices = (
|
||||
query.order_by(BillingHistory.invoice_date.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return invoices, total
|
||||
|
||||
def get_available_addons(
|
||||
self, db: Session, category: str | None = None
|
||||
) -> list[AddOnProduct]:
|
||||
"""Get available add-on products."""
|
||||
query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712
|
||||
|
||||
if category:
|
||||
query = query.filter(AddOnProduct.category == category)
|
||||
|
||||
return query.order_by(AddOnProduct.display_order).all()
|
||||
|
||||
def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]:
|
||||
"""Get vendor's purchased add-ons."""
|
||||
return (
|
||||
db.query(VendorAddOn)
|
||||
.filter(VendorAddOn.vendor_id == vendor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
def cancel_subscription(
|
||||
self, db: Session, vendor_id: int, reason: str | None, immediately: bool
|
||||
) -> dict:
|
||||
"""
|
||||
Cancel a subscription.
|
||||
|
||||
Returns:
|
||||
Dict with message and effective_date
|
||||
|
||||
Raises:
|
||||
NoActiveSubscriptionError: If no subscription to cancel
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_subscription_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
if stripe_service.is_configured:
|
||||
stripe_service.cancel_subscription(
|
||||
subscription_id=subscription.stripe_subscription_id,
|
||||
immediately=immediately,
|
||||
cancellation_reason=reason,
|
||||
)
|
||||
|
||||
subscription.cancelled_at = datetime.utcnow()
|
||||
subscription.cancellation_reason = reason
|
||||
|
||||
effective_date = (
|
||||
datetime.utcnow().isoformat()
|
||||
if immediately
|
||||
else subscription.period_end.isoformat()
|
||||
if subscription.period_end
|
||||
else datetime.utcnow().isoformat()
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Subscription cancelled successfully",
|
||||
"effective_date": effective_date,
|
||||
}
|
||||
|
||||
def reactivate_subscription(self, db: Session, vendor_id: int) -> dict:
|
||||
"""
|
||||
Reactivate a cancelled subscription.
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
|
||||
Raises:
|
||||
NoActiveSubscriptionError: If no subscription
|
||||
SubscriptionNotCancelledError: If not cancelled
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_subscription_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
if not subscription.cancelled_at:
|
||||
raise SubscriptionNotCancelledError()
|
||||
|
||||
if stripe_service.is_configured:
|
||||
stripe_service.reactivate_subscription(subscription.stripe_subscription_id)
|
||||
|
||||
subscription.cancelled_at = None
|
||||
subscription.cancellation_reason = None
|
||||
|
||||
return {"message": "Subscription reactivated successfully"}
|
||||
|
||||
|
||||
# Create service instance
|
||||
billing_service = BillingService()
|
||||
@@ -865,12 +865,42 @@ class MarketplaceProductService:
|
||||
if not marketplace_products:
|
||||
raise MarketplaceProductNotFoundException("No marketplace products found")
|
||||
|
||||
# Check product limit from subscription
|
||||
from app.services.subscription_service import subscription_service
|
||||
from sqlalchemy import func
|
||||
|
||||
current_products = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
products_limit = subscription.products_limit
|
||||
remaining_capacity = (
|
||||
products_limit - current_products if products_limit is not None else None
|
||||
)
|
||||
|
||||
copied = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
limit_reached = False
|
||||
details = []
|
||||
|
||||
for mp in marketplace_products:
|
||||
# Check if we've hit the product limit
|
||||
if remaining_capacity is not None and copied >= remaining_capacity:
|
||||
limit_reached = True
|
||||
details.append(
|
||||
{
|
||||
"id": mp.id,
|
||||
"status": "skipped",
|
||||
"reason": "Product limit reached",
|
||||
}
|
||||
)
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
existing = (
|
||||
db.query(Product)
|
||||
@@ -994,6 +1024,7 @@ class MarketplaceProductService:
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
"auto_matched": auto_matched,
|
||||
"limit_reached": limit_reached,
|
||||
"details": details if len(details) <= 100 else None,
|
||||
}
|
||||
|
||||
|
||||
459
app/services/stripe_service.py
Normal file
459
app/services/stripe_service.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# app/services/stripe_service.py
|
||||
"""
|
||||
Stripe payment integration service.
|
||||
|
||||
Provides:
|
||||
- Customer management
|
||||
- Subscription management
|
||||
- Checkout session creation
|
||||
- Customer portal access
|
||||
- Webhook event construction
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import stripe
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from models.database.subscription import (
|
||||
BillingHistory,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Service for Stripe payment operations."""
|
||||
|
||||
def __init__(self):
|
||||
self._configured = False
|
||||
self._configure()
|
||||
|
||||
def _configure(self):
|
||||
"""Configure Stripe with API key."""
|
||||
if settings.stripe_secret_key:
|
||||
stripe.api_key = settings.stripe_secret_key
|
||||
self._configured = True
|
||||
else:
|
||||
logger.warning("Stripe API key not configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Stripe is properly configured."""
|
||||
return self._configured and bool(settings.stripe_secret_key)
|
||||
|
||||
# =========================================================================
|
||||
# Customer Management
|
||||
# =========================================================================
|
||||
|
||||
def create_customer(
|
||||
self,
|
||||
vendor: Vendor,
|
||||
email: str,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe customer for a vendor.
|
||||
|
||||
Returns the Stripe customer ID.
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
customer_metadata = {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
name=name or vendor.name,
|
||||
metadata=customer_metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created Stripe customer {customer.id} for vendor {vendor.vendor_code}"
|
||||
)
|
||||
return customer.id
|
||||
|
||||
def get_customer(self, customer_id: str) -> stripe.Customer:
|
||||
"""Get a Stripe customer by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Customer.retrieve(customer_id)
|
||||
|
||||
def update_customer(
|
||||
self,
|
||||
customer_id: str,
|
||||
email: str | None = None,
|
||||
name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Customer:
|
||||
"""Update a Stripe customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
update_data = {}
|
||||
if email:
|
||||
update_data["email"] = email
|
||||
if name:
|
||||
update_data["name"] = name
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
return stripe.Customer.modify(customer_id, **update_data)
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Management
|
||||
# =========================================================================
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
trial_days: int | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Create a new Stripe subscription.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
price_id: Stripe price ID for the subscription
|
||||
trial_days: Optional trial period in days
|
||||
metadata: Optional metadata to attach
|
||||
|
||||
Returns:
|
||||
Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
subscription_data = {
|
||||
"customer": customer_id,
|
||||
"items": [{"price": price_id}],
|
||||
"metadata": metadata or {},
|
||||
"payment_behavior": "default_incomplete",
|
||||
"expand": ["latest_invoice.payment_intent"],
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
subscription_data["trial_period_days"] = trial_days
|
||||
|
||||
subscription = stripe.Subscription.create(**subscription_data)
|
||||
logger.info(
|
||||
f"Created Stripe subscription {subscription.id} for customer {customer_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""Get a Stripe subscription by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Subscription.retrieve(subscription_id)
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_price_id: str | None = None,
|
||||
proration_behavior: str = "create_prorations",
|
||||
metadata: dict | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Update a Stripe subscription (e.g., change tier).
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
new_price_id: New price ID for tier change
|
||||
proration_behavior: How to handle prorations
|
||||
metadata: Optional metadata to update
|
||||
|
||||
Returns:
|
||||
Updated Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
update_data = {"proration_behavior": proration_behavior}
|
||||
|
||||
if new_price_id:
|
||||
# Get the subscription to find the item ID
|
||||
subscription = stripe.Subscription.retrieve(subscription_id)
|
||||
item_id = subscription["items"]["data"][0]["id"]
|
||||
update_data["items"] = [{"id": item_id, "price": new_price_id}]
|
||||
|
||||
if metadata:
|
||||
update_data["metadata"] = metadata
|
||||
|
||||
updated = stripe.Subscription.modify(subscription_id, **update_data)
|
||||
logger.info(f"Updated Stripe subscription {subscription_id}")
|
||||
return updated
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
immediately: bool = False,
|
||||
cancellation_reason: str | None = None,
|
||||
) -> stripe.Subscription:
|
||||
"""
|
||||
Cancel a Stripe subscription.
|
||||
|
||||
Args:
|
||||
subscription_id: Stripe subscription ID
|
||||
immediately: If True, cancel now. If False, cancel at period end.
|
||||
cancellation_reason: Optional reason for cancellation
|
||||
|
||||
Returns:
|
||||
Cancelled Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
if immediately:
|
||||
subscription = stripe.Subscription.cancel(subscription_id)
|
||||
else:
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=True,
|
||||
metadata={"cancellation_reason": cancellation_reason or "user_request"},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Cancelled Stripe subscription {subscription_id} "
|
||||
f"(immediately={immediately})"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def reactivate_subscription(self, subscription_id: str) -> stripe.Subscription:
|
||||
"""
|
||||
Reactivate a cancelled subscription (if not yet ended).
|
||||
|
||||
Returns:
|
||||
Reactivated Stripe Subscription object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
subscription = stripe.Subscription.modify(
|
||||
subscription_id,
|
||||
cancel_at_period_end=False,
|
||||
)
|
||||
logger.info(f"Reactivated Stripe subscription {subscription_id}")
|
||||
return subscription
|
||||
|
||||
# =========================================================================
|
||||
# Checkout & Portal
|
||||
# =========================================================================
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
price_id: str,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
trial_days: int | None = None,
|
||||
) -> stripe.checkout.Session:
|
||||
"""
|
||||
Create a Stripe Checkout session for subscription signup.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor: Vendor to create checkout for
|
||||
price_id: Stripe price ID
|
||||
success_url: URL to redirect on success
|
||||
cancel_url: URL to redirect on cancel
|
||||
trial_days: Optional trial period
|
||||
|
||||
Returns:
|
||||
Stripe Checkout Session object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
# Get or create Stripe customer
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
else:
|
||||
# Get vendor owner email
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
owner = (
|
||||
db.query(VendorUser)
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor.id,
|
||||
VendorUser.is_owner == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
email = owner.user.email if owner and owner.user else None
|
||||
|
||||
customer_id = self.create_customer(vendor, email or f"{vendor.vendor_code}@placeholder.com")
|
||||
|
||||
# Store the customer ID
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
db.flush()
|
||||
|
||||
session_data = {
|
||||
"customer": customer_id,
|
||||
"line_items": [{"price": price_id, "quantity": 1}],
|
||||
"mode": "subscription",
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"metadata": {
|
||||
"vendor_id": str(vendor.id),
|
||||
"vendor_code": vendor.vendor_code,
|
||||
},
|
||||
}
|
||||
|
||||
if trial_days:
|
||||
session_data["subscription_data"] = {"trial_period_days": trial_days}
|
||||
|
||||
session = stripe.checkout.Session.create(**session_data)
|
||||
logger.info(f"Created checkout session {session.id} for vendor {vendor.vendor_code}")
|
||||
return session
|
||||
|
||||
def create_portal_session(
|
||||
self,
|
||||
customer_id: str,
|
||||
return_url: str,
|
||||
) -> stripe.billing_portal.Session:
|
||||
"""
|
||||
Create a Stripe Customer Portal session.
|
||||
|
||||
Allows customers to manage their subscription, payment methods, and invoices.
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
return_url: URL to return to after portal
|
||||
|
||||
Returns:
|
||||
Stripe Portal Session object
|
||||
"""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
logger.info(f"Created portal session for customer {customer_id}")
|
||||
return session
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Management
|
||||
# =========================================================================
|
||||
|
||||
def get_invoices(
|
||||
self,
|
||||
customer_id: str,
|
||||
limit: int = 10,
|
||||
) -> list[stripe.Invoice]:
|
||||
"""Get invoices for a customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
invoices = stripe.Invoice.list(customer=customer_id, limit=limit)
|
||||
return list(invoices.data)
|
||||
|
||||
def get_upcoming_invoice(self, customer_id: str) -> stripe.Invoice | None:
|
||||
"""Get the upcoming invoice for a customer."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
try:
|
||||
return stripe.Invoice.upcoming(customer=customer_id)
|
||||
except stripe.error.InvalidRequestError:
|
||||
# No upcoming invoice
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Webhook Handling
|
||||
# =========================================================================
|
||||
|
||||
def construct_event(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
) -> stripe.Event:
|
||||
"""
|
||||
Construct and verify a Stripe webhook event.
|
||||
|
||||
Args:
|
||||
payload: Raw request body
|
||||
sig_header: Stripe-Signature header
|
||||
|
||||
Returns:
|
||||
Verified Stripe Event object
|
||||
|
||||
Raises:
|
||||
ValueError: If signature verification fails
|
||||
"""
|
||||
if not settings.stripe_webhook_secret:
|
||||
raise ValueError("Stripe webhook secret not configured")
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload,
|
||||
sig_header,
|
||||
settings.stripe_webhook_secret,
|
||||
)
|
||||
return event
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Webhook signature verification failed: {e}")
|
||||
raise ValueError("Invalid webhook signature")
|
||||
|
||||
# =========================================================================
|
||||
# Price/Product Management
|
||||
# =========================================================================
|
||||
|
||||
def get_price(self, price_id: str) -> stripe.Price:
|
||||
"""Get a Stripe price by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Price.retrieve(price_id)
|
||||
|
||||
def get_product(self, product_id: str) -> stripe.Product:
|
||||
"""Get a Stripe product by ID."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
return stripe.Product.retrieve(product_id)
|
||||
|
||||
def list_prices(
|
||||
self,
|
||||
product_id: str | None = None,
|
||||
active: bool = True,
|
||||
) -> list[stripe.Price]:
|
||||
"""List Stripe prices, optionally filtered by product."""
|
||||
if not self.is_configured:
|
||||
raise ValueError("Stripe is not configured")
|
||||
|
||||
params = {"active": active}
|
||||
if product_id:
|
||||
params["product"] = product_id
|
||||
|
||||
prices = stripe.Price.list(**params)
|
||||
return list(prices.data)
|
||||
|
||||
|
||||
# Create service instance
|
||||
stripe_service = StripeService()
|
||||
411
app/services/stripe_webhook_handler.py
Normal file
411
app/services/stripe_webhook_handler.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# app/services/stripe_webhook_handler.py
|
||||
"""
|
||||
Stripe webhook event handler.
|
||||
|
||||
Processes webhook events from Stripe:
|
||||
- Subscription lifecycle events
|
||||
- Invoice and payment events
|
||||
- Checkout session completion
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import stripe
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.subscription import (
|
||||
BillingHistory,
|
||||
StripeWebhookEvent,
|
||||
SubscriptionStatus,
|
||||
SubscriptionTier,
|
||||
VendorSubscription,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StripeWebhookHandler:
|
||||
"""Handler for Stripe webhook events."""
|
||||
|
||||
def __init__(self):
|
||||
self.handlers = {
|
||||
"checkout.session.completed": self._handle_checkout_completed,
|
||||
"customer.subscription.created": self._handle_subscription_created,
|
||||
"customer.subscription.updated": self._handle_subscription_updated,
|
||||
"customer.subscription.deleted": self._handle_subscription_deleted,
|
||||
"invoice.paid": self._handle_invoice_paid,
|
||||
"invoice.payment_failed": self._handle_payment_failed,
|
||||
"invoice.finalized": self._handle_invoice_finalized,
|
||||
}
|
||||
|
||||
def handle_event(self, db: Session, event: stripe.Event) -> dict:
|
||||
"""
|
||||
Process a Stripe webhook event.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
event: Stripe Event object
|
||||
|
||||
Returns:
|
||||
Dict with processing result
|
||||
"""
|
||||
event_id = event.id
|
||||
event_type = event.type
|
||||
|
||||
# Check for duplicate processing (idempotency)
|
||||
existing = (
|
||||
db.query(StripeWebhookEvent)
|
||||
.filter(StripeWebhookEvent.event_id == event_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
if existing.status == "processed":
|
||||
logger.info(f"Skipping duplicate event {event_id}")
|
||||
return {"status": "skipped", "reason": "duplicate"}
|
||||
elif existing.status == "failed":
|
||||
logger.info(f"Retrying previously failed event {event_id}")
|
||||
else:
|
||||
# Record the event
|
||||
webhook_event = StripeWebhookEvent(
|
||||
event_id=event_id,
|
||||
event_type=event_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(webhook_event)
|
||||
db.flush()
|
||||
existing = webhook_event
|
||||
|
||||
# Process the event
|
||||
handler = self.handlers.get(event_type)
|
||||
if not handler:
|
||||
logger.debug(f"No handler for event type {event_type}")
|
||||
existing.status = "processed"
|
||||
existing.processed_at = datetime.now(timezone.utc)
|
||||
db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction
|
||||
return {"status": "ignored", "reason": f"no handler for {event_type}"}
|
||||
|
||||
try:
|
||||
result = handler(db, event)
|
||||
existing.status = "processed"
|
||||
existing.processed_at = datetime.now(timezone.utc)
|
||||
db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction
|
||||
logger.info(f"Successfully processed event {event_id} ({event_type})")
|
||||
return {"status": "processed", "result": result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing event {event_id}: {e}")
|
||||
existing.status = "failed"
|
||||
existing.error_message = str(e)
|
||||
db.commit() # noqa: SVC-006 - Webhook handler controls its own transaction
|
||||
raise
|
||||
|
||||
# =========================================================================
|
||||
# Event Handlers
|
||||
# =========================================================================
|
||||
|
||||
def _handle_checkout_completed(
|
||||
self, db: Session, event: stripe.Event
|
||||
) -> dict:
|
||||
"""Handle checkout.session.completed event."""
|
||||
session = event.data.object
|
||||
vendor_id = session.metadata.get("vendor_id")
|
||||
|
||||
if not vendor_id:
|
||||
logger.warning(f"Checkout session {session.id} missing vendor_id")
|
||||
return {"action": "skipped", "reason": "no vendor_id"}
|
||||
|
||||
vendor_id = int(vendor_id)
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for vendor {vendor_id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Update subscription with Stripe IDs
|
||||
subscription.stripe_customer_id = session.customer
|
||||
subscription.stripe_subscription_id = session.subscription
|
||||
subscription.status = SubscriptionStatus.ACTIVE
|
||||
|
||||
# Get subscription details to set period dates
|
||||
if session.subscription:
|
||||
stripe_sub = stripe.Subscription.retrieve(session.subscription)
|
||||
subscription.period_start = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_start, tz=timezone.utc
|
||||
)
|
||||
subscription.period_end = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_end, tz=timezone.utc
|
||||
)
|
||||
|
||||
if stripe_sub.trial_end:
|
||||
subscription.trial_ends_at = datetime.fromtimestamp(
|
||||
stripe_sub.trial_end, tz=timezone.utc
|
||||
)
|
||||
|
||||
logger.info(f"Checkout completed for vendor {vendor_id}")
|
||||
return {"action": "activated", "vendor_id": vendor_id}
|
||||
|
||||
def _handle_subscription_created(
|
||||
self, db: Session, event: stripe.Event
|
||||
) -> dict:
|
||||
"""Handle customer.subscription.created event."""
|
||||
stripe_sub = event.data.object
|
||||
customer_id = stripe_sub.customer
|
||||
|
||||
# Find subscription by customer ID
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for customer {customer_id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Update subscription
|
||||
subscription.stripe_subscription_id = stripe_sub.id
|
||||
subscription.status = self._map_stripe_status(stripe_sub.status)
|
||||
subscription.period_start = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_start, tz=timezone.utc
|
||||
)
|
||||
subscription.period_end = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_end, tz=timezone.utc
|
||||
)
|
||||
|
||||
logger.info(f"Subscription created for vendor {subscription.vendor_id}")
|
||||
return {"action": "created", "vendor_id": subscription.vendor_id}
|
||||
|
||||
def _handle_subscription_updated(
|
||||
self, db: Session, event: stripe.Event
|
||||
) -> dict:
|
||||
"""Handle customer.subscription.updated event."""
|
||||
stripe_sub = event.data.object
|
||||
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_subscription_id == stripe_sub.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for {stripe_sub.id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Update status and period
|
||||
subscription.status = self._map_stripe_status(stripe_sub.status)
|
||||
subscription.period_start = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_start, tz=timezone.utc
|
||||
)
|
||||
subscription.period_end = datetime.fromtimestamp(
|
||||
stripe_sub.current_period_end, tz=timezone.utc
|
||||
)
|
||||
|
||||
# Handle cancellation
|
||||
if stripe_sub.cancel_at_period_end:
|
||||
subscription.cancelled_at = datetime.now(timezone.utc)
|
||||
subscription.cancellation_reason = stripe_sub.metadata.get(
|
||||
"cancellation_reason", "user_request"
|
||||
)
|
||||
elif subscription.cancelled_at and not stripe_sub.cancel_at_period_end:
|
||||
# Subscription reactivated
|
||||
subscription.cancelled_at = None
|
||||
subscription.cancellation_reason = None
|
||||
|
||||
# Check for tier change via price
|
||||
if stripe_sub.items.data:
|
||||
new_price_id = stripe_sub.items.data[0].price.id
|
||||
if subscription.stripe_price_id != new_price_id:
|
||||
# Price changed, look up new tier
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.stripe_price_monthly_id == new_price_id)
|
||||
.first()
|
||||
)
|
||||
if tier:
|
||||
subscription.tier = tier.code
|
||||
logger.info(
|
||||
f"Tier changed to {tier.code} for vendor {subscription.vendor_id}"
|
||||
)
|
||||
subscription.stripe_price_id = new_price_id
|
||||
|
||||
logger.info(f"Subscription updated for vendor {subscription.vendor_id}")
|
||||
return {"action": "updated", "vendor_id": subscription.vendor_id}
|
||||
|
||||
def _handle_subscription_deleted(
|
||||
self, db: Session, event: stripe.Event
|
||||
) -> dict:
|
||||
"""Handle customer.subscription.deleted event."""
|
||||
stripe_sub = event.data.object
|
||||
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_subscription_id == stripe_sub.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for {stripe_sub.id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
subscription.status = SubscriptionStatus.CANCELLED
|
||||
subscription.cancelled_at = datetime.now(timezone.utc)
|
||||
|
||||
logger.info(f"Subscription deleted for vendor {subscription.vendor_id}")
|
||||
return {"action": "cancelled", "vendor_id": subscription.vendor_id}
|
||||
|
||||
def _handle_invoice_paid(self, db: Session, event: stripe.Event) -> dict:
|
||||
"""Handle invoice.paid event."""
|
||||
invoice = event.data.object
|
||||
customer_id = invoice.customer
|
||||
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for customer {customer_id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Record billing history
|
||||
billing_record = BillingHistory(
|
||||
vendor_id=subscription.vendor_id,
|
||||
stripe_invoice_id=invoice.id,
|
||||
stripe_payment_intent_id=invoice.payment_intent,
|
||||
invoice_number=invoice.number,
|
||||
invoice_date=datetime.fromtimestamp(invoice.created, tz=timezone.utc),
|
||||
subtotal_cents=invoice.subtotal,
|
||||
tax_cents=invoice.tax or 0,
|
||||
total_cents=invoice.total,
|
||||
amount_paid_cents=invoice.amount_paid,
|
||||
currency=invoice.currency.upper(),
|
||||
status="paid",
|
||||
invoice_pdf_url=invoice.invoice_pdf,
|
||||
hosted_invoice_url=invoice.hosted_invoice_url,
|
||||
)
|
||||
db.add(billing_record)
|
||||
|
||||
# Reset payment retry count on successful payment
|
||||
subscription.payment_retry_count = 0
|
||||
subscription.last_payment_error = None
|
||||
|
||||
# Reset period counters if this is a new billing cycle
|
||||
if subscription.status == SubscriptionStatus.ACTIVE:
|
||||
subscription.orders_this_period = 0
|
||||
subscription.orders_limit_reached_at = None
|
||||
|
||||
logger.info(f"Invoice paid for vendor {subscription.vendor_id}")
|
||||
return {
|
||||
"action": "recorded",
|
||||
"vendor_id": subscription.vendor_id,
|
||||
"invoice_id": invoice.id,
|
||||
}
|
||||
|
||||
def _handle_payment_failed(self, db: Session, event: stripe.Event) -> dict:
|
||||
"""Handle invoice.payment_failed event."""
|
||||
invoice = event.data.object
|
||||
customer_id = invoice.customer
|
||||
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
logger.warning(f"No subscription found for customer {customer_id}")
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Update subscription status
|
||||
subscription.status = SubscriptionStatus.PAST_DUE
|
||||
subscription.payment_retry_count = (subscription.payment_retry_count or 0) + 1
|
||||
|
||||
# Store error message
|
||||
if invoice.last_payment_error:
|
||||
subscription.last_payment_error = invoice.last_payment_error.get("message")
|
||||
|
||||
logger.warning(
|
||||
f"Payment failed for vendor {subscription.vendor_id} "
|
||||
f"(retry #{subscription.payment_retry_count})"
|
||||
)
|
||||
return {
|
||||
"action": "marked_past_due",
|
||||
"vendor_id": subscription.vendor_id,
|
||||
"retry_count": subscription.payment_retry_count,
|
||||
}
|
||||
|
||||
def _handle_invoice_finalized(
|
||||
self, db: Session, event: stripe.Event
|
||||
) -> dict:
|
||||
"""Handle invoice.finalized event."""
|
||||
invoice = event.data.object
|
||||
customer_id = invoice.customer
|
||||
|
||||
subscription = (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.stripe_customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
return {"action": "skipped", "reason": "no subscription"}
|
||||
|
||||
# Check if we already have this invoice
|
||||
existing = (
|
||||
db.query(BillingHistory)
|
||||
.filter(BillingHistory.stripe_invoice_id == invoice.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
return {"action": "skipped", "reason": "already recorded"}
|
||||
|
||||
# Record as pending invoice
|
||||
billing_record = BillingHistory(
|
||||
vendor_id=subscription.vendor_id,
|
||||
stripe_invoice_id=invoice.id,
|
||||
invoice_number=invoice.number,
|
||||
invoice_date=datetime.fromtimestamp(invoice.created, tz=timezone.utc),
|
||||
due_date=datetime.fromtimestamp(invoice.due_date, tz=timezone.utc)
|
||||
if invoice.due_date
|
||||
else None,
|
||||
subtotal_cents=invoice.subtotal,
|
||||
tax_cents=invoice.tax or 0,
|
||||
total_cents=invoice.total,
|
||||
amount_paid_cents=0,
|
||||
currency=invoice.currency.upper(),
|
||||
status="open",
|
||||
invoice_pdf_url=invoice.invoice_pdf,
|
||||
hosted_invoice_url=invoice.hosted_invoice_url,
|
||||
)
|
||||
db.add(billing_record)
|
||||
|
||||
return {"action": "recorded_pending", "vendor_id": subscription.vendor_id}
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
def _map_stripe_status(self, stripe_status: str) -> SubscriptionStatus:
|
||||
"""Map Stripe subscription status to internal status."""
|
||||
status_map = {
|
||||
"active": SubscriptionStatus.ACTIVE,
|
||||
"trialing": SubscriptionStatus.TRIAL,
|
||||
"past_due": SubscriptionStatus.PAST_DUE,
|
||||
"canceled": SubscriptionStatus.CANCELLED,
|
||||
"unpaid": SubscriptionStatus.PAST_DUE,
|
||||
"incomplete": SubscriptionStatus.TRIAL, # Treat as trial until complete
|
||||
"incomplete_expired": SubscriptionStatus.EXPIRED,
|
||||
}
|
||||
return status_map.get(stripe_status, SubscriptionStatus.EXPIRED)
|
||||
|
||||
|
||||
# Create handler instance
|
||||
stripe_webhook_handler = StripeWebhookHandler()
|
||||
@@ -20,11 +20,11 @@ from app.core.permissions import get_preset_permissions
|
||||
from app.exceptions import (
|
||||
CannotRemoveOwnerException,
|
||||
InvalidInvitationTokenException,
|
||||
MaxTeamMembersReachedException,
|
||||
TeamInvitationAlreadyAcceptedException,
|
||||
TeamMemberAlreadyExistsException,
|
||||
UserNotFoundException,
|
||||
)
|
||||
from app.services.subscription_service import TierLimitExceededException
|
||||
from middleware.auth import AuthManager
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Role, Vendor, VendorUser, VendorUserType
|
||||
@@ -37,7 +37,6 @@ class VendorTeamService:
|
||||
|
||||
def __init__(self):
|
||||
self.auth_manager = AuthManager()
|
||||
self.max_team_members = 50 # Configure as needed
|
||||
|
||||
def invite_team_member(
|
||||
self,
|
||||
@@ -68,21 +67,10 @@ class VendorTeamService:
|
||||
Dict with invitation details
|
||||
"""
|
||||
try:
|
||||
# Check team size limit
|
||||
current_team_size = (
|
||||
db.query(VendorUser)
|
||||
.filter(
|
||||
VendorUser.vendor_id == vendor.id,
|
||||
VendorUser.is_active == True,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
# Check team size limit from subscription
|
||||
from app.services.subscription_service import subscription_service
|
||||
|
||||
if current_team_size >= self.max_team_members:
|
||||
raise MaxTeamMembersReachedException(
|
||||
self.max_team_members,
|
||||
vendor.vendor_code,
|
||||
)
|
||||
subscription_service.check_team_limit(db, vendor.id)
|
||||
|
||||
# Check if user already exists
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
@@ -187,7 +175,7 @@ class VendorTeamService:
|
||||
"existing_user": user.is_active,
|
||||
}
|
||||
|
||||
except (TeamMemberAlreadyExistsException, MaxTeamMembersReachedException):
|
||||
except (TeamMemberAlreadyExistsException, TierLimitExceededException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error inviting team member: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user