feat: add invoicing system and subscription tier enforcement

Phase 1 OMS implementation:

Invoicing:
- Add Invoice and VendorInvoiceSettings database models
- Full EU VAT support (27 countries, OSS, B2B reverse charge)
- Invoice PDF generation with WeasyPrint + Jinja2 templates
- Vendor invoice API endpoints for settings, creation, PDF download

Subscription Tiers:
- Add VendorSubscription model with 4 tiers (Essential/Professional/Business/Enterprise)
- Tier limit enforcement for orders, products, team members
- Feature gating based on subscription tier
- Automatic trial subscription creation for new vendors
- Integrate limit checks into order creation (direct and Letzshop sync)

Marketing:
- Update pricing documentation with 4-tier structure
- Revise back-office positioning strategy
- Update homepage with Veeqo-inspired Letzshop-focused messaging

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-24 18:15:27 +01:00
parent 4d9b816072
commit 6232bb47f6
23 changed files with 4342 additions and 241 deletions

View File

@@ -0,0 +1,164 @@
# app/services/invoice_pdf_service.py
"""
Invoice PDF generation service using WeasyPrint.
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
Stores generated PDFs in the configured storage location.
"""
import logging
import os
from datetime import UTC, datetime
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
from sqlalchemy.orm import Session
from app.core.config import settings
from models.database.invoice import Invoice
logger = logging.getLogger(__name__)
# Template directory
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
# PDF storage directory (relative to project root)
PDF_STORAGE_DIR = Path("storage") / "invoices"
class InvoicePDFService:
"""Service for generating invoice PDFs."""
def __init__(self):
"""Initialize the PDF service with Jinja2 environment."""
self.env = Environment(
loader=FileSystemLoader(str(TEMPLATE_DIR)),
autoescape=True,
)
def _ensure_storage_dir(self, vendor_id: int) -> Path:
"""Ensure the storage directory exists for a vendor."""
storage_path = PDF_STORAGE_DIR / str(vendor_id)
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _get_pdf_filename(self, invoice: Invoice) -> str:
"""Generate PDF filename for an invoice."""
# Sanitize invoice number for filename
safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
return f"{safe_number}.pdf"
def generate_pdf(
self,
db: Session,
invoice: Invoice,
force_regenerate: bool = False,
) -> str:
"""
Generate PDF for an invoice.
Args:
db: Database session
invoice: Invoice to generate PDF for
force_regenerate: If True, regenerate even if PDF already exists
Returns:
Path to the generated PDF file
"""
# Check if PDF already exists
if invoice.pdf_path and not force_regenerate:
if Path(invoice.pdf_path).exists():
logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
return invoice.pdf_path
# Ensure storage directory exists
storage_dir = self._ensure_storage_dir(invoice.vendor_id)
pdf_filename = self._get_pdf_filename(invoice)
pdf_path = storage_dir / pdf_filename
# Render HTML template
html_content = self._render_html(invoice)
# Generate PDF using WeasyPrint
try:
from weasyprint import HTML
html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
html_doc.write_pdf(str(pdf_path))
logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
except ImportError:
logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
raise RuntimeError("WeasyPrint not installed")
except Exception as e:
logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
raise
# Update invoice record with PDF path and timestamp
invoice.pdf_path = str(pdf_path)
invoice.pdf_generated_at = datetime.now(UTC)
db.commit()
return str(pdf_path)
def _render_html(self, invoice: Invoice) -> str:
"""Render the invoice HTML template."""
template = self.env.get_template("invoice.html")
# Prepare template context
context = {
"invoice": invoice,
"seller": invoice.seller_details,
"buyer": invoice.buyer_details,
"line_items": invoice.line_items,
"bank_details": invoice.bank_details,
"payment_terms": invoice.payment_terms,
"footer_text": invoice.footer_text,
"now": datetime.now(UTC),
}
return template.render(**context)
def get_pdf_path(self, invoice: Invoice) -> str | None:
"""Get the PDF path for an invoice if it exists."""
if invoice.pdf_path and Path(invoice.pdf_path).exists():
return invoice.pdf_path
return None
def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
"""
Delete the PDF file for an invoice.
Args:
invoice: Invoice whose PDF to delete
db: Database session
Returns:
True if deleted, False if not found
"""
if not invoice.pdf_path:
return False
pdf_path = Path(invoice.pdf_path)
if pdf_path.exists():
try:
pdf_path.unlink()
logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
except Exception as e:
logger.error(f"Failed to delete PDF {pdf_path}: {e}")
return False
# Clear PDF fields
invoice.pdf_path = None
invoice.pdf_generated_at = None
db.commit()
return True
def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
"""Force regenerate PDF for an invoice."""
return self.generate_pdf(db, invoice, force_regenerate=True)
# Singleton instance
invoice_pdf_service = InvoicePDFService()

