feat: add invoicing system and subscription tier enforcement
Phase 1 OMS implementation: Invoicing: - Add Invoice and VendorInvoiceSettings database models - Full EU VAT support (27 countries, OSS, B2B reverse charge) - Invoice PDF generation with WeasyPrint + Jinja2 templates - Vendor invoice API endpoints for settings, creation, PDF download Subscription Tiers: - Add VendorSubscription model with 4 tiers (Essential/Professional/Business/Enterprise) - Tier limit enforcement for orders, products, team members - Feature gating based on subscription tier - Automatic trial subscription creation for new vendors - Integrate limit checks into order creation (direct and Letzshop sync) Marketing: - Update pricing documentation with 4-tier structure - Revise back-office positioning strategy - Update homepage with Veeqo-inspired Letzshop-focused messaging 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
164
app/services/invoice_pdf_service.py
Normal file
164
app/services/invoice_pdf_service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# app/services/invoice_pdf_service.py
|
||||
"""
|
||||
Invoice PDF generation service using WeasyPrint.
|
||||
|
||||
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
|
||||
Stores generated PDFs in the configured storage location.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from models.database.invoice import Invoice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
|
||||
|
||||
# PDF storage directory (relative to project root)
|
||||
PDF_STORAGE_DIR = Path("storage") / "invoices"
|
||||
|
||||
|
||||
class InvoicePDFService:
|
||||
"""Service for generating invoice PDFs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the PDF service with Jinja2 environment."""
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(TEMPLATE_DIR)),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
def _ensure_storage_dir(self, vendor_id: int) -> Path:
|
||||
"""Ensure the storage directory exists for a vendor."""
|
||||
storage_path = PDF_STORAGE_DIR / str(vendor_id)
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
return storage_path
|
||||
|
||||
def _get_pdf_filename(self, invoice: Invoice) -> str:
|
||||
"""Generate PDF filename for an invoice."""
|
||||
# Sanitize invoice number for filename
|
||||
safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
|
||||
return f"{safe_number}.pdf"
|
||||
|
||||
def generate_pdf(
|
||||
self,
|
||||
db: Session,
|
||||
invoice: Invoice,
|
||||
force_regenerate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate PDF for an invoice.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
invoice: Invoice to generate PDF for
|
||||
force_regenerate: If True, regenerate even if PDF already exists
|
||||
|
||||
Returns:
|
||||
Path to the generated PDF file
|
||||
"""
|
||||
# Check if PDF already exists
|
||||
if invoice.pdf_path and not force_regenerate:
|
||||
if Path(invoice.pdf_path).exists():
|
||||
logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
|
||||
return invoice.pdf_path
|
||||
|
||||
# Ensure storage directory exists
|
||||
storage_dir = self._ensure_storage_dir(invoice.vendor_id)
|
||||
pdf_filename = self._get_pdf_filename(invoice)
|
||||
pdf_path = storage_dir / pdf_filename
|
||||
|
||||
# Render HTML template
|
||||
html_content = self._render_html(invoice)
|
||||
|
||||
# Generate PDF using WeasyPrint
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
|
||||
html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
|
||||
html_doc.write_pdf(str(pdf_path))
|
||||
|
||||
logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
|
||||
except ImportError:
|
||||
logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
|
||||
raise RuntimeError("WeasyPrint not installed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
|
||||
raise
|
||||
|
||||
# Update invoice record with PDF path and timestamp
|
||||
invoice.pdf_path = str(pdf_path)
|
||||
invoice.pdf_generated_at = datetime.now(UTC)
|
||||
db.commit()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
def _render_html(self, invoice: Invoice) -> str:
|
||||
"""Render the invoice HTML template."""
|
||||
template = self.env.get_template("invoice.html")
|
||||
|
||||
# Prepare template context
|
||||
context = {
|
||||
"invoice": invoice,
|
||||
"seller": invoice.seller_details,
|
||||
"buyer": invoice.buyer_details,
|
||||
"line_items": invoice.line_items,
|
||||
"bank_details": invoice.bank_details,
|
||||
"payment_terms": invoice.payment_terms,
|
||||
"footer_text": invoice.footer_text,
|
||||
"now": datetime.now(UTC),
|
||||
}
|
||||
|
||||
return template.render(**context)
|
||||
|
||||
def get_pdf_path(self, invoice: Invoice) -> str | None:
|
||||
"""Get the PDF path for an invoice if it exists."""
|
||||
if invoice.pdf_path and Path(invoice.pdf_path).exists():
|
||||
return invoice.pdf_path
|
||||
return None
|
||||
|
||||
def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
|
||||
"""
|
||||
Delete the PDF file for an invoice.
|
||||
|
||||
Args:
|
||||
invoice: Invoice whose PDF to delete
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
if not invoice.pdf_path:
|
||||
return False
|
||||
|
||||
pdf_path = Path(invoice.pdf_path)
|
||||
if pdf_path.exists():
|
||||
try:
|
||||
pdf_path.unlink()
|
||||
logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete PDF {pdf_path}: {e}")
|
||||
return False
|
||||
|
||||
# Clear PDF fields
|
||||
invoice.pdf_path = None
|
||||
invoice.pdf_generated_at = None
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
|
||||
"""Force regenerate PDF for an invoice."""
|
||||
return self.generate_pdf(db, invoice, force_regenerate=True)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
invoice_pdf_service = InvoicePDFService()
|
||||
666
app/services/invoice_service.py
Normal file
666
app/services/invoice_service.py
Normal file
@@ -0,0 +1,666 @@
|
||||
# app/services/invoice_service.py
|
||||
"""
|
||||
Invoice service for generating and managing invoices.
|
||||
|
||||
Handles:
|
||||
- Vendor invoice settings management
|
||||
- Invoice generation from orders
|
||||
- VAT calculation (Luxembourg, EU, B2B reverse charge)
|
||||
- Invoice number sequencing
|
||||
- PDF generation (via separate module)
|
||||
|
||||
VAT Logic:
|
||||
- Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate)
|
||||
- EU cross-border B2C with OSS: Use destination country VAT rate
|
||||
- EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle)
|
||||
- EU B2B with valid VAT number: Reverse charge (0% VAT)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
OrderNotFoundException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.database.invoice import (
|
||||
Invoice,
|
||||
InvoiceStatus,
|
||||
VATRegime,
|
||||
VendorInvoiceSettings,
|
||||
)
|
||||
from models.database.order import Order
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.invoice import (
|
||||
InvoiceBuyerDetails,
|
||||
InvoiceCreate,
|
||||
InvoiceLineItem,
|
||||
InvoiceManualCreate,
|
||||
InvoiceSellerDetails,
|
||||
VendorInvoiceSettingsCreate,
|
||||
VendorInvoiceSettingsUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# EU VAT rates by country code (2024 standard rates)
|
||||
EU_VAT_RATES: dict[str, Decimal] = {
|
||||
"AT": Decimal("20.00"), # Austria
|
||||
"BE": Decimal("21.00"), # Belgium
|
||||
"BG": Decimal("20.00"), # Bulgaria
|
||||
"HR": Decimal("25.00"), # Croatia
|
||||
"CY": Decimal("19.00"), # Cyprus
|
||||
"CZ": Decimal("21.00"), # Czech Republic
|
||||
"DK": Decimal("25.00"), # Denmark
|
||||
"EE": Decimal("22.00"), # Estonia
|
||||
"FI": Decimal("24.00"), # Finland
|
||||
"FR": Decimal("20.00"), # France
|
||||
"DE": Decimal("19.00"), # Germany
|
||||
"GR": Decimal("24.00"), # Greece
|
||||
"HU": Decimal("27.00"), # Hungary
|
||||
"IE": Decimal("23.00"), # Ireland
|
||||
"IT": Decimal("22.00"), # Italy
|
||||
"LV": Decimal("21.00"), # Latvia
|
||||
"LT": Decimal("21.00"), # Lithuania
|
||||
"LU": Decimal("17.00"), # Luxembourg (standard)
|
||||
"MT": Decimal("18.00"), # Malta
|
||||
"NL": Decimal("21.00"), # Netherlands
|
||||
"PL": Decimal("23.00"), # Poland
|
||||
"PT": Decimal("23.00"), # Portugal
|
||||
"RO": Decimal("19.00"), # Romania
|
||||
"SK": Decimal("20.00"), # Slovakia
|
||||
"SI": Decimal("22.00"), # Slovenia
|
||||
"ES": Decimal("21.00"), # Spain
|
||||
"SE": Decimal("25.00"), # Sweden
|
||||
}
|
||||
|
||||
# Luxembourg specific VAT rates
|
||||
LU_VAT_RATES = {
|
||||
"standard": Decimal("17.00"),
|
||||
"intermediate": Decimal("14.00"),
|
||||
"reduced": Decimal("8.00"),
|
||||
"super_reduced": Decimal("3.00"),
|
||||
}
|
||||
|
||||
|
||||
class InvoiceNotFoundException(Exception):
|
||||
"""Raised when invoice not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvoiceSettingsNotFoundException(Exception):
|
||||
"""Raised when vendor invoice settings not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvoiceService:
|
||||
"""Service for invoice operations."""
|
||||
|
||||
# =========================================================================
|
||||
# VAT Calculation
|
||||
# =========================================================================
|
||||
|
||||
def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
|
||||
"""Get standard VAT rate for EU country."""
|
||||
return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))
|
||||
|
||||
def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
|
||||
"""Get human-readable VAT rate label."""
|
||||
country_names = {
|
||||
"AT": "Austria",
|
||||
"BE": "Belgium",
|
||||
"BG": "Bulgaria",
|
||||
"HR": "Croatia",
|
||||
"CY": "Cyprus",
|
||||
"CZ": "Czech Republic",
|
||||
"DK": "Denmark",
|
||||
"EE": "Estonia",
|
||||
"FI": "Finland",
|
||||
"FR": "France",
|
||||
"DE": "Germany",
|
||||
"GR": "Greece",
|
||||
"HU": "Hungary",
|
||||
"IE": "Ireland",
|
||||
"IT": "Italy",
|
||||
"LV": "Latvia",
|
||||
"LT": "Lithuania",
|
||||
"LU": "Luxembourg",
|
||||
"MT": "Malta",
|
||||
"NL": "Netherlands",
|
||||
"PL": "Poland",
|
||||
"PT": "Portugal",
|
||||
"RO": "Romania",
|
||||
"SK": "Slovakia",
|
||||
"SI": "Slovenia",
|
||||
"ES": "Spain",
|
||||
"SE": "Sweden",
|
||||
}
|
||||
country_name = country_names.get(country_iso.upper(), country_iso)
|
||||
return f"{country_name} VAT {vat_rate}%"
|
||||
|
||||
def determine_vat_regime(
|
||||
self,
|
||||
seller_country: str,
|
||||
buyer_country: str,
|
||||
buyer_vat_number: str | None,
|
||||
seller_oss_registered: bool,
|
||||
) -> tuple[VATRegime, Decimal, str | None]:
|
||||
"""
|
||||
Determine VAT regime and rate for invoice.
|
||||
|
||||
Returns: (regime, vat_rate, destination_country)
|
||||
"""
|
||||
seller_country = seller_country.upper()
|
||||
buyer_country = buyer_country.upper()
|
||||
|
||||
# Same country = domestic VAT
|
||||
if seller_country == buyer_country:
|
||||
vat_rate = self.get_vat_rate_for_country(seller_country)
|
||||
return VATRegime.DOMESTIC, vat_rate, None
|
||||
|
||||
# Different EU countries
|
||||
if buyer_country in EU_VAT_RATES:
|
||||
# B2B with valid VAT number = reverse charge
|
||||
if buyer_vat_number:
|
||||
return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country
|
||||
|
||||
# B2C cross-border
|
||||
if seller_oss_registered:
|
||||
# OSS: use destination country VAT
|
||||
vat_rate = self.get_vat_rate_for_country(buyer_country)
|
||||
return VATRegime.OSS, vat_rate, buyer_country
|
||||
else:
|
||||
# No OSS: use origin country VAT
|
||||
vat_rate = self.get_vat_rate_for_country(seller_country)
|
||||
return VATRegime.ORIGIN, vat_rate, buyer_country
|
||||
|
||||
# Non-EU = VAT exempt (export)
|
||||
return VATRegime.EXEMPT, Decimal("0.00"), buyer_country
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Settings Management
|
||||
# =========================================================================
|
||||
|
||||
def get_settings(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorInvoiceSettings | None:
|
||||
"""Get vendor invoice settings."""
|
||||
return (
|
||||
db.query(VendorInvoiceSettings)
|
||||
.filter(VendorInvoiceSettings.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_settings_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Get vendor invoice settings or raise exception."""
|
||||
settings = self.get_settings(db, vendor_id)
|
||||
if not settings:
|
||||
raise InvoiceSettingsNotFoundException(
|
||||
f"Invoice settings not configured for vendor {vendor_id}"
|
||||
)
|
||||
return settings
|
||||
|
||||
def create_settings(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: VendorInvoiceSettingsCreate,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Create vendor invoice settings."""
|
||||
# Check if settings already exist
|
||||
existing = self.get_settings(db, vendor_id)
|
||||
if existing:
|
||||
raise ValidationException(
|
||||
"Invoice settings already exist for this vendor"
|
||||
)
|
||||
|
||||
settings = VendorInvoiceSettings(
|
||||
vendor_id=vendor_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
db.add(settings)
|
||||
db.commit()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Created invoice settings for vendor {vendor_id}")
|
||||
return settings
|
||||
|
||||
def update_settings(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: VendorInvoiceSettingsUpdate,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Update vendor invoice settings."""
|
||||
settings = self.get_settings_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(settings, key, value)
|
||||
|
||||
settings.updated_at = datetime.now(UTC)
|
||||
db.commit()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Updated invoice settings for vendor {vendor_id}")
|
||||
return settings
|
||||
|
||||
def create_settings_from_vendor(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""
|
||||
Create invoice settings from vendor/company info.
|
||||
|
||||
Used for initial setup based on existing vendor data.
|
||||
"""
|
||||
company = vendor.company
|
||||
|
||||
settings = VendorInvoiceSettings(
|
||||
vendor_id=vendor.id,
|
||||
company_name=company.legal_name if company else vendor.name,
|
||||
company_address=vendor.effective_business_address,
|
||||
company_city=None, # Would need to parse from address
|
||||
company_postal_code=None,
|
||||
company_country="LU",
|
||||
vat_number=vendor.effective_tax_number,
|
||||
is_vat_registered=bool(vendor.effective_tax_number),
|
||||
)
|
||||
db.add(settings)
|
||||
db.commit()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
|
||||
return settings
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Number Generation
|
||||
# =========================================================================
|
||||
|
||||
def _get_next_invoice_number(
|
||||
self, db: Session, settings: VendorInvoiceSettings
|
||||
) -> str:
|
||||
"""Generate next invoice number and increment counter."""
|
||||
number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
|
||||
invoice_number = f"{settings.invoice_prefix}{number}"
|
||||
|
||||
# Increment counter
|
||||
settings.invoice_next_number += 1
|
||||
db.flush()
|
||||
|
||||
return invoice_number
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Creation
|
||||
# =========================================================================
|
||||
|
||||
def create_invoice_from_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
notes: str | None = None,
|
||||
) -> Invoice:
|
||||
"""
|
||||
Create an invoice from an order.
|
||||
|
||||
Captures snapshots of seller/buyer details and calculates VAT.
|
||||
"""
|
||||
# Get invoice settings
|
||||
settings = self.get_settings_or_raise(db, vendor_id)
|
||||
|
||||
# Get order
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
|
||||
# Check for existing invoice
|
||||
existing = (
|
||||
db.query(Invoice)
|
||||
.filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValidationException(f"Invoice already exists for order {order_id}")
|
||||
|
||||
# Determine VAT regime
|
||||
buyer_country = order.bill_country_iso
|
||||
vat_regime, vat_rate, destination_country = self.determine_vat_regime(
|
||||
seller_country=settings.company_country,
|
||||
buyer_country=buyer_country,
|
||||
buyer_vat_number=None, # TODO: Add B2B VAT number support
|
||||
seller_oss_registered=settings.is_oss_registered,
|
||||
)
|
||||
|
||||
# Build seller details snapshot
|
||||
seller_details = {
|
||||
"company_name": settings.company_name,
|
||||
"address": settings.company_address,
|
||||
"city": settings.company_city,
|
||||
"postal_code": settings.company_postal_code,
|
||||
"country": settings.company_country,
|
||||
"vat_number": settings.vat_number,
|
||||
}
|
||||
|
||||
# Build buyer details snapshot
|
||||
buyer_details = {
|
||||
"name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
|
||||
"email": order.customer_email,
|
||||
"address": order.bill_address_line_1,
|
||||
"city": order.bill_city,
|
||||
"postal_code": order.bill_postal_code,
|
||||
"country": order.bill_country_iso,
|
||||
"vat_number": None, # TODO: B2B support
|
||||
}
|
||||
if order.bill_company:
|
||||
buyer_details["company"] = order.bill_company
|
||||
|
||||
# Build line items from order items
|
||||
line_items = []
|
||||
for item in order.items:
|
||||
line_items.append({
|
||||
"description": item.product_name,
|
||||
"quantity": item.quantity,
|
||||
"unit_price_cents": item.unit_price_cents,
|
||||
"total_cents": item.total_price_cents,
|
||||
"sku": item.product_sku,
|
||||
"ean": item.gtin,
|
||||
})
|
||||
|
||||
# Calculate amounts
|
||||
subtotal_cents = sum(item["total_cents"] for item in line_items)
|
||||
|
||||
# Calculate VAT
|
||||
if vat_rate > 0:
|
||||
vat_amount_cents = int(
|
||||
subtotal_cents * float(vat_rate) / 100
|
||||
)
|
||||
else:
|
||||
vat_amount_cents = 0
|
||||
|
||||
total_cents = subtotal_cents + vat_amount_cents
|
||||
|
||||
# Get VAT label
|
||||
vat_rate_label = None
|
||||
if vat_rate > 0:
|
||||
if destination_country:
|
||||
vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
|
||||
else:
|
||||
vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)
|
||||
|
||||
# Generate invoice number
|
||||
invoice_number = self._get_next_invoice_number(db, settings)
|
||||
|
||||
# Create invoice
|
||||
invoice = Invoice(
|
||||
vendor_id=vendor_id,
|
||||
order_id=order_id,
|
||||
invoice_number=invoice_number,
|
||||
invoice_date=datetime.now(UTC),
|
||||
status=InvoiceStatus.DRAFT.value,
|
||||
seller_details=seller_details,
|
||||
buyer_details=buyer_details,
|
||||
line_items=line_items,
|
||||
vat_regime=vat_regime.value,
|
||||
destination_country=destination_country,
|
||||
vat_rate=vat_rate,
|
||||
vat_rate_label=vat_rate_label,
|
||||
currency=order.currency,
|
||||
subtotal_cents=subtotal_cents,
|
||||
vat_amount_cents=vat_amount_cents,
|
||||
total_cents=total_cents,
|
||||
payment_terms=settings.payment_terms,
|
||||
bank_details={
|
||||
"bank_name": settings.bank_name,
|
||||
"iban": settings.bank_iban,
|
||||
"bic": settings.bank_bic,
|
||||
} if settings.bank_iban else None,
|
||||
footer_text=settings.footer_text,
|
||||
notes=notes,
|
||||
)
|
||||
|
||||
db.add(invoice)
|
||||
db.commit()
|
||||
db.refresh(invoice)
|
||||
|
||||
logger.info(
|
||||
f"Created invoice {invoice_number} for order {order_id} "
|
||||
f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
|
||||
)
|
||||
|
||||
return invoice
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Retrieval
|
||||
# =========================================================================
|
||||
|
||||
def get_invoice(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice | None:
|
||||
"""Get invoice by ID."""
|
||||
return (
|
||||
db.query(Invoice)
|
||||
.filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_invoice_or_raise(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Get invoice by ID or raise exception."""
|
||||
invoice = self.get_invoice(db, vendor_id, invoice_id)
|
||||
if not invoice:
|
||||
raise InvoiceNotFoundException(f"Invoice {invoice_id} not found")
|
||||
return invoice
|
||||
|
||||
def get_invoice_by_number(
|
||||
self, db: Session, vendor_id: int, invoice_number: str
|
||||
) -> Invoice | None:
|
||||
"""Get invoice by invoice number."""
|
||||
return (
|
||||
db.query(Invoice)
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.invoice_number == invoice_number,
|
||||
Invoice.vendor_id == vendor_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def list_invoices(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
status: str | None = None,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
) -> tuple[list[Invoice], int]:
|
||||
"""
|
||||
List invoices for vendor with pagination.
|
||||
|
||||
Returns: (invoices, total_count)
|
||||
"""
|
||||
query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Invoice.status == status)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Apply pagination and order
|
||||
invoices = (
|
||||
query.order_by(Invoice.invoice_date.desc())
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return invoices, total
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Status Management
|
||||
# =========================================================================
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
new_status: str,
|
||||
) -> Invoice:
|
||||
"""Update invoice status."""
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
|
||||
# Validate status transition
|
||||
valid_statuses = [s.value for s in InvoiceStatus]
|
||||
if new_status not in valid_statuses:
|
||||
raise ValidationException(f"Invalid status: {new_status}")
|
||||
|
||||
# Cannot change cancelled invoices
|
||||
if invoice.status == InvoiceStatus.CANCELLED.value:
|
||||
raise ValidationException("Cannot change status of cancelled invoice")
|
||||
|
||||
invoice.status = new_status
|
||||
invoice.updated_at = datetime.now(UTC)
|
||||
db.commit()
|
||||
db.refresh(invoice)
|
||||
|
||||
logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
|
||||
return invoice
|
||||
|
||||
def mark_as_issued(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Mark invoice as issued."""
|
||||
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)
|
||||
|
||||
def mark_as_paid(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Mark invoice as paid."""
|
||||
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)
|
||||
|
||||
def cancel_invoice(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Cancel invoice."""
|
||||
return self.update_status(
|
||||
db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_invoice_stats(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> dict[str, Any]:
|
||||
"""Get invoice statistics for vendor."""
|
||||
total_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(Invoice.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
total_revenue = (
|
||||
db.query(func.sum(Invoice.total_cents))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status.in_([
|
||||
InvoiceStatus.ISSUED.value,
|
||||
InvoiceStatus.PAID.value,
|
||||
]),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
draft_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status == InvoiceStatus.DRAFT.value,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
paid_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status == InvoiceStatus.PAID.value,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_invoices": total_count,
|
||||
"total_revenue_cents": total_revenue,
|
||||
"total_revenue": total_revenue / 100 if total_revenue else 0,
|
||||
"draft_count": draft_count,
|
||||
"paid_count": paid_count,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PDF Generation
|
||||
# =========================================================================
|
||||
|
||||
def generate_pdf(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
force_regenerate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate PDF for an invoice.
|
||||
|
||||
Returns path to the generated PDF.
|
||||
"""
|
||||
from app.services.invoice_pdf_service import invoice_pdf_service
|
||||
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)
|
||||
|
||||
def get_pdf_path(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
) -> str | None:
|
||||
"""Get PDF path for an invoice if it exists."""
|
||||
from app.services.invoice_pdf_service import invoice_pdf_service
|
||||
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
return invoice_pdf_service.get_pdf_path(invoice)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
invoice_service = InvoiceService()
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy import String, and_, func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.order_service import order_service as unified_order_service
|
||||
from app.services.subscription_service import subscription_service
|
||||
from models.database.letzshop import (
|
||||
LetzshopFulfillmentQueue,
|
||||
LetzshopHistoricalImportJob,
|
||||
@@ -792,6 +793,7 @@ class LetzshopOrderService:
|
||||
"updated": 0,
|
||||
"skipped": 0,
|
||||
"errors": 0,
|
||||
"limit_exceeded": 0,
|
||||
"products_matched": 0,
|
||||
"products_not_found": 0,
|
||||
"eans_processed": set(),
|
||||
@@ -800,6 +802,10 @@ class LetzshopOrderService:
|
||||
"error_messages": [],
|
||||
}
|
||||
|
||||
# Get subscription usage upfront for batch efficiency
|
||||
usage = subscription_service.get_usage(self.db, vendor_id)
|
||||
orders_remaining = usage.orders_remaining # None = unlimited
|
||||
|
||||
for i, shipment in enumerate(shipments):
|
||||
shipment_id = shipment.get("id")
|
||||
if not shipment_id:
|
||||
@@ -844,11 +850,24 @@ class LetzshopOrderService:
|
||||
else:
|
||||
stats["skipped"] += 1
|
||||
else:
|
||||
# Check tier limit before creating order
|
||||
if orders_remaining is not None and orders_remaining <= 0:
|
||||
stats["limit_exceeded"] += 1
|
||||
stats["error_messages"].append(
|
||||
f"Shipment {shipment_id}: Order limit reached"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create new order using unified service
|
||||
try:
|
||||
self.create_order(vendor_id, shipment)
|
||||
self.db.commit() # noqa: SVC-006 - background task needs incremental commits
|
||||
stats["imported"] += 1
|
||||
|
||||
# Decrement remaining count for batch efficiency
|
||||
if orders_remaining is not None:
|
||||
orders_remaining -= 1
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback() # Rollback failed order
|
||||
stats["errors"] += 1
|
||||
|
||||
@@ -31,6 +31,10 @@ from app.exceptions import (
|
||||
ValidationException,
|
||||
)
|
||||
from app.services.order_item_exception_service import order_item_exception_service
|
||||
from app.services.subscription_service import (
|
||||
subscription_service,
|
||||
TierLimitExceededException,
|
||||
)
|
||||
from app.utils.money import Money, cents_to_euros, euros_to_cents
|
||||
from models.database.customer import Customer
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
@@ -271,7 +275,11 @@ class OrderService:
|
||||
Raises:
|
||||
ValidationException: If order data is invalid
|
||||
InsufficientInventoryException: If not enough inventory
|
||||
TierLimitExceededException: If vendor has reached order limit
|
||||
"""
|
||||
# Check tier limit before creating order
|
||||
subscription_service.check_order_limit(db, vendor_id)
|
||||
|
||||
try:
|
||||
# Get or create customer
|
||||
if order_data.customer_id:
|
||||
@@ -428,6 +436,9 @@ class OrderService:
|
||||
db.flush()
|
||||
db.refresh(order)
|
||||
|
||||
# Increment order count for subscription tracking
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number} created for vendor {vendor_id}, "
|
||||
f"total: EUR {cents_to_euros(total_amount_cents):.2f}"
|
||||
@@ -439,6 +450,7 @@ class OrderService:
|
||||
ValidationException,
|
||||
InsufficientInventoryException,
|
||||
CustomerNotFoundException,
|
||||
TierLimitExceededException,
|
||||
):
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -450,6 +462,7 @@ class OrderService:
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
shipment_data: dict[str, Any],
|
||||
skip_limit_check: bool = False,
|
||||
) -> Order:
|
||||
"""
|
||||
Create an order from Letzshop shipment data.
|
||||
@@ -462,13 +475,27 @@ class OrderService:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
shipment_data: Raw shipment data from Letzshop API
|
||||
skip_limit_check: If True, skip tier limit check (for batch imports
|
||||
that check limit upfront)
|
||||
|
||||
Returns:
|
||||
Created Order object
|
||||
|
||||
Raises:
|
||||
ValidationException: If product not found by GTIN
|
||||
TierLimitExceededException: If vendor has reached order limit
|
||||
"""
|
||||
# Check tier limit before creating order (unless skipped for batch ops)
|
||||
if not skip_limit_check:
|
||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||
if not can_create:
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Order limit exceeded",
|
||||
limit_type="orders",
|
||||
current=0, # Will be filled by caller if needed
|
||||
limit=0,
|
||||
)
|
||||
|
||||
order_data = shipment_data.get("order", {})
|
||||
|
||||
# Generate order number using Letzshop order number
|
||||
@@ -777,6 +804,9 @@ class OrderService:
|
||||
f"order {order.order_number}"
|
||||
)
|
||||
|
||||
# Increment order count for subscription tracking
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
|
||||
logger.info(
|
||||
f"Letzshop order {order.order_number} created for vendor {vendor_id}, "
|
||||
f"status: {status}, items: {len(inventory_units)}"
|
||||
|
||||
512
app/services/subscription_service.py
Normal file
512
app/services/subscription_service.py
Normal file
@@ -0,0 +1,512 @@
|
||||
# app/services/subscription_service.py
|
||||
"""
|
||||
Subscription service for tier-based access control.
|
||||
|
||||
Handles:
|
||||
- Subscription creation and management
|
||||
- Tier limit enforcement
|
||||
- Usage tracking
|
||||
- Feature gating
|
||||
|
||||
Usage:
|
||||
from app.services.subscription_service import subscription_service
|
||||
|
||||
# Check if vendor can create an order
|
||||
can_create, message = subscription_service.can_create_order(db, vendor_id)
|
||||
|
||||
# Increment order counter after successful order
|
||||
subscription_service.increment_order_count(db, vendor_id)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.product import Product
|
||||
from models.database.subscription import (
|
||||
SubscriptionStatus,
|
||||
TIER_LIMITS,
|
||||
TierCode,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
from models.schema.subscription import (
|
||||
SubscriptionCreate,
|
||||
SubscriptionUpdate,
|
||||
SubscriptionUsage,
|
||||
TierInfo,
|
||||
TierLimits,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionNotFoundException(Exception):
|
||||
"""Raised when subscription not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TierLimitExceededException(Exception):
|
||||
"""Raised when a tier limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str, limit_type: str, current: int, limit: int):
|
||||
super().__init__(message)
|
||||
self.limit_type = limit_type
|
||||
self.current = current
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class FeatureNotAvailableException(Exception):
|
||||
"""Raised when a feature is not available in current tier."""
|
||||
|
||||
def __init__(self, feature: str, current_tier: str, required_tier: str):
|
||||
message = f"Feature '{feature}' requires {required_tier} tier (current: {current_tier})"
|
||||
super().__init__(message)
|
||||
self.feature = feature
|
||||
self.current_tier = current_tier
|
||||
self.required_tier = required_tier
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""Service for subscription and tier limit operations."""
|
||||
|
||||
# =========================================================================
|
||||
# Tier Information
|
||||
# =========================================================================
|
||||
|
||||
def get_tier_info(self, tier_code: str) -> TierInfo:
|
||||
"""Get full tier information."""
|
||||
try:
|
||||
tier = TierCode(tier_code)
|
||||
except ValueError:
|
||||
tier = TierCode.ESSENTIAL
|
||||
|
||||
limits = TIER_LIMITS[tier]
|
||||
return TierInfo(
|
||||
code=tier.value,
|
||||
name=limits["name"],
|
||||
price_monthly_cents=limits["price_monthly_cents"],
|
||||
price_annual_cents=limits.get("price_annual_cents"),
|
||||
limits=TierLimits(
|
||||
orders_per_month=limits.get("orders_per_month"),
|
||||
products_limit=limits.get("products_limit"),
|
||||
team_members=limits.get("team_members"),
|
||||
order_history_months=limits.get("order_history_months"),
|
||||
),
|
||||
features=limits.get("features", []),
|
||||
)
|
||||
|
||||
def get_all_tiers(self) -> list[TierInfo]:
|
||||
"""Get information for all tiers."""
|
||||
return [
|
||||
self.get_tier_info(tier.value)
|
||||
for tier in TierCode
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# Subscription CRUD
|
||||
# =========================================================================
|
||||
|
||||
def get_subscription(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription | None:
|
||||
"""Get vendor subscription."""
|
||||
return (
|
||||
db.query(VendorSubscription)
|
||||
.filter(VendorSubscription.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_subscription_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorSubscription:
|
||||
"""Get vendor subscription or raise exception."""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if not subscription:
|
||||
raise SubscriptionNotFoundException(
|
||||
f"No subscription found for vendor {vendor_id}"
|
||||
)
|
||||
return subscription
|
||||
|
||||
def get_or_create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier: str = TierCode.ESSENTIAL.value,
|
||||
trial_days: int = 14,
|
||||
) -> VendorSubscription:
|
||||
"""
|
||||
Get existing subscription or create a new trial subscription.
|
||||
|
||||
Used when a vendor first accesses the system.
|
||||
"""
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
if subscription:
|
||||
return subscription
|
||||
|
||||
# Create new trial subscription
|
||||
now = datetime.now(UTC)
|
||||
trial_end = now + timedelta(days=trial_days)
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=tier,
|
||||
status=SubscriptionStatus.TRIAL.value,
|
||||
period_start=now,
|
||||
period_end=trial_end,
|
||||
trial_ends_at=trial_end,
|
||||
is_annual=False,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(
|
||||
f"Created trial subscription for vendor {vendor_id} "
|
||||
f"(tier={tier}, trial_ends={trial_end})"
|
||||
)
|
||||
|
||||
return subscription
|
||||
|
||||
def create_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionCreate,
|
||||
) -> VendorSubscription:
|
||||
"""Create a subscription for a vendor."""
|
||||
# Check if subscription exists
|
||||
existing = self.get_subscription(db, vendor_id)
|
||||
if existing:
|
||||
raise ValueError("Vendor already has a subscription")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Calculate period end based on billing cycle
|
||||
if data.is_annual:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
# Handle trial
|
||||
trial_ends_at = None
|
||||
status = SubscriptionStatus.ACTIVE.value
|
||||
if data.trial_days > 0:
|
||||
trial_ends_at = now + timedelta(days=data.trial_days)
|
||||
status = SubscriptionStatus.TRIAL.value
|
||||
period_end = trial_ends_at
|
||||
|
||||
subscription = VendorSubscription(
|
||||
vendor_id=vendor_id,
|
||||
tier=data.tier,
|
||||
status=status,
|
||||
period_start=now,
|
||||
period_end=period_end,
|
||||
trial_ends_at=trial_ends_at,
|
||||
is_annual=data.is_annual,
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Created subscription for vendor {vendor_id}: {data.tier}")
|
||||
return subscription
|
||||
|
||||
def update_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: SubscriptionUpdate,
|
||||
) -> VendorSubscription:
|
||||
"""Update a vendor subscription."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(subscription, key, value)
|
||||
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Updated subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
def upgrade_tier(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
new_tier: str,
|
||||
) -> VendorSubscription:
|
||||
"""Upgrade vendor to a new tier."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
old_tier = subscription.tier
|
||||
subscription.tier = new_tier
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
# If upgrading from trial, mark as active
|
||||
if subscription.status == SubscriptionStatus.TRIAL.value:
|
||||
subscription.status = SubscriptionStatus.ACTIVE.value
|
||||
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Upgraded vendor {vendor_id} from {old_tier} to {new_tier}")
|
||||
return subscription
|
||||
|
||||
def cancel_subscription(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reason: str | None = None,
|
||||
) -> VendorSubscription:
|
||||
"""Cancel a vendor subscription (access until period end)."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
|
||||
subscription.status = SubscriptionStatus.CANCELLED.value
|
||||
subscription.cancelled_at = datetime.now(UTC)
|
||||
subscription.cancellation_reason = reason
|
||||
subscription.updated_at = datetime.now(UTC)
|
||||
|
||||
db.commit()
|
||||
db.refresh(subscription)
|
||||
|
||||
logger.info(f"Cancelled subscription for vendor {vendor_id}")
|
||||
return subscription
|
||||
|
||||
# =========================================================================
|
||||
# Usage Tracking
|
||||
# =========================================================================
|
||||
|
||||
def get_usage(self, db: Session, vendor_id: int) -> SubscriptionUsage:
|
||||
"""Get current subscription usage statistics."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Get actual counts
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Calculate usage stats
|
||||
orders_limit = subscription.orders_limit
|
||||
products_limit = subscription.products_limit
|
||||
team_limit = subscription.team_members_limit
|
||||
|
||||
def calc_remaining(current: int, limit: int | None) -> int | None:
|
||||
if limit is None:
|
||||
return None
|
||||
return max(0, limit - current)
|
||||
|
||||
def calc_percent(current: int, limit: int | None) -> float | None:
|
||||
if limit is None or limit == 0:
|
||||
return None
|
||||
return min(100.0, (current / limit) * 100)
|
||||
|
||||
return SubscriptionUsage(
|
||||
orders_used=subscription.orders_this_period,
|
||||
orders_limit=orders_limit,
|
||||
orders_remaining=calc_remaining(subscription.orders_this_period, orders_limit),
|
||||
orders_percent_used=calc_percent(subscription.orders_this_period, orders_limit),
|
||||
products_used=products_count,
|
||||
products_limit=products_limit,
|
||||
products_remaining=calc_remaining(products_count, products_limit),
|
||||
products_percent_used=calc_percent(products_count, products_limit),
|
||||
team_members_used=team_count,
|
||||
team_members_limit=team_limit,
|
||||
team_members_remaining=calc_remaining(team_count, team_limit),
|
||||
team_members_percent_used=calc_percent(team_count, team_limit),
|
||||
)
|
||||
|
||||
def increment_order_count(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Increment the order counter for the current period.
|
||||
|
||||
Call this after successfully creating/importing an order.
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
subscription.increment_order_count()
|
||||
db.commit()
|
||||
|
||||
def reset_period_counters(self, db: Session, vendor_id: int) -> None:
|
||||
"""Reset counters for a new billing period."""
|
||||
subscription = self.get_subscription_or_raise(db, vendor_id)
|
||||
subscription.reset_period_counters()
|
||||
db.commit()
|
||||
logger.info(f"Reset period counters for vendor {vendor_id}")
|
||||
|
||||
# =========================================================================
|
||||
# Limit Checks
|
||||
# =========================================================================
|
||||
|
||||
def can_create_order(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can create/import another order.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.can_create_order()
|
||||
|
||||
def check_order_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check order limit and raise exception if exceeded.
|
||||
|
||||
Use this in order creation flows.
|
||||
"""
|
||||
can_create, message = self.can_create_order(db, vendor_id)
|
||||
if not can_create:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Order limit exceeded",
|
||||
limit_type="orders",
|
||||
current=subscription.orders_this_period if subscription else 0,
|
||||
limit=subscription.orders_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_product(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another product.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_product(products_count)
|
||||
|
||||
def check_product_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check product limit and raise exception if exceeded.
|
||||
|
||||
Use this in product creation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_product(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
products_count = (
|
||||
db.query(func.count(Product.id))
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Product limit exceeded",
|
||||
limit_type="products",
|
||||
current=products_count,
|
||||
limit=subscription.products_limit if subscription else 0,
|
||||
)
|
||||
|
||||
def can_add_team_member(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Check if vendor can add another team member.
|
||||
|
||||
Returns: (allowed, error_message)
|
||||
"""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return subscription.can_add_team_member(team_count)
|
||||
|
||||
def check_team_limit(self, db: Session, vendor_id: int) -> None:
|
||||
"""
|
||||
Check team member limit and raise exception if exceeded.
|
||||
|
||||
Use this in team member invitation flows.
|
||||
"""
|
||||
can_add, message = self.can_add_team_member(db, vendor_id)
|
||||
if not can_add:
|
||||
subscription = self.get_subscription(db, vendor_id)
|
||||
team_count = (
|
||||
db.query(func.count(VendorUser.id))
|
||||
.filter(VendorUser.vendor_id == vendor_id, VendorUser.is_active == True)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
raise TierLimitExceededException(
|
||||
message=message or "Team member limit exceeded",
|
||||
limit_type="team_members",
|
||||
current=team_count,
|
||||
limit=subscription.team_members_limit if subscription else 0,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Feature Gating
|
||||
# =========================================================================
|
||||
|
||||
def has_feature(self, db: Session, vendor_id: int, feature: str) -> bool:
|
||||
"""Check if vendor has access to a feature."""
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
return subscription.has_feature(feature)
|
||||
|
||||
def check_feature(self, db: Session, vendor_id: int, feature: str) -> None:
|
||||
"""
|
||||
Check feature access and raise exception if not available.
|
||||
|
||||
Use this to gate premium features.
|
||||
"""
|
||||
if not self.has_feature(db, vendor_id, feature):
|
||||
subscription = self.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Find which tier has this feature
|
||||
required_tier = None
|
||||
for tier_code, limits in TIER_LIMITS.items():
|
||||
if feature in limits.get("features", []):
|
||||
required_tier = limits["name"]
|
||||
break
|
||||
|
||||
raise FeatureNotAvailableException(
|
||||
feature=feature,
|
||||
current_tier=subscription.tier,
|
||||
required_tier=required_tier or "higher",
|
||||
)
|
||||
|
||||
def get_feature_tier(self, feature: str) -> str | None:
|
||||
"""Get the minimum tier required for a feature."""
|
||||
for tier_code in [
|
||||
TierCode.ESSENTIAL,
|
||||
TierCode.PROFESSIONAL,
|
||||
TierCode.BUSINESS,
|
||||
TierCode.ENTERPRISE,
|
||||
]:
|
||||
if feature in TIER_LIMITS[tier_code].get("features", []):
|
||||
return tier_code.value
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
subscription_service = SubscriptionService()
|
||||
Reference in New Issue
Block a user