Files
orion/app/services/subscription_service.py
Samir Boulahtit 3614d448e4 chore: PostgreSQL migration compatibility and infrastructure improvements
Database & Migrations:
- Update all Alembic migrations for PostgreSQL compatibility
- Remove SQLite-specific syntax (AUTOINCREMENT, etc.)
- Add database utility helpers for PostgreSQL operations
- Fix services to use PostgreSQL-compatible queries

Documentation:
- Add comprehensive Docker deployment guide
- Add production deployment documentation
- Add infrastructure architecture documentation
- Update database setup guide for PostgreSQL-only
- Expand troubleshooting guide

Architecture & Validation:
- Add migration.yaml rules for SQL compatibility checking
- Enhance validate_architecture.py with migration validation
- Update architecture rules to validate Alembic migrations

Development:
- Fix duplicate install-all target in Makefile
- Add Celery/Redis validation to install.py script
- Add docker-compose.test.yml for CI testing
- Add squash_migrations.py utility script
- Update tests for PostgreSQL compatibility
- Improve test fixtures in conftest.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 17:52:28 +01:00

656 lines
22 KiB
Python

# app/services/subscription_service.py
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
Usage:
from app.services.subscription_service import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.subscription import (
SubscriptionStatus,
SubscriptionTier,
TIER_LIMITS,
TierCode,
VendorSubscription,
)
from models.database.vendor import Vendor, VendorUser
from models.schema.subscription import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
UsageSummary,
)
# Module-level logger; used below to record subscription lifecycle events
# (trial/subscription creation, updates, upgrades, cancellations, counter resets).
logger = logging.getLogger(__name__)
class SubscriptionNotFoundException(Exception):
    """Signal that a vendor has no subscription record.

    Only a human-readable message is carried (as the single exception
    argument); no extra attributes are attached.
    """
class TierLimitExceededException(Exception):
    """Signal that a vendor has reached a quota of its subscription tier.

    Attributes:
        limit_type: Which quota was hit (the service in this module uses
            "orders", "products" and "team_members").
        current: Usage count observed when the check failed.
        limit: The allowance for that quota.
    """

    def __init__(self, message: str, limit_type: str, current: int, limit: int):
        super().__init__(message)
        # Keep the structured details alongside the message so callers can
        # build their own responses.
        self.limit_type, self.current, self.limit = limit_type, current, limit
class FeatureNotAvailableException(Exception):
    """Signal that a feature is gated behind a tier the vendor does not have.

    Attributes:
        feature: Name of the requested feature.
        current_tier: The vendor's tier at the time of the check.
        required_tier: The tier (or a placeholder) that unlocks the feature.
    """

    def __init__(self, feature: str, current_tier: str, required_tier: str):
        self.feature = feature
        self.current_tier = current_tier
        self.required_tier = required_tier
        super().__init__(
            f"Feature '{feature}' requires {required_tier} tier "
            f"(current: {current_tier})"
        )