View File

@@ -0,0 +1,666 @@
# app/services/invoice_service.py
"""
Invoice service for generating and managing invoices.
Handles:
- Vendor invoice settings management
- Invoice generation from orders
- VAT calculation (Luxembourg, EU, B2B reverse charge)
- Invoice number sequencing
- PDF generation (via separate module)
VAT Logic:
- Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate)
- EU cross-border B2C with OSS: Use destination country VAT rate
- EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle)
- EU B2B with valid VAT number: Reverse charge (0% VAT)
"""
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from app.exceptions import (
OrderNotFoundException,
ValidationException,
)
from models.database.invoice import (
Invoice,
InvoiceStatus,
VATRegime,
VendorInvoiceSettings,
)
from models.database.order import Order
from models.database.vendor import Vendor
from models.schema.invoice import (
InvoiceBuyerDetails,
InvoiceCreate,
InvoiceLineItem,
InvoiceManualCreate,
InvoiceSellerDetails,
VendorInvoiceSettingsCreate,
VendorInvoiceSettingsUpdate,
)
logger = logging.getLogger(__name__)
# EU VAT rates by country code (2024 standard rates)
EU_VAT_RATES: dict[str, Decimal] = {
"AT": Decimal("20.00"), # Austria
"BE": Decimal("21.00"), # Belgium
"BG": Decimal("20.00"), # Bulgaria
"HR": Decimal("25.00"), # Croatia
"CY": Decimal("19.00"), # Cyprus
"CZ": Decimal("21.00"), # Czech Republic
"DK": Decimal("25.00"), # Denmark
"EE": Decimal("22.00"), # Estonia
"FI": Decimal("24.00"), # Finland
"FR": Decimal("20.00"), # France
"DE": Decimal("19.00"), # Germany
"GR": Decimal("24.00"), # Greece
"HU": Decimal("27.00"), # Hungary
"IE": Decimal("23.00"), # Ireland
"IT": Decimal("22.00"), # Italy
"LV": Decimal("21.00"), # Latvia
"LT": Decimal("21.00"), # Lithuania
"LU": Decimal("17.00"), # Luxembourg (standard)
"MT": Decimal("18.00"), # Malta
"NL": Decimal("21.00"), # Netherlands
"PL": Decimal("23.00"), # Poland
"PT": Decimal("23.00"), # Portugal
"RO": Decimal("19.00"), # Romania
"SK": Decimal("20.00"), # Slovakia
"SI": Decimal("22.00"), # Slovenia
"ES": Decimal("21.00"), # Spain
"SE": Decimal("25.00"), # Sweden
}
# Luxembourg specific VAT rates
LU_VAT_RATES = {
"standard": Decimal("17.00"),
"intermediate": Decimal("14.00"),
"reduced": Decimal("8.00"),
"super_reduced": Decimal("3.00"),
}
class InvoiceNotFoundException(Exception):
"""Raised when invoice not found."""
pass
class InvoiceSettingsNotFoundException(Exception):
"""Raised when vendor invoice settings not found."""
pass
class InvoiceService:
"""Service for invoice operations."""
# =========================================================================
# VAT Calculation
# =========================================================================
def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
"""Get standard VAT rate for EU country."""
return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))
def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
"""Get human-readable VAT rate label."""
country_names = {
"AT": "Austria",
"BE": "Belgium",
"BG": "Bulgaria",
"HR": "Croatia",
"CY": "Cyprus",
"CZ": "Czech Republic",
"DK": "Denmark",
"EE": "Estonia",
"FI": "Finland",
"FR": "France",
"DE": "Germany",
"GR": "Greece",
"HU": "Hungary",
"IE": "Ireland",
"IT": "Italy",
"LV": "Latvia",
"LT": "Lithuania",
"LU": "Luxembourg",
"MT": "Malta",
"NL": "Netherlands",
"PL": "Poland",
"PT": "Portugal",
"RO": "Romania",
"SK": "Slovakia",
"SI": "Slovenia",
"ES": "Spain",
"SE": "Sweden",
}
country_name = country_names.get(country_iso.upper(), country_iso)
return f"{country_name} VAT {vat_rate}%"
def determine_vat_regime(
self,
seller_country: str,
buyer_country: str,
buyer_vat_number: str | None,
seller_oss_registered: bool,
) -> tuple[VATRegime, Decimal, str | None]:
"""
Determine VAT regime and rate for invoice.
Returns: (regime, vat_rate, destination_country)
"""
seller_country = seller_country.upper()
buyer_country = buyer_country.upper()
# Same country = domestic VAT
if seller_country == buyer_country:
vat_rate = self.get_vat_rate_for_country(seller_country)
return VATRegime.DOMESTIC, vat_rate, None
# Different EU countries
if buyer_country in EU_VAT_RATES:
# B2B with valid VAT number = reverse charge
if buyer_vat_number:
return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country
# B2C cross-border
if seller_oss_registered:
# OSS: use destination country VAT
vat_rate = self.get_vat_rate_for_country(buyer_country)
return VATRegime.OSS, vat_rate, buyer_country
else:
# No OSS: use origin country VAT
vat_rate = self.get_vat_rate_for_country(seller_country)
return VATRegime.ORIGIN, vat_rate, buyer_country
# Non-EU = VAT exempt (export)
return VATRegime.EXEMPT, Decimal("0.00"), buyer_country
# =========================================================================
# Invoice Settings Management
# =========================================================================
def get_settings(
self, db: Session, vendor_id: int
) -> VendorInvoiceSettings | None:
"""Get vendor invoice settings."""
return (
db.query(VendorInvoiceSettings)
.filter(VendorInvoiceSettings.vendor_id == vendor_id)
.first()
)
def get_settings_or_raise(
self, db: Session, vendor_id: int
) -> VendorInvoiceSettings:
"""Get vendor invoice settings or raise exception."""
settings = self.get_settings(db, vendor_id)
if not settings:
raise InvoiceSettingsNotFoundException(
f"Invoice settings not configured for vendor {vendor_id}"
)
return settings
def create_settings(
self,
db: Session,
vendor_id: int,
data: VendorInvoiceSettingsCreate,
) -> VendorInvoiceSettings:
"""Create vendor invoice settings."""
# Check if settings already exist
existing = self.get_settings(db, vendor_id)
if existing:
raise ValidationException(
"Invoice settings already exist for this vendor"
)
settings = VendorInvoiceSettings(
vendor_id=vendor_id,
**data.model_dump(),
)
db.add(settings)
db.commit()
db.refresh(settings)
logger.info(f"Created invoice settings for vendor {vendor_id}")
return settings
def update_settings(
self,
db: Session,
vendor_id: int,
data: VendorInvoiceSettingsUpdate,
) -> VendorInvoiceSettings:
"""Update vendor invoice settings."""
settings = self.get_settings_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(settings, key, value)
settings.updated_at = datetime.now(UTC)
db.commit()
db.refresh(settings)
logger.info(f"Updated invoice settings for vendor {vendor_id}")
return settings
def create_settings_from_vendor(
self,
db: Session,
vendor: Vendor,
) -> VendorInvoiceSettings:
"""
Create invoice settings from vendor/company info.
Used for initial setup based on existing vendor data.
"""
company = vendor.company
settings = VendorInvoiceSettings(
vendor_id=vendor.id,
company_name=company.legal_name if company else vendor.name,
company_address=vendor.effective_business_address,
company_city=None, # Would need to parse from address
company_postal_code=None,
company_country="LU",
vat_number=vendor.effective_tax_number,
is_vat_registered=bool(vendor.effective_tax_number),
)
db.add(settings)
db.commit()
db.refresh(settings)
logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
return settings
# =========================================================================
# Invoice Number Generation
# =========================================================================
def _get_next_invoice_number(
self, db: Session, settings: VendorInvoiceSettings
) -> str:
"""Generate next invoice number and increment counter."""
number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
invoice_number = f"{settings.invoice_prefix}{number}"
# Increment counter
settings.invoice_next_number += 1
db.flush()
return invoice_number
# =========================================================================
# Invoice Creation
# =========================================================================
def create_invoice_from_order(
self,
db: Session,
vendor_id: int,
order_id: int,
notes: str | None = None,
) -> Invoice:
"""
Create an invoice from an order.
Captures snapshots of seller/buyer details and calculates VAT.
"""
# Get invoice settings
settings = self.get_settings_or_raise(db, vendor_id)
# Get order
order = (
db.query(Order)
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
.first()
)
if not order:
raise OrderNotFoundException(f"Order {order_id} not found")
# Check for existing invoice
existing = (
db.query(Invoice)
.filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
.first()
)
if existing:
raise ValidationException(f"Invoice already exists for order {order_id}")
# Determine VAT regime
buyer_country = order.bill_country_iso
vat_regime, vat_rate, destination_country = self.determine_vat_regime(
seller_country=settings.company_country,
buyer_country=buyer_country,
buyer_vat_number=None, # TODO: Add B2B VAT number support
seller_oss_registered=settings.is_oss_registered,
)
# Build seller details snapshot
seller_details = {
"company_name": settings.company_name,
"address": settings.company_address,
"city": settings.company_city,
"postal_code": settings.company_postal_code,
"country": settings.company_country,
"vat_number": settings.vat_number,
}
# Build buyer details snapshot
buyer_details = {
"name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
"email": order.customer_email,
"address": order.bill_address_line_1,
"city": order.bill_city,
"postal_code": order.bill_postal_code,
"country": order.bill_country_iso,
"vat_number": None, # TODO: B2B support
}
if order.bill_company:
buyer_details["company"] = order.bill_company
# Build line items from order items
line_items = []
for item in order.items:
line_items.append({
"description": item.product_name,
"quantity": item.quantity,
"unit_price_cents": item.unit_price_cents,
"total_cents": item.total_price_cents,
"sku": item.product_sku,
"ean": item.gtin,
})
# Calculate amounts
subtotal_cents = sum(item["total_cents"] for item in line_items)
# Calculate VAT
if vat_rate > 0:
vat_amount_cents = int(
subtotal_cents * float(vat_rate) / 100
)
else:
vat_amount_cents = 0
total_cents = subtotal_cents + vat_amount_cents
# Get VAT label
vat_rate_label = None
if vat_rate > 0:
if destination_country:
vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
else:
vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)
# Generate invoice number
invoice_number = self._get_next_invoice_number(db, settings)
# Create invoice
invoice = Invoice(
vendor_id=vendor_id,
order_id=order_id,
invoice_number=invoice_number,
invoice_date=datetime.now(UTC),
status=InvoiceStatus.DRAFT.value,
seller_details=seller_details,
buyer_details=buyer_details,
line_items=line_items,
vat_regime=vat_regime.value,
destination_country=destination_country,
vat_rate=vat_rate,
vat_rate_label=vat_rate_label,
currency=order.currency,
subtotal_cents=subtotal_cents,
vat_amount_cents=vat_amount_cents,
total_cents=total_cents,
payment_terms=settings.payment_terms,
bank_details={
"bank_name": settings.bank_name,
"iban": settings.bank_iban,
"bic": settings.bank_bic,
} if settings.bank_iban else None,
footer_text=settings.footer_text,
notes=notes,
)
db.add(invoice)
db.commit()
db.refresh(invoice)
logger.info(
f"Created invoice {invoice_number} for order {order_id} "
f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
)
return invoice
# =========================================================================
# Invoice Retrieval
# =========================================================================
def get_invoice(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice | None:
"""Get invoice by ID."""
return (
db.query(Invoice)
.filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
.first()
)
def get_invoice_or_raise(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Get invoice by ID or raise exception."""
invoice = self.get_invoice(db, vendor_id, invoice_id)
if not invoice:
raise InvoiceNotFoundException(f"Invoice {invoice_id} not found")
return invoice
def get_invoice_by_number(
self, db: Session, vendor_id: int, invoice_number: str
) -> Invoice | None:
"""Get invoice by invoice number."""
return (
db.query(Invoice)
.filter(
and_(
Invoice.invoice_number == invoice_number,
Invoice.vendor_id == vendor_id,
)
)
.first()
)
def list_invoices(
self,
db: Session,
vendor_id: int,
status: str | None = None,
page: int = 1,
per_page: int = 20,
) -> tuple[list[Invoice], int]:
"""
List invoices for vendor with pagination.
Returns: (invoices, total_count)
"""
query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)
if status:
query = query.filter(Invoice.status == status)
# Get total count
total = query.count()
# Apply pagination and order
invoices = (
query.order_by(Invoice.invoice_date.desc())
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
return invoices, total
# =========================================================================
# Invoice Status Management
# =========================================================================
def update_status(
self,
db: Session,
vendor_id: int,
invoice_id: int,
new_status: str,
) -> Invoice:
"""Update invoice status."""
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
# Validate status transition
valid_statuses = [s.value for s in InvoiceStatus]
if new_status not in valid_statuses:
raise ValidationException(f"Invalid status: {new_status}")
# Cannot change cancelled invoices
if invoice.status == InvoiceStatus.CANCELLED.value:
raise ValidationException("Cannot change status of cancelled invoice")
invoice.status = new_status
invoice.updated_at = datetime.now(UTC)
db.commit()
db.refresh(invoice)
logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
return invoice
def mark_as_issued(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Mark invoice as issued."""
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)
def mark_as_paid(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Mark invoice as paid."""
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)
def cancel_invoice(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Cancel invoice."""
return self.update_status(
db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value
)
# =========================================================================
# Statistics
# =========================================================================
def get_invoice_stats(
self, db: Session, vendor_id: int
) -> dict[str, Any]:
"""Get invoice statistics for vendor."""
total_count = (
db.query(func.count(Invoice.id))
.filter(Invoice.vendor_id == vendor_id)
.scalar()
or 0
)
total_revenue = (
db.query(func.sum(Invoice.total_cents))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status.in_([
InvoiceStatus.ISSUED.value,
InvoiceStatus.PAID.value,
]),
)
)
.scalar()
or 0
)
draft_count = (
db.query(func.count(Invoice.id))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status == InvoiceStatus.DRAFT.value,
)
)
.scalar()
or 0
)
paid_count = (
db.query(func.count(Invoice.id))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status == InvoiceStatus.PAID.value,
)
)
.scalar()
or 0
)
return {
"total_invoices": total_count,
"total_revenue_cents": total_revenue,
"total_revenue": total_revenue / 100 if total_revenue else 0,
"draft_count": draft_count,
"paid_count": paid_count,
}
# =========================================================================
# PDF Generation
# =========================================================================
def generate_pdf(
self,
db: Session,
vendor_id: int,
invoice_id: int,
force_regenerate: bool = False,
) -> str:
"""
Generate PDF for an invoice.
Returns path to the generated PDF.
"""
from app.services.invoice_pdf_service import invoice_pdf_service
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)
def get_pdf_path(
self,
db: Session,
vendor_id: int,
invoice_id: int,
) -> str | None:
"""Get PDF path for an invoice if it exists."""
from app.services.invoice_pdf_service import invoice_pdf_service
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
return invoice_pdf_service.get_pdf_path(invoice)
# Singleton instance
invoice_service = InvoiceService()

View File

@@ -15,6 +15,7 @@ from sqlalchemy import String, and_, func, or_
from sqlalchemy.orm import Session
from app.services.order_service import order_service as unified_order_service
from app.services.subscription_service import subscription_service
from models.database.letzshop import (
LetzshopFulfillmentQueue,
LetzshopHistoricalImportJob,
@@ -792,6 +793,7 @@ class LetzshopOrderService:
"updated": 0,
"skipped": 0,
"errors": 0,
"limit_exceeded": 0,
"products_matched": 0,
"products_not_found": 0,
"eans_processed": set(),
@@ -800,6 +802,10 @@ class LetzshopOrderService:
"error_messages": [],
}
# Get subscription usage upfront for batch efficiency
usage = subscription_service.get_usage(self.db, vendor_id)
orders_remaining = usage.orders_remaining # None = unlimited
for i, shipment in enumerate(shipments):
shipment_id = shipment.get("id")
if not shipment_id:
@@ -844,11 +850,24 @@ class LetzshopOrderService:
else:
stats["skipped"] += 1
else:
# Check tier limit before creating order
if orders_remaining is not None and orders_remaining <= 0:
stats["limit_exceeded"] += 1
stats["error_messages"].append(
f"Shipment {shipment_id}: Order limit reached"
)
continue
# Create new order using unified service
try:
self.create_order(vendor_id, shipment)
self.db.commit() # noqa: SVC-006 - background task needs incremental commits
stats["imported"] += 1
# Decrement remaining count for batch efficiency
if orders_remaining is not None:
orders_remaining -= 1
except Exception as e:
self.db.rollback() # Rollback failed order
stats["errors"] += 1