class SubscriptionService:
    """
    Service for subscription and tier limit operations.

    Tier definitions are read from the ``SubscriptionTier`` table when a
    database session is supplied, with the hardcoded ``TIER_LIMITS`` mapping
    as a fallback. Write operations only ``flush()`` the session; committing
    the transaction is left to the caller.
    """

    # =========================================================================
    # Tier Information
    # =========================================================================

    @staticmethod
    def _tier_info_from_db(tier: SubscriptionTier) -> TierInfo:
        """Build a ``TierInfo`` schema from a ``SubscriptionTier`` row."""
        return TierInfo(
            code=tier.code,
            name=tier.name,
            price_monthly_cents=tier.price_monthly_cents,
            price_annual_cents=tier.price_annual_cents,
            limits=TierLimits(
                orders_per_month=tier.orders_per_month,
                products_limit=tier.products_limit,
                team_members=tier.team_members,
                order_history_months=tier.order_history_months,
            ),
            features=tier.features or [],
        )

    def get_tier_info(self, tier_code: str, db: Session | None = None) -> TierInfo:
        """
        Get full tier information.

        Queries the database if a session is provided. Otherwise, or when the
        code is not found there, falls back to TIER_LIMITS.

        Args:
            tier_code: Tier code to look up.
            db: Optional database session.

        Returns:
            TierInfo for the tier. Codes found neither in the database nor in
            TierCode resolve to the ESSENTIAL tier.
        """
        if db is not None:
            db_tier = self.get_tier_by_code(db, tier_code)
            if db_tier:
                return self._tier_info_from_db(db_tier)
        # Fallback to hardcoded TIER_LIMITS
        return self._get_tier_from_legacy(tier_code)

    def _get_tier_from_legacy(self, tier_code: str) -> TierInfo:
        """
        Get tier info from hardcoded TIER_LIMITS (fallback).

        Unknown tier codes fall back to the ESSENTIAL tier. A warning is
        logged so that a misconfigured code does not go unnoticed.
        """
        try:
            tier = TierCode(tier_code)
        except ValueError:
            logger.warning(
                "Unknown tier code %r; falling back to %s",
                tier_code,
                TierCode.ESSENTIAL.value,
            )
            tier = TierCode.ESSENTIAL
        limits = TIER_LIMITS[tier]
        return TierInfo(
            code=tier.value,
            name=limits["name"],
            price_monthly_cents=limits["price_monthly_cents"],
            price_annual_cents=limits.get("price_annual_cents"),
            limits=TierLimits(
                orders_per_month=limits.get("orders_per_month"),
                products_limit=limits.get("products_limit"),
                team_members=limits.get("team_members"),
                order_history_months=limits.get("order_history_months"),
            ),
            features=limits.get("features", []),
        )

    def get_all_tiers(self, db: Session | None = None) -> list[TierInfo]:
        """
        Get information for all tiers.

        With a session, returns the active and public database tiers ordered
        by ``display_order``. Falls back to TIER_LIMITS (one entry per
        TierCode) when no session is given or no such tiers exist.
        """
        if db is not None:
            db_tiers = (
                db.query(SubscriptionTier)
                .filter(
                    SubscriptionTier.is_active == True,  # noqa: E712
                    SubscriptionTier.is_public == True,  # noqa: E712
                )
                .order_by(SubscriptionTier.display_order)
                .all()
            )
            if db_tiers:
                return [self._tier_info_from_db(t) for t in db_tiers]
        # Fallback to hardcoded
        return [self._get_tier_from_legacy(tier.value) for tier in TierCode]

    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier | None:
        """Get subscription tier by code, or None if no row matches."""
        return (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == tier_code)
            .first()
        )

    def get_tier_id(self, db: Session, tier_code: str) -> int | None:
        """Get tier ID from tier code. Returns None if tier not found."""
        tier = self.get_tier_by_code(db, tier_code)
        return tier.id if tier else None

    # =========================================================================
    # Subscription CRUD
    # =========================================================================

    def get_subscription(
        self, db: Session, vendor_id: int
    ) -> VendorSubscription | None:
        """Get vendor subscription, or None if the vendor has none."""
        return (
            db.query(VendorSubscription)
            .filter(VendorSubscription.vendor_id == vendor_id)
            .first()
        )

    def get_subscription_or_raise(
        self, db: Session, vendor_id: int
    ) -> VendorSubscription:
        """
        Get vendor subscription or raise exception.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription(db, vendor_id)
        if not subscription:
            raise SubscriptionNotFoundException(
                f"No subscription found for vendor {vendor_id}"
            )
        return subscription

    def get_current_tier(
        self, db: Session, vendor_id: int
    ) -> TierCode | None:
        """
        Get vendor's current subscription tier code.

        Returns None if the vendor has no subscription or its stored tier is
        not a valid TierCode.
        """
        subscription = self.get_subscription(db, vendor_id)
        if subscription:
            try:
                return TierCode(subscription.tier)
            except ValueError:
                return None
        return None

    def get_or_create_subscription(
        self,
        db: Session,
        vendor_id: int,
        tier: str = TierCode.ESSENTIAL.value,
        trial_days: int = 14,
    ) -> VendorSubscription:
        """
        Get existing subscription or create a new trial subscription.

        Used when a vendor first accesses the system. The new subscription is
        in TRIAL status, and both its period and its trial end ``trial_days``
        from now.
        """
        subscription = self.get_subscription(db, vendor_id)
        if subscription:
            return subscription

        now = datetime.now(UTC)
        trial_end = now + timedelta(days=trial_days)
        subscription = VendorSubscription(
            vendor_id=vendor_id,
            tier=tier,
            # tier_id stays None if the code has no SubscriptionTier row.
            tier_id=self.get_tier_id(db, tier),
            status=SubscriptionStatus.TRIAL.value,
            period_start=now,
            period_end=trial_end,
            trial_ends_at=trial_end,
            is_annual=False,
        )
        db.add(subscription)
        db.flush()
        db.refresh(subscription)
        logger.info(
            "Created trial subscription for vendor %s (tier=%s, trial_ends=%s)",
            vendor_id,
            tier,
            trial_end,
        )
        return subscription

    def create_subscription(
        self,
        db: Session,
        vendor_id: int,
        data: SubscriptionCreate,
    ) -> VendorSubscription:
        """
        Create a subscription for a vendor.

        The billing period is 365 days for annual plans and 30 days otherwise.
        A positive ``data.trial_days`` starts the subscription in TRIAL status,
        with the period ending when the trial ends.

        Raises:
            ValueError: If the vendor already has a subscription.
        """
        if self.get_subscription(db, vendor_id):
            raise ValueError("Vendor already has a subscription")

        now = datetime.now(UTC)
        period_end = now + timedelta(days=365 if data.is_annual else 30)

        trial_ends_at = None
        status = SubscriptionStatus.ACTIVE.value
        if data.trial_days > 0:
            trial_ends_at = now + timedelta(days=data.trial_days)
            status = SubscriptionStatus.TRIAL.value
            period_end = trial_ends_at

        subscription = VendorSubscription(
            vendor_id=vendor_id,
            tier=data.tier,
            tier_id=self.get_tier_id(db, data.tier),
            status=status,
            period_start=now,
            period_end=period_end,
            trial_ends_at=trial_ends_at,
            is_annual=data.is_annual,
        )
        db.add(subscription)
        db.flush()
        db.refresh(subscription)
        logger.info("Created subscription for vendor %s: %s", vendor_id, data.tier)
        return subscription

    def update_subscription(
        self,
        db: Session,
        vendor_id: int,
        data: SubscriptionUpdate,
    ) -> VendorSubscription:
        """
        Update a vendor subscription with the fields explicitly set in ``data``.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)
        update_data = data.model_dump(exclude_unset=True)
        # Keep the tier_id foreign key in sync with the tier code.
        if "tier" in update_data:
            update_data["tier_id"] = self.get_tier_id(db, update_data["tier"])
        for key, value in update_data.items():
            setattr(subscription, key, value)
        subscription.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(subscription)
        logger.info("Updated subscription for vendor %s", vendor_id)
        return subscription

    def upgrade_tier(
        self,
        db: Session,
        vendor_id: int,
        new_tier: str,
    ) -> VendorSubscription:
        """
        Upgrade vendor to a new tier.

        A subscription still in TRIAL is switched to ACTIVE.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)
        old_tier = subscription.tier
        subscription.tier = new_tier
        subscription.tier_id = self.get_tier_id(db, new_tier)
        subscription.updated_at = datetime.now(UTC)
        if subscription.status == SubscriptionStatus.TRIAL.value:
            subscription.status = SubscriptionStatus.ACTIVE.value
        db.flush()
        db.refresh(subscription)
        logger.info("Upgraded vendor %s from %s to %s", vendor_id, old_tier, new_tier)
        return subscription

    def cancel_subscription(
        self,
        db: Session,
        vendor_id: int,
        reason: str | None = None,
    ) -> VendorSubscription:
        """
        Cancel a vendor subscription (access until period end).

        Only the status and cancellation fields are changed; ``period_end``
        is left untouched.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)
        # Use a single timestamp so cancelled_at and updated_at agree.
        now = datetime.now(UTC)
        subscription.status = SubscriptionStatus.CANCELLED.value
        subscription.cancelled_at = now
        subscription.cancellation_reason = reason
        subscription.updated_at = now
        db.flush()
        db.refresh(subscription)
        logger.info("Cancelled subscription for vendor %s", vendor_id)
        return subscription

    # =========================================================================
    # Usage Tracking
    # =========================================================================

    @staticmethod
    def _count_products(db: Session, vendor_id: int) -> int:
        """Count all products belonging to the vendor (0 if none)."""
        return (
            db.query(func.count(Product.id))
            .filter(Product.vendor_id == vendor_id)
            .scalar()
            or 0
        )

    @staticmethod
    def _count_active_team_members(db: Session, vendor_id: int) -> int:
        """Count the vendor's active VendorUser rows (0 if none)."""
        return (
            db.query(func.count(VendorUser.id))
            .filter(
                VendorUser.vendor_id == vendor_id,
                VendorUser.is_active == True,  # noqa: E712
            )
            .scalar()
            or 0
        )

    @staticmethod
    def _remaining(current: int, limit: int | None) -> int | None:
        """Return what is left of ``limit`` (never negative); None if there is no limit."""
        if limit is None:
            return None
        return max(0, limit - current)

    @staticmethod
    def _percent_used(current: int, limit: int | None) -> float | None:
        """Return usage as a percentage capped at 100; None if the limit is None or 0."""
        # A zero limit would divide by zero, so it is reported as None too.
        if limit is None or limit == 0:
            return None
        return min(100.0, (current / limit) * 100)

    def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
        """
        Get current subscription usage statistics.

        Creates a trial subscription if the vendor has none yet.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        products_count = self._count_products(db, vendor_id)
        team_count = self._count_active_team_members(db, vendor_id)

        orders_used = subscription.orders_this_period
        orders_limit = subscription.orders_limit
        products_limit = subscription.products_limit
        team_limit = subscription.team_members_limit

        return SubscriptionUsage(
            orders_used=orders_used,
            orders_limit=orders_limit,
            orders_remaining=self._remaining(orders_used, orders_limit),
            orders_percent_used=self._percent_used(orders_used, orders_limit),
            products_used=products_count,
            products_limit=products_limit,
            products_remaining=self._remaining(products_count, products_limit),
            products_percent_used=self._percent_used(products_count, products_limit),
            team_members_used=team_count,
            team_members_limit=team_limit,
            team_members_remaining=self._remaining(team_count, team_limit),
            team_members_percent_used=self._percent_used(team_count, team_limit),
        )

    def get_usage_summary(self, db: Session, vendor_id: int) -> UsageSummary:
        """
        Get usage summary for billing page display.

        Creates a trial subscription if the vendor has none yet.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        products_count = self._count_products(db, vendor_id)
        team_count = self._count_active_team_members(db, vendor_id)

        orders_used = subscription.orders_this_period
        orders_limit = subscription.orders_limit
        products_limit = subscription.products_limit
        team_limit = subscription.team_members_limit

        return UsageSummary(
            orders_this_period=orders_used,
            orders_limit=orders_limit,
            orders_remaining=self._remaining(orders_used, orders_limit),
            products_count=products_count,
            products_limit=products_limit,
            products_remaining=self._remaining(products_count, products_limit),
            team_count=team_count,
            team_limit=team_limit,
            team_remaining=self._remaining(team_count, team_limit),
        )

    def increment_order_count(self, db: Session, vendor_id: int) -> None:
        """
        Increment the order counter for the current period.

        Call this after successfully creating/importing an order.
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        subscription.increment_order_count()
        db.flush()

    def reset_period_counters(self, db: Session, vendor_id: int) -> None:
        """
        Reset counters for a new billing period.

        Raises:
            SubscriptionNotFoundException: If the vendor has no subscription.
        """
        subscription = self.get_subscription_or_raise(db, vendor_id)
        subscription.reset_period_counters()
        db.flush()
        logger.info("Reset period counters for vendor %s", vendor_id)

    # =========================================================================
    # Limit Checks
    # =========================================================================

    def can_create_order(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can create/import another order.

        Returns: (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_create_order()

    def check_order_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check order limit and raise exception if exceeded.

        Use this in order creation flows.

        Raises:
            TierLimitExceededException: With ``limit_type="orders"``.
        """
        can_create, message = self.can_create_order(db, vendor_id)
        if can_create:
            return
        subscription = self.get_subscription(db, vendor_id)
        raise TierLimitExceededException(
            message=message or "Order limit exceeded",
            limit_type="orders",
            current=subscription.orders_this_period if subscription else 0,
            limit=subscription.orders_limit if subscription else 0,
        )

    def can_add_product(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can add another product.

        Returns: (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_add_product(self._count_products(db, vendor_id))

    def check_product_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check product limit and raise exception if exceeded.

        Use this in product creation flows.

        Raises:
            TierLimitExceededException: With ``limit_type="products"``.
        """
        can_add, message = self.can_add_product(db, vendor_id)
        if can_add:
            return
        subscription = self.get_subscription(db, vendor_id)
        raise TierLimitExceededException(
            message=message or "Product limit exceeded",
            limit_type="products",
            current=self._count_products(db, vendor_id),
            limit=subscription.products_limit if subscription else 0,
        )

    def can_add_team_member(
        self, db: Session, vendor_id: int
    ) -> tuple[bool, str | None]:
        """
        Check if vendor can add another team member.

        Returns: (allowed, error_message)
        """
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.can_add_team_member(
            self._count_active_team_members(db, vendor_id)
        )

    def check_team_limit(self, db: Session, vendor_id: int) -> None:
        """
        Check team member limit and raise exception if exceeded.

        Use this in team member invitation flows.

        Raises:
            TierLimitExceededException: With ``limit_type="team_members"``.
        """
        can_add, message = self.can_add_team_member(db, vendor_id)
        if can_add:
            return
        subscription = self.get_subscription(db, vendor_id)
        raise TierLimitExceededException(
            message=message or "Team member limit exceeded",
            limit_type="team_members",
            current=self._count_active_team_members(db, vendor_id),
            limit=subscription.team_members_limit if subscription else 0,
        )

    # =========================================================================
    # Feature Gating
    # =========================================================================

    @staticmethod
    def _minimum_tier_for_feature(feature: str) -> TierCode | None:
        """
        Return the lowest hardcoded tier whose feature list includes ``feature``.

        Tiers are checked in explicit ascending order rather than in
        TIER_LIMITS' dict order. This keeps ``check_feature`` and
        ``get_feature_tier`` in agreement.
        """
        for tier_code in (
            TierCode.ESSENTIAL,
            TierCode.PROFESSIONAL,
            TierCode.BUSINESS,
            TierCode.ENTERPRISE,
        ):
            if feature in TIER_LIMITS[tier_code].get("features", []):
                return tier_code
        return None

    def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
        """Check if vendor has access to a feature (creating a trial subscription if needed)."""
        subscription = self.get_or_create_subscription(db, vendor_id)
        return subscription.has_feature(feature)

    def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
        """
        Check feature access and raise exception if not available.

        Use this to gate premium features.

        Raises:
            FeatureNotAvailableException: ``required_tier`` is the display name
                of the lowest tier that offers the feature, or ``"higher"``
                if no hardcoded tier lists it.
        """
        if self.has_feature(db, vendor_id, feature):
            return
        subscription = self.get_or_create_subscription(db, vendor_id)
        required = self._minimum_tier_for_feature(feature)
        raise FeatureNotAvailableException(
            feature=feature,
            current_tier=subscription.tier,
            required_tier=(
                TIER_LIMITS[required]["name"] if required is not None else "higher"
            ),
        )

    def get_feature_tier(self, feature: str) -> str | None:
        """Get the code of the minimum tier required for a feature, or None if none offers it."""
        tier = self._minimum_tier_for_feature(feature)
        return tier.value if tier is not None else None
# Singleton instance: shared module-level SubscriptionService that callers
# import directly (see the usage example in the module docstring) instead of
# constructing their own.
subscription_service = SubscriptionService()