View File

@@ -31,6 +31,10 @@ from app.exceptions import (
ValidationException,
)
from app.services.order_item_exception_service import order_item_exception_service
from app.services.subscription_service import (
subscription_service,
TierLimitExceededException,
)
from app.utils.money import Money, cents_to_euros, euros_to_cents
from models.database.customer import Customer
from models.database.marketplace_product import MarketplaceProduct
@@ -271,7 +275,11 @@ class OrderService:
Raises:
ValidationException: If order data is invalid
InsufficientInventoryException: If not enough inventory
TierLimitExceededException: If vendor has reached order limit
"""
# Check tier limit before creating order
subscription_service.check_order_limit(db, vendor_id)
try:
# Get or create customer
if order_data.customer_id:
@@ -428,6 +436,9 @@ class OrderService:
db.flush()
db.refresh(order)
# Increment order count for subscription tracking
subscription_service.increment_order_count(db, vendor_id)
logger.info(
f"Order {order.order_number} created for vendor {vendor_id}, "
f"total: EUR {cents_to_euros(total_amount_cents):.2f}"
@@ -439,6 +450,7 @@ class OrderService:
ValidationException,
InsufficientInventoryException,
CustomerNotFoundException,
TierLimitExceededException,
):
raise
except Exception as e:
@@ -450,6 +462,7 @@ class OrderService:
db: Session,
vendor_id: int,
shipment_data: dict[str, Any],
skip_limit_check: bool = False,
) -> Order:
"""
Create an order from Letzshop shipment data.
@@ -462,13 +475,27 @@ class OrderService:
db: Database session
vendor_id: Vendor ID
shipment_data: Raw shipment data from Letzshop API
skip_limit_check: If True, skip tier limit check (for batch imports
that check limit upfront)
Returns:
Created Order object
Raises:
ValidationException: If product not found by GTIN
TierLimitExceededException: If vendor has reached order limit
"""
# Check tier limit before creating order (unless skipped for batch ops)
if not skip_limit_check:
can_create, message = subscription_service.can_create_order(db, vendor_id)
if not can_create:
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=0, # Will be filled by caller if needed
limit=0,
)
order_data = shipment_data.get("order", {})
# Generate order number using Letzshop order number
@@ -777,6 +804,9 @@ class OrderService:
f"order {order.order_number}"
)
# Increment order count for subscription tracking
subscription_service.increment_order_count(db, vendor_id)
logger.info(
f"Letzshop order {order.order_number} created for vendor {vendor_id}, "
f"status: {status}, items: {len(inventory_units)}"

View File

@@ -0,0 +1,512 @@
# app/services/subscription_service.py
"""
Subscription service for tier-based access control.
Handles:
- Subscription creation and management
- Tier limit enforcement
- Usage tracking
- Feature gating
Usage:
from app.services.subscription_service import subscription_service
# Check if vendor can create an order
can_create, message = subscription_service.can_create_order(db, vendor_id)
# Increment order counter after successful order
subscription_service.increment_order_count(db, vendor_id)
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from models.database.product import Product
from models.database.subscription import (
SubscriptionStatus,
TIER_LIMITS,
TierCode,
VendorSubscription,
)
from models.database.vendor import Vendor, VendorUser
from models.schema.subscription import (
SubscriptionCreate,
SubscriptionUpdate,
SubscriptionUsage,
TierInfo,
TierLimits,
)
logger = logging.getLogger(__name__)
class SubscriptionNotFoundException(Exception):
"""Raised when subscription not found."""
pass
class TierLimitExceededException(Exception):
"""Raised when a tier limit is exceeded."""
def __init__(self, message: str, limit_type: str, current: int, limit: int):
super().__init__(message)
self.limit_type = limit_type
self.current = current
self.limit = limit
class FeatureNotAvailableException(Exception):
"""Raised when a feature is not available in current tier."""
def __init__(self, feature: str, current_tier: str, required_tier: str):
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
super().__init__(message)
self.feature = feature
self.current_tier = current_tier
self.required_tier = required_tier
class SubscriptionService:
"""Service for subscription and tier limit operations."""
# =========================================================================
# Tier Information
# =========================================================================
def get_tier_info(self, tier_code: str) -> TierInfo:
"""Get full tier information."""
try:
tier = TierCode(tier_code)
except ValueError:
tier = TierCode.ESSENTIAL
limits = TIER_LIMITS[tier]
return TierInfo(
code=tier.value,
name=limits["name"],
price_monthly_cents=limits["price_monthly_cents"],
price_annual_cents=limits.get("price_annual_cents"),
limits=TierLimits(
orders_per_month=limits.get("orders_per_month"),
products_limit=limits.get("products_limit"),
team_members=limits.get("team_members"),
order_history_months=limits.get("order_history_months"),
),
features=limits.get("features", []),
)
def get_all_tiers(self) -> list[TierInfo]:
"""Get information for all tiers."""
return [
self.get_tier_info(tier.value)
for tier in TierCode
]
# =========================================================================
# Subscription CRUD
# =========================================================================
def get_subscription(
self, db: Session, vendor_id: int
) -> VendorSubscription | None:
"""Get vendor subscription."""
return (
db.query(VendorSubscription)
.filter(VendorSubscription.vendor_id == vendor_id)
.first()
)
def get_subscription_or_raise(
self, db: Session, vendor_id: int
) -> VendorSubscription:
"""Get vendor subscription or raise exception."""
subscription = self.get_subscription(db, vendor_id)
if not subscription:
raise SubscriptionNotFoundException(
f"No subscription found for vendor {vendor_id}"
)
return subscription
def get_or_create_subscription(
self,
db: Session,
vendor_id: int,
tier: str = TierCode.ESSENTIAL.value,
trial_days: int = 14,
) -> VendorSubscription:
"""
Get existing subscription or create a new trial subscription.
Used when a vendor first accesses the system.
"""
subscription = self.get_subscription(db, vendor_id)
if subscription:
return subscription
# Create new trial subscription
now = datetime.now(UTC)
trial_end = now + timedelta(days=trial_days)
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=tier,
status=SubscriptionStatus.TRIAL.value,
period_start=now,
period_end=trial_end,
trial_ends_at=trial_end,
is_annual=False,
)
db.add(subscription)
db.commit()
db.refresh(subscription)
logger.info(
f"Created trial subscription for vendor {vendor_id} "
f"(tier={tier}, trial_ends={trial_end})"
)
return subscription
def create_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionCreate,
) -> VendorSubscription:
"""Create a subscription for a vendor."""
# Check if subscription exists
existing = self.get_subscription(db, vendor_id)
if existing:
raise ValueError("Vendor already has a subscription")
now = datetime.now(UTC)
# Calculate period end based on billing cycle
if data.is_annual:
period_end = now + timedelta(days=365)
else:
period_end = now + timedelta(days=30)
# Handle trial
trial_ends_at = None
status = SubscriptionStatus.ACTIVE.value
if data.trial_days > 0:
trial_ends_at = now + timedelta(days=data.trial_days)
status = SubscriptionStatus.TRIAL.value
period_end = trial_ends_at
subscription = VendorSubscription(
vendor_id=vendor_id,
tier=data.tier,
status=status,
period_start=now,
period_end=period_end,
trial_ends_at=trial_ends_at,
is_annual=data.is_annual,
)
db.add(subscription)
db.commit()
db.refresh(subscription)
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
return subscription
def update_subscription(
self,
db: Session,
vendor_id: int,
data: SubscriptionUpdate,
) -> VendorSubscription:
"""Update a vendor subscription."""
subscription = self.get_subscription_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(subscription, key, value)
subscription.updated_at = datetime.now(UTC)
db.commit()
db.refresh(subscription)
logger.info(f"Updated subscription for vendor {vendor_id}")
return subscription
def upgrade_tier(
self,
db: Session,
vendor_id: int,
new_tier: str,
) -> VendorSubscription:
"""Upgrade vendor to a new tier."""
subscription = self.get_subscription_or_raise(db, vendor_id)
old_tier = subscription.tier
subscription.tier = new_tier
subscription.updated_at = datetime.now(UTC)
# If upgrading from trial, mark as active
if subscription.status == SubscriptionStatus.TRIAL.value:
subscription.status = SubscriptionStatus.ACTIVE.value
db.commit()
db.refresh(subscription)
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
return subscription
def cancel_subscription(
self,
db: Session,
vendor_id: int,
reason: str | None = None,
) -> VendorSubscription:
"""Cancel a vendor subscription (access until period end)."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.status = SubscriptionStatus.CANCELLED.value
subscription.cancelled_at = datetime.now(UTC)
subscription.cancellation_reason = reason
subscription.updated_at = datetime.now(UTC)
db.commit()
db.refresh(subscription)
logger.info(f"Cancelled subscription for vendor {vendor_id}")
return subscription
# =========================================================================
# Usage Tracking
# =========================================================================
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
"""Get current subscription usage statistics."""
subscription = self.get_or_create_subscription(db, vendor_id)
# Get actual counts
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
# Calculate usage stats
orders_limit = subscription.orders_limit
products_limit = subscription.products_limit
team_limit = subscription.team_members_limit
def calc_remaining(current: int, limit: int | None) -> int | None:
if limit is None:
return None
return max(0, limit - current)
def calc_percent(current: int, limit: int | None) -> float | None:
if limit is None or limit == 0:
return None
return min(100.0, (current / limit) * 100)
return SubscriptionUsage(
orders_used=subscription.orders_this_period,
orders_limit=orders_limit,
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
products_used=products_count,
products_limit=products_limit,
products_remaining=calc_remaining(products_count, products_limit),
products_percent_used=calc_percent(products_count, products_limit),
team_members_used=team_count,
team_members_limit=team_limit,
team_members_remaining=calc_remaining(team_count, team_limit),
team_members_percent_used=calc_percent(team_count, team_limit),
)
def increment_order_count(self, db: Session, vendor_id: int) -> None:
"""
Increment the order counter for the current period.
Call this after successfully creating/importing an order.
"""
subscription = self.get_or_create_subscription(db, vendor_id)
subscription.increment_order_count()
db.commit()
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
"""Reset counters for a new billing period."""
subscription = self.get_subscription_or_raise(db, vendor_id)
subscription.reset_period_counters()
db.commit()
logger.info(f"Reset period counters for vendor {vendor_id}")
# =========================================================================
# Limit Checks
# =========================================================================
def can_create_order(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can create/import another order.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.can_create_order()
def check_order_limit(self, db: Session, vendor_id: int) -> None:
"""
Check order limit and raise exception if exceeded.
Use this in order creation flows.
"""
can_create, message = self.can_create_order(db, vendor_id)
if not can_create:
subscription = self.get_subscription(db, vendor_id)
raise TierLimitExceededException(
message=message or "Order limit exceeded",
limit_type="orders",
current=subscription.orders_this_period if subscription else 0,
limit=subscription.orders_limit if subscription else 0,
)
def can_add_product(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another product.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
return subscription.can_add_product(products_count)
def check_product_limit(self, db: Session, vendor_id: int) -> None:
"""
Check product limit and raise exception if exceeded.
Use this in product creation flows.
"""
can_add, message = self.can_add_product(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
products_count = (
db.query(func.count(Product.id))
.filter(Product.vendor_id == vendor_id)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Product limit exceeded",
limit_type="products",
current=products_count,
limit=subscription.products_limit if subscription else 0,
)
def can_add_team_member(
self, db: Session, vendor_id: int
) -> tuple[bool, str | None]:
"""
Check if vendor can add another team member.
Returns: (allowed, error_message)
"""
subscription = self.get_or_create_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
return subscription.can_add_team_member(team_count)
def check_team_limit(self, db: Session, vendor_id: int) -> None:
"""
Check team member limit and raise exception if exceeded.
Use this in team member invitation flows.
"""
can_add, message = self.can_add_team_member(db, vendor_id)
if not can_add:
subscription = self.get_subscription(db, vendor_id)
team_count = (
db.query(func.count(VendorUser.id))
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
.scalar()
or 0
)
raise TierLimitExceededException(
message=message or "Team member limit exceeded",
limit_type="team_members",
current=team_count,
limit=subscription.team_members_limit if subscription else 0,
)
# =========================================================================
# Feature Gating
# =========================================================================
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
"""Check if vendor has access to a feature."""
subscription = self.get_or_create_subscription(db, vendor_id)
return subscription.has_feature(feature)
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
"""
Check feature access and raise exception if not available.
Use this to gate premium features.
"""
if not self.has_feature(db, vendor_id, feature):
subscription = self.get_or_create_subscription(db, vendor_id)
# Find which tier has this feature
required_tier = None
for tier_code, limits in TIER_LIMITS.items():
if feature in limits.get("features", []):
required_tier = limits["name"]
break
raise FeatureNotAvailableException(
feature=feature,
current_tier=subscription.tier,
required_tier=required_tier or "higher",
)
def get_feature_tier(self, feature: str) -> str | None:
"""Get the minimum tier required for a feature."""
for tier_code in [
TierCode.ESSENTIAL,
TierCode.PROFESSIONAL,
TierCode.BUSINESS,
TierCode.ENTERPRISE,
]:
if feature in TIER_LIMITS[tier_code].get("features", []):
return tier_code.value
return None
# Singleton instance
subscription_service = SubscriptionService()