refactor: migrate modules from re-exports to canonical implementations

Move actual code implementations into module directories:
- orders: 5 services, 4 models, order/invoice schemas
- inventory: 3 services, 2 models, 30+ schemas
- customers: 3 services, 2 models, customer schemas
- messaging: 3 services, 2 models, message/notification schemas
- monitoring: background_tasks_service
- marketplace: 5+ services including letzshop submodule
- dev_tools: code_quality_service, test_runner_service
- billing: billing_service
- contracts: definition.py

Legacy files in app/services/, models/database/, models/schema/
now re-export from canonical module locations for backwards
compatibility. Architecture validator passes with 0 errors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-29 21:28:56 +01:00
parent b5a803cde8
commit de83875d0a
99 changed files with 19413 additions and 15357 deletions

View File

@@ -12,7 +12,7 @@ from fastapi.responses import HTMLResponse
from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_from_cookie_or_header, get_db
from app.services.platform_settings_service import platform_settings_service
from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service
from app.templates_config import templates
from models.database.user import User
from models.database.vendor import Vendor

View File

@@ -17,8 +17,7 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_admin_api, require_module_access
from app.core.database import get_db
from app.services.admin_subscription_service import admin_subscription_service
from app.services.subscription_service import subscription_service
from app.modules.billing.services import admin_subscription_service, subscription_service
from models.database.user import User
from models.schema.billing import (
BillingHistoryListResponse,

View File

@@ -19,8 +19,7 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_api, require_module_access
from app.core.config import settings
from app.core.database import get_db
from app.services.billing_service import billing_service
from app.services.subscription_service import subscription_service
from app.modules.billing.services import billing_service, subscription_service
from models.database.user import User
logger = logging.getLogger(__name__)

View File

@@ -17,6 +17,16 @@ from app.modules.billing.services.admin_subscription_service import (
AdminSubscriptionService,
admin_subscription_service,
)
from app.modules.billing.services.billing_service import (
BillingService,
billing_service,
BillingServiceError,
PaymentSystemNotConfiguredError,
TierNotFoundError,
StripePriceNotConfiguredError,
NoActiveSubscriptionError,
SubscriptionNotCancelledError,
)
__all__ = [
"SubscriptionService",
@@ -25,4 +35,12 @@ __all__ = [
"stripe_service",
"AdminSubscriptionService",
"admin_subscription_service",
"BillingService",
"billing_service",
"BillingServiceError",
"PaymentSystemNotConfiguredError",
"TierNotFoundError",
"StripePriceNotConfiguredError",
"NoActiveSubscriptionError",
"SubscriptionNotCancelledError",
]

View File

@@ -0,0 +1,588 @@
# app/modules/billing/services/billing_service.py
"""
Billing service for subscription and payment operations.
Provides:
- Subscription status and usage queries
- Tier management
- Invoice history
- Add-on management
"""
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.modules.billing.services.stripe_service import stripe_service
from app.modules.billing.services.subscription_service import subscription_service
from app.modules.billing.models import (
AddOnProduct,
BillingHistory,
SubscriptionTier,
VendorAddOn,
VendorSubscription,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class BillingServiceError(Exception):
"""Base exception for billing service errors."""
pass
class PaymentSystemNotConfiguredError(BillingServiceError):
"""Raised when Stripe is not configured."""
def __init__(self):
super().__init__("Payment system not configured")
class TierNotFoundError(BillingServiceError):
"""Raised when a tier is not found."""
def __init__(self, tier_code: str):
self.tier_code = tier_code
super().__init__(f"Tier '{tier_code}' not found")
class StripePriceNotConfiguredError(BillingServiceError):
"""Raised when Stripe price is not configured for a tier."""
def __init__(self, tier_code: str):
self.tier_code = tier_code
super().__init__(f"Stripe price not configured for tier '{tier_code}'")
class NoActiveSubscriptionError(BillingServiceError):
"""Raised when no active subscription exists."""
def __init__(self):
super().__init__("No active subscription found")
class SubscriptionNotCancelledError(BillingServiceError):
"""Raised when trying to reactivate a non-cancelled subscription."""
def __init__(self):
super().__init__("Subscription is not cancelled")
class BillingService:
"""Service for billing operations."""
def get_subscription_with_tier(
self, db: Session, vendor_id: int
) -> tuple[VendorSubscription, SubscriptionTier | None]:
"""
Get subscription and its tier info.
Returns:
Tuple of (subscription, tier) where tier may be None
"""
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
tier = (
db.query(SubscriptionTier)
.filter(SubscriptionTier.code == subscription.tier)
.first()
)
return subscription, tier
def get_available_tiers(
self, db: Session, current_tier: str
) -> tuple[list[dict], dict[str, int]]:
"""
Get all available tiers with upgrade/downgrade flags.
Returns:
Tuple of (tier_list, tier_order_map)
"""
tiers = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.is_active == True, # noqa: E712
SubscriptionTier.is_public == True, # noqa: E712
)
.order_by(SubscriptionTier.display_order)
.all()
)
tier_order = {t.code: t.display_order for t in tiers}
current_order = tier_order.get(current_tier, 0)
tier_list = []
for tier in tiers:
tier_list.append({
"code": tier.code,
"name": tier.name,
"description": tier.description,
"price_monthly_cents": tier.price_monthly_cents,
"price_annual_cents": tier.price_annual_cents,
"orders_per_month": tier.orders_per_month,
"products_limit": tier.products_limit,
"team_members": tier.team_members,
"features": tier.features or [],
"is_current": tier.code == current_tier,
"can_upgrade": tier.display_order > current_order,
"can_downgrade": tier.display_order < current_order,
})
return tier_list, tier_order
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
"""
Get a tier by its code.
Raises:
TierNotFoundError: If tier doesn't exist
"""
tier = (
db.query(SubscriptionTier)
.filter(
SubscriptionTier.code == tier_code,
SubscriptionTier.is_active == True, # noqa: E712
)
.first()
)
if not tier:
raise TierNotFoundError(tier_code)
return tier
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
"""
Get vendor by ID.
Raises:
VendorNotFoundException from app.exceptions
"""
from app.exceptions import VendorNotFoundException
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
return vendor
def create_checkout_session(
self,
db: Session,
vendor_id: int,
tier_code: str,
is_annual: bool,
success_url: str,
cancel_url: str,
) -> dict:
"""
Create a Stripe checkout session.
Returns:
Dict with checkout_url and session_id
Raises:
PaymentSystemNotConfiguredError: If Stripe not configured
TierNotFoundError: If tier doesn't exist
StripePriceNotConfiguredError: If price not configured
"""
if not stripe_service.is_configured:
raise PaymentSystemNotConfiguredError()
vendor = self.get_vendor(db, vendor_id)
tier = self.get_tier_by_code(db, tier_code)
price_id = (
tier.stripe_price_annual_id
if is_annual and tier.stripe_price_annual_id
else tier.stripe_price_monthly_id
)
if not price_id:
raise StripePriceNotConfiguredError(tier_code)
# Check if this is a new subscription (for trial)
existing_sub = subscription_service.get_subscription(db, vendor_id)
trial_days = None
if not existing_sub or not existing_sub.stripe_subscription_id:
from app.core.config import settings
trial_days = settings.stripe_trial_days
session = stripe_service.create_checkout_session(
db=db,
vendor=vendor,
price_id=price_id,
success_url=success_url,
cancel_url=cancel_url,
trial_days=trial_days,
)
# Update subscription with tier info
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
subscription.tier = tier_code
subscription.is_annual = is_annual
return {
"checkout_url": session.url,
"session_id": session.id,
}
def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict:
"""
Create a Stripe customer portal session.
Returns:
Dict with portal_url
Raises:
PaymentSystemNotConfiguredError: If Stripe not configured
NoActiveSubscriptionError: If no subscription with customer ID
"""
if not stripe_service.is_configured:
raise PaymentSystemNotConfiguredError()
subscription = subscription_service.get_subscription(db, vendor_id)
if not subscription or not subscription.stripe_customer_id:
raise NoActiveSubscriptionError()
session = stripe_service.create_portal_session(
customer_id=subscription.stripe_customer_id,
return_url=return_url,
)
return {"portal_url": session.url}
def get_invoices(
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20
) -> tuple[list[BillingHistory], int]:
"""
Get invoice history for a vendor.
Returns:
Tuple of (invoices, total_count)
"""
query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id)
total = query.count()
invoices = (
query.order_by(BillingHistory.invoice_date.desc())
.offset(skip)
.limit(limit)
.all()
)
return invoices, total
def get_available_addons(
self, db: Session, category: str | None = None
) -> list[AddOnProduct]:
"""Get available add-on products."""
query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712
if category:
query = query.filter(AddOnProduct.category == category)
return query.order_by(AddOnProduct.display_order).all()
def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]:
"""Get vendor's purchased add-ons."""
return (
db.query(VendorAddOn)
.filter(VendorAddOn.vendor_id == vendor_id)
.all()
)
def cancel_subscription(
self, db: Session, vendor_id: int, reason: str | None, immediately: bool
) -> dict:
"""
Cancel a subscription.
Returns:
Dict with message and effective_date
Raises:
NoActiveSubscriptionError: If no subscription to cancel
"""
subscription = subscription_service.get_subscription(db, vendor_id)
if not subscription or not subscription.stripe_subscription_id:
raise NoActiveSubscriptionError()
if stripe_service.is_configured:
stripe_service.cancel_subscription(
subscription_id=subscription.stripe_subscription_id,
immediately=immediately,
cancellation_reason=reason,
)
subscription.cancelled_at = datetime.utcnow()
subscription.cancellation_reason = reason
effective_date = (
datetime.utcnow().isoformat()
if immediately
else subscription.period_end.isoformat()
if subscription.period_end
else datetime.utcnow().isoformat()
)
return {
"message": "Subscription cancelled successfully",
"effective_date": effective_date,
}
def reactivate_subscription(self, db: Session, vendor_id: int) -> dict:
"""
Reactivate a cancelled subscription.
Returns:
Dict with success message
Raises:
NoActiveSubscriptionError: If no subscription
SubscriptionNotCancelledError: If not cancelled
"""
subscription = subscription_service.get_subscription(db, vendor_id)
if not subscription or not subscription.stripe_subscription_id:
raise NoActiveSubscriptionError()
if not subscription.cancelled_at:
raise SubscriptionNotCancelledError()
if stripe_service.is_configured:
stripe_service.reactivate_subscription(subscription.stripe_subscription_id)
subscription.cancelled_at = None
subscription.cancellation_reason = None
return {"message": "Subscription reactivated successfully"}
def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict:
"""
Get upcoming invoice preview.
Returns:
Dict with amount_due_cents, currency, next_payment_date, line_items
Raises:
NoActiveSubscriptionError: If no subscription with customer ID
"""
subscription = subscription_service.get_subscription(db, vendor_id)
if not subscription or not subscription.stripe_customer_id:
raise NoActiveSubscriptionError()
if not stripe_service.is_configured:
# Return empty preview if Stripe not configured
return {
"amount_due_cents": 0,
"currency": "EUR",
"next_payment_date": None,
"line_items": [],
}
invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id)
if not invoice:
return {
"amount_due_cents": 0,
"currency": "EUR",
"next_payment_date": None,
"line_items": [],
}
line_items = []
if invoice.lines and invoice.lines.data:
for line in invoice.lines.data:
line_items.append({
"description": line.description or "",
"amount_cents": line.amount,
"quantity": line.quantity or 1,
})
return {
"amount_due_cents": invoice.amount_due,
"currency": invoice.currency.upper(),
"next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat()
if invoice.next_payment_attempt
else None,
"line_items": line_items,
}
def change_tier(
self,
db: Session,
vendor_id: int,
new_tier_code: str,
is_annual: bool,
) -> dict:
"""
Change subscription tier (upgrade/downgrade).
Returns:
Dict with message, new_tier, effective_immediately
Raises:
TierNotFoundError: If tier doesn't exist
NoActiveSubscriptionError: If no subscription
StripePriceNotConfiguredError: If price not configured
"""
subscription = subscription_service.get_subscription(db, vendor_id)
if not subscription or not subscription.stripe_subscription_id:
raise NoActiveSubscriptionError()
tier = self.get_tier_by_code(db, new_tier_code)
price_id = (
tier.stripe_price_annual_id
if is_annual and tier.stripe_price_annual_id
else tier.stripe_price_monthly_id
)
if not price_id:
raise StripePriceNotConfiguredError(new_tier_code)
# Update in Stripe
if stripe_service.is_configured:
stripe_service.update_subscription(
subscription_id=subscription.stripe_subscription_id,
new_price_id=price_id,
)
# Update local subscription
old_tier = subscription.tier
subscription.tier = new_tier_code
subscription.tier_id = tier.id
subscription.is_annual = is_annual
subscription.updated_at = datetime.utcnow()
is_upgrade = self._is_upgrade(db, old_tier, new_tier_code)
return {
"message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}",
"new_tier": new_tier_code,
"effective_immediately": True,
}
def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool:
"""Check if tier change is an upgrade."""
old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first()
new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first()
if not old or not new:
return False
return new.display_order > old.display_order
def purchase_addon(
self,
db: Session,
vendor_id: int,
addon_code: str,
domain_name: str | None,
quantity: int,
success_url: str,
cancel_url: str,
) -> dict:
"""
Create checkout session for add-on purchase.
Returns:
Dict with checkout_url and session_id
Raises:
PaymentSystemNotConfiguredError: If Stripe not configured
AddonNotFoundError: If addon doesn't exist
"""
if not stripe_service.is_configured:
raise PaymentSystemNotConfiguredError()
addon = (
db.query(AddOnProduct)
.filter(
AddOnProduct.code == addon_code,
AddOnProduct.is_active == True, # noqa: E712
)
.first()
)
if not addon:
raise BillingServiceError(f"Add-on '{addon_code}' not found")
if not addon.stripe_price_id:
raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'")
vendor = self.get_vendor(db, vendor_id)
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
# Create checkout session for add-on
session = stripe_service.create_checkout_session(
db=db,
vendor=vendor,
price_id=addon.stripe_price_id,
success_url=success_url,
cancel_url=cancel_url,
quantity=quantity,
metadata={
"addon_code": addon_code,
"domain_name": domain_name or "",
},
)
return {
"checkout_url": session.url,
"session_id": session.id,
}
def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict:
"""
Cancel a purchased add-on.
Returns:
Dict with message and addon_code
Raises:
BillingServiceError: If addon not found or not owned by vendor
"""
vendor_addon = (
db.query(VendorAddOn)
.filter(
VendorAddOn.id == addon_id,
VendorAddOn.vendor_id == vendor_id,
)
.first()
)
if not vendor_addon:
raise BillingServiceError("Add-on not found")
addon_code = vendor_addon.addon_product.code
# Cancel in Stripe if applicable
if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id:
try:
stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id)
except Exception as e:
logger.warning(f"Failed to cancel addon in Stripe: {e}")
# Mark as cancelled
vendor_addon.status = "cancelled"
vendor_addon.cancelled_at = datetime.utcnow()
return {
"message": "Add-on cancelled successfully",
"addon_code": addon_code,
}
# Create service instance
billing_service = BillingService()

View File

@@ -25,7 +25,7 @@ from app.modules.cms.schemas import (
CMSUsageResponse,
)
from app.modules.cms.services import content_page_service
from app.services.vendor_service import VendorService
from app.services.vendor_service import VendorService # noqa: MOD-004 - shared platform service
from models.database.user import User
vendor_service = VendorService()

View File

@@ -13,7 +13,7 @@ from sqlalchemy.orm import Session
from app.api.deps import get_current_vendor_from_cookie_or_header, get_db
from app.modules.cms.services import content_page_service
from app.services.platform_settings_service import platform_settings_service
from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service
from app.templates_config import templates
from models.database.user import User
from models.database.vendor import Vendor

View File

@@ -0,0 +1,24 @@
# app/modules/contracts/definition.py
"""
Contracts module definition.
Cross-module contracts and Protocol interfaces.
Infrastructure module - cannot be disabled.
"""
from app.modules.base import ModuleDefinition
contracts_module = ModuleDefinition(
code="contracts",
name="Module Contracts",
description="Cross-module contracts using Protocol pattern for type-safe inter-module communication.",
version="1.0.0",
is_core=True,
features=[
"service_protocols",
"cross_module_interfaces",
],
menu_items={}, # Infrastructure module - no UI
)
__all__ = ["contracts_module"]

View File

@@ -2,14 +2,23 @@
"""
Customers module database models.
Re-exports customer-related models from their source locations.
This is the canonical location for customer models. Module models are
automatically discovered and registered with SQLAlchemy's Base.metadata
at startup.
Usage:
from app.modules.customers.models import (
Customer,
CustomerAddress,
PasswordResetToken,
)
"""
from models.database.customer import (
from app.modules.customers.models.customer import (
Customer,
CustomerAddress,
)
from models.database.password_reset_token import PasswordResetToken
from app.modules.customers.models.password_reset_token import PasswordResetToken
__all__ = [
"Customer",

View File

@@ -0,0 +1,93 @@
# app/modules/customers/models/customer.py
"""
Customer database models.
Provides Customer and CustomerAddress models for vendor-scoped
customer management.
"""
from sqlalchemy import (
JSON,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
Numeric,
String,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class Customer(Base, TimestampMixin):
"""Customer model with vendor isolation."""
__tablename__ = "customers"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
email = Column(
String(255), nullable=False, index=True
) # Unique within vendor scope
hashed_password = Column(String(255), nullable=False)
first_name = Column(String(100))
last_name = Column(String(100))
phone = Column(String(50))
customer_number = Column(
String(100), nullable=False, index=True
) # Vendor-specific ID
preferences = Column(JSON, default=dict)
marketing_consent = Column(Boolean, default=False)
last_order_date = Column(DateTime)
total_orders = Column(Integer, default=0)
total_spent = Column(Numeric(10, 2), default=0)
is_active = Column(Boolean, default=True, nullable=False)
# Language preference (NULL = use vendor storefront_language default)
# Supported: en, fr, de, lb
preferred_language = Column(String(5), nullable=True)
# Relationships
vendor = relationship("Vendor", back_populates="customers")
addresses = relationship("CustomerAddress", back_populates="customer")
orders = relationship("Order", back_populates="customer")
def __repr__(self):
return f"<Customer(id={self.id}, vendor_id={self.vendor_id}, email='{self.email}')>"
@property
def full_name(self):
if self.first_name and self.last_name:
return f"{self.first_name} {self.last_name}"
return self.email
class CustomerAddress(Base, TimestampMixin):
"""Customer address model for shipping and billing addresses."""
__tablename__ = "customer_addresses"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False)
address_type = Column(String(50), nullable=False) # 'billing', 'shipping'
first_name = Column(String(100), nullable=False)
last_name = Column(String(100), nullable=False)
company = Column(String(200))
address_line_1 = Column(String(255), nullable=False)
address_line_2 = Column(String(255))
city = Column(String(100), nullable=False)
postal_code = Column(String(20), nullable=False)
country_name = Column(String(100), nullable=False)
country_iso = Column(String(5), nullable=False)
is_default = Column(Boolean, default=False)
# Relationships
vendor = relationship("Vendor")
customer = relationship("Customer", back_populates="addresses")
def __repr__(self):
return f"<CustomerAddress(id={self.id}, customer_id={self.customer_id}, type='{self.address_type}')>"

View File

@@ -0,0 +1,91 @@
# app/modules/customers/models/password_reset_token.py
"""
Password reset token model for customer accounts.
Security features:
- Tokens are stored as SHA256 hashes, not plaintext
- Tokens expire after 1 hour
- Only one active token per customer (old tokens invalidated on new request)
"""
import hashlib
import secrets
from datetime import datetime, timedelta
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm import Session, relationship
from app.core.database import Base
class PasswordResetToken(Base):
"""Password reset token for customer accounts."""
__tablename__ = "password_reset_tokens"
# Token expiry in hours
TOKEN_EXPIRY_HOURS = 1
id = Column(Integer, primary_key=True, index=True)
customer_id = Column(
Integer, ForeignKey("customers.id", ondelete="CASCADE"), nullable=False
)
token_hash = Column(String(64), nullable=False, index=True)
expires_at = Column(DateTime, nullable=False)
used_at = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
# Relationships
customer = relationship("Customer")
def __repr__(self):
return f"<PasswordResetToken(id={self.id}, customer_id={self.customer_id}, expires_at={self.expires_at})>"
@staticmethod
def hash_token(token: str) -> str:
"""Hash a token using SHA256."""
return hashlib.sha256(token.encode()).hexdigest()
@classmethod
def create_for_customer(cls, db: Session, customer_id: int) -> str:
"""Create a new password reset token for a customer.
Invalidates any existing tokens for the customer.
Returns the plaintext token (to be sent via email).
"""
# Invalidate existing tokens for this customer
db.query(cls).filter(
cls.customer_id == customer_id,
cls.used_at.is_(None),
).delete()
# Generate new token
plaintext_token = secrets.token_urlsafe(32)
token_hash = cls.hash_token(plaintext_token)
# Create token record
token = cls(
customer_id=customer_id,
token_hash=token_hash,
expires_at=datetime.utcnow() + timedelta(hours=cls.TOKEN_EXPIRY_HOURS),
)
db.add(token)
db.flush()
return plaintext_token
@classmethod
def find_valid_token(cls, db: Session, plaintext_token: str) -> "PasswordResetToken | None":
"""Find a valid (not expired, not used) token."""
token_hash = cls.hash_token(plaintext_token)
return db.query(cls).filter(
cls.token_hash == token_hash,
cls.expires_at > datetime.utcnow(),
cls.used_at.is_(None),
).first()
def mark_used(self, db: Session) -> None:
"""Mark this token as used."""
self.used_at = datetime.utcnow()
db.flush()

View File

@@ -2,27 +2,68 @@
"""
Customers module Pydantic schemas.
Re-exports customer-related schemas from their source locations.
This is the canonical location for customer schemas.
Usage:
from app.modules.customers.schemas import (
CustomerRegister,
CustomerUpdate,
CustomerResponse,
)
"""
from models.schema.customer import (
from app.modules.customers.schemas.customer import (
# Registration & Authentication
CustomerRegister,
CustomerUpdate,
CustomerLogin,
CustomerPasswordChange,
# Customer Response
CustomerResponse,
CustomerListResponse,
# Address
CustomerAddressCreate,
CustomerAddressUpdate,
CustomerAddressResponse,
CustomerAddressListResponse,
# Preferences
CustomerPreferencesUpdate,
# Vendor Management
CustomerMessageResponse,
VendorCustomerListResponse,
CustomerDetailResponse,
CustomerOrderInfo,
CustomerOrdersResponse,
CustomerStatisticsResponse,
# Admin Management
AdminCustomerItem,
AdminCustomerListResponse,
AdminCustomerDetailResponse,
)
__all__ = [
# Registration & Authentication
"CustomerRegister",
"CustomerUpdate",
"CustomerLogin",
"CustomerPasswordChange",
# Customer Response
"CustomerResponse",
"CustomerListResponse",
# Address
"CustomerAddressCreate",
"CustomerAddressUpdate",
"CustomerAddressResponse",
"CustomerAddressListResponse",
# Preferences
"CustomerPreferencesUpdate",
# Vendor Management
"CustomerMessageResponse",
"VendorCustomerListResponse",
"CustomerDetailResponse",
"CustomerOrderInfo",
"CustomerOrdersResponse",
"CustomerStatisticsResponse",
# Admin Management
"AdminCustomerItem",
"AdminCustomerListResponse",
"AdminCustomerDetailResponse",
]

View File

@@ -0,0 +1,340 @@
# app/modules/customers/schemas/customer.py
"""
Pydantic schemas for customer-related operations.
Provides schemas for:
- Customer registration and authentication
- Customer profile management
- Customer addresses
- Admin customer management
"""
from datetime import datetime
from decimal import Decimal
from pydantic import BaseModel, EmailStr, Field, field_validator
# ============================================================================
# Customer Registration & Authentication
# ============================================================================
class CustomerRegister(BaseModel):
"""Schema for customer registration."""
email: EmailStr = Field(..., description="Customer email address")
password: str = Field(
..., min_length=8, description="Password (minimum 8 characters)"
)
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
phone: str | None = Field(None, max_length=50)
marketing_consent: bool = Field(default=False)
preferred_language: str | None = Field(
None, description="Preferred language (en, fr, de, lb)"
)
@field_validator("email")
@classmethod
def email_lowercase(cls, v: str) -> str:
"""Convert email to lowercase."""
return v.lower()
@field_validator("password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Validate password strength."""
if len(v) < 8:
raise ValueError("Password must be at least 8 characters")
if not any(char.isdigit() for char in v):
raise ValueError("Password must contain at least one digit")
if not any(char.isalpha() for char in v):
raise ValueError("Password must contain at least one letter")
return v
class CustomerUpdate(BaseModel):
"""Schema for updating customer profile."""
email: EmailStr | None = None
first_name: str | None = Field(None, min_length=1, max_length=100)
last_name: str | None = Field(None, min_length=1, max_length=100)
phone: str | None = Field(None, max_length=50)
marketing_consent: bool | None = None
preferred_language: str | None = Field(
None, description="Preferred language (en, fr, de, lb)"
)
@field_validator("email")
@classmethod
def email_lowercase(cls, v: str | None) -> str | None:
"""Convert email to lowercase."""
return v.lower() if v else None
class CustomerPasswordChange(BaseModel):
"""Schema for customer password change."""
current_password: str = Field(..., description="Current password")
new_password: str = Field(
..., min_length=8, description="New password (minimum 8 characters)"
)
confirm_password: str = Field(..., description="Confirm new password")
@field_validator("new_password")
@classmethod
def password_strength(cls, v: str) -> str:
"""Validate password strength."""
if len(v) < 8:
raise ValueError("Password must be at least 8 characters")
if not any(char.isdigit() for char in v):
raise ValueError("Password must contain at least one digit")
if not any(char.isalpha() for char in v):
raise ValueError("Password must contain at least one letter")
return v
# ============================================================================
# Customer Response
# ============================================================================
class CustomerResponse(BaseModel):
"""Schema for customer response (excludes password)."""
id: int
vendor_id: int
email: str
first_name: str | None
last_name: str | None
phone: str | None
customer_number: str
marketing_consent: bool
preferred_language: str | None
last_order_date: datetime | None
total_orders: int
total_spent: Decimal
is_active: bool
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
@property
def full_name(self) -> str:
"""Get customer full name."""
if self.first_name and self.last_name:
return f"{self.first_name} {self.last_name}"
return self.email
class CustomerListResponse(BaseModel):
"""Schema for paginated customer list."""
customers: list[CustomerResponse]
total: int
page: int
per_page: int
total_pages: int
# ============================================================================
# Customer Address
# ============================================================================
class CustomerAddressCreate(BaseModel):
"""Schema for creating customer address."""
address_type: str = Field(..., pattern="^(billing|shipping)$")
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
company: str | None = Field(None, max_length=200)
address_line_1: str = Field(..., min_length=1, max_length=255)
address_line_2: str | None = Field(None, max_length=255)
city: str = Field(..., min_length=1, max_length=100)
postal_code: str = Field(..., min_length=1, max_length=20)
country_name: str = Field(..., min_length=2, max_length=100)
country_iso: str = Field(..., min_length=2, max_length=5)
is_default: bool = Field(default=False)
class CustomerAddressUpdate(BaseModel):
"""Schema for updating customer address."""
address_type: str | None = Field(None, pattern="^(billing|shipping)$")
first_name: str | None = Field(None, min_length=1, max_length=100)
last_name: str | None = Field(None, min_length=1, max_length=100)
company: str | None = Field(None, max_length=200)
address_line_1: str | None = Field(None, min_length=1, max_length=255)
address_line_2: str | None = Field(None, max_length=255)
city: str | None = Field(None, min_length=1, max_length=100)
postal_code: str | None = Field(None, min_length=1, max_length=20)
country_name: str | None = Field(None, min_length=2, max_length=100)
country_iso: str | None = Field(None, min_length=2, max_length=5)
is_default: bool | None = None
class CustomerAddressResponse(BaseModel):
"""Schema for customer address response."""
id: int
vendor_id: int
customer_id: int
address_type: str
first_name: str
last_name: str
company: str | None
address_line_1: str
address_line_2: str | None
city: str
postal_code: str
country_name: str
country_iso: str
is_default: bool
created_at: datetime
updated_at: datetime
model_config = {"from_attributes": True}
class CustomerAddressListResponse(BaseModel):
"""Schema for customer address list response."""
addresses: list[CustomerAddressResponse]
total: int
# ============================================================================
# Customer Preferences
# ============================================================================
class CustomerPreferencesUpdate(BaseModel):
"""Schema for updating customer preferences."""
marketing_consent: bool | None = None
preferred_language: str | None = Field(
None, description="Preferred language (en, fr, de, lb)"
)
currency: str | None = Field(None, max_length=3)
notification_preferences: dict[str, bool] | None = None
# ============================================================================
# Vendor Customer Management Response Schemas
# ============================================================================
class CustomerMessageResponse(BaseModel):
"""Simple message response for customer operations."""
message: str
class VendorCustomerListResponse(BaseModel):
"""Schema for vendor customer list with skip/limit pagination."""
customers: list[CustomerResponse] = []
total: int = 0
skip: int = 0
limit: int = 100
message: str | None = None
class CustomerDetailResponse(BaseModel):
"""Detailed customer response for vendor management."""
id: int | None = None
vendor_id: int | None = None
email: str | None = None
first_name: str | None = None
last_name: str | None = None
phone: str | None = None
customer_number: str | None = None
marketing_consent: bool | None = None
preferred_language: str | None = None
last_order_date: datetime | None = None
total_orders: int | None = None
total_spent: Decimal | None = None
is_active: bool | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
message: str | None = None
model_config = {"from_attributes": True}
class CustomerOrderInfo(BaseModel):
"""Basic order info for customer order history."""
id: int
order_number: str
status: str
total: Decimal
created_at: datetime
class CustomerOrdersResponse(BaseModel):
"""Response for customer order history."""
orders: list[CustomerOrderInfo] = []
total: int = 0
message: str | None = None
class CustomerStatisticsResponse(BaseModel):
"""Response for customer statistics."""
total: int = 0
active: int = 0
inactive: int = 0
with_orders: int = 0
total_spent: float = 0.0
total_orders: int = 0
avg_order_value: float = 0.0
# ============================================================================
# Admin Customer Management Response Schemas
# ============================================================================
class AdminCustomerItem(BaseModel):
"""Admin customer list item with vendor info."""
id: int
vendor_id: int
email: str
first_name: str | None = None
last_name: str | None = None
phone: str | None = None
customer_number: str
marketing_consent: bool = False
preferred_language: str | None = None
last_order_date: datetime | None = None
total_orders: int = 0
total_spent: float = 0.0
is_active: bool = True
created_at: datetime
updated_at: datetime
vendor_name: str | None = None
vendor_code: str | None = None
model_config = {"from_attributes": True}
class AdminCustomerListResponse(BaseModel):
"""Admin paginated customer list with skip/limit."""
customers: list[AdminCustomerItem] = []
total: int = 0
skip: int = 0
limit: int = 20
class AdminCustomerDetailResponse(AdminCustomerItem):
"""Detailed customer response for admin."""
pass

View File

@@ -2,18 +2,25 @@
"""
Customers module services.
Re-exports customer-related services from their source locations.
This is the canonical location for customer services.
Usage:
from app.modules.customers.services import (
customer_service,
admin_customer_service,
customer_address_service,
)
"""
from app.services.customer_service import (
from app.modules.customers.services.customer_service import (
customer_service,
CustomerService,
)
from app.services.admin_customer_service import (
from app.modules.customers.services.admin_customer_service import (
admin_customer_service,
AdminCustomerService,
)
from app.services.customer_address_service import (
from app.modules.customers.services.customer_address_service import (
customer_address_service,
CustomerAddressService,
)

View File

@@ -0,0 +1,242 @@
# app/modules/customers/services/admin_customer_service.py
"""
Admin customer management service.
Handles customer operations for admin users across all vendors.
"""
import logging
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions.customer import CustomerNotFoundException
from app.modules.customers.models import Customer
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class AdminCustomerService:
"""Service for admin-level customer management across vendors."""
def list_customers(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
is_active: bool | None = None,
skip: int = 0,
limit: int = 20,
) -> tuple[list[dict[str, Any]], int]:
"""
Get paginated list of customers across all vendors.
Args:
db: Database session
vendor_id: Optional vendor ID filter
search: Search by email, name, or customer number
is_active: Filter by active status
skip: Number of records to skip
limit: Maximum records to return
Returns:
Tuple of (customers list, total count)
"""
# Build query
query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id)
# Apply filters
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
if search:
search_term = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_term))
| (Customer.first_name.ilike(search_term))
| (Customer.last_name.ilike(search_term))
| (Customer.customer_number.ilike(search_term))
)
if is_active is not None:
query = query.filter(Customer.is_active == is_active)
# Get total count
total = query.count()
# Get paginated results with vendor info
customers = (
query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
.order_by(Customer.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# Format response
result = []
for row in customers:
customer = row[0]
vendor_name = row[1]
vendor_code = row[2]
customer_dict = {
"id": customer.id,
"vendor_id": customer.vendor_id,
"email": customer.email,
"first_name": customer.first_name,
"last_name": customer.last_name,
"phone": customer.phone,
"customer_number": customer.customer_number,
"marketing_consent": customer.marketing_consent,
"preferred_language": customer.preferred_language,
"last_order_date": customer.last_order_date,
"total_orders": customer.total_orders,
"total_spent": float(customer.total_spent) if customer.total_spent else 0,
"is_active": customer.is_active,
"created_at": customer.created_at,
"updated_at": customer.updated_at,
"vendor_name": vendor_name,
"vendor_code": vendor_code,
}
result.append(customer_dict)
return result, total
def get_customer_stats(
self,
db: Session,
vendor_id: int | None = None,
) -> dict[str, Any]:
"""
Get customer statistics.
Args:
db: Database session
vendor_id: Optional vendor ID filter
Returns:
Dict with customer statistics
"""
query = db.query(Customer)
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
total = query.count()
active = query.filter(Customer.is_active == True).count() # noqa: E712
inactive = query.filter(Customer.is_active == False).count() # noqa: E712
with_orders = query.filter(Customer.total_orders > 0).count()
# Total spent across all customers
total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar()
total_spent = float(total_spent_result) if total_spent_result else 0
# Average order value
total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar()
total_orders = int(total_orders_result) if total_orders_result else 0
avg_order_value = total_spent / total_orders if total_orders > 0 else 0
return {
"total": total,
"active": active,
"inactive": inactive,
"with_orders": with_orders,
"total_spent": total_spent,
"total_orders": total_orders,
"avg_order_value": round(avg_order_value, 2),
}
def get_customer(
self,
db: Session,
customer_id: int,
) -> dict[str, Any]:
"""
Get customer details by ID.
Args:
db: Database session
customer_id: Customer ID
Returns:
Customer dict with vendor info
Raises:
CustomerNotFoundException: If customer not found
"""
result = (
db.query(Customer)
.join(Vendor, Customer.vendor_id == Vendor.id)
.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
.filter(Customer.id == customer_id)
.first()
)
if not result:
raise CustomerNotFoundException(str(customer_id))
customer = result[0]
return {
"id": customer.id,
"vendor_id": customer.vendor_id,
"email": customer.email,
"first_name": customer.first_name,
"last_name": customer.last_name,
"phone": customer.phone,
"customer_number": customer.customer_number,
"marketing_consent": customer.marketing_consent,
"preferred_language": customer.preferred_language,
"last_order_date": customer.last_order_date,
"total_orders": customer.total_orders,
"total_spent": float(customer.total_spent) if customer.total_spent else 0,
"is_active": customer.is_active,
"created_at": customer.created_at,
"updated_at": customer.updated_at,
"vendor_name": result[1],
"vendor_code": result[2],
}
def toggle_customer_status(
self,
db: Session,
customer_id: int,
admin_email: str,
) -> dict[str, Any]:
"""
Toggle customer active status.
Args:
db: Database session
customer_id: Customer ID
admin_email: Admin user email for logging
Returns:
Dict with customer ID, new status, and message
Raises:
CustomerNotFoundException: If customer not found
"""
customer = db.query(Customer).filter(Customer.id == customer_id).first()
if not customer:
raise CustomerNotFoundException(str(customer_id))
customer.is_active = not customer.is_active
db.flush()
db.refresh(customer)
status = "activated" if customer.is_active else "deactivated"
logger.info(f"Customer {customer.email} {status} by admin {admin_email}")
return {
"id": customer.id,
"is_active": customer.is_active,
"message": f"Customer {status} successfully",
}
# Singleton instance
admin_customer_service = AdminCustomerService()

View File

@@ -0,0 +1,314 @@
# app/modules/customers/services/customer_address_service.py
"""
Customer Address Service
Business logic for managing customer addresses with vendor isolation.
"""
import logging
from sqlalchemy.orm import Session
from app.exceptions import (
AddressLimitExceededException,
AddressNotFoundException,
)
from app.modules.customers.models import CustomerAddress
from app.modules.customers.schemas import CustomerAddressCreate, CustomerAddressUpdate
logger = logging.getLogger(__name__)
class CustomerAddressService:
"""Service for managing customer addresses with vendor isolation."""
MAX_ADDRESSES_PER_CUSTOMER = 10
def list_addresses(
self, db: Session, vendor_id: int, customer_id: int
) -> list[CustomerAddress]:
"""
Get all addresses for a customer.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
Returns:
List of customer addresses
"""
return (
db.query(CustomerAddress)
.filter(
CustomerAddress.vendor_id == vendor_id,
CustomerAddress.customer_id == customer_id,
)
.order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc())
.all()
)
def get_address(
self, db: Session, vendor_id: int, customer_id: int, address_id: int
) -> CustomerAddress:
"""
Get a specific address with ownership validation.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_id: Address ID
Returns:
Customer address
Raises:
AddressNotFoundException: If address not found or doesn't belong to customer
"""
address = (
db.query(CustomerAddress)
.filter(
CustomerAddress.id == address_id,
CustomerAddress.vendor_id == vendor_id,
CustomerAddress.customer_id == customer_id,
)
.first()
)
if not address:
raise AddressNotFoundException(address_id)
return address
def get_default_address(
self, db: Session, vendor_id: int, customer_id: int, address_type: str
) -> CustomerAddress | None:
"""
Get the default address for a specific type.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_type: 'shipping' or 'billing'
Returns:
Default address or None if not set
"""
return (
db.query(CustomerAddress)
.filter(
CustomerAddress.vendor_id == vendor_id,
CustomerAddress.customer_id == customer_id,
CustomerAddress.address_type == address_type,
CustomerAddress.is_default == True, # noqa: E712
)
.first()
)
def create_address(
self,
db: Session,
vendor_id: int,
customer_id: int,
address_data: CustomerAddressCreate,
) -> CustomerAddress:
"""
Create a new address for a customer.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_data: Address creation data
Returns:
Created customer address
Raises:
AddressLimitExceededException: If max addresses reached
"""
# Check address limit
current_count = (
db.query(CustomerAddress)
.filter(
CustomerAddress.vendor_id == vendor_id,
CustomerAddress.customer_id == customer_id,
)
.count()
)
if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER:
raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER)
# If setting as default, clear other defaults of same type
if address_data.is_default:
self._clear_other_defaults(
db, vendor_id, customer_id, address_data.address_type
)
# Create the address
address = CustomerAddress(
vendor_id=vendor_id,
customer_id=customer_id,
address_type=address_data.address_type,
first_name=address_data.first_name,
last_name=address_data.last_name,
company=address_data.company,
address_line_1=address_data.address_line_1,
address_line_2=address_data.address_line_2,
city=address_data.city,
postal_code=address_data.postal_code,
country_name=address_data.country_name,
country_iso=address_data.country_iso,
is_default=address_data.is_default,
)
db.add(address)
db.flush()
logger.info(
f"Created address {address.id} for customer {customer_id} "
f"(type={address_data.address_type}, default={address_data.is_default})"
)
return address
def update_address(
self,
db: Session,
vendor_id: int,
customer_id: int,
address_id: int,
address_data: CustomerAddressUpdate,
) -> CustomerAddress:
"""
Update an existing address.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_id: Address ID
address_data: Address update data
Returns:
Updated customer address
Raises:
AddressNotFoundException: If address not found
"""
address = self.get_address(db, vendor_id, customer_id, address_id)
# Update only provided fields
update_data = address_data.model_dump(exclude_unset=True)
# Handle default flag - clear others if setting to default
if update_data.get("is_default") is True:
# Use updated type if provided, otherwise current type
address_type = update_data.get("address_type", address.address_type)
self._clear_other_defaults(
db, vendor_id, customer_id, address_type, exclude_id=address_id
)
for field, value in update_data.items():
setattr(address, field, value)
db.flush()
logger.info(f"Updated address {address_id} for customer {customer_id}")
return address
def delete_address(
self, db: Session, vendor_id: int, customer_id: int, address_id: int
) -> None:
"""
Delete an address.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_id: Address ID
Raises:
AddressNotFoundException: If address not found
"""
address = self.get_address(db, vendor_id, customer_id, address_id)
db.delete(address)
db.flush()
logger.info(f"Deleted address {address_id} for customer {customer_id}")
def set_default(
self, db: Session, vendor_id: int, customer_id: int, address_id: int
) -> CustomerAddress:
"""
Set an address as the default for its type.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_id: Address ID
Returns:
Updated customer address
Raises:
AddressNotFoundException: If address not found
"""
address = self.get_address(db, vendor_id, customer_id, address_id)
# Clear other defaults of same type
self._clear_other_defaults(
db, vendor_id, customer_id, address.address_type, exclude_id=address_id
)
# Set this one as default
address.is_default = True
db.flush()
logger.info(
f"Set address {address_id} as default {address.address_type} "
f"for customer {customer_id}"
)
return address
def _clear_other_defaults(
self,
db: Session,
vendor_id: int,
customer_id: int,
address_type: str,
exclude_id: int | None = None,
) -> None:
"""
Clear the default flag on other addresses of the same type.
Args:
db: Database session
vendor_id: Vendor ID for isolation
customer_id: Customer ID
address_type: 'shipping' or 'billing'
exclude_id: Address ID to exclude from clearing
"""
query = db.query(CustomerAddress).filter(
CustomerAddress.vendor_id == vendor_id,
CustomerAddress.customer_id == customer_id,
CustomerAddress.address_type == address_type,
CustomerAddress.is_default == True, # noqa: E712
)
if exclude_id:
query = query.filter(CustomerAddress.id != exclude_id)
query.update({"is_default": False}, synchronize_session=False)
# Singleton instance
customer_address_service = CustomerAddressService()

View File

@@ -0,0 +1,659 @@
# app/modules/customers/services/customer_service.py
"""
Customer management service.
Handles customer registration, authentication, and profile management
with complete vendor isolation.
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions.customer import (
CustomerNotActiveException,
CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException,
InvalidPasswordResetTokenException,
PasswordTooShortException,
)
from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException
from app.services.auth_service import AuthService
from app.modules.customers.models import Customer, PasswordResetToken
from app.modules.customers.schemas import CustomerRegister, CustomerUpdate
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class CustomerService:
"""Service for managing vendor-scoped customers."""
def __init__(self):
self.auth_service = AuthService()
def register_customer(
self, db: Session, vendor_id: int, customer_data: CustomerRegister
) -> Customer:
"""
Register a new customer for a specific vendor.
Args:
db: Database session
vendor_id: Vendor ID
customer_data: Customer registration data
Returns:
Customer: Created customer object
Raises:
VendorNotFoundException: If vendor doesn't exist
VendorNotActiveException: If vendor is not active
DuplicateCustomerEmailException: If email already exists for this vendor
CustomerValidationException: If customer data is invalid
"""
# Verify vendor exists and is active
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
if not vendor.is_active:
raise VendorNotActiveException(vendor.vendor_code)
# Check if email already exists for this vendor
existing_customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == customer_data.email.lower(),
)
)
.first()
)
if existing_customer:
raise DuplicateCustomerEmailException(
customer_data.email, vendor.vendor_code
)
# Generate unique customer number for this vendor
customer_number = self._generate_customer_number(
db, vendor_id, vendor.vendor_code
)
# Hash password
hashed_password = self.auth_service.hash_password(customer_data.password)
# Create customer
customer = Customer(
vendor_id=vendor_id,
email=customer_data.email.lower(),
hashed_password=hashed_password,
first_name=customer_data.first_name,
last_name=customer_data.last_name,
phone=customer_data.phone,
customer_number=customer_number,
marketing_consent=(
customer_data.marketing_consent
if hasattr(customer_data, "marketing_consent")
else False
),
is_active=True,
)
try:
db.add(customer)
db.flush()
db.refresh(customer)
logger.info(
f"Customer registered successfully: {customer.email} "
f"(ID: {customer.id}, Number: {customer.customer_number}) "
f"for vendor {vendor.vendor_code}"
)
return customer
except Exception as e:
logger.error(f"Error registering customer: {str(e)}")
raise CustomerValidationException(
message="Failed to register customer", details={"error": str(e)}
)
def login_customer(
self, db: Session, vendor_id: int, credentials
) -> dict[str, Any]:
"""
Authenticate customer and generate JWT token.
Args:
db: Database session
vendor_id: Vendor ID
credentials: Login credentials (UserLogin schema)
Returns:
Dict containing customer and token data
Raises:
VendorNotFoundException: If vendor doesn't exist
InvalidCustomerCredentialsException: If credentials are invalid
CustomerNotActiveException: If customer account is inactive
"""
# Verify vendor exists
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
# Find customer by email (vendor-scoped)
customer = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == credentials.email_or_username.lower(),
)
)
.first()
)
if not customer:
raise InvalidCustomerCredentialsException()
# Verify password using auth_manager directly
if not self.auth_service.auth_manager.verify_password(
credentials.password, customer.hashed_password
):
raise InvalidCustomerCredentialsException()
# Check if customer is active
if not customer.is_active:
raise CustomerNotActiveException(customer.email)
# Generate JWT token with customer context
from jose import jwt
auth_manager = self.auth_service.auth_manager
expires_delta = timedelta(minutes=auth_manager.token_expire_minutes)
expire = datetime.now(UTC) + expires_delta
payload = {
"sub": str(customer.id),
"email": customer.email,
"vendor_id": vendor_id,
"type": "customer",
"exp": expire,
"iat": datetime.now(UTC),
}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
token_data = {
"access_token": token,
"token_type": "bearer",
"expires_in": auth_manager.token_expire_minutes * 60,
}
logger.info(
f"Customer login successful: {customer.email} "
f"for vendor {vendor.vendor_code}"
)
return {"customer": customer, "token_data": token_data}
def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
"""
Get customer by ID with vendor isolation.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
Returns:
Customer: Customer object
Raises:
CustomerNotFoundException: If customer not found
"""
customer = (
db.query(Customer)
.filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
.first()
)
if not customer:
raise CustomerNotFoundException(str(customer_id))
return customer
def get_customer_by_email(
self, db: Session, vendor_id: int, email: str
) -> Customer | None:
"""
Get customer by email (vendor-scoped).
Args:
db: Database session
vendor_id: Vendor ID
email: Customer email
Returns:
Optional[Customer]: Customer object or None
"""
return (
db.query(Customer)
.filter(
and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
)
.first()
)
def get_vendor_customers(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
search: str | None = None,
is_active: bool | None = None,
) -> tuple[list[Customer], int]:
"""
Get all customers for a vendor with filtering and pagination.
Args:
db: Database session
vendor_id: Vendor ID
skip: Pagination offset
limit: Pagination limit
search: Search in name/email
is_active: Filter by active status
Returns:
Tuple of (customers, total_count)
"""
from sqlalchemy import or_
query = db.query(Customer).filter(Customer.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
or_(
Customer.email.ilike(search_pattern),
Customer.first_name.ilike(search_pattern),
Customer.last_name.ilike(search_pattern),
Customer.customer_number.ilike(search_pattern),
)
)
if is_active is not None:
query = query.filter(Customer.is_active == is_active)
# Order by most recent first
query = query.order_by(Customer.created_at.desc())
total = query.count()
customers = query.offset(skip).limit(limit).all()
return customers, total
def get_customer_orders(
self,
db: Session,
vendor_id: int,
customer_id: int,
skip: int = 0,
limit: int = 50,
) -> tuple[list, int]:
"""
Get orders for a specific customer.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
skip: Pagination offset
limit: Pagination limit
Returns:
Tuple of (orders, total_count)
Raises:
CustomerNotFoundException: If customer not found
"""
from models.database.order import Order
# Verify customer belongs to vendor
self.get_customer(db, vendor_id, customer_id)
# Get customer orders
query = (
db.query(Order)
.filter(
Order.customer_id == customer_id,
Order.vendor_id == vendor_id,
)
.order_by(Order.created_at.desc())
)
total = query.count()
orders = query.offset(skip).limit(limit).all()
return orders, total
def get_customer_statistics(
self, db: Session, vendor_id: int, customer_id: int
) -> dict:
"""
Get detailed statistics for a customer.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
Returns:
Dict with customer statistics
"""
from sqlalchemy import func
from models.database.order import Order
customer = self.get_customer(db, vendor_id, customer_id)
# Get order statistics
order_stats = (
db.query(
func.count(Order.id).label("total_orders"),
func.sum(Order.total_cents).label("total_spent_cents"),
func.avg(Order.total_cents).label("avg_order_cents"),
func.max(Order.created_at).label("last_order_date"),
)
.filter(Order.customer_id == customer_id)
.first()
)
total_orders = order_stats.total_orders or 0
total_spent_cents = order_stats.total_spent_cents or 0
avg_order_cents = order_stats.avg_order_cents or 0
return {
"customer_id": customer_id,
"total_orders": total_orders,
"total_spent": total_spent_cents / 100, # Convert to euros
"average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0,
"last_order_date": order_stats.last_order_date,
"member_since": customer.created_at,
"is_active": customer.is_active,
}
def toggle_customer_status(
self, db: Session, vendor_id: int, customer_id: int
) -> Customer:
"""
Toggle customer active status.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
Returns:
Customer: Updated customer
"""
customer = self.get_customer(db, vendor_id, customer_id)
customer.is_active = not customer.is_active
db.flush()
db.refresh(customer)
action = "activated" if customer.is_active else "deactivated"
logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})")
return customer
def update_customer(
self,
db: Session,
vendor_id: int,
customer_id: int,
customer_data: CustomerUpdate,
) -> Customer:
"""
Update customer profile.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
customer_data: Updated customer data
Returns:
Customer: Updated customer object
Raises:
CustomerNotFoundException: If customer not found
CustomerValidationException: If update data is invalid
"""
customer = self.get_customer(db, vendor_id, customer_id)
# Update fields
update_data = customer_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "email" and value:
# Check if new email already exists for this vendor
existing = (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.email == value.lower(),
Customer.id != customer_id,
)
)
.first()
)
if existing:
raise DuplicateCustomerEmailException(value, "vendor")
setattr(customer, field, value.lower())
elif hasattr(customer, field):
setattr(customer, field, value)
try:
db.flush()
db.refresh(customer)
logger.info(f"Customer updated: {customer.email} (ID: {customer.id})")
return customer
except Exception as e:
logger.error(f"Error updating customer: {str(e)}")
raise CustomerValidationException(
message="Failed to update customer", details={"error": str(e)}
)
def deactivate_customer(
self, db: Session, vendor_id: int, customer_id: int
) -> Customer:
"""
Deactivate customer account.
Args:
db: Database session
vendor_id: Vendor ID
customer_id: Customer ID
Returns:
Customer: Deactivated customer object
Raises:
CustomerNotFoundException: If customer not found
"""
customer = self.get_customer(db, vendor_id, customer_id)
customer.is_active = False
db.flush()
db.refresh(customer)
logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})")
return customer
def update_customer_stats(
self, db: Session, customer_id: int, order_total: float
) -> None:
"""
Update customer statistics after order.
Args:
db: Database session
customer_id: Customer ID
order_total: Order total amount
"""
customer = db.query(Customer).filter(Customer.id == customer_id).first()
if customer:
customer.total_orders += 1
customer.total_spent += order_total
customer.last_order_date = datetime.utcnow()
logger.debug(f"Updated stats for customer {customer.email}")
def _generate_customer_number(
self, db: Session, vendor_id: int, vendor_code: str
) -> str:
"""
Generate unique customer number for vendor.
Format: {VENDOR_CODE}-CUST-{SEQUENCE}
Example: VENDORA-CUST-00001
Args:
db: Database session
vendor_id: Vendor ID
vendor_code: Vendor code
Returns:
str: Unique customer number
"""
# Get count of customers for this vendor
count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
# Generate number with padding
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
# Ensure uniqueness (in case of deletions)
while (
db.query(Customer)
.filter(
and_(
Customer.vendor_id == vendor_id,
Customer.customer_number == customer_number,
)
)
.first()
):
count += 1
sequence = str(count + 1).zfill(5)
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
return customer_number
def get_customer_for_password_reset(
self, db: Session, vendor_id: int, email: str
) -> Customer | None:
"""
Get active customer by email for password reset.
Args:
db: Database session
vendor_id: Vendor ID
email: Customer email
Returns:
Customer if found and active, None otherwise
"""
return (
db.query(Customer)
.filter(
Customer.vendor_id == vendor_id,
Customer.email == email.lower(),
Customer.is_active == True, # noqa: E712
)
.first()
)
def validate_and_reset_password(
self,
db: Session,
vendor_id: int,
reset_token: str,
new_password: str,
) -> Customer:
"""
Validate reset token and update customer password.
Args:
db: Database session
vendor_id: Vendor ID
reset_token: Password reset token from email
new_password: New password
Returns:
Customer: Updated customer
Raises:
PasswordTooShortException: If password too short
InvalidPasswordResetTokenException: If token invalid/expired
CustomerNotActiveException: If customer not active
"""
# Validate password length
if len(new_password) < 8:
raise PasswordTooShortException(min_length=8)
# Find valid token
token_record = PasswordResetToken.find_valid_token(db, reset_token)
if not token_record:
raise InvalidPasswordResetTokenException()
# Get the customer and verify they belong to this vendor
customer = (
db.query(Customer)
.filter(Customer.id == token_record.customer_id)
.first()
)
if not customer or customer.vendor_id != vendor_id:
raise InvalidPasswordResetTokenException()
if not customer.is_active:
raise CustomerNotActiveException(customer.email)
# Hash the new password and update customer
hashed_password = self.auth_service.hash_password(new_password)
customer.hashed_password = hashed_password
# Mark token as used
token_record.mark_used(db)
logger.info(f"Password reset completed for customer {customer.id}")
return customer
# Singleton instance
customer_service = CustomerService()

View File

@@ -2,16 +2,14 @@
"""
Dev-Tools module services.
This module re-exports services from their current locations.
In future cleanup phases, the actual service implementations
may be moved here.
Provides code quality scanning and test running functionality.
Services:
- code_quality_service: Code quality scanning and violation management
- test_runner_service: Test execution and results management
"""
from app.services.code_quality_service import (
from app.modules.dev_tools.services.code_quality_service import (
code_quality_service,
CodeQualityService,
VALIDATOR_ARCHITECTURE,
@@ -21,7 +19,7 @@ from app.services.code_quality_service import (
VALIDATOR_SCRIPTS,
VALIDATOR_NAMES,
)
from app.services.test_runner_service import (
from app.modules.dev_tools.services.test_runner_service import (
test_runner_service,
TestRunnerService,
)

View File

@@ -0,0 +1,820 @@
"""
Code Quality Service
Business logic for managing code quality scans and violations
Supports multiple validator types: architecture, security, performance
"""
import json
import logging
import subprocess
from datetime import datetime, UTC
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from app.exceptions import (
ScanParseException,
ScanTimeoutException,
ViolationNotFoundException,
)
from app.modules.dev_tools.models import (
ArchitectureScan,
ArchitectureViolation,
ViolationAssignment,
ViolationComment,
)
logger = logging.getLogger(__name__)
# Validator type constants
VALIDATOR_ARCHITECTURE = "architecture"
VALIDATOR_SECURITY = "security"
VALIDATOR_PERFORMANCE = "performance"
VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE]
# Map validator types to their scripts
VALIDATOR_SCRIPTS = {
VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py",
VALIDATOR_SECURITY: "scripts/validate_security.py",
VALIDATOR_PERFORMANCE: "scripts/validate_performance.py",
}
# Human-readable names
VALIDATOR_NAMES = {
VALIDATOR_ARCHITECTURE: "Architecture",
VALIDATOR_SECURITY: "Security",
VALIDATOR_PERFORMANCE: "Performance",
}
class CodeQualityService:
"""Service for managing code quality scans and violations"""
def run_scan(
self,
db: Session,
triggered_by: str = "manual",
validator_type: str = VALIDATOR_ARCHITECTURE,
) -> ArchitectureScan:
"""
Run a code quality validator and store results in database
Args:
db: Database session
triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd')
validator_type: Type of validator ('architecture', 'security', 'performance')
Returns:
ArchitectureScan object with results
Raises:
ValueError: If validator_type is invalid
ScanTimeoutException: If validator times out
ScanParseException: If validator output cannot be parsed
"""
if validator_type not in VALID_VALIDATOR_TYPES:
raise ValueError(
f"Invalid validator type: {validator_type}. "
f"Must be one of: {VALID_VALIDATOR_TYPES}"
)
script_path = VALIDATOR_SCRIPTS[validator_type]
validator_name = VALIDATOR_NAMES[validator_type]
logger.info(
f"Starting {validator_name} scan (triggered by: {triggered_by})"
)
# Get git commit hash
git_commit = self._get_git_commit_hash()
# Run validator with JSON output
start_time = datetime.now()
try:
result = subprocess.run(
["python", script_path, "--json"],
capture_output=True,
text=True,
timeout=300, # 5 minute timeout
)
except subprocess.TimeoutExpired:
logger.error(f"{validator_name} scan timed out after 5 minutes")
raise ScanTimeoutException(timeout_seconds=300)
duration = (datetime.now() - start_time).total_seconds()
# Parse JSON output (get only the JSON part, skip progress messages)
try:
# Find the JSON part in stdout
lines = result.stdout.strip().split("\n")
json_start = -1
for i, line in enumerate(lines):
if line.strip().startswith("{"):
json_start = i
break
if json_start == -1:
raise ValueError("No JSON output found")
json_output = "\n".join(lines[json_start:])
data = json.loads(json_output)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse {validator_name} validator output: {e}")
logger.error(f"Stdout: {result.stdout}")
logger.error(f"Stderr: {result.stderr}")
raise ScanParseException(reason=str(e))
# Create scan record
scan = ArchitectureScan(
timestamp=datetime.now(),
validator_type=validator_type,
total_files=data.get("files_checked", 0),
total_violations=data.get("total_violations", 0),
errors=data.get("errors", 0),
warnings=data.get("warnings", 0),
duration_seconds=duration,
triggered_by=triggered_by,
git_commit_hash=git_commit,
)
db.add(scan)
db.flush() # Get scan.id
# Create violation records
violations_data = data.get("violations", [])
logger.info(f"Creating {len(violations_data)} {validator_name} violation records")
for v in violations_data:
violation = ArchitectureViolation(
scan_id=scan.id,
validator_type=validator_type,
rule_id=v["rule_id"],
rule_name=v["rule_name"],
severity=v["severity"],
file_path=v["file_path"],
line_number=v["line_number"],
message=v["message"],
context=v.get("context", ""),
suggestion=v.get("suggestion", ""),
status="open",
)
db.add(violation)
db.flush()
db.refresh(scan)
logger.info(
f"{validator_name} scan completed: {scan.total_violations} violations found"
)
return scan
def run_all_scans(
self, db: Session, triggered_by: str = "manual"
) -> list[ArchitectureScan]:
"""
Run all validators and return list of scans
Args:
db: Database session
triggered_by: Who/what triggered the scan
Returns:
List of ArchitectureScan objects (one per validator)
"""
results = []
for validator_type in VALID_VALIDATOR_TYPES:
try:
scan = self.run_scan(db, triggered_by, validator_type)
results.append(scan)
except Exception as e:
logger.error(f"Failed to run {validator_type} scan: {e}")
# Continue with other validators even if one fails
return results
def get_latest_scan(
self, db: Session, validator_type: str = None
) -> ArchitectureScan | None:
"""
Get the most recent scan
Args:
db: Database session
validator_type: Optional filter by validator type
Returns:
Most recent ArchitectureScan or None
"""
query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
if validator_type:
query = query.filter(ArchitectureScan.validator_type == validator_type)
return query.first()
def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]:
"""
Get the most recent scan for each validator type
Returns:
Dictionary mapping validator_type to its latest scan
"""
result = {}
for vtype in VALID_VALIDATOR_TYPES:
scan = self.get_latest_scan(db, validator_type=vtype)
if scan:
result[vtype] = scan
return result
def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None:
"""Get scan by ID"""
return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first()
def create_pending_scan(
self, db: Session, validator_type: str, triggered_by: str
) -> ArchitectureScan:
"""
Create a new scan record with pending status.
Args:
db: Database session
validator_type: Type of validator (architecture, security, performance)
triggered_by: Who triggered the scan (e.g., "manual:username")
Returns:
The created ArchitectureScan record with ID populated
"""
scan = ArchitectureScan(
timestamp=datetime.now(UTC),
validator_type=validator_type,
status="pending",
triggered_by=triggered_by,
)
db.add(scan)
db.flush() # Get scan.id
return scan
def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
"""
Get all currently running scans (pending or running status).
Returns:
List of scans with status 'pending' or 'running', newest first
"""
return (
db.query(ArchitectureScan)
.filter(ArchitectureScan.status.in_(["pending", "running"]))
.order_by(ArchitectureScan.timestamp.desc())
.all()
)
def get_scan_history(
self, db: Session, limit: int = 30, validator_type: str = None
) -> list[ArchitectureScan]:
"""
Get scan history for trend graphs
Args:
db: Database session
limit: Maximum number of scans to return
validator_type: Optional filter by validator type
Returns:
List of ArchitectureScan objects, newest first
"""
query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
if validator_type:
query = query.filter(ArchitectureScan.validator_type == validator_type)
return query.limit(limit).all()
def get_violations(
self,
db: Session,
scan_id: int = None,
validator_type: str = None,
severity: str = None,
status: str = None,
rule_id: str = None,
file_path: str = None,
limit: int = 100,
offset: int = 0,
) -> tuple[list[ArchitectureViolation], int]:
"""
Get violations with filtering and pagination
Args:
db: Database session
scan_id: Filter by scan ID (if None, use latest scan(s))
validator_type: Filter by validator type
severity: Filter by severity ('error', 'warning')
status: Filter by status ('open', 'assigned', 'resolved', etc.)
rule_id: Filter by rule ID
file_path: Filter by file path (partial match)
limit: Page size
offset: Page offset
Returns:
Tuple of (violations list, total count)
"""
# Build query
query = db.query(ArchitectureViolation)
# If scan_id specified, filter by it
if scan_id is not None:
query = query.filter(ArchitectureViolation.scan_id == scan_id)
else:
# If no scan_id, get violations from latest scan(s)
if validator_type:
# Get latest scan for specific validator type
latest_scan = self.get_latest_scan(db, validator_type)
if not latest_scan:
return [], 0
query = query.filter(ArchitectureViolation.scan_id == latest_scan.id)
else:
# Get violations from latest scans of all types
latest_scans = self.get_latest_scans_by_type(db)
if not latest_scans:
return [], 0
scan_ids = [s.id for s in latest_scans.values()]
query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids))
# Apply validator_type filter if specified (for scan_id queries)
if validator_type and scan_id is not None:
query = query.filter(ArchitectureViolation.validator_type == validator_type)
# Apply other filters
if severity:
query = query.filter(ArchitectureViolation.severity == severity)
if status:
query = query.filter(ArchitectureViolation.status == status)
if rule_id:
query = query.filter(ArchitectureViolation.rule_id == rule_id)
if file_path:
query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
# Get total count
total = query.count()
# Get page of results
violations = (
query.order_by(
ArchitectureViolation.severity.desc(),
ArchitectureViolation.validator_type,
ArchitectureViolation.file_path,
)
.limit(limit)
.offset(offset)
.all()
)
return violations, total
def get_violation_by_id(
self, db: Session, violation_id: int
) -> ArchitectureViolation | None:
"""Get single violation with details"""
return (
db.query(ArchitectureViolation)
.filter(ArchitectureViolation.id == violation_id)
.first()
)
def assign_violation(
self,
db: Session,
violation_id: int,
user_id: int,
assigned_by: int,
due_date: datetime = None,
priority: str = "medium",
) -> ViolationAssignment:
"""
Assign violation to a developer
Args:
db: Database session
violation_id: Violation ID
user_id: User to assign to
assigned_by: User who is assigning
due_date: Due date (optional)
priority: Priority level ('low', 'medium', 'high', 'critical')
Returns:
ViolationAssignment object
"""
# Update violation status
violation = self.get_violation_by_id(db, violation_id)
if violation:
violation.status = "assigned"
violation.assigned_to = user_id
# Create assignment record
assignment = ViolationAssignment(
violation_id=violation_id,
user_id=user_id,
assigned_by=assigned_by,
due_date=due_date,
priority=priority,
)
db.add(assignment)
db.flush()
logger.info(f"Violation {violation_id} assigned to user {user_id}")
return assignment
def resolve_violation(
self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
) -> ArchitectureViolation:
"""
Mark violation as resolved
Args:
db: Database session
violation_id: Violation ID
resolved_by: User who resolved it
resolution_note: Note about resolution
Returns:
Updated ArchitectureViolation object
"""
violation = self.get_violation_by_id(db, violation_id)
if not violation:
raise ViolationNotFoundException(violation_id)
violation.status = "resolved"
violation.resolved_at = datetime.now()
violation.resolved_by = resolved_by
violation.resolution_note = resolution_note
db.flush()
logger.info(f"Violation {violation_id} resolved by user {resolved_by}")
return violation
def ignore_violation(
self, db: Session, violation_id: int, ignored_by: int, reason: str
) -> ArchitectureViolation:
"""
Mark violation as ignored/won't fix
Args:
db: Database session
violation_id: Violation ID
ignored_by: User who ignored it
reason: Reason for ignoring
Returns:
Updated ArchitectureViolation object
"""
violation = self.get_violation_by_id(db, violation_id)
if not violation:
raise ViolationNotFoundException(violation_id)
violation.status = "ignored"
violation.resolved_at = datetime.now()
violation.resolved_by = ignored_by
violation.resolution_note = f"Ignored: {reason}"
db.flush()
logger.info(f"Violation {violation_id} ignored by user {ignored_by}")
return violation
def add_comment(
self, db: Session, violation_id: int, user_id: int, comment: str
) -> ViolationComment:
"""
Add comment to violation
Args:
db: Database session
violation_id: Violation ID
user_id: User posting comment
comment: Comment text
Returns:
ViolationComment object
"""
comment_obj = ViolationComment(
violation_id=violation_id, user_id=user_id, comment=comment
)
db.add(comment_obj)
db.flush()
logger.info(f"Comment added to violation {violation_id} by user {user_id}")
return comment_obj
def get_dashboard_stats(
self, db: Session, validator_type: str = None
) -> dict:
"""
Get statistics for dashboard
Args:
db: Database session
validator_type: Optional filter by validator type. If None, returns combined stats.
Returns:
Dictionary with various statistics including per-validator breakdown
"""
# Get latest scans by type
latest_scans = self.get_latest_scans_by_type(db)
if not latest_scans:
return self._empty_dashboard_stats()
# If specific validator type requested
if validator_type and validator_type in latest_scans:
scan = latest_scans[validator_type]
return self._get_stats_for_scan(db, scan, validator_type)
# Combined stats across all validators
return self._get_combined_stats(db, latest_scans)
def _empty_dashboard_stats(self) -> dict:
"""Return empty dashboard stats structure"""
return {
"total_violations": 0,
"errors": 0,
"warnings": 0,
"info": 0,
"open": 0,
"assigned": 0,
"resolved": 0,
"ignored": 0,
"technical_debt_score": 100,
"trend": [],
"by_severity": {},
"by_rule": {},
"by_module": {},
"top_files": [],
"last_scan": None,
"by_validator": {},
}
def _get_stats_for_scan(
self, db: Session, scan: ArchitectureScan, validator_type: str
) -> dict:
"""Get stats for a single scan/validator type"""
# Get violation counts by status
status_counts = (
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
.filter(ArchitectureViolation.scan_id == scan.id)
.group_by(ArchitectureViolation.status)
.all()
)
status_dict = {status: count for status, count in status_counts}
# Get violations by severity
severity_counts = (
db.query(
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == scan.id)
.group_by(ArchitectureViolation.severity)
.all()
)
by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule
rule_counts = (
db.query(
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id == scan.id)
.group_by(ArchitectureViolation.rule_id)
.all()
)
by_rule = {
rule: count
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
}
# Get top violating files
file_counts = (
db.query(
ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label("count"),
)
.filter(ArchitectureViolation.scan_id == scan.id)
.group_by(ArchitectureViolation.file_path)
.order_by(desc("count"))
.limit(10)
.all()
)
top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module
by_module = self._get_violations_by_module(db, scan.id)
# Get trend for this validator type
trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type)
trend = [
{
"timestamp": s.timestamp.isoformat(),
"violations": s.total_violations,
"errors": s.errors,
"warnings": s.warnings,
}
for s in reversed(trend_scans)
]
return {
"total_violations": scan.total_violations,
"errors": scan.errors,
"warnings": scan.warnings,
"info": by_severity.get("info", 0),
"open": status_dict.get("open", 0),
"assigned": status_dict.get("assigned", 0),
"resolved": status_dict.get("resolved", 0),
"ignored": status_dict.get("ignored", 0),
"technical_debt_score": self._calculate_score(scan.errors, scan.warnings),
"trend": trend,
"by_severity": by_severity,
"by_rule": by_rule,
"by_module": by_module,
"top_files": top_files,
"last_scan": scan.timestamp.isoformat(),
"validator_type": validator_type,
"by_validator": {
validator_type: {
"total_violations": scan.total_violations,
"errors": scan.errors,
"warnings": scan.warnings,
"last_scan": scan.timestamp.isoformat(),
}
},
}
def _get_combined_stats(
self, db: Session, latest_scans: dict[str, ArchitectureScan]
) -> dict:
"""Get combined stats across all validators"""
# Aggregate totals
total_violations = sum(s.total_violations for s in latest_scans.values())
total_errors = sum(s.errors for s in latest_scans.values())
total_warnings = sum(s.warnings for s in latest_scans.values())
# Get all scan IDs
scan_ids = [s.id for s in latest_scans.values()]
# Get violation counts by status
status_counts = (
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
.group_by(ArchitectureViolation.status)
.all()
)
status_dict = {status: count for status, count in status_counts}
# Get violations by severity
severity_counts = (
db.query(
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
.group_by(ArchitectureViolation.severity)
.all()
)
by_severity = {sev: count for sev, count in severity_counts}
# Get violations by rule (across all validators)
rule_counts = (
db.query(
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
)
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
.group_by(ArchitectureViolation.rule_id)
.all()
)
by_rule = {
rule: count
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
}
# Get top violating files
file_counts = (
db.query(
ArchitectureViolation.file_path,
func.count(ArchitectureViolation.id).label("count"),
)
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
.group_by(ArchitectureViolation.file_path)
.order_by(desc("count"))
.limit(10)
.all()
)
top_files = [{"file": file, "count": count} for file, count in file_counts]
# Get violations by module
by_module = {}
for scan_id in scan_ids:
module_counts = self._get_violations_by_module(db, scan_id)
for module, count in module_counts.items():
by_module[module] = by_module.get(module, 0) + count
by_module = dict(
sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
)
# Per-validator breakdown
by_validator = {}
for vtype, scan in latest_scans.items():
by_validator[vtype] = {
"total_violations": scan.total_violations,
"errors": scan.errors,
"warnings": scan.warnings,
"last_scan": scan.timestamp.isoformat(),
}
# Get most recent scan timestamp
most_recent = max(latest_scans.values(), key=lambda s: s.timestamp)
return {
"total_violations": total_violations,
"errors": total_errors,
"warnings": total_warnings,
"info": by_severity.get("info", 0),
"open": status_dict.get("open", 0),
"assigned": status_dict.get("assigned", 0),
"resolved": status_dict.get("resolved", 0),
"ignored": status_dict.get("ignored", 0),
"technical_debt_score": self._calculate_score(total_errors, total_warnings),
"trend": [], # Combined trend would need special handling
"by_severity": by_severity,
"by_rule": by_rule,
"by_module": by_module,
"top_files": top_files,
"last_scan": most_recent.timestamp.isoformat(),
"by_validator": by_validator,
}
def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]:
"""Extract module from file paths and count violations"""
by_module = {}
violations = (
db.query(ArchitectureViolation.file_path)
.filter(ArchitectureViolation.scan_id == scan_id)
.all()
)
for v in violations:
path_parts = v.file_path.split("/")
if len(path_parts) >= 2:
module = "/".join(path_parts[:2])
else:
module = path_parts[0]
by_module[module] = by_module.get(module, 0) + 1
return dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
def _calculate_score(self, errors: int, warnings: int) -> int:
"""Calculate technical debt score (0-100)"""
score = 100 - (errors * 0.5 + warnings * 0.05)
return max(0, min(100, int(score)))
def calculate_technical_debt_score(
self, db: Session, scan_id: int = None, validator_type: str = None
) -> int:
"""
Calculate technical debt score (0-100)
Formula: 100 - (errors * 0.5 + warnings * 0.05)
Capped at 0 minimum
Args:
db: Database session
scan_id: Scan ID (if None, use latest)
validator_type: Filter by validator type
Returns:
Score from 0-100
"""
if scan_id is None:
latest_scan = self.get_latest_scan(db, validator_type)
if not latest_scan:
return 100
scan_id = latest_scan.id
scan = self.get_scan_by_id(db, scan_id)
if not scan:
return 100
return self._calculate_score(scan.errors, scan.warnings)
def _get_git_commit_hash(self) -> str | None:
"""Get current git commit hash"""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
return result.stdout.strip()[:40]
except Exception:
pass
return None
# Singleton instance
code_quality_service = CodeQualityService()

View File

@@ -0,0 +1,507 @@
"""
Test Runner Service
Service for running pytest and storing results
"""
import json
import logging
import re
import subprocess
import tempfile
from datetime import UTC, datetime
from pathlib import Path
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from app.modules.dev_tools.models import TestCollection, TestResult, TestRun
logger = logging.getLogger(__name__)
class TestRunnerService:
"""Service for managing pytest test runs"""
def __init__(self):
self.project_root = Path(__file__).parent.parent.parent.parent.parent
def create_test_run(
self,
db: Session,
test_path: str = "tests",
triggered_by: str = "manual",
) -> TestRun:
"""Create a test run record without executing tests"""
test_run = TestRun(
timestamp=datetime.now(UTC),
triggered_by=triggered_by,
test_path=test_path,
status="running",
git_commit_hash=self._get_git_commit(),
git_branch=self._get_git_branch(),
)
db.add(test_run)
db.flush()
return test_run
def run_tests(
self,
db: Session,
test_path: str = "tests",
triggered_by: str = "manual",
extra_args: list[str] | None = None,
) -> TestRun:
"""
Run pytest synchronously and store results in database
Args:
db: Database session
test_path: Path to tests (relative to project root)
triggered_by: Who triggered the run
extra_args: Additional pytest arguments
Returns:
TestRun object with results
"""
test_run = self.create_test_run(db, test_path, triggered_by)
self._execute_tests(db, test_run, test_path, extra_args)
return test_run
def _execute_tests(
self,
db: Session,
test_run: TestRun,
test_path: str,
extra_args: list[str] | None,
) -> None:
"""Execute pytest and update the test run record"""
try:
# Build pytest command with JSON output
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json_report_path = f.name
pytest_args = [
"python",
"-m",
"pytest",
test_path,
"--json-report",
f"--json-report-file={json_report_path}",
"-v",
"--tb=short",
]
if extra_args:
pytest_args.extend(extra_args)
test_run.pytest_args = " ".join(pytest_args)
# Run pytest
start_time = datetime.now(UTC)
result = subprocess.run(
pytest_args,
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=600, # 10 minute timeout
)
end_time = datetime.now(UTC)
test_run.duration_seconds = (end_time - start_time).total_seconds()
# Parse JSON report
try:
with open(json_report_path) as f:
report = json.load(f)
self._process_json_report(db, test_run, report)
except (FileNotFoundError, json.JSONDecodeError) as e:
# Fallback to parsing stdout if JSON report failed
logger.warning(f"JSON report unavailable ({e}), parsing stdout")
self._parse_pytest_output(test_run, result.stdout, result.stderr)
finally:
# Clean up temp file
try:
Path(json_report_path).unlink()
except Exception:
pass
# Set final status
if test_run.failed > 0 or test_run.errors > 0:
test_run.status = "failed"
else:
test_run.status = "passed"
except subprocess.TimeoutExpired:
test_run.status = "error"
logger.error("Pytest run timed out")
except Exception as e:
test_run.status = "error"
logger.error(f"Error running tests: {e}")
def _process_json_report(self, db: Session, test_run: TestRun, report: dict):
"""Process pytest-json-report output"""
summary = report.get("summary", {})
test_run.total_tests = summary.get("total", 0)
test_run.passed = summary.get("passed", 0)
test_run.failed = summary.get("failed", 0)
test_run.errors = summary.get("error", 0)
test_run.skipped = summary.get("skipped", 0)
test_run.xfailed = summary.get("xfailed", 0)
test_run.xpassed = summary.get("xpassed", 0)
# Process individual test results
tests = report.get("tests", [])
for test in tests:
node_id = test.get("nodeid", "")
outcome = test.get("outcome", "unknown")
# Parse node_id to get file, class, function
test_file, test_class, test_name = self._parse_node_id(node_id)
# Get failure details
error_message = None
traceback = None
if outcome in ("failed", "error"):
call_info = test.get("call", {})
if "longrepr" in call_info:
traceback = call_info["longrepr"]
# Extract error message from traceback
if isinstance(traceback, str):
lines = traceback.strip().split("\n")
if lines:
error_message = lines[-1][:500] # Last line, limited length
test_result = TestResult(
run_id=test_run.id,
node_id=node_id,
test_name=test_name,
test_file=test_file,
test_class=test_class,
outcome=outcome,
duration_seconds=test.get("duration", 0.0),
error_message=error_message,
traceback=traceback,
markers=test.get("keywords", []),
)
db.add(test_result)
def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]:
"""Parse pytest node_id into file, class, function"""
# Format: tests/unit/test_foo.py::TestClass::test_method
# or: tests/unit/test_foo.py::test_function
parts = node_id.split("::")
test_file = parts[0] if parts else ""
test_class = None
test_name = parts[-1] if parts else ""
if len(parts) == 3:
test_class = parts[1]
elif len(parts) == 2:
# Could be Class::method or file::function
if parts[1].startswith("Test"):
test_class = parts[1]
test_name = parts[1]
# Handle parametrized tests
if "[" in test_name:
test_name = test_name.split("[")[0]
return test_file, test_class, test_name
def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str):
"""Fallback parser for pytest text output"""
# Parse summary line like: "10 passed, 2 failed, 1 skipped"
summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)"
for match in re.finditer(summary_pattern, stdout):
count = int(match.group(1))
status = match.group(2)
if status == "passed":
test_run.passed = count
elif status == "failed":
test_run.failed = count
elif status == "error":
test_run.errors = count
elif status == "skipped":
test_run.skipped = count
elif status == "xfailed":
test_run.xfailed = count
elif status == "xpassed":
test_run.xpassed = count
test_run.total_tests = (
test_run.passed
+ test_run.failed
+ test_run.errors
+ test_run.skipped
+ test_run.xfailed
+ test_run.xpassed
)
def _get_git_commit(self) -> str | None:
"""Get current git commit hash"""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip()[:40] if result.returncode == 0 else None
except:
return None
def _get_git_branch(self) -> str | None:
"""Get current git branch"""
try:
result = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip() if result.returncode == 0 else None
except:
return None
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
"""Get recent test run history"""
return db.query(TestRun).order_by(desc(TestRun.timestamp)).limit(limit).all()
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
"""Get a specific test run with results"""
return db.query(TestRun).filter(TestRun.id == run_id).first()
def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]:
"""Get failed tests from a run"""
return (
db.query(TestResult)
.filter(
TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"])
)
.all()
)
def get_run_results(
self, db: Session, run_id: int, outcome: str | None = None
) -> list[TestResult]:
"""Get test results for a specific run, optionally filtered by outcome"""
query = db.query(TestResult).filter(TestResult.run_id == run_id)
if outcome:
query = query.filter(TestResult.outcome == outcome)
return query.all()
def get_dashboard_stats(self, db: Session) -> dict:
"""Get statistics for the testing dashboard"""
# Get latest run
latest_run = (
db.query(TestRun)
.filter(TestRun.status != "running")
.order_by(desc(TestRun.timestamp))
.first()
)
# Get test collection info (or calculate from latest run)
collection = (
db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
)
# Get trend data (last 10 runs)
trend_runs = (
db.query(TestRun)
.filter(TestRun.status != "running")
.order_by(desc(TestRun.timestamp))
.limit(10)
.all()
)
# Calculate stats by category from latest run
by_category = {}
if latest_run:
results = (
db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
)
for result in results:
# Categorize by test path
if "unit" in result.test_file:
category = "Unit Tests"
elif "integration" in result.test_file:
category = "Integration Tests"
elif "performance" in result.test_file:
category = "Performance Tests"
elif "system" in result.test_file:
category = "System Tests"
else:
category = "Other"
if category not in by_category:
by_category[category] = {"total": 0, "passed": 0, "failed": 0}
by_category[category]["total"] += 1
if result.outcome == "passed":
by_category[category]["passed"] += 1
elif result.outcome in ("failed", "error"):
by_category[category]["failed"] += 1
# Get top failing tests (across recent runs)
top_failing = (
db.query(
TestResult.test_name,
TestResult.test_file,
func.count(TestResult.id).label("failure_count"),
)
.filter(TestResult.outcome.in_(["failed", "error"]))
.group_by(TestResult.test_name, TestResult.test_file)
.order_by(desc("failure_count"))
.limit(10)
.all()
)
return {
# Current run stats
"total_tests": latest_run.total_tests if latest_run else 0,
"passed": latest_run.passed if latest_run else 0,
"failed": latest_run.failed if latest_run else 0,
"errors": latest_run.errors if latest_run else 0,
"skipped": latest_run.skipped if latest_run else 0,
"pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
"duration_seconds": round(latest_run.duration_seconds, 2)
if latest_run
else 0,
"coverage_percent": latest_run.coverage_percent if latest_run else None,
"last_run": latest_run.timestamp.isoformat() if latest_run else None,
"last_run_status": latest_run.status if latest_run else None,
# Collection stats
"total_test_files": collection.total_files if collection else 0,
"collected_tests": collection.total_tests if collection else 0,
"unit_tests": collection.unit_tests if collection else 0,
"integration_tests": collection.integration_tests if collection else 0,
"performance_tests": collection.performance_tests if collection else 0,
"system_tests": collection.system_tests if collection else 0,
"last_collected": collection.collected_at.isoformat()
if collection
else None,
# Trend data
"trend": [
{
"timestamp": run.timestamp.isoformat(),
"total": run.total_tests,
"passed": run.passed,
"failed": run.failed,
"pass_rate": round(run.pass_rate, 1),
"duration": round(run.duration_seconds, 1),
}
for run in reversed(trend_runs)
],
# By category
"by_category": by_category,
# Top failing tests
"top_failing": [
{
"test_name": t.test_name,
"test_file": t.test_file,
"failure_count": t.failure_count,
}
for t in top_failing
],
}
def collect_tests(self, db: Session) -> TestCollection:
"""Collect test information without running tests"""
collection = TestCollection(
collected_at=datetime.now(UTC),
)
try:
# Run pytest --collect-only with JSON report
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json_report_path = f.name
result = subprocess.run(
[
"python",
"-m",
"pytest",
"--collect-only",
"--json-report",
f"--json-report-file={json_report_path}",
"tests",
],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=120,
)
# Parse JSON report
json_path = Path(json_report_path)
if json_path.exists():
with open(json_path) as f:
report = json.load(f)
# Get total from summary
collection.total_tests = report.get("summary", {}).get("collected", 0)
# Parse collectors to get test files and counts
test_files = {}
for collector in report.get("collectors", []):
for item in collector.get("result", []):
if item.get("type") == "Function":
node_id = item.get("nodeid", "")
if "::" in node_id:
file_path = node_id.split("::")[0]
if file_path not in test_files:
test_files[file_path] = 0
test_files[file_path] += 1
# Count files and categorize
for file_path, count in test_files.items():
collection.total_files += 1
if "/unit/" in file_path or file_path.startswith("tests/unit"):
collection.unit_tests += count
elif "/integration/" in file_path or file_path.startswith(
"tests/integration"
):
collection.integration_tests += count
elif "/performance/" in file_path or file_path.startswith(
"tests/performance"
):
collection.performance_tests += count
elif "/system/" in file_path or file_path.startswith(
"tests/system"
):
collection.system_tests += count
collection.test_files = [
{"file": f, "count": c}
for f, c in sorted(test_files.items(), key=lambda x: -x[1])
]
# Cleanup
json_path.unlink(missing_ok=True)
logger.info(
f"Collected {collection.total_tests} tests from {collection.total_files} files"
)
except Exception as e:
logger.error(f"Error collecting tests: {e}", exc_info=True)
db.add(collection)
return collection
# Singleton instance
test_runner_service = TestRunnerService()

View File

@@ -2,11 +2,11 @@
"""
Inventory module database models.
Re-exports inventory-related models from their source locations.
This module contains the canonical implementations of inventory-related models.
"""
from models.database.inventory import Inventory
from models.database.inventory_transaction import (
from app.modules.inventory.models.inventory import Inventory
from app.modules.inventory.models.inventory_transaction import (
InventoryTransaction,
TransactionType,
)

View File

@@ -0,0 +1,61 @@
# app/modules/inventory/models/inventory.py
"""
Inventory model for tracking stock at warehouse/bin locations.
Each entry represents a quantity of a product at a specific bin location
within a warehouse. Products can be scattered across multiple bins.
Example:
Warehouse: "strassen"
Bin: "SA-10-02"
Product: GTIN 4007817144145
Quantity: 3
"""
from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class Inventory(Base, TimestampMixin):
__tablename__ = "inventory"
id = Column(Integer, primary_key=True, index=True)
product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
# Location: warehouse + bin
warehouse = Column(String, nullable=False, default="strassen", index=True)
bin_location = Column(String, nullable=False, index=True) # e.g., "SA-10-02"
# Legacy field - kept for backward compatibility, will be removed
location = Column(String, index=True)
quantity = Column(Integer, nullable=False, default=0)
reserved_quantity = Column(Integer, default=0)
# Keep GTIN for reference/reporting (matches Product.gtin)
gtin = Column(String, index=True)
# Relationships
product = relationship("Product", back_populates="inventory_entries")
vendor = relationship("Vendor")
# Constraints
__table_args__ = (
UniqueConstraint(
"product_id", "warehouse", "bin_location", name="uq_inventory_product_warehouse_bin"
),
Index("idx_inventory_vendor_product", "vendor_id", "product_id"),
Index("idx_inventory_warehouse_bin", "warehouse", "bin_location"),
)
def __repr__(self):
return f"<Inventory(product_id={self.product_id}, location='{self.location}', quantity={self.quantity})>"
@property
def available_quantity(self):
"""Calculate available quantity (total - reserved)."""
return max(0, self.quantity - self.reserved_quantity)

View File

@@ -0,0 +1,170 @@
# app/modules/inventory/models/inventory_transaction.py
"""
Inventory Transaction Model - Audit trail for all stock movements.
This model tracks every change to inventory quantities, providing:
- Complete audit trail for compliance and debugging
- Order-linked transactions for traceability
- Support for different transaction types (reserve, fulfill, adjust, etc.)
All stock movements should create a transaction record.
"""
from datetime import UTC, datetime
from enum import Enum
from sqlalchemy import (
Column,
DateTime,
Enum as SQLEnum,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
class TransactionType(str, Enum):
"""Types of inventory transactions."""
# Order-related
RESERVE = "reserve" # Stock reserved for order
FULFILL = "fulfill" # Reserved stock consumed (shipped)
RELEASE = "release" # Reserved stock released (cancelled)
# Manual adjustments
ADJUST = "adjust" # Manual adjustment (+/-)
SET = "set" # Set to exact quantity
# Imports
IMPORT = "import" # Initial import/sync
# Returns
RETURN = "return" # Stock returned from customer
class InventoryTransaction(Base):
"""
Audit log for inventory movements.
Every change to inventory quantity creates a transaction record,
enabling complete traceability of stock levels over time.
"""
__tablename__ = "inventory_transactions"
id = Column(Integer, primary_key=True, index=True)
# Core references
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True)
inventory_id = Column(
Integer, ForeignKey("inventory.id"), nullable=True, index=True
)
# Transaction details
transaction_type = Column(
SQLEnum(TransactionType), nullable=False, index=True
)
quantity_change = Column(Integer, nullable=False) # Positive = add, negative = remove
# Quantities after transaction (snapshot)
quantity_after = Column(Integer, nullable=False)
reserved_after = Column(Integer, nullable=False, default=0)
# Location context
location = Column(String, nullable=True)
warehouse = Column(String, nullable=True)
# Order reference (for order-related transactions)
order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True)
order_number = Column(String, nullable=True)
# Audit fields
reason = Column(Text, nullable=True) # Human-readable reason
created_by = Column(String, nullable=True) # User/system that created
created_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
nullable=False,
index=True,
)
# Relationships
vendor = relationship("Vendor")
product = relationship("Product")
inventory = relationship("Inventory")
order = relationship("Order")
# Indexes for common queries
__table_args__ = (
Index("idx_inv_tx_vendor_product", "vendor_id", "product_id"),
Index("idx_inv_tx_vendor_created", "vendor_id", "created_at"),
Index("idx_inv_tx_order", "order_id"),
Index("idx_inv_tx_type_created", "transaction_type", "created_at"),
)
def __repr__(self) -> str:
return (
f"<InventoryTransaction {self.id}: "
f"{self.transaction_type.value} {self.quantity_change:+d} "
f"for product {self.product_id}>"
)
@classmethod
def create_transaction(
cls,
vendor_id: int,
product_id: int,
transaction_type: TransactionType,
quantity_change: int,
quantity_after: int,
reserved_after: int = 0,
inventory_id: int | None = None,
location: str | None = None,
warehouse: str | None = None,
order_id: int | None = None,
order_number: str | None = None,
reason: str | None = None,
created_by: str | None = None,
) -> "InventoryTransaction":
"""
Factory method to create a transaction record.
Args:
vendor_id: Vendor ID
product_id: Product ID
transaction_type: Type of transaction
quantity_change: Change in quantity (positive = add, negative = remove)
quantity_after: Total quantity after this transaction
reserved_after: Reserved quantity after this transaction
inventory_id: Optional inventory record ID
location: Optional location
warehouse: Optional warehouse
order_id: Optional order ID (for order-related transactions)
order_number: Optional order number for display
reason: Optional human-readable reason
created_by: Optional user/system identifier
Returns:
InventoryTransaction instance (not yet added to session)
"""
return cls(
vendor_id=vendor_id,
product_id=product_id,
inventory_id=inventory_id,
transaction_type=transaction_type,
quantity_change=quantity_change,
quantity_after=quantity_after,
reserved_after=reserved_after,
location=location,
warehouse=warehouse,
order_id=order_id,
order_number=order_number,
reason=reason,
created_by=created_by,
)

View File

@@ -2,27 +2,77 @@
"""
Inventory module Pydantic schemas.
Re-exports inventory-related schemas from their source locations.
This module contains the canonical implementations of inventory-related schemas.
"""
from models.schema.inventory import (
from app.modules.inventory.schemas.inventory import (
# Base schemas
InventoryBase,
InventoryCreate,
InventoryAdjust,
InventoryUpdate,
InventoryReserve,
# Response schemas
InventoryResponse,
InventoryLocationResponse,
ProductInventorySummary,
InventoryListResponse,
InventoryTransactionResponse,
InventoryMessageResponse,
InventorySummaryResponse,
# Admin schemas
AdminInventoryCreate,
AdminInventoryAdjust,
AdminInventoryItem,
AdminInventoryListResponse,
AdminInventoryStats,
AdminLowStockItem,
AdminVendorWithInventory,
AdminVendorsWithInventoryResponse,
AdminInventoryLocationsResponse,
# Transaction schemas
InventoryTransactionResponse,
InventoryTransactionWithProduct,
InventoryTransactionListResponse,
ProductTransactionHistoryResponse,
OrderTransactionHistoryResponse,
# Admin transaction schemas
AdminInventoryTransactionItem,
AdminInventoryTransactionListResponse,
AdminTransactionStatsResponse,
)
__all__ = [
# Base schemas
"InventoryBase",
"InventoryCreate",
"InventoryAdjust",
"InventoryUpdate",
"InventoryReserve",
# Response schemas
"InventoryResponse",
"InventoryLocationResponse",
"ProductInventorySummary",
"InventoryListResponse",
"InventoryTransactionResponse",
"InventoryMessageResponse",
"InventorySummaryResponse",
# Admin schemas
"AdminInventoryCreate",
"AdminInventoryAdjust",
"AdminInventoryItem",
"AdminInventoryListResponse",
"AdminInventoryStats",
"AdminLowStockItem",
"AdminVendorWithInventory",
"AdminVendorsWithInventoryResponse",
"AdminInventoryLocationsResponse",
# Transaction schemas
"InventoryTransactionResponse",
"InventoryTransactionWithProduct",
"InventoryTransactionListResponse",
"ProductTransactionHistoryResponse",
"OrderTransactionHistoryResponse",
# Admin transaction schemas
"AdminInventoryTransactionItem",
"AdminInventoryTransactionListResponse",
"AdminTransactionStatsResponse",
]

View File

@@ -0,0 +1,294 @@
# app/modules/inventory/schemas/inventory.py
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
class InventoryBase(BaseModel):
product_id: int = Field(..., description="Product ID in vendor catalog")
location: str = Field(..., description="Storage location")
class InventoryCreate(InventoryBase):
"""Set exact inventory quantity (replaces existing)."""
quantity: int = Field(..., description="Exact inventory quantity", ge=0)
class InventoryAdjust(InventoryBase):
"""Add or remove inventory quantity."""
quantity: int = Field(
..., description="Quantity to add (positive) or remove (negative)"
)
class InventoryUpdate(BaseModel):
"""Update inventory fields."""
quantity: int | None = Field(None, ge=0)
reserved_quantity: int | None = Field(None, ge=0)
location: str | None = None
class InventoryReserve(BaseModel):
"""Reserve inventory for orders."""
product_id: int
location: str
quantity: int = Field(..., gt=0)
class InventoryResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
product_id: int
vendor_id: int
location: str
quantity: int
reserved_quantity: int
gtin: str | None
created_at: datetime
updated_at: datetime
@property
def available_quantity(self):
return max(0, self.quantity - self.reserved_quantity)
class InventoryLocationResponse(BaseModel):
location: str
quantity: int
reserved_quantity: int
available_quantity: int
class ProductInventorySummary(BaseModel):
"""Inventory summary for a product."""
product_id: int
vendor_id: int
product_sku: str | None
product_title: str
total_quantity: int
total_reserved: int
total_available: int
locations: list[InventoryLocationResponse]
class InventoryListResponse(BaseModel):
inventories: list[InventoryResponse]
total: int
skip: int
limit: int
class InventoryMessageResponse(BaseModel):
"""Simple message response for inventory operations."""
message: str
class InventorySummaryResponse(BaseModel):
"""Inventory summary response for marketplace product service."""
gtin: str
total_quantity: int
locations: list[InventoryLocationResponse]
# ============================================================================
# Admin Inventory Schemas
# ============================================================================
class AdminInventoryCreate(BaseModel):
"""Admin version of inventory create - requires explicit vendor_id."""
vendor_id: int = Field(..., description="Target vendor ID")
product_id: int = Field(..., description="Product ID in vendor catalog")
location: str = Field(..., description="Storage location")
quantity: int = Field(..., description="Exact inventory quantity", ge=0)
class AdminInventoryAdjust(BaseModel):
"""Admin version of inventory adjust - requires explicit vendor_id."""
vendor_id: int = Field(..., description="Target vendor ID")
product_id: int = Field(..., description="Product ID in vendor catalog")
location: str = Field(..., description="Storage location")
quantity: int = Field(
..., description="Quantity to add (positive) or remove (negative)"
)
reason: str | None = Field(None, description="Reason for adjustment")
class AdminInventoryItem(BaseModel):
"""Inventory item with vendor info for admin list view."""
model_config = ConfigDict(from_attributes=True)
id: int
product_id: int
vendor_id: int
vendor_name: str | None = None
vendor_code: str | None = None
product_title: str | None = None
product_sku: str | None = None
location: str
quantity: int
reserved_quantity: int
available_quantity: int
gtin: str | None = None
created_at: datetime
updated_at: datetime
class AdminInventoryListResponse(BaseModel):
"""Cross-vendor inventory list for admin."""
inventories: list[AdminInventoryItem]
total: int
skip: int
limit: int
vendor_filter: int | None = None
location_filter: str | None = None
class AdminInventoryStats(BaseModel):
"""Inventory statistics for admin dashboard."""
total_entries: int
total_quantity: int
total_reserved: int
total_available: int
low_stock_count: int
vendors_with_inventory: int
unique_locations: int
class AdminLowStockItem(BaseModel):
"""Low stock item for admin alerts."""
id: int
product_id: int
vendor_id: int
vendor_name: str | None = None
product_title: str | None = None
location: str
quantity: int
reserved_quantity: int
available_quantity: int
class AdminVendorWithInventory(BaseModel):
"""Vendor with inventory entries."""
id: int
name: str
vendor_code: str
class AdminVendorsWithInventoryResponse(BaseModel):
"""Response for vendors with inventory list."""
vendors: list[AdminVendorWithInventory]
class AdminInventoryLocationsResponse(BaseModel):
"""Response for unique inventory locations."""
locations: list[str]
# ============================================================================
# Inventory Transaction Schemas
# ============================================================================
class InventoryTransactionResponse(BaseModel):
"""Single inventory transaction record."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
product_id: int
inventory_id: int | None = None
transaction_type: str
quantity_change: int
quantity_after: int
reserved_after: int
location: str | None = None
warehouse: str | None = None
order_id: int | None = None
order_number: str | None = None
reason: str | None = None
created_by: str | None = None
created_at: datetime
class InventoryTransactionWithProduct(InventoryTransactionResponse):
"""Transaction with product details for list views."""
product_title: str | None = None
product_sku: str | None = None
class InventoryTransactionListResponse(BaseModel):
"""Paginated list of inventory transactions."""
transactions: list[InventoryTransactionWithProduct]
total: int
skip: int
limit: int
class ProductTransactionHistoryResponse(BaseModel):
"""Transaction history for a specific product."""
product_id: int
product_title: str | None = None
product_sku: str | None = None
current_quantity: int
current_reserved: int
transactions: list[InventoryTransactionResponse]
total: int
class OrderTransactionHistoryResponse(BaseModel):
"""Transaction history for a specific order."""
order_id: int
order_number: str
transactions: list[InventoryTransactionWithProduct]
# ============================================================================
# Admin Inventory Transaction Schemas
# ============================================================================
class AdminInventoryTransactionItem(InventoryTransactionWithProduct):
"""Transaction with vendor details for admin views."""
vendor_name: str | None = None
vendor_code: str | None = None
class AdminInventoryTransactionListResponse(BaseModel):
"""Paginated list of transactions for admin."""
transactions: list[AdminInventoryTransactionItem]
total: int
skip: int
limit: int
class AdminTransactionStatsResponse(BaseModel):
"""Transaction statistics for admin dashboard."""
total_transactions: int
transactions_today: int
by_type: dict[str, int]

View File

@@ -2,20 +2,21 @@
"""
Inventory module services.
Re-exports inventory-related services from their source locations.
This module contains the canonical implementations of inventory-related services.
"""
from app.services.inventory_service import (
from app.modules.inventory.services.inventory_service import (
inventory_service,
InventoryService,
)
from app.services.inventory_transaction_service import (
from app.modules.inventory.services.inventory_transaction_service import (
inventory_transaction_service,
InventoryTransactionService,
)
from app.services.inventory_import_service import (
from app.modules.inventory.services.inventory_import_service import (
inventory_import_service,
InventoryImportService,
ImportResult,
)
__all__ = [
@@ -25,4 +26,5 @@ __all__ = [
"InventoryTransactionService",
"inventory_import_service",
"InventoryImportService",
"ImportResult",
]

View File

@@ -0,0 +1,250 @@
# app/modules/inventory/services/inventory_import_service.py
"""
Inventory import service for bulk importing stock from TSV/CSV files.
Supports two formats:
1. One row per unit (quantity = count of rows):
BIN EAN PRODUCT
SA-10-02 0810050910101 Product Name
SA-10-02 0810050910101 Product Name (2nd unit)
2. With explicit quantity column:
BIN EAN PRODUCT QUANTITY
SA-10-02 0810050910101 Product Name 12
Products are matched by GTIN/EAN to existing vendor products.
"""
import csv
import io
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from sqlalchemy.orm import Session
from app.modules.inventory.models.inventory import Inventory
from models.database.product import Product
logger = logging.getLogger(__name__)
@dataclass
class ImportResult:
"""Result of an inventory import operation."""
success: bool = True
total_rows: int = 0
entries_created: int = 0
entries_updated: int = 0
quantity_imported: int = 0
unmatched_gtins: list = field(default_factory=list)
errors: list = field(default_factory=list)
class InventoryImportService:
"""Service for importing inventory from TSV/CSV files."""
def import_from_text(
self,
db: Session,
content: str,
vendor_id: int,
warehouse: str = "strassen",
delimiter: str = "\t",
clear_existing: bool = False,
) -> ImportResult:
"""
Import inventory from TSV/CSV text content.
Args:
db: Database session
content: TSV/CSV content as string
vendor_id: Vendor ID for inventory
warehouse: Warehouse name (default: "strassen")
delimiter: Column delimiter (default: tab)
clear_existing: If True, clear existing inventory before import
Returns:
ImportResult with summary and errors
"""
result = ImportResult()
try:
# Parse CSV/TSV
reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
# Normalize headers (case-insensitive, strip whitespace)
if reader.fieldnames:
reader.fieldnames = [h.strip().upper() for h in reader.fieldnames]
# Validate required columns
required = {"BIN", "EAN"}
if not reader.fieldnames or not required.issubset(set(reader.fieldnames)):
result.success = False
result.errors.append(
f"Missing required columns. Found: {reader.fieldnames}, Required: {required}"
)
return result
has_quantity = "QUANTITY" in reader.fieldnames
# Group entries by (EAN, BIN)
# Key: (ean, bin) -> quantity
inventory_data: dict[tuple[str, str], int] = defaultdict(int)
product_names: dict[str, str] = {} # EAN -> product name (for logging)
for row in reader:
result.total_rows += 1
ean = row.get("EAN", "").strip()
bin_loc = row.get("BIN", "").strip()
product_name = row.get("PRODUCT", "").strip()
if not ean or not bin_loc:
result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN")
continue
# Get quantity
if has_quantity:
try:
qty = int(row.get("QUANTITY", "1").strip())
except ValueError:
result.errors.append(
f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'"
)
continue
else:
qty = 1 # Each row = 1 unit
inventory_data[(ean, bin_loc)] += qty
if product_name:
product_names[ean] = product_name
# Clear existing inventory if requested
if clear_existing:
db.query(Inventory).filter(
Inventory.vendor_id == vendor_id,
Inventory.warehouse == warehouse,
).delete()
db.flush()
# Build EAN to Product mapping for this vendor
products = (
db.query(Product)
.filter(
Product.vendor_id == vendor_id,
Product.gtin.isnot(None),
)
.all()
)
ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin}
# Track unmatched GTINs
unmatched: dict[str, int] = {} # EAN -> total quantity
# Process inventory entries
for (ean, bin_loc), quantity in inventory_data.items():
product = ean_to_product.get(ean)
if not product:
# Track unmatched
if ean not in unmatched:
unmatched[ean] = 0
unmatched[ean] += quantity
continue
# Upsert inventory entry
existing = (
db.query(Inventory)
.filter(
Inventory.product_id == product.id,
Inventory.warehouse == warehouse,
Inventory.bin_location == bin_loc,
)
.first()
)
if existing:
existing.quantity = quantity
existing.gtin = ean
result.entries_updated += 1
else:
inv = Inventory(
product_id=product.id,
vendor_id=vendor_id,
warehouse=warehouse,
bin_location=bin_loc,
location=bin_loc, # Legacy field
quantity=quantity,
gtin=ean,
)
db.add(inv)
result.entries_created += 1
result.quantity_imported += quantity
db.flush()
# Format unmatched GTINs for result
for ean, qty in unmatched.items():
product_name = product_names.get(ean, "Unknown")
result.unmatched_gtins.append(
{"gtin": ean, "quantity": qty, "product_name": product_name}
)
if result.unmatched_gtins:
logger.warning(
f"Import had {len(result.unmatched_gtins)} unmatched GTINs"
)
except Exception as e:
logger.exception("Inventory import failed")
result.success = False
result.errors.append(str(e))
return result
def import_from_file(
self,
db: Session,
file_path: str,
vendor_id: int,
warehouse: str = "strassen",
clear_existing: bool = False,
) -> ImportResult:
"""
Import inventory from a TSV/CSV file.
Args:
db: Database session
file_path: Path to TSV/CSV file
vendor_id: Vendor ID for inventory
warehouse: Warehouse name
clear_existing: If True, clear existing inventory before import
Returns:
ImportResult with summary and errors
"""
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
except Exception as e:
return ImportResult(success=False, errors=[f"Failed to read file: {e}"])
# Detect delimiter
first_line = content.split("\n")[0] if content else ""
delimiter = "\t" if "\t" in first_line else ","
return self.import_from_text(
db=db,
content=content,
vendor_id=vendor_id,
warehouse=warehouse,
delimiter=delimiter,
clear_existing=clear_existing,
)
# Singleton instance
inventory_import_service = InventoryImportService()

View File

@@ -0,0 +1,949 @@
# app/modules/inventory/services/inventory_service.py
import logging
from datetime import UTC, datetime
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import (
InsufficientInventoryException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
ProductNotFoundException,
ValidationException,
VendorNotFoundException,
)
from app.modules.inventory.models.inventory import Inventory
from app.modules.inventory.schemas.inventory import (
AdminInventoryItem,
AdminInventoryListResponse,
AdminInventoryLocationsResponse,
AdminInventoryStats,
AdminLowStockItem,
AdminVendorsWithInventoryResponse,
AdminVendorWithInventory,
InventoryAdjust,
InventoryCreate,
InventoryLocationResponse,
InventoryReserve,
InventoryUpdate,
ProductInventorySummary,
)
from models.database.product import Product
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class InventoryService:
"""Service for inventory operations with vendor isolation."""
def set_inventory(
self, db: Session, vendor_id: int, inventory_data: InventoryCreate
) -> Inventory:
"""
Set exact inventory quantity for a product at a location (replaces existing).
Args:
db: Database session
vendor_id: Vendor ID (from middleware)
inventory_data: Inventory data
Returns:
Inventory object
"""
try:
# Validate product belongs to vendor
product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
# Validate location
location = self._validate_location(inventory_data.location)
# Validate quantity
self._validate_quantity(inventory_data.quantity, allow_zero=True)
# Check if inventory entry exists
existing = self._get_inventory_entry(
db, inventory_data.product_id, location
)
if existing:
old_qty = existing.quantity
existing.quantity = inventory_data.quantity
existing.updated_at = datetime.now(UTC)
db.flush()
db.refresh(existing)
logger.info(
f"Set inventory for product {inventory_data.product_id} at {location}: "
f"{old_qty}{inventory_data.quantity}"
)
return existing
# Create new inventory entry
new_inventory = Inventory(
product_id=inventory_data.product_id,
vendor_id=vendor_id,
warehouse="strassen", # Default warehouse
bin_location=location, # Use location as bin location
location=location, # Keep for backward compatibility
quantity=inventory_data.quantity,
gtin=product.marketplace_product.gtin, # Optional reference
)
db.add(new_inventory)
db.flush()
db.refresh(new_inventory)
logger.info(
f"Created inventory for product {inventory_data.product_id} at {location}: "
f"{inventory_data.quantity}"
)
return new_inventory
except (
ProductNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error setting inventory: {str(e)}")
raise ValidationException("Failed to set inventory")
def adjust_inventory(
self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
) -> Inventory:
"""
Adjust inventory by adding or removing quantity.
Positive quantity = add, negative = remove.
Args:
db: Database session
vendor_id: Vendor ID
inventory_data: Adjustment data
Returns:
Updated Inventory object
"""
try:
# Validate product belongs to vendor
product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
# Validate location
location = self._validate_location(inventory_data.location)
# Check if inventory exists
existing = self._get_inventory_entry(
db, inventory_data.product_id, location
)
if not existing:
# Create new if adding, error if removing
if inventory_data.quantity < 0:
raise InventoryNotFoundException(
f"No inventory found for product {inventory_data.product_id} at {location}"
)
# Create with positive quantity
new_inventory = Inventory(
product_id=inventory_data.product_id,
vendor_id=vendor_id,
warehouse="strassen", # Default warehouse
bin_location=location, # Use location as bin location
location=location, # Keep for backward compatibility
quantity=inventory_data.quantity,
gtin=product.marketplace_product.gtin,
)
db.add(new_inventory)
db.flush()
db.refresh(new_inventory)
logger.info(
f"Created inventory for product {inventory_data.product_id} at {location}: "
f"+{inventory_data.quantity}"
)
return new_inventory
# Adjust existing inventory
old_qty = existing.quantity
new_qty = old_qty + inventory_data.quantity
# Validate resulting quantity
if new_qty < 0:
raise InsufficientInventoryException(
f"Insufficient inventory. Available: {old_qty}, "
f"Requested removal: {abs(inventory_data.quantity)}"
)
existing.quantity = new_qty
existing.updated_at = datetime.now(UTC)
db.flush()
db.refresh(existing)
logger.info(
f"Adjusted inventory for product {inventory_data.product_id} at {location}: "
f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}"
)
return existing
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error adjusting inventory: {str(e)}")
raise ValidationException("Failed to adjust inventory")
def reserve_inventory(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Reserve inventory for an order (increases reserved_quantity).
Args:
db: Database session
vendor_id: Vendor ID
reserve_data: Reservation data
Returns:
Updated Inventory object
"""
try:
# Validate product
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
# Validate location and quantity
location = self._validate_location(reserve_data.location)
self._validate_quantity(reserve_data.quantity, allow_zero=False)
# Get inventory entry
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
if not inventory:
raise InventoryNotFoundException(
f"No inventory found for product {reserve_data.product_id} at {location}"
)
# Check available quantity
available = inventory.quantity - inventory.reserved_quantity
if available < reserve_data.quantity:
raise InsufficientInventoryException(
f"Insufficient available inventory. Available: {available}, "
f"Requested: {reserve_data.quantity}"
)
# Reserve inventory
inventory.reserved_quantity += reserve_data.quantity
inventory.updated_at = datetime.now(UTC)
db.flush()
db.refresh(inventory)
logger.info(
f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} "
f"at {location}"
)
return inventory
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error reserving inventory: {str(e)}")
raise ValidationException("Failed to reserve inventory")
def release_reservation(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Release reserved inventory (decreases reserved_quantity).
Args:
db: Database session
vendor_id: Vendor ID
reserve_data: Reservation data
Returns:
Updated Inventory object
"""
try:
# Validate product
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
location = self._validate_location(reserve_data.location)
self._validate_quantity(reserve_data.quantity, allow_zero=False)
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
if not inventory:
raise InventoryNotFoundException(
f"No inventory found for product {reserve_data.product_id} at {location}"
)
# Validate reserved quantity
if inventory.reserved_quantity < reserve_data.quantity:
logger.warning(
f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, "
f"Requested: {reserve_data.quantity}"
)
inventory.reserved_quantity = 0
else:
inventory.reserved_quantity -= reserve_data.quantity
inventory.updated_at = datetime.now(UTC)
db.flush()
db.refresh(inventory)
logger.info(
f"Released {reserve_data.quantity} units for product {reserve_data.product_id} "
f"at {location}"
)
return inventory
except (
ProductNotFoundException,
InventoryNotFoundException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error releasing reservation: {str(e)}")
raise ValidationException("Failed to release reservation")
def fulfill_reservation(
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
"""
Fulfill a reservation (decreases both quantity and reserved_quantity).
Use when order is shipped/completed.
Args:
db: Database session
vendor_id: Vendor ID
reserve_data: Reservation data
Returns:
Updated Inventory object
"""
try:
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
location = self._validate_location(reserve_data.location)
self._validate_quantity(reserve_data.quantity, allow_zero=False)
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
if not inventory:
raise InventoryNotFoundException(
f"No inventory found for product {reserve_data.product_id} at {location}"
)
# Validate quantities
if inventory.quantity < reserve_data.quantity:
raise InsufficientInventoryException(
f"Insufficient inventory. Available: {inventory.quantity}, "
f"Requested: {reserve_data.quantity}"
)
if inventory.reserved_quantity < reserve_data.quantity:
logger.warning(
f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, "
f"Fulfilling: {reserve_data.quantity}"
)
# Fulfill (remove from both quantity and reserved)
inventory.quantity -= reserve_data.quantity
inventory.reserved_quantity = max(
0, inventory.reserved_quantity - reserve_data.quantity
)
inventory.updated_at = datetime.now(UTC)
db.flush()
db.refresh(inventory)
logger.info(
f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} "
f"at {location}"
)
return inventory
except (
ProductNotFoundException,
InventoryNotFoundException,
InsufficientInventoryException,
InvalidQuantityException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error fulfilling reservation: {str(e)}")
raise ValidationException("Failed to fulfill reservation")
def get_product_inventory(
self, db: Session, vendor_id: int, product_id: int
) -> ProductInventorySummary:
"""
Get inventory summary for a product across all locations.
Args:
db: Database session
vendor_id: Vendor ID
product_id: Product ID
Returns:
ProductInventorySummary
"""
try:
product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = (
db.query(Inventory).filter(Inventory.product_id == product_id).all()
)
if not inventory_entries:
return ProductInventorySummary(
product_id=product_id,
vendor_id=vendor_id,
product_sku=product.vendor_sku,
product_title=product.marketplace_product.get_title() or "",
total_quantity=0,
total_reserved=0,
total_available=0,
locations=[],
)
total_qty = sum(inv.quantity for inv in inventory_entries)
total_reserved = sum(inv.reserved_quantity for inv in inventory_entries)
total_available = sum(inv.available_quantity for inv in inventory_entries)
locations = [
InventoryLocationResponse(
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
)
for inv in inventory_entries
]
return ProductInventorySummary(
product_id=product_id,
vendor_id=vendor_id,
product_sku=product.vendor_sku,
product_title=product.marketplace_product.get_title() or "",
total_quantity=total_qty,
total_reserved=total_reserved,
total_available=total_available,
locations=locations,
)
except ProductNotFoundException:
raise
except Exception as e:
logger.error(f"Error getting product inventory: {str(e)}")
raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
location: str | None = None,
low_stock_threshold: int | None = None,
) -> list[Inventory]:
"""
Get all inventory for a vendor with filtering.
Args:
db: Database session
vendor_id: Vendor ID
skip: Pagination offset
limit: Pagination limit
location: Filter by location
low_stock_threshold: Filter items below threshold
Returns:
List of Inventory objects
"""
try:
query = db.query(Inventory).filter(Inventory.vendor_id == vendor_id)
if location:
query = query.filter(Inventory.location.ilike(f"%{location}%"))
if low_stock_threshold is not None:
query = query.filter(Inventory.quantity <= low_stock_threshold)
return query.offset(skip).limit(limit).all()
except Exception as e:
logger.error(f"Error getting vendor inventory: {str(e)}")
raise ValidationException("Failed to retrieve vendor inventory")
def update_inventory(
self,
db: Session,
vendor_id: int,
inventory_id: int,
inventory_update: InventoryUpdate,
) -> Inventory:
"""Update inventory entry."""
try:
inventory = self._get_inventory_by_id(db, inventory_id)
# Verify ownership
if inventory.vendor_id != vendor_id:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
# Update fields
if inventory_update.quantity is not None:
self._validate_quantity(inventory_update.quantity, allow_zero=True)
inventory.quantity = inventory_update.quantity
if inventory_update.reserved_quantity is not None:
self._validate_quantity(
inventory_update.reserved_quantity, allow_zero=True
)
inventory.reserved_quantity = inventory_update.reserved_quantity
if inventory_update.location:
inventory.location = self._validate_location(inventory_update.location)
inventory.updated_at = datetime.now(UTC)
db.flush()
db.refresh(inventory)
logger.info(f"Updated inventory {inventory_id}")
return inventory
except (
InventoryNotFoundException,
InvalidQuantityException,
InventoryValidationException,
):
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"Error updating inventory: {str(e)}")
raise ValidationException("Failed to update inventory")
def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
"""Delete inventory entry."""
try:
inventory = self._get_inventory_by_id(db, inventory_id)
# Verify ownership
if inventory.vendor_id != vendor_id:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
db.delete(inventory)
db.flush()
logger.info(f"Deleted inventory {inventory_id}")
return True
except InventoryNotFoundException:
raise
except Exception as e:
db.rollback()
logger.error(f"Error deleting inventory: {str(e)}")
raise ValidationException("Failed to delete inventory")
# =========================================================================
# Admin Methods (cross-vendor operations)
# =========================================================================
def get_all_inventory_admin(
self,
db: Session,
skip: int = 0,
limit: int = 50,
vendor_id: int | None = None,
location: str | None = None,
low_stock: int | None = None,
search: str | None = None,
) -> AdminInventoryListResponse:
"""
Get inventory across all vendors with filtering (admin only).
Args:
db: Database session
skip: Pagination offset
limit: Pagination limit
vendor_id: Filter by vendor
location: Filter by location
low_stock: Filter items below threshold
search: Search by product title or SKU
Returns:
AdminInventoryListResponse
"""
query = db.query(Inventory).join(Product).join(Vendor)
# Apply filters
if vendor_id is not None:
query = query.filter(Inventory.vendor_id == vendor_id)
if location:
query = query.filter(Inventory.location.ilike(f"%{location}%"))
if low_stock is not None:
query = query.filter(Inventory.quantity <= low_stock)
if search:
from models.database.marketplace_product import MarketplaceProduct
from models.database.marketplace_product_translation import (
MarketplaceProductTranslation,
)
query = (
query.join(MarketplaceProduct)
.outerjoin(MarketplaceProductTranslation)
.filter(
(MarketplaceProductTranslation.title.ilike(f"%{search}%"))
| (Product.vendor_sku.ilike(f"%{search}%"))
)
)
# Get total count before pagination
total = query.count()
# Apply pagination
inventories = query.offset(skip).limit(limit).all()
# Build response with vendor/product info
items = []
for inv in inventories:
product = inv.product
vendor = inv.vendor
title = None
if product and product.marketplace_product:
title = product.marketplace_product.get_title()
items.append(
AdminInventoryItem(
id=inv.id,
product_id=inv.product_id,
vendor_id=inv.vendor_id,
vendor_name=vendor.name if vendor else None,
vendor_code=vendor.vendor_code if vendor else None,
product_title=title,
product_sku=product.vendor_sku if product else None,
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
gtin=inv.gtin,
created_at=inv.created_at,
updated_at=inv.updated_at,
)
)
return AdminInventoryListResponse(
inventories=items,
total=total,
skip=skip,
limit=limit,
vendor_filter=vendor_id,
location_filter=location,
)
def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats:
"""Get platform-wide inventory statistics (admin only)."""
# Total entries
total_entries = db.query(func.count(Inventory.id)).scalar() or 0
# Aggregate quantities
totals = db.query(
func.sum(Inventory.quantity).label("total_qty"),
func.sum(Inventory.reserved_quantity).label("total_reserved"),
).first()
total_quantity = totals.total_qty or 0
total_reserved = totals.total_reserved or 0
total_available = total_quantity - total_reserved
# Low stock count (default threshold: 10)
low_stock_count = (
db.query(func.count(Inventory.id))
.filter(Inventory.quantity <= 10)
.scalar()
or 0
)
# Vendors with inventory
vendors_with_inventory = (
db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0
)
# Unique locations
unique_locations = (
db.query(func.count(func.distinct(Inventory.location))).scalar() or 0
)
return AdminInventoryStats(
total_entries=total_entries,
total_quantity=total_quantity,
total_reserved=total_reserved,
total_available=total_available,
low_stock_count=low_stock_count,
vendors_with_inventory=vendors_with_inventory,
unique_locations=unique_locations,
)
def get_low_stock_items_admin(
self,
db: Session,
threshold: int = 10,
vendor_id: int | None = None,
limit: int = 50,
) -> list[AdminLowStockItem]:
"""Get items with low stock levels (admin only)."""
query = (
db.query(Inventory)
.join(Product)
.join(Vendor)
.filter(Inventory.quantity <= threshold)
)
if vendor_id is not None:
query = query.filter(Inventory.vendor_id == vendor_id)
# Order by quantity ascending (most critical first)
query = query.order_by(Inventory.quantity.asc())
inventories = query.limit(limit).all()
items = []
for inv in inventories:
product = inv.product
vendor = inv.vendor
title = None
if product and product.marketplace_product:
title = product.marketplace_product.get_title()
items.append(
AdminLowStockItem(
id=inv.id,
product_id=inv.product_id,
vendor_id=inv.vendor_id,
vendor_name=vendor.name if vendor else None,
product_title=title,
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
)
)
return items
def get_vendors_with_inventory_admin(
self, db: Session
) -> AdminVendorsWithInventoryResponse:
"""Get list of vendors that have inventory entries (admin only)."""
# noqa: SVC-005 - Admin function, intentionally cross-vendor
# Use subquery to avoid DISTINCT on JSON columns (PostgreSQL can't compare JSON)
vendor_ids_subquery = (
db.query(Inventory.vendor_id)
.distinct()
.subquery()
)
vendors = (
db.query(Vendor)
.filter(Vendor.id.in_(db.query(vendor_ids_subquery.c.vendor_id)))
.order_by(Vendor.name)
.all()
)
return AdminVendorsWithInventoryResponse(
vendors=[
AdminVendorWithInventory(
id=v.id, name=v.name, vendor_code=v.vendor_code
)
for v in vendors
]
)
def get_inventory_locations_admin(
self, db: Session, vendor_id: int | None = None
) -> AdminInventoryLocationsResponse:
"""Get list of unique inventory locations (admin only)."""
query = db.query(func.distinct(Inventory.location))
if vendor_id is not None:
query = query.filter(Inventory.vendor_id == vendor_id)
locations = [loc[0] for loc in query.all()]
return AdminInventoryLocationsResponse(locations=sorted(locations))
def get_vendor_inventory_admin(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 50,
location: str | None = None,
low_stock: int | None = None,
) -> AdminInventoryListResponse:
"""Get inventory for a specific vendor (admin only)."""
# Verify vendor exists
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(f"Vendor {vendor_id} not found")
# Use the existing method
inventories = self.get_vendor_inventory(
db=db,
vendor_id=vendor_id,
skip=skip,
limit=limit,
location=location,
low_stock_threshold=low_stock,
)
# Build response with product info
items = []
for inv in inventories:
product = inv.product
title = None
if product and product.marketplace_product:
title = product.marketplace_product.get_title()
items.append(
AdminInventoryItem(
id=inv.id,
product_id=inv.product_id,
vendor_id=inv.vendor_id,
vendor_name=vendor.name,
vendor_code=vendor.vendor_code,
product_title=title,
product_sku=product.vendor_sku if product else None,
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
gtin=inv.gtin,
created_at=inv.created_at,
updated_at=inv.updated_at,
)
)
# Get total count for pagination
total_query = db.query(func.count(Inventory.id)).filter(
Inventory.vendor_id == vendor_id
)
if location:
total_query = total_query.filter(Inventory.location.ilike(f"%{location}%"))
if low_stock is not None:
total_query = total_query.filter(Inventory.quantity <= low_stock)
total = total_query.scalar() or 0
return AdminInventoryListResponse(
inventories=items,
total=total,
skip=skip,
limit=limit,
vendor_filter=vendor_id,
location_filter=location,
)
def get_product_inventory_admin(
self, db: Session, product_id: int
) -> ProductInventorySummary:
"""Get inventory summary for a product (admin only - no vendor check)."""
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise ProductNotFoundException(f"Product {product_id} not found")
# Use existing method with the product's vendor_id
return self.get_product_inventory(db, product.vendor_id, product_id)
def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor:
"""Verify vendor exists and return it."""
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(f"Vendor {vendor_id} not found")
return vendor
def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory:
"""Get inventory by ID (admin only - returns inventory with vendor_id)."""
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
if not inventory:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
return inventory
# =========================================================================
# Private helper methods
# =========================================================================
def _get_vendor_product(
self, db: Session, vendor_id: int, product_id: int
) -> Product:
"""Get product and verify it belongs to vendor."""
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(
f"Product {product_id} not found in your catalog"
)
return product
def _get_inventory_entry(
self, db: Session, product_id: int, location: str
) -> Inventory | None:
"""Get inventory entry by product and location."""
return (
db.query(Inventory)
.filter(Inventory.product_id == product_id, Inventory.location == location)
.first()
)
def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory:
"""Get inventory by ID or raise exception."""
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
if not inventory:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
return inventory
def _validate_location(self, location: str) -> str:
"""Validate and normalize location."""
if not location or not location.strip():
raise InventoryValidationException("Location is required")
return location.strip().upper()
def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None:
"""Validate quantity value."""
if quantity is None:
raise InvalidQuantityException("Quantity is required")
if not isinstance(quantity, int):
raise InvalidQuantityException("Quantity must be an integer")
if quantity < 0:
raise InvalidQuantityException("Quantity cannot be negative")
if not allow_zero and quantity == 0:
raise InvalidQuantityException("Quantity must be positive")
# Create service instance
inventory_service = InventoryService()

View File

@@ -0,0 +1,431 @@
# app/modules/inventory/services/inventory_transaction_service.py
"""
Inventory Transaction Service.
Provides query operations for inventory transaction history.
All transaction WRITES are handled by OrderInventoryService.
This service handles transaction READS for reporting and auditing.
"""
import logging
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions import OrderNotFoundException, ProductNotFoundException
from app.modules.inventory.models.inventory import Inventory
from app.modules.inventory.models.inventory_transaction import InventoryTransaction
from models.database.order import Order
from models.database.product import Product
logger = logging.getLogger(__name__)
class InventoryTransactionService:
"""Service for querying inventory transaction history."""
def get_vendor_transactions(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 50,
product_id: int | None = None,
transaction_type: str | None = None,
) -> tuple[list[dict], int]:
"""
Get inventory transactions for a vendor with optional filters.
Args:
db: Database session
vendor_id: Vendor ID
skip: Pagination offset
limit: Pagination limit
product_id: Optional product filter
transaction_type: Optional transaction type filter
Returns:
Tuple of (transactions with product details, total count)
"""
# Build query
query = db.query(InventoryTransaction).filter(
InventoryTransaction.vendor_id == vendor_id
)
# Apply filters
if product_id:
query = query.filter(InventoryTransaction.product_id == product_id)
if transaction_type:
query = query.filter(
InventoryTransaction.transaction_type == transaction_type
)
# Get total count
total = query.count()
# Get transactions with pagination (newest first)
transactions = (
query.order_by(InventoryTransaction.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# Build result with product details
result = []
for tx in transactions:
product = db.query(Product).filter(Product.id == tx.product_id).first()
product_title = None
product_sku = None
if product:
product_sku = product.vendor_sku
if product.marketplace_product:
product_title = product.marketplace_product.get_title()
result.append(
{
"id": tx.id,
"vendor_id": tx.vendor_id,
"product_id": tx.product_id,
"inventory_id": tx.inventory_id,
"transaction_type": (
tx.transaction_type.value if tx.transaction_type else None
),
"quantity_change": tx.quantity_change,
"quantity_after": tx.quantity_after,
"reserved_after": tx.reserved_after,
"location": tx.location,
"warehouse": tx.warehouse,
"order_id": tx.order_id,
"order_number": tx.order_number,
"reason": tx.reason,
"created_by": tx.created_by,
"created_at": tx.created_at,
"product_title": product_title,
"product_sku": product_sku,
}
)
return result, total
def get_product_history(
self,
db: Session,
vendor_id: int,
product_id: int,
limit: int = 50,
) -> dict:
"""
Get transaction history for a specific product.
Args:
db: Database session
vendor_id: Vendor ID
product_id: Product ID
limit: Max transactions to return
Returns:
Dict with product info, current inventory, and transactions
Raises:
ProductNotFoundException: If product not found or doesn't belong to vendor
"""
# Get product details
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(
f"Product {product_id} not found in vendor catalog"
)
product_title = None
product_sku = product.vendor_sku
if product.marketplace_product:
product_title = product.marketplace_product.get_title()
# Get current inventory
inventory = (
db.query(Inventory)
.filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id)
.first()
)
current_quantity = inventory.quantity if inventory else 0
current_reserved = inventory.reserved_quantity if inventory else 0
# Get transactions
transactions = (
db.query(InventoryTransaction)
.filter(
InventoryTransaction.vendor_id == vendor_id,
InventoryTransaction.product_id == product_id,
)
.order_by(InventoryTransaction.created_at.desc())
.limit(limit)
.all()
)
total = (
db.query(func.count(InventoryTransaction.id))
.filter(
InventoryTransaction.vendor_id == vendor_id,
InventoryTransaction.product_id == product_id,
)
.scalar()
or 0
)
return {
"product_id": product_id,
"product_title": product_title,
"product_sku": product_sku,
"current_quantity": current_quantity,
"current_reserved": current_reserved,
"transactions": [
{
"id": tx.id,
"vendor_id": tx.vendor_id,
"product_id": tx.product_id,
"inventory_id": tx.inventory_id,
"transaction_type": (
tx.transaction_type.value if tx.transaction_type else None
),
"quantity_change": tx.quantity_change,
"quantity_after": tx.quantity_after,
"reserved_after": tx.reserved_after,
"location": tx.location,
"warehouse": tx.warehouse,
"order_id": tx.order_id,
"order_number": tx.order_number,
"reason": tx.reason,
"created_by": tx.created_by,
"created_at": tx.created_at,
}
for tx in transactions
],
"total": total,
}
def get_order_history(
self,
db: Session,
vendor_id: int,
order_id: int,
) -> dict:
"""
Get all inventory transactions for a specific order.
Args:
db: Database session
vendor_id: Vendor ID
order_id: Order ID
Returns:
Dict with order info and transactions
Raises:
OrderNotFoundException: If order not found or doesn't belong to vendor
"""
# Verify order belongs to vendor
order = (
db.query(Order)
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
.first()
)
if not order:
raise OrderNotFoundException(f"Order {order_id} not found")
# Get transactions for this order
transactions = (
db.query(InventoryTransaction)
.filter(InventoryTransaction.order_id == order_id)
.order_by(InventoryTransaction.created_at.asc())
.all()
)
# Build result with product details
result = []
for tx in transactions:
product = db.query(Product).filter(Product.id == tx.product_id).first()
product_title = None
product_sku = None
if product:
product_sku = product.vendor_sku
if product.marketplace_product:
product_title = product.marketplace_product.get_title()
result.append(
{
"id": tx.id,
"vendor_id": tx.vendor_id,
"product_id": tx.product_id,
"inventory_id": tx.inventory_id,
"transaction_type": (
tx.transaction_type.value if tx.transaction_type else None
),
"quantity_change": tx.quantity_change,
"quantity_after": tx.quantity_after,
"reserved_after": tx.reserved_after,
"location": tx.location,
"warehouse": tx.warehouse,
"order_id": tx.order_id,
"order_number": tx.order_number,
"reason": tx.reason,
"created_by": tx.created_by,
"created_at": tx.created_at,
"product_title": product_title,
"product_sku": product_sku,
}
)
return {
"order_id": order_id,
"order_number": order.order_number,
"transactions": result,
}
# =========================================================================
# Admin Methods (cross-vendor operations)
# =========================================================================
def get_all_transactions_admin(
self,
db: Session,
skip: int = 0,
limit: int = 50,
vendor_id: int | None = None,
product_id: int | None = None,
transaction_type: str | None = None,
order_id: int | None = None,
) -> tuple[list[dict], int]:
"""
Get inventory transactions across all vendors (admin only).
Args:
db: Database session
skip: Pagination offset
limit: Pagination limit
vendor_id: Optional vendor filter
product_id: Optional product filter
transaction_type: Optional transaction type filter
order_id: Optional order filter
Returns:
Tuple of (transactions with details, total count)
"""
from models.database.vendor import Vendor
# Build query
query = db.query(InventoryTransaction)
# Apply filters
if vendor_id:
query = query.filter(InventoryTransaction.vendor_id == vendor_id)
if product_id:
query = query.filter(InventoryTransaction.product_id == product_id)
if transaction_type:
query = query.filter(
InventoryTransaction.transaction_type == transaction_type
)
if order_id:
query = query.filter(InventoryTransaction.order_id == order_id)
# Get total count
total = query.count()
# Get transactions with pagination (newest first)
transactions = (
query.order_by(InventoryTransaction.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# Build result with vendor and product details
result = []
for tx in transactions:
vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first()
product = db.query(Product).filter(Product.id == tx.product_id).first()
product_title = None
product_sku = None
if product:
product_sku = product.vendor_sku
if product.marketplace_product:
product_title = product.marketplace_product.get_title()
result.append(
{
"id": tx.id,
"vendor_id": tx.vendor_id,
"vendor_name": vendor.name if vendor else None,
"vendor_code": vendor.vendor_code if vendor else None,
"product_id": tx.product_id,
"inventory_id": tx.inventory_id,
"transaction_type": (
tx.transaction_type.value if tx.transaction_type else None
),
"quantity_change": tx.quantity_change,
"quantity_after": tx.quantity_after,
"reserved_after": tx.reserved_after,
"location": tx.location,
"warehouse": tx.warehouse,
"order_id": tx.order_id,
"order_number": tx.order_number,
"reason": tx.reason,
"created_by": tx.created_by,
"created_at": tx.created_at,
"product_title": product_title,
"product_sku": product_sku,
}
)
return result, total
def get_transaction_stats_admin(self, db: Session) -> dict:
"""
Get transaction statistics across the platform (admin only).
Returns:
Dict with transaction counts by type
"""
from sqlalchemy import func as sql_func
# Count by transaction type
type_counts = (
db.query(
InventoryTransaction.transaction_type,
sql_func.count(InventoryTransaction.id).label("count"),
)
.group_by(InventoryTransaction.transaction_type)
.all()
)
# Total transactions
total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0
# Transactions today
from datetime import UTC, datetime, timedelta
today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
today_count = (
db.query(sql_func.count(InventoryTransaction.id))
.filter(InventoryTransaction.created_at >= today_start)
.scalar()
or 0
)
return {
"total_transactions": total,
"transactions_today": today_count,
"by_type": {tc.transaction_type.value: tc.count for tc in type_counts},
}
# Create service instance
inventory_transaction_service = InventoryTransactionService()

View File

@@ -2,45 +2,35 @@
"""
Marketplace module services.
Re-exports Letzshop and marketplace services from their current locations.
Services remain in app/services/ for now to avoid breaking existing imports.
Usage:
from app.modules.marketplace.services import (
letzshop_export_service,
marketplace_import_job_service,
marketplace_product_service,
)
from app.modules.marketplace.services.letzshop import (
LetzshopClient,
LetzshopCredentialsService,
LetzshopOrderService,
LetzshopVendorSyncService,
)
This module contains the canonical implementations of marketplace-related services.
"""
# Re-export from existing locations for convenience
from app.services.letzshop_export_service import (
# Main marketplace services
from app.modules.marketplace.services.letzshop_export_service import (
LetzshopExportService,
letzshop_export_service,
)
from app.services.marketplace_import_job_service import (
from app.modules.marketplace.services.marketplace_import_job_service import (
MarketplaceImportJobService,
marketplace_import_job_service,
)
from app.services.marketplace_product_service import (
from app.modules.marketplace.services.marketplace_product_service import (
MarketplaceProductService,
marketplace_product_service,
)
# Letzshop submodule re-exports
from app.services.letzshop import (
# Letzshop submodule services
from app.modules.marketplace.services.letzshop import (
LetzshopClient,
LetzshopClientError,
)
from app.services.letzshop.credentials_service import LetzshopCredentialsService
from app.services.letzshop.order_service import LetzshopOrderService
from app.services.letzshop.vendor_sync_service import LetzshopVendorSyncService
from app.modules.marketplace.services.letzshop.credentials_service import (
LetzshopCredentialsService,
)
from app.modules.marketplace.services.letzshop.order_service import LetzshopOrderService
from app.modules.marketplace.services.letzshop.vendor_sync_service import (
LetzshopVendorSyncService,
)
__all__ = [
# Export service

View File

@@ -0,0 +1,53 @@
# app/modules/marketplace/services/letzshop/__init__.py
"""
Letzshop marketplace integration services.
Provides:
- GraphQL client for API communication
- Credential management service
- Order import service
- Fulfillment sync service
- Vendor directory sync service
"""
from .client_service import (
LetzshopAPIError,
LetzshopAuthError,
LetzshopClient,
LetzshopClientError,
LetzshopConnectionError,
)
from .credentials_service import (
CredentialsError,
CredentialsNotFoundError,
LetzshopCredentialsService,
)
from .order_service import (
LetzshopOrderService,
OrderNotFoundError,
VendorNotFoundError,
)
from .vendor_sync_service import (
LetzshopVendorSyncService,
get_vendor_sync_service,
)
__all__ = [
# Client
"LetzshopClient",
"LetzshopClientError",
"LetzshopAuthError",
"LetzshopAPIError",
"LetzshopConnectionError",
# Credentials
"LetzshopCredentialsService",
"CredentialsError",
"CredentialsNotFoundError",
# Order Service
"LetzshopOrderService",
"OrderNotFoundError",
"VendorNotFoundError",
# Vendor Sync Service
"LetzshopVendorSyncService",
"get_vendor_sync_service",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,400 @@
# app/services/letzshop/credentials_service.py
"""
Letzshop credentials management service.
Handles secure storage and retrieval of per-vendor Letzshop API credentials.
"""
import logging
from datetime import UTC, datetime
from sqlalchemy.orm import Session
from app.utils.encryption import decrypt_value, encrypt_value, mask_api_key
from models.database.letzshop import VendorLetzshopCredentials
from .client_service import LetzshopClient
logger = logging.getLogger(__name__)
# Default Letzshop GraphQL endpoint
DEFAULT_ENDPOINT = "https://letzshop.lu/graphql"
class CredentialsError(Exception):
"""Base exception for credentials errors."""
class CredentialsNotFoundError(CredentialsError):
"""Raised when credentials are not found for a vendor."""
class LetzshopCredentialsService:
"""
Service for managing Letzshop API credentials.
Provides secure storage and retrieval of encrypted API keys,
connection testing, and sync status updates.
"""
def __init__(self, db: Session):
"""
Initialize the credentials service.
Args:
db: SQLAlchemy database session.
"""
self.db = db
# ========================================================================
# CRUD Operations
# ========================================================================
def get_credentials(self, vendor_id: int) -> VendorLetzshopCredentials | None:
"""
Get Letzshop credentials for a vendor.
Args:
vendor_id: The vendor ID.
Returns:
VendorLetzshopCredentials or None if not found.
"""
return (
self.db.query(VendorLetzshopCredentials)
.filter(VendorLetzshopCredentials.vendor_id == vendor_id)
.first()
)
def get_credentials_or_raise(self, vendor_id: int) -> VendorLetzshopCredentials:
"""
Get Letzshop credentials for a vendor or raise an exception.
Args:
vendor_id: The vendor ID.
Returns:
VendorLetzshopCredentials.
Raises:
CredentialsNotFoundError: If credentials are not found.
"""
credentials = self.get_credentials(vendor_id)
if credentials is None:
raise CredentialsNotFoundError(
f"Letzshop credentials not found for vendor {vendor_id}"
)
return credentials
def create_credentials(
self,
vendor_id: int,
api_key: str,
api_endpoint: str | None = None,
auto_sync_enabled: bool = False,
sync_interval_minutes: int = 15,
) -> VendorLetzshopCredentials:
"""
Create Letzshop credentials for a vendor.
Args:
vendor_id: The vendor ID.
api_key: The Letzshop API key (will be encrypted).
api_endpoint: Custom API endpoint (optional).
auto_sync_enabled: Whether to enable automatic sync.
sync_interval_minutes: Sync interval in minutes.
Returns:
Created VendorLetzshopCredentials.
"""
# Encrypt the API key
encrypted_key = encrypt_value(api_key)
credentials = VendorLetzshopCredentials(
vendor_id=vendor_id,
api_key_encrypted=encrypted_key,
api_endpoint=api_endpoint or DEFAULT_ENDPOINT,
auto_sync_enabled=auto_sync_enabled,
sync_interval_minutes=sync_interval_minutes,
)
self.db.add(credentials)
self.db.flush()
logger.info(f"Created Letzshop credentials for vendor {vendor_id}")
return credentials
def update_credentials(
self,
vendor_id: int,
api_key: str | None = None,
api_endpoint: str | None = None,
auto_sync_enabled: bool | None = None,
sync_interval_minutes: int | None = None,
) -> VendorLetzshopCredentials:
"""
Update Letzshop credentials for a vendor.
Args:
vendor_id: The vendor ID.
api_key: New API key (optional, will be encrypted if provided).
api_endpoint: New API endpoint (optional).
auto_sync_enabled: New auto-sync setting (optional).
sync_interval_minutes: New sync interval (optional).
Returns:
Updated VendorLetzshopCredentials.
Raises:
CredentialsNotFoundError: If credentials are not found.
"""
credentials = self.get_credentials_or_raise(vendor_id)
if api_key is not None:
credentials.api_key_encrypted = encrypt_value(api_key)
if api_endpoint is not None:
credentials.api_endpoint = api_endpoint
if auto_sync_enabled is not None:
credentials.auto_sync_enabled = auto_sync_enabled
if sync_interval_minutes is not None:
credentials.sync_interval_minutes = sync_interval_minutes
self.db.flush()
logger.info(f"Updated Letzshop credentials for vendor {vendor_id}")
return credentials
def delete_credentials(self, vendor_id: int) -> bool:
"""
Delete Letzshop credentials for a vendor.
Args:
vendor_id: The vendor ID.
Returns:
True if deleted, False if not found.
"""
credentials = self.get_credentials(vendor_id)
if credentials is None:
return False
self.db.delete(credentials)
self.db.flush()
logger.info(f"Deleted Letzshop credentials for vendor {vendor_id}")
return True
def upsert_credentials(
self,
vendor_id: int,
api_key: str,
api_endpoint: str | None = None,
auto_sync_enabled: bool = False,
sync_interval_minutes: int = 15,
) -> VendorLetzshopCredentials:
"""
Create or update Letzshop credentials for a vendor.
Args:
vendor_id: The vendor ID.
api_key: The Letzshop API key (will be encrypted).
api_endpoint: Custom API endpoint (optional).
auto_sync_enabled: Whether to enable automatic sync.
sync_interval_minutes: Sync interval in minutes.
Returns:
Created or updated VendorLetzshopCredentials.
"""
existing = self.get_credentials(vendor_id)
if existing:
return self.update_credentials(
vendor_id=vendor_id,
api_key=api_key,
api_endpoint=api_endpoint,
auto_sync_enabled=auto_sync_enabled,
sync_interval_minutes=sync_interval_minutes,
)
return self.create_credentials(
vendor_id=vendor_id,
api_key=api_key,
api_endpoint=api_endpoint,
auto_sync_enabled=auto_sync_enabled,
sync_interval_minutes=sync_interval_minutes,
)
# ========================================================================
# Key Decryption and Client Creation
# ========================================================================
def get_decrypted_api_key(self, vendor_id: int) -> str:
"""
Get the decrypted API key for a vendor.
Args:
vendor_id: The vendor ID.
Returns:
Decrypted API key.
Raises:
CredentialsNotFoundError: If credentials are not found.
"""
credentials = self.get_credentials_or_raise(vendor_id)
return decrypt_value(credentials.api_key_encrypted)
def get_masked_api_key(self, vendor_id: int) -> str:
"""
Get a masked version of the API key for display.
Args:
vendor_id: The vendor ID.
Returns:
Masked API key (e.g., "sk-a***************").
Raises:
CredentialsNotFoundError: If credentials are not found.
"""
api_key = self.get_decrypted_api_key(vendor_id)
return mask_api_key(api_key)
def create_client(self, vendor_id: int) -> LetzshopClient:
"""
Create a Letzshop client for a vendor.
Args:
vendor_id: The vendor ID.
Returns:
Configured LetzshopClient.
Raises:
CredentialsNotFoundError: If credentials are not found.
"""
credentials = self.get_credentials_or_raise(vendor_id)
api_key = decrypt_value(credentials.api_key_encrypted)
return LetzshopClient(
api_key=api_key,
endpoint=credentials.api_endpoint,
)
# ========================================================================
# Connection Testing
# ========================================================================
def test_connection(self, vendor_id: int) -> tuple[bool, float | None, str | None]:
"""
Test the connection for a vendor's credentials.
Args:
vendor_id: The vendor ID.
Returns:
Tuple of (success, response_time_ms, error_message).
"""
try:
with self.create_client(vendor_id) as client:
return client.test_connection()
except CredentialsNotFoundError:
return False, None, "Letzshop credentials not configured"
except Exception as e:
logger.error(f"Connection test failed for vendor {vendor_id}: {e}")
return False, None, str(e)
def test_api_key(
self,
api_key: str,
api_endpoint: str | None = None,
) -> tuple[bool, float | None, str | None]:
"""
Test an API key without saving it.
Args:
api_key: The API key to test.
api_endpoint: Optional custom endpoint.
Returns:
Tuple of (success, response_time_ms, error_message).
"""
try:
with LetzshopClient(
api_key=api_key,
endpoint=api_endpoint or DEFAULT_ENDPOINT,
) as client:
return client.test_connection()
except Exception as e:
logger.error(f"API key test failed: {e}")
return False, None, str(e)
# ========================================================================
# Sync Status Updates
# ========================================================================
def update_sync_status(
self,
vendor_id: int,
status: str,
error: str | None = None,
) -> VendorLetzshopCredentials | None:
"""
Update the last sync status for a vendor.
Args:
vendor_id: The vendor ID.
status: Sync status (success, failed, partial).
error: Error message if sync failed.
Returns:
Updated credentials or None if not found.
"""
credentials = self.get_credentials(vendor_id)
if credentials is None:
return None
credentials.last_sync_at = datetime.now(UTC)
credentials.last_sync_status = status
credentials.last_sync_error = error
self.db.flush()
return credentials
# ========================================================================
# Status Helpers
# ========================================================================
def is_configured(self, vendor_id: int) -> bool:
"""Check if Letzshop is configured for a vendor."""
return self.get_credentials(vendor_id) is not None
def get_status(self, vendor_id: int) -> dict:
"""
Get the Letzshop integration status for a vendor.
Args:
vendor_id: The vendor ID.
Returns:
Status dictionary with configuration and sync info.
"""
credentials = self.get_credentials(vendor_id)
if credentials is None:
return {
"is_configured": False,
"is_connected": False,
"last_sync_at": None,
"last_sync_status": None,
"auto_sync_enabled": False,
}
return {
"is_configured": True,
"is_connected": credentials.last_sync_status == "success",
"last_sync_at": credentials.last_sync_at,
"last_sync_status": credentials.last_sync_status,
"auto_sync_enabled": credentials.auto_sync_enabled,
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,521 @@
# app/services/letzshop/vendor_sync_service.py
"""
Service for syncing Letzshop vendor directory to local cache.
Fetches vendor data from Letzshop's public GraphQL API and stores it
in the letzshop_vendor_cache table for fast lookups during signup.
"""
import logging
from datetime import UTC, datetime
from typing import Any, Callable
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from app.services.letzshop.client_service import LetzshopClient
from models.database.letzshop import LetzshopVendorCache
logger = logging.getLogger(__name__)
class LetzshopVendorSyncService:
"""
Service for syncing Letzshop vendor directory.
Usage:
service = LetzshopVendorSyncService(db)
stats = service.sync_all_vendors()
"""
def __init__(self, db: Session):
"""Initialize the sync service."""
self.db = db
def sync_all_vendors(
self,
progress_callback: Callable[[int, int, int], None] | None = None,
max_pages: int | None = None,
) -> dict[str, Any]:
"""
Sync all vendors from Letzshop to local cache.
Args:
progress_callback: Optional callback(page, fetched, total) for progress.
Returns:
Dictionary with sync statistics.
"""
stats = {
"started_at": datetime.now(UTC),
"total_fetched": 0,
"created": 0,
"updated": 0,
"errors": 0,
"error_details": [],
}
logger.info("Starting Letzshop vendor directory sync...")
# Create client (no API key needed for public vendor data)
client = LetzshopClient(api_key="")
try:
# Fetch all vendors
vendors = client.get_all_vendors_paginated(
page_size=50,
max_pages=max_pages,
progress_callback=progress_callback,
)
stats["total_fetched"] = len(vendors)
logger.info(f"Fetched {len(vendors)} vendors from Letzshop")
# Process each vendor
for vendor_data in vendors:
try:
result = self._upsert_vendor(vendor_data)
if result == "created":
stats["created"] += 1
elif result == "updated":
stats["updated"] += 1
except Exception as e:
stats["errors"] += 1
error_info = {
"vendor_id": vendor_data.get("id"),
"slug": vendor_data.get("slug"),
"error": str(e),
}
stats["error_details"].append(error_info)
logger.error(f"Error processing vendor {vendor_data.get('slug')}: {e}")
# Commit all changes
self.db.commit()
logger.info(
f"Sync complete: {stats['created']} created, "
f"{stats['updated']} updated, {stats['errors']} errors"
)
except Exception as e:
self.db.rollback()
logger.error(f"Vendor sync failed: {e}")
stats["error"] = str(e)
raise
finally:
client.close()
stats["completed_at"] = datetime.now(UTC)
stats["duration_seconds"] = (
stats["completed_at"] - stats["started_at"]
).total_seconds()
return stats
def _upsert_vendor(self, vendor_data: dict[str, Any]) -> str:
"""
Insert or update a vendor in the cache.
Args:
vendor_data: Raw vendor data from Letzshop API.
Returns:
"created" or "updated" indicating the operation performed.
"""
letzshop_id = vendor_data.get("id")
slug = vendor_data.get("slug")
if not letzshop_id or not slug:
raise ValueError("Vendor missing required id or slug")
# Parse the vendor data
parsed = self._parse_vendor_data(vendor_data)
# Check if exists
existing = (
self.db.query(LetzshopVendorCache)
.filter(LetzshopVendorCache.letzshop_id == letzshop_id)
.first()
)
if existing:
# Update existing record (preserve claimed status)
for key, value in parsed.items():
if key not in ("claimed_by_vendor_id", "claimed_at"):
setattr(existing, key, value)
existing.last_synced_at = datetime.now(UTC)
return "updated"
else:
# Create new record
cache_entry = LetzshopVendorCache(
**parsed,
last_synced_at=datetime.now(UTC),
)
self.db.add(cache_entry)
return "created"
def _parse_vendor_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""
Parse raw Letzshop vendor data into cache model fields.
Args:
data: Raw vendor data from Letzshop API.
Returns:
Dictionary of parsed fields for LetzshopVendorCache.
"""
# Extract location
location = data.get("location") or {}
country = location.get("country") or {}
# Extract descriptions
description = data.get("description") or {}
# Extract opening hours
opening_hours = data.get("openingHours") or {}
# Extract categories (list of translated name objects)
categories = []
for cat in data.get("vendorCategories") or []:
cat_name = cat.get("name") or {}
# Prefer English, fallback to French or German
name = cat_name.get("en") or cat_name.get("fr") or cat_name.get("de")
if name:
categories.append(name)
# Extract social media URLs
social_links = []
for link in data.get("socialMediaLinks") or []:
url = link.get("url")
if url:
social_links.append(url)
# Extract background image
bg_image = data.get("backgroundImage") or {}
return {
"letzshop_id": data.get("id"),
"slug": data.get("slug"),
"name": data.get("name"),
"company_name": data.get("companyName") or data.get("legalName"),
"is_active": data.get("active", True),
# Descriptions
"description_en": description.get("en"),
"description_fr": description.get("fr"),
"description_de": description.get("de"),
# Contact
"email": data.get("email"),
"phone": data.get("phone"),
"fax": data.get("fax"),
"website": data.get("homepage"),
# Location
"street": location.get("street"),
"street_number": location.get("number"),
"city": location.get("city"),
"zipcode": location.get("zipcode"),
"country_iso": country.get("iso", "LU"),
"latitude": str(data.get("lat")) if data.get("lat") else None,
"longitude": str(data.get("lng")) if data.get("lng") else None,
# Categories and media
"categories": categories,
"background_image_url": bg_image.get("url"),
"social_media_links": social_links,
# Opening hours
"opening_hours_en": opening_hours.get("en"),
"opening_hours_fr": opening_hours.get("fr"),
"opening_hours_de": opening_hours.get("de"),
# Representative
"representative_name": data.get("representative"),
"representative_title": data.get("representativeTitle"),
# Raw data for reference
"raw_data": data,
}
def sync_single_vendor(self, slug: str) -> LetzshopVendorCache | None:
"""
Sync a single vendor by slug.
Useful for on-demand refresh when a user looks up a vendor.
Args:
slug: The vendor's URL slug.
Returns:
The updated/created cache entry, or None if not found.
"""
client = LetzshopClient(api_key="")
try:
vendor_data = client.get_vendor_by_slug(slug)
if not vendor_data:
logger.warning(f"Vendor not found on Letzshop: {slug}")
return None
result = self._upsert_vendor(vendor_data)
self.db.commit()
logger.info(f"Single vendor sync: {slug} ({result})")
return (
self.db.query(LetzshopVendorCache)
.filter(LetzshopVendorCache.slug == slug)
.first()
)
finally:
client.close()
def get_cached_vendor(self, slug: str) -> LetzshopVendorCache | None:
"""
Get a vendor from cache by slug.
Args:
slug: The vendor's URL slug.
Returns:
Cache entry or None if not found.
"""
return (
self.db.query(LetzshopVendorCache)
.filter(LetzshopVendorCache.slug == slug.lower())
.first()
)
def search_cached_vendors(
self,
search: str | None = None,
city: str | None = None,
category: str | None = None,
only_unclaimed: bool = False,
page: int = 1,
limit: int = 20,
) -> tuple[list[LetzshopVendorCache], int]:
"""
Search cached vendors with filters.
Args:
search: Search term for name.
city: Filter by city.
category: Filter by category.
only_unclaimed: Only return vendors not yet claimed.
page: Page number (1-indexed).
limit: Items per page.
Returns:
Tuple of (vendors list, total count).
"""
query = self.db.query(LetzshopVendorCache).filter(
LetzshopVendorCache.is_active == True # noqa: E712
)
if search:
search_term = f"%{search.lower()}%"
query = query.filter(
func.lower(LetzshopVendorCache.name).like(search_term)
)
if city:
query = query.filter(
func.lower(LetzshopVendorCache.city) == city.lower()
)
if category:
# Search in JSON array
query = query.filter(
LetzshopVendorCache.categories.contains([category])
)
if only_unclaimed:
query = query.filter(
LetzshopVendorCache.claimed_by_vendor_id.is_(None)
)
# Get total count
total = query.count()
# Apply pagination
offset = (page - 1) * limit
vendors = (
query.order_by(LetzshopVendorCache.name)
.offset(offset)
.limit(limit)
.all()
)
return vendors, total
def get_sync_stats(self) -> dict[str, Any]:
"""
Get statistics about the vendor cache.
Returns:
Dictionary with cache statistics.
"""
total = self.db.query(LetzshopVendorCache).count()
active = (
self.db.query(LetzshopVendorCache)
.filter(LetzshopVendorCache.is_active == True) # noqa: E712
.count()
)
claimed = (
self.db.query(LetzshopVendorCache)
.filter(LetzshopVendorCache.claimed_by_vendor_id.isnot(None))
.count()
)
# Get last sync time
last_synced = (
self.db.query(func.max(LetzshopVendorCache.last_synced_at)).scalar()
)
# Get unique cities
cities = (
self.db.query(LetzshopVendorCache.city)
.filter(LetzshopVendorCache.city.isnot(None))
.distinct()
.count()
)
return {
"total_vendors": total,
"active_vendors": active,
"claimed_vendors": claimed,
"unclaimed_vendors": active - claimed,
"unique_cities": cities,
"last_synced_at": last_synced.isoformat() if last_synced else None,
}
def mark_vendor_claimed(
self,
letzshop_slug: str,
vendor_id: int,
) -> bool:
"""
Mark a Letzshop vendor as claimed by a platform vendor.
Args:
letzshop_slug: The Letzshop vendor slug.
vendor_id: The platform vendor ID that claimed it.
Returns:
True if successful, False if vendor not found.
"""
cache_entry = self.get_cached_vendor(letzshop_slug)
if not cache_entry:
return False
cache_entry.claimed_by_vendor_id = vendor_id
cache_entry.claimed_at = datetime.now(UTC)
self.db.commit()
logger.info(f"Vendor {letzshop_slug} claimed by vendor_id={vendor_id}")
return True
def create_vendor_from_cache(
self,
letzshop_slug: str,
company_id: int,
) -> dict[str, Any]:
"""
Create a platform vendor from a cached Letzshop vendor.
Args:
letzshop_slug: The Letzshop vendor slug.
company_id: The company ID to create the vendor under.
Returns:
Dictionary with created vendor info.
Raises:
ValueError: If vendor not found, already claimed, or company not found.
"""
import random
from sqlalchemy import func
from app.services.admin_service import admin_service
from models.database.company import Company
from models.database.vendor import Vendor
from models.schema.vendor import VendorCreate
# Get cache entry
cache_entry = self.get_cached_vendor(letzshop_slug)
if not cache_entry:
raise ValueError(f"Letzshop vendor '{letzshop_slug}' not found in cache")
if cache_entry.is_claimed:
raise ValueError(
f"Letzshop vendor '{cache_entry.name}' is already claimed "
f"by vendor ID {cache_entry.claimed_by_vendor_id}"
)
# Verify company exists
company = self.db.query(Company).filter(Company.id == company_id).first()
if not company:
raise ValueError(f"Company with ID {company_id} not found")
# Generate vendor code from slug
vendor_code = letzshop_slug.upper().replace("-", "_")[:20]
# Check if vendor code already exists
existing = (
self.db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_code)
.first()
)
if existing:
vendor_code = f"{vendor_code[:16]}_{random.randint(100, 999)}"
# Generate subdomain from slug
subdomain = letzshop_slug.lower().replace("_", "-")[:30]
existing_subdomain = (
self.db.query(Vendor)
.filter(func.lower(Vendor.subdomain) == subdomain)
.first()
)
if existing_subdomain:
subdomain = f"{subdomain[:26]}-{random.randint(100, 999)}"
# Create vendor data from cache
address = f"{cache_entry.street or ''} {cache_entry.street_number or ''}".strip()
vendor_data = VendorCreate(
name=cache_entry.name,
vendor_code=vendor_code,
subdomain=subdomain,
company_id=company_id,
email=cache_entry.email or company.email,
phone=cache_entry.phone,
description=cache_entry.description_en or cache_entry.description_fr or "",
city=cache_entry.city,
country=cache_entry.country_iso or "LU",
website=cache_entry.website,
address_line_1=address or None,
postal_code=cache_entry.zipcode,
)
# Create vendor
vendor = admin_service.create_vendor(self.db, vendor_data)
# Mark the Letzshop vendor as claimed (commits internally) # noqa: SVC-006
self.mark_vendor_claimed(letzshop_slug, vendor.id)
logger.info(
f"Created vendor {vendor.vendor_code} from Letzshop vendor {letzshop_slug}"
)
return {
"id": vendor.id,
"vendor_code": vendor.vendor_code,
"name": vendor.name,
"subdomain": vendor.subdomain,
"company_id": vendor.company_id,
}
# Singleton-style function for easy access
def get_vendor_sync_service(db: Session) -> LetzshopVendorSyncService:
"""Get a vendor sync service instance."""
return LetzshopVendorSyncService(db)

View File

@@ -0,0 +1,338 @@
# app/services/letzshop_export_service.py
"""
Service for exporting products to Letzshop CSV format.
Generates Google Shopping compatible CSV files for Letzshop marketplace.
"""
import csv
import io
import logging
from datetime import UTC, datetime
from sqlalchemy.orm import Session, joinedload
from models.database.letzshop import LetzshopSyncLog
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
logger = logging.getLogger(__name__)
# Letzshop CSV columns in order
LETZSHOP_CSV_COLUMNS = [
"id",
"title",
"description",
"link",
"image_link",
"additional_image_link",
"availability",
"price",
"sale_price",
"brand",
"gtin",
"mpn",
"google_product_category",
"product_type",
"condition",
"adult",
"multipack",
"is_bundle",
"age_group",
"color",
"gender",
"material",
"pattern",
"size",
"size_type",
"size_system",
"item_group_id",
"custom_label_0",
"custom_label_1",
"custom_label_2",
"custom_label_3",
"custom_label_4",
"identifier_exists",
"unit_pricing_measure",
"unit_pricing_base_measure",
"shipping",
"atalanda:tax_rate",
"atalanda:quantity",
"atalanda:boost_sort",
"atalanda:delivery_method",
]
class LetzshopExportService:
"""Service for exporting products to Letzshop CSV format."""
def __init__(self, default_tax_rate: float = 17.0):
"""
Initialize the export service.
Args:
default_tax_rate: Default VAT rate for Luxembourg (17%)
"""
self.default_tax_rate = default_tax_rate
def export_vendor_products(
self,
db: Session,
vendor_id: int,
language: str = "en",
include_inactive: bool = False,
) -> str:
"""
Export all products for a vendor in Letzshop CSV format.
Args:
db: Database session
vendor_id: Vendor ID to export products for
language: Language for title/description (en, fr, de)
include_inactive: Whether to include inactive products
Returns:
CSV string content
"""
# Query products for this vendor with their marketplace product data
query = (
db.query(Product)
.filter(Product.vendor_id == vendor_id)
.options(
joinedload(Product.marketplace_product).joinedload(
MarketplaceProduct.translations
)
)
)
if not include_inactive:
query = query.filter(Product.is_active == True)
products = query.all()
logger.info(
f"Exporting {len(products)} products for vendor {vendor_id} in {language}"
)
return self._generate_csv(products, language)
def export_marketplace_products(
self,
db: Session,
marketplace: str = "Letzshop",
language: str = "en",
limit: int | None = None,
) -> str:
"""
Export marketplace products directly (admin use).
Args:
db: Database session
marketplace: Filter by marketplace source
language: Language for title/description
limit: Optional limit on number of products
Returns:
CSV string content
"""
query = (
db.query(MarketplaceProduct)
.filter(MarketplaceProduct.is_active == True)
.options(joinedload(MarketplaceProduct.translations))
)
if marketplace:
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if limit:
query = query.limit(limit)
products = query.all()
logger.info(
f"Exporting {len(products)} marketplace products for {marketplace} in {language}"
)
return self._generate_csv_from_marketplace_products(products, language)
def _generate_csv(self, products: list[Product], language: str) -> str:
"""Generate CSV from vendor Product objects."""
output = io.StringIO()
writer = csv.DictWriter(
output,
fieldnames=LETZSHOP_CSV_COLUMNS,
delimiter="\t",
quoting=csv.QUOTE_MINIMAL,
)
writer.writeheader()
for product in products:
if product.marketplace_product:
row = self._product_to_row(product, language)
writer.writerow(row)
return output.getvalue()
def _generate_csv_from_marketplace_products(
self, products: list[MarketplaceProduct], language: str
) -> str:
"""Generate CSV from MarketplaceProduct objects directly."""
output = io.StringIO()
writer = csv.DictWriter(
output,
fieldnames=LETZSHOP_CSV_COLUMNS,
delimiter="\t",
quoting=csv.QUOTE_MINIMAL,
)
writer.writeheader()
for mp in products:
row = self._marketplace_product_to_row(mp, language)
writer.writerow(row)
return output.getvalue()
def _product_to_row(self, product: Product, language: str) -> dict:
"""Convert a Product (with MarketplaceProduct) to a CSV row."""
mp = product.marketplace_product
return self._marketplace_product_to_row(
mp, language, vendor_sku=product.vendor_sku
)
def _marketplace_product_to_row(
self,
mp: MarketplaceProduct,
language: str,
vendor_sku: str | None = None,
) -> dict:
"""Convert a MarketplaceProduct to a CSV row dict."""
# Get localized title and description
title = mp.get_title(language) or ""
description = mp.get_description(language) or ""
# Format price with currency
price = ""
if mp.price_numeric:
price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}"
elif mp.price:
price = mp.price
# Format sale price
sale_price = ""
if mp.sale_price_numeric:
sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}"
elif mp.sale_price:
sale_price = mp.sale_price
# Additional images - join with comma if multiple
additional_images = ""
if mp.additional_images:
additional_images = ",".join(mp.additional_images)
elif mp.additional_image_link:
additional_images = mp.additional_image_link
# Determine identifier_exists
identifier_exists = mp.identifier_exists
if not identifier_exists:
identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no"
return {
"id": vendor_sku or mp.marketplace_product_id,
"title": title,
"description": description,
"link": mp.link or mp.source_url or "",
"image_link": mp.image_link or "",
"additional_image_link": additional_images,
"availability": mp.availability or "in stock",
"price": price,
"sale_price": sale_price,
"brand": mp.brand or "",
"gtin": mp.gtin or "",
"mpn": mp.mpn or "",
"google_product_category": mp.google_product_category or "",
"product_type": mp.product_type_raw or "",
"condition": mp.condition or "new",
"adult": mp.adult or "no",
"multipack": str(mp.multipack) if mp.multipack else "",
"is_bundle": mp.is_bundle or "no",
"age_group": mp.age_group or "",
"color": mp.color or "",
"gender": mp.gender or "",
"material": mp.material or "",
"pattern": mp.pattern or "",
"size": mp.size or "",
"size_type": mp.size_type or "",
"size_system": mp.size_system or "",
"item_group_id": mp.item_group_id or "",
"custom_label_0": mp.custom_label_0 or "",
"custom_label_1": mp.custom_label_1 or "",
"custom_label_2": mp.custom_label_2 or "",
"custom_label_3": mp.custom_label_3 or "",
"custom_label_4": mp.custom_label_4 or "",
"identifier_exists": identifier_exists,
"unit_pricing_measure": mp.unit_pricing_measure or "",
"unit_pricing_base_measure": mp.unit_pricing_base_measure or "",
"shipping": mp.shipping or "",
"atalanda:tax_rate": str(self.default_tax_rate),
"atalanda:quantity": "", # Would need inventory data
"atalanda:boost_sort": "",
"atalanda:delivery_method": "",
}
def log_export(
self,
db: Session,
vendor_id: int,
started_at: datetime,
completed_at: datetime,
files_processed: int,
files_succeeded: int,
files_failed: int,
products_exported: int,
triggered_by: str,
error_details: dict | None = None,
) -> LetzshopSyncLog:
"""
Log an export operation to the sync log.
Args:
db: Database session
vendor_id: Vendor ID
started_at: When the export started
completed_at: When the export completed
files_processed: Number of language files to export (e.g., 3)
files_succeeded: Number of files successfully exported
files_failed: Number of files that failed
products_exported: Total products in the export
triggered_by: Who triggered the export (e.g., "admin:123")
error_details: Optional error details if any failures
Returns:
Created LetzshopSyncLog entry
"""
sync_log = LetzshopSyncLog(
vendor_id=vendor_id,
operation_type="product_export",
direction="outbound",
status="completed" if files_failed == 0 else "partial",
records_processed=files_processed,
records_succeeded=files_succeeded,
records_failed=files_failed,
started_at=started_at,
completed_at=completed_at,
duration_seconds=int((completed_at - started_at).total_seconds()),
triggered_by=triggered_by,
error_details={
"products_exported": products_exported,
**(error_details or {}),
} if products_exported or error_details else None,
)
db.add(sync_log)
db.flush()
return sync_log
# Singleton instance
letzshop_export_service = LetzshopExportService()

View File

@@ -0,0 +1,334 @@
# app/services/marketplace_import_job_service.py
import logging
from sqlalchemy.orm import Session
from app.exceptions import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ValidationException,
)
from models.database.marketplace_import_job import (
MarketplaceImportError,
MarketplaceImportJob,
)
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (
AdminMarketplaceImportJobResponse,
MarketplaceImportJobRequest,
MarketplaceImportJobResponse,
)
logger = logging.getLogger(__name__)
class MarketplaceImportJobService:
"""Service class for Marketplace operations."""
def create_import_job(
self,
db: Session,
request: MarketplaceImportJobRequest,
vendor: Vendor, # CHANGED: Vendor object from middleware
user: User,
) -> MarketplaceImportJob:
"""
Create a new marketplace import job.
Args:
db: Database session
request: Import request data
vendor: Vendor object (from middleware)
user: User creating the job
Returns:
Created MarketplaceImportJob object
"""
try:
# Create marketplace import job record
import_job = MarketplaceImportJob(
status="pending",
source_url=request.source_url,
marketplace=request.marketplace,
language=request.language,
vendor_id=vendor.id,
user_id=user.id,
)
db.add(import_job)
db.flush()
db.refresh(import_job)
logger.info(
f"Created marketplace import job {import_job.id}: "
f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) "
f"by user {user.username}"
)
return import_job
except Exception as e:
logger.error(f"Error creating import job: {str(e)}")
raise ValidationException("Failed to create import job")
def get_import_job_by_id(
self, db: Session, job_id: int, user: User
) -> MarketplaceImportJob:
"""Get a marketplace import job by ID with access control."""
try:
job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job:
raise ImportJobNotFoundException(job_id)
# Users can only see their own jobs, admins can see all
if user.role != "admin" and job.user_id != user.id:
raise ImportJobNotOwnedException(job_id, user.id)
return job
except (ImportJobNotFoundException, ImportJobNotOwnedException):
raise
except Exception as e:
logger.error(f"Error getting import job {job_id}: {str(e)}")
raise ValidationException("Failed to retrieve import job")
def get_import_job_for_vendor(
self, db: Session, job_id: int, vendor_id: int
) -> MarketplaceImportJob:
"""
Get a marketplace import job by ID with vendor access control.
Validates that the job belongs to the specified vendor.
Args:
db: Database session
job_id: Import job ID
vendor_id: Vendor ID from token (to verify ownership)
Raises:
ImportJobNotFoundException: If job not found
UnauthorizedVendorAccessException: If job doesn't belong to vendor
"""
from app.exceptions import UnauthorizedVendorAccessException
try:
job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job:
raise ImportJobNotFoundException(job_id)
# Verify job belongs to vendor (service layer validation)
if job.vendor_id != vendor_id:
raise UnauthorizedVendorAccessException(
vendor_code=str(vendor_id),
user_id=0, # Not user-specific, but vendor mismatch
)
return job
except (ImportJobNotFoundException, UnauthorizedVendorAccessException):
raise
except Exception as e:
logger.error(
f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}"
)
raise ValidationException("Failed to retrieve import job")
def get_import_jobs(
self,
db: Session,
vendor: Vendor, # ADDED: Vendor filter
user: User,
marketplace: str | None = None,
skip: int = 0,
limit: int = 50,
) -> list[MarketplaceImportJob]:
"""Get marketplace import jobs for a specific vendor."""
try:
query = db.query(MarketplaceImportJob).filter(
MarketplaceImportJob.vendor_id == vendor.id
)
# Users can only see their own jobs, admins can see all vendor jobs
if user.role != "admin":
query = query.filter(MarketplaceImportJob.user_id == user.id)
# Apply marketplace filter
if marketplace:
query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
# Order by creation date (newest first) and apply pagination
jobs = (
query.order_by(
MarketplaceImportJob.created_at.desc(),
MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp
)
.offset(skip)
.limit(limit)
.all()
)
return jobs
except Exception as e:
logger.error(f"Error getting import jobs: {str(e)}")
raise ValidationException("Failed to retrieve import jobs")
def convert_to_response_model(
self, job: MarketplaceImportJob
) -> MarketplaceImportJobResponse:
"""Convert database model to API response model."""
return MarketplaceImportJobResponse(
job_id=job.id,
status=job.status,
marketplace=job.marketplace,
language=job.language,
vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None,
vendor_name=job.vendor.name if job.vendor else None,
source_url=job.source_url,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
total_processed=job.total_processed or 0,
error_count=job.error_count or 0,
error_message=job.error_message,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
)
def convert_to_admin_response_model(
self, job: MarketplaceImportJob
) -> AdminMarketplaceImportJobResponse:
"""Convert database model to admin API response model with extra fields."""
return AdminMarketplaceImportJobResponse(
id=job.id,
job_id=job.id,
status=job.status,
marketplace=job.marketplace,
language=job.language,
vendor_id=job.vendor_id,
vendor_code=job.vendor.vendor_code if job.vendor else None,
vendor_name=job.vendor.name if job.vendor else None,
source_url=job.source_url,
imported=job.imported_count or 0,
updated=job.updated_count or 0,
total_processed=job.total_processed or 0,
error_count=job.error_count or 0,
error_message=job.error_message,
error_details=[],
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
created_by_name=job.user.username if job.user else None,
)
def get_all_import_jobs_paginated(
self,
db: Session,
marketplace: str | None = None,
status: str | None = None,
page: int = 1,
limit: int = 100,
) -> tuple[list[MarketplaceImportJob], int]:
"""Get all marketplace import jobs with pagination (for admin)."""
try:
query = db.query(MarketplaceImportJob)
if marketplace:
query = query.filter(
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
)
if status:
query = query.filter(MarketplaceImportJob.status == status)
total = query.count()
skip = (page - 1) * limit
jobs = (
query.order_by(
MarketplaceImportJob.created_at.desc(),
MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp
)
.offset(skip)
.limit(limit)
.all()
)
return jobs, total
except Exception as e:
logger.error(f"Error getting all import jobs: {str(e)}")
raise ValidationException("Failed to retrieve import jobs")
def get_import_job_by_id_admin(
self, db: Session, job_id: int
) -> MarketplaceImportJob:
"""Get a marketplace import job by ID (admin - no access control)."""
job = (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.id == job_id)
.first()
)
if not job:
raise ImportJobNotFoundException(job_id)
return job
def get_import_job_errors(
self,
db: Session,
job_id: int,
error_type: str | None = None,
page: int = 1,
limit: int = 50,
) -> tuple[list[MarketplaceImportError], int]:
"""
Get import errors for a specific job with pagination.
Args:
db: Database session
job_id: Import job ID
error_type: Optional filter by error type
page: Page number (1-indexed)
limit: Number of items per page
Returns:
Tuple of (list of errors, total count)
"""
try:
query = db.query(MarketplaceImportError).filter(
MarketplaceImportError.import_job_id == job_id
)
if error_type:
query = query.filter(MarketplaceImportError.error_type == error_type)
total = query.count()
offset = (page - 1) * limit
errors = (
query.order_by(MarketplaceImportError.row_number)
.offset(offset)
.limit(limit)
.all()
)
return errors, total
except Exception as e:
logger.error(f"Error getting import job errors for job {job_id}: {str(e)}")
raise ValidationException("Failed to retrieve import errors")
marketplace_import_job_service = MarketplaceImportJobService()

File diff suppressed because it is too large Load Diff

View File

@@ -2,10 +2,10 @@
"""
Messaging module database models.
Re-exports messaging-related models from their source locations.
This module contains the canonical implementations of messaging-related models.
"""
from models.database.message import (
from app.modules.messaging.models.message import (
Conversation,
ConversationParticipant,
ConversationType,
@@ -13,7 +13,7 @@ from models.database.message import (
MessageAttachment,
ParticipantType,
)
from models.database.admin import AdminNotification
from app.modules.messaging.models.admin_notification import AdminNotification
__all__ = [
"Conversation",

View File

@@ -0,0 +1,54 @@
# app/modules/messaging/models/admin_notification.py
"""
Admin notification database model.
This model handles admin-specific notifications for system alerts and warnings.
"""
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
JSON,
String,
Text,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class AdminNotification(Base, TimestampMixin):
"""
Admin-specific notifications for system alerts and warnings.
Different from vendor/customer notifications - these are for platform
administrators to track system health and issues requiring attention.
"""
__tablename__ = "admin_notifications"
id = Column(Integer, primary_key=True, index=True)
type = Column(
String(50), nullable=False, index=True
) # system_alert, vendor_issue, import_failure
priority = Column(
String(20), default="normal", index=True
) # low, normal, high, critical
title = Column(String(200), nullable=False)
message = Column(Text, nullable=False)
is_read = Column(Boolean, default=False, index=True)
read_at = Column(DateTime, nullable=True)
read_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
action_required = Column(Boolean, default=False, index=True)
action_url = Column(String(500)) # Link to relevant admin page
notification_metadata = Column(JSON) # Additional contextual data
# Relationships
read_by = relationship("User", foreign_keys=[read_by_user_id])
def __repr__(self):
return f"<AdminNotification(id={self.id}, type='{self.type}', priority='{self.priority}')>"

View File

@@ -0,0 +1,272 @@
# app/modules/messaging/models/message.py
"""
Messaging system database models.
Supports three communication channels:
- Admin <-> Vendor
- Vendor <-> Customer
- Admin <-> Customer
Multi-tenant isolation is enforced via vendor_id for conversations
involving customers.
"""
import enum
from datetime import datetime
from sqlalchemy import (
Boolean,
Column,
DateTime,
Enum,
ForeignKey,
Index,
Integer,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class ConversationType(str, enum.Enum):
"""Defines the three supported conversation channels."""
ADMIN_VENDOR = "admin_vendor"
VENDOR_CUSTOMER = "vendor_customer"
ADMIN_CUSTOMER = "admin_customer"
class ParticipantType(str, enum.Enum):
"""Type of participant in a conversation."""
ADMIN = "admin" # User with role="admin"
VENDOR = "vendor" # User with role="vendor" (via VendorUser)
CUSTOMER = "customer" # Customer model
def _enum_values(enum_class):
"""Extract enum values for SQLAlchemy Enum column."""
return [e.value for e in enum_class]
class Conversation(Base, TimestampMixin):
"""
Represents a threaded conversation between participants.
Multi-tenancy: vendor_id is required for vendor_customer and admin_customer
conversations to ensure customer data isolation.
"""
__tablename__ = "conversations"
id = Column(Integer, primary_key=True, index=True)
# Conversation type determines participant structure
conversation_type = Column(
Enum(ConversationType, values_callable=_enum_values),
nullable=False,
index=True,
)
# Subject line for the conversation thread
subject = Column(String(500), nullable=False)
# For vendor_customer and admin_customer conversations
# Required for multi-tenant data isolation
vendor_id = Column(
Integer,
ForeignKey("vendors.id"),
nullable=True,
index=True,
)
# Status flags
is_closed = Column(Boolean, default=False, nullable=False)
closed_at = Column(DateTime, nullable=True)
closed_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True)
closed_by_id = Column(Integer, nullable=True)
# Last activity tracking for sorting
last_message_at = Column(DateTime, nullable=True, index=True)
message_count = Column(Integer, default=0, nullable=False)
# Relationships
vendor = relationship("Vendor", foreign_keys=[vendor_id])
participants = relationship(
"ConversationParticipant",
back_populates="conversation",
cascade="all, delete-orphan",
)
messages = relationship(
"Message",
back_populates="conversation",
cascade="all, delete-orphan",
order_by="Message.created_at",
)
# Indexes for common queries
__table_args__ = (
Index("ix_conversations_type_vendor", "conversation_type", "vendor_id"),
)
def __repr__(self) -> str:
return (
f"<Conversation(id={self.id}, type='{self.conversation_type.value}', "
f"subject='{self.subject[:30]}...')>"
)
class ConversationParticipant(Base, TimestampMixin):
"""
Links participants (users or customers) to conversations.
Polymorphic relationship:
- participant_type="admin" or "vendor": references users.id
- participant_type="customer": references customers.id
"""
__tablename__ = "conversation_participants"
id = Column(Integer, primary_key=True, index=True)
conversation_id = Column(
Integer,
ForeignKey("conversations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Polymorphic participant reference
participant_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False)
participant_id = Column(Integer, nullable=False, index=True)
# For vendor participants, track which vendor they represent
vendor_id = Column(
Integer,
ForeignKey("vendors.id"),
nullable=True,
)
# Unread tracking per participant
unread_count = Column(Integer, default=0, nullable=False)
last_read_at = Column(DateTime, nullable=True)
# Notification preferences for this conversation
email_notifications = Column(Boolean, default=True, nullable=False)
muted = Column(Boolean, default=False, nullable=False)
# Relationships
conversation = relationship("Conversation", back_populates="participants")
vendor = relationship("Vendor", foreign_keys=[vendor_id])
__table_args__ = (
UniqueConstraint(
"conversation_id",
"participant_type",
"participant_id",
name="uq_conversation_participant",
),
Index(
"ix_participant_lookup",
"participant_type",
"participant_id",
),
)
def __repr__(self) -> str:
return (
f"<ConversationParticipant(conversation_id={self.conversation_id}, "
f"type='{self.participant_type.value}', id={self.participant_id})>"
)
class Message(Base, TimestampMixin):
"""
Individual message within a conversation thread.
Sender polymorphism follows same pattern as participant.
"""
__tablename__ = "messages"
id = Column(Integer, primary_key=True, index=True)
conversation_id = Column(
Integer,
ForeignKey("conversations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Polymorphic sender reference
sender_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False)
sender_id = Column(Integer, nullable=False, index=True)
# Message content
content = Column(Text, nullable=False)
# System messages (e.g., "conversation closed")
is_system_message = Column(Boolean, default=False, nullable=False)
# Soft delete for moderation
is_deleted = Column(Boolean, default=False, nullable=False)
deleted_at = Column(DateTime, nullable=True)
deleted_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True)
deleted_by_id = Column(Integer, nullable=True)
# Relationships
conversation = relationship("Conversation", back_populates="messages")
attachments = relationship(
"MessageAttachment",
back_populates="message",
cascade="all, delete-orphan",
)
__table_args__ = (
Index("ix_messages_conversation_created", "conversation_id", "created_at"),
)
def __repr__(self) -> str:
return (
f"<Message(id={self.id}, conversation_id={self.conversation_id}, "
f"sender={self.sender_type.value}:{self.sender_id})>"
)
class MessageAttachment(Base, TimestampMixin):
"""
File attachments for messages.
Files are stored in platform storage (local/S3) with references here.
"""
__tablename__ = "message_attachments"
id = Column(Integer, primary_key=True, index=True)
message_id = Column(
Integer,
ForeignKey("messages.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# File metadata
filename = Column(String(255), nullable=False)
original_filename = Column(String(255), nullable=False)
file_path = Column(String(1000), nullable=False) # Storage path
file_size = Column(Integer, nullable=False) # Size in bytes
mime_type = Column(String(100), nullable=False)
# For image attachments
is_image = Column(Boolean, default=False, nullable=False)
image_width = Column(Integer, nullable=True)
image_height = Column(Integer, nullable=True)
thumbnail_path = Column(String(1000), nullable=True)
# Relationships
message = relationship("Message", back_populates="attachments")
def __repr__(self) -> str:
return f"<MessageAttachment(id={self.id}, filename='{self.original_filename}')>"

View File

@@ -2,27 +2,106 @@
"""
Messaging module Pydantic schemas.
Re-exports messaging-related schemas from their source locations.
This module contains the canonical implementations of messaging-related schemas.
"""
from models.schema.message import (
ConversationResponse,
ConversationListResponse,
MessageResponse,
MessageCreate,
from app.modules.messaging.schemas.message import (
# Attachment schemas
AttachmentResponse,
# Message schemas
MessageCreate,
MessageResponse,
# Participant schemas
ParticipantInfo,
ParticipantResponse,
# Conversation schemas
ConversationCreate,
ConversationSummary,
ConversationDetailResponse,
ConversationListResponse,
ConversationResponse,
# Unread count
UnreadCountResponse,
# Notification preferences
NotificationPreferencesUpdate,
# Conversation actions
CloseConversationResponse,
ReopenConversationResponse,
MarkReadResponse,
# Recipient selection
RecipientOption,
RecipientListResponse,
# Admin schemas
AdminConversationSummary,
AdminConversationListResponse,
AdminMessageStats,
)
from models.schema.notification import (
from app.modules.messaging.schemas.notification import (
# Response schemas
MessageResponse as NotificationMessageResponse,
UnreadCountResponse as NotificationUnreadCountResponse,
# Notification schemas
NotificationResponse,
NotificationListResponse,
# Settings schemas
NotificationSettingsResponse,
NotificationSettingsUpdate,
# Template schemas
NotificationTemplateResponse,
NotificationTemplateListResponse,
NotificationTemplateUpdate,
# Test notification
TestNotificationRequest,
# Alert statistics
AlertStatisticsResponse,
)
__all__ = [
"ConversationResponse",
"ConversationListResponse",
"MessageResponse",
"MessageCreate",
# Attachment schemas
"AttachmentResponse",
# Message schemas
"MessageCreate",
"MessageResponse",
# Participant schemas
"ParticipantInfo",
"ParticipantResponse",
# Conversation schemas
"ConversationCreate",
"ConversationSummary",
"ConversationDetailResponse",
"ConversationListResponse",
"ConversationResponse",
# Unread count
"UnreadCountResponse",
# Notification preferences
"NotificationPreferencesUpdate",
# Conversation actions
"CloseConversationResponse",
"ReopenConversationResponse",
"MarkReadResponse",
# Recipient selection
"RecipientOption",
"RecipientListResponse",
# Admin schemas
"AdminConversationSummary",
"AdminConversationListResponse",
"AdminMessageStats",
# Notification response schemas
"NotificationMessageResponse",
"NotificationUnreadCountResponse",
# Notification schemas
"NotificationResponse",
"NotificationListResponse",
# Settings schemas
"NotificationSettingsResponse",
"NotificationSettingsUpdate",
# Template schemas
"NotificationTemplateResponse",
"NotificationTemplateListResponse",
"NotificationTemplateUpdate",
# Test notification
"TestNotificationRequest",
# Alert statistics
"AlertStatisticsResponse",
]

View File

@@ -0,0 +1,312 @@
# app/modules/messaging/schemas/message.py
"""
Pydantic schemas for the messaging system.
Supports three communication channels:
- Admin <-> Vendor
- Vendor <-> Customer
- Admin <-> Customer
"""
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
from app.modules.messaging.models.message import ConversationType, ParticipantType
# ============================================================================
# Attachment Schemas
# ============================================================================
class AttachmentResponse(BaseModel):
"""Schema for message attachment in responses."""
model_config = ConfigDict(from_attributes=True)
id: int
filename: str
original_filename: str
file_size: int
mime_type: str
is_image: bool
image_width: int | None = None
image_height: int | None = None
download_url: str | None = None
thumbnail_url: str | None = None
@property
def file_size_display(self) -> str:
"""Human-readable file size."""
if self.file_size < 1024:
return f"{self.file_size} B"
elif self.file_size < 1024 * 1024:
return f"{self.file_size / 1024:.1f} KB"
else:
return f"{self.file_size / 1024 / 1024:.1f} MB"
# ============================================================================
# Message Schemas
# ============================================================================
class MessageCreate(BaseModel):
"""Schema for sending a new message."""
content: str = Field(..., min_length=1, max_length=10000)
class MessageResponse(BaseModel):
"""Schema for a single message in responses."""
model_config = ConfigDict(from_attributes=True)
id: int
conversation_id: int
sender_type: ParticipantType
sender_id: int
content: str
is_system_message: bool
is_deleted: bool
created_at: datetime
# Enriched sender info (populated by API)
sender_name: str | None = None
sender_email: str | None = None
# Attachments
attachments: list[AttachmentResponse] = []
# ============================================================================
# Participant Schemas
# ============================================================================
class ParticipantInfo(BaseModel):
"""Schema for participant information."""
id: int
type: ParticipantType
name: str
email: str | None = None
avatar_url: str | None = None
class ParticipantResponse(BaseModel):
"""Schema for conversation participant in responses."""
model_config = ConfigDict(from_attributes=True)
id: int
participant_type: ParticipantType
participant_id: int
unread_count: int
last_read_at: datetime | None
email_notifications: bool
muted: bool
# Enriched info (populated by API)
participant_info: ParticipantInfo | None = None
# ============================================================================
# Conversation Schemas
# ============================================================================
class ConversationCreate(BaseModel):
"""Schema for creating a new conversation."""
conversation_type: ConversationType
subject: str = Field(..., min_length=1, max_length=500)
recipient_type: ParticipantType
recipient_id: int
vendor_id: int | None = None
initial_message: str | None = Field(None, min_length=1, max_length=10000)
class ConversationSummary(BaseModel):
"""Schema for conversation in list views."""
model_config = ConfigDict(from_attributes=True)
id: int
conversation_type: ConversationType
subject: str
vendor_id: int | None = None
is_closed: bool
closed_at: datetime | None
last_message_at: datetime | None
message_count: int
created_at: datetime
# Unread count for current user (from participant)
unread_count: int = 0
# Other participant info (enriched by API)
other_participant: ParticipantInfo | None = None
# Last message preview
last_message_preview: str | None = None
class ConversationDetailResponse(BaseModel):
"""Schema for full conversation detail with messages."""
model_config = ConfigDict(from_attributes=True)
id: int
conversation_type: ConversationType
subject: str
vendor_id: int | None = None
is_closed: bool
closed_at: datetime | None
closed_by_type: ParticipantType | None = None
closed_by_id: int | None = None
last_message_at: datetime | None
message_count: int
created_at: datetime
updated_at: datetime
# Participants with enriched info
participants: list[ParticipantResponse] = []
# Messages ordered by created_at
messages: list[MessageResponse] = []
# Current user's unread count
unread_count: int = 0
# Vendor info if applicable
vendor_name: str | None = None
class ConversationListResponse(BaseModel):
"""Schema for paginated conversation list."""
conversations: list[ConversationSummary]
total: int
total_unread: int
skip: int
limit: int
# Backward compatibility alias
ConversationResponse = ConversationDetailResponse
# ============================================================================
# Unread Count Schemas
# ============================================================================
class UnreadCountResponse(BaseModel):
"""Schema for unread message count (for header badge)."""
total_unread: int
# ============================================================================
# Notification Preferences Schemas
# ============================================================================
class NotificationPreferencesUpdate(BaseModel):
"""Schema for updating notification preferences."""
email_notifications: bool | None = None
muted: bool | None = None
# ============================================================================
# Conversation Action Schemas
# ============================================================================
class CloseConversationResponse(BaseModel):
"""Response after closing a conversation."""
success: bool
message: str
conversation_id: int
class ReopenConversationResponse(BaseModel):
"""Response after reopening a conversation."""
success: bool
message: str
conversation_id: int
class MarkReadResponse(BaseModel):
"""Response after marking conversation as read."""
success: bool
conversation_id: int
unread_count: int
# ============================================================================
# Recipient Selection Schemas (for compose modal)
# ============================================================================
class RecipientOption(BaseModel):
"""Schema for a selectable recipient in compose modal."""
id: int
type: ParticipantType
name: str
email: str | None = None
vendor_id: int | None = None # For vendor users
vendor_name: str | None = None
class RecipientListResponse(BaseModel):
"""Schema for list of available recipients."""
recipients: list[RecipientOption]
total: int
# ============================================================================
# Admin-specific Schemas
# ============================================================================
class AdminConversationSummary(ConversationSummary):
"""Extended conversation summary with vendor info for admin views."""
vendor_name: str | None = None
vendor_code: str | None = None
class AdminConversationListResponse(BaseModel):
"""Schema for admin conversation list with vendor info."""
conversations: list[AdminConversationSummary]
total: int
total_unread: int
skip: int
limit: int
class AdminMessageStats(BaseModel):
"""Messaging statistics for admin dashboard."""
total_conversations: int = 0
open_conversations: int = 0
closed_conversations: int = 0
total_messages: int = 0
# By type
admin_vendor_conversations: int = 0
vendor_customer_conversations: int = 0
admin_customer_conversations: int = 0
# Unread
unread_admin: int = 0

View File

@@ -0,0 +1,152 @@
# app/modules/messaging/schemas/notification.py
"""
Notification Pydantic schemas for API validation and responses.
This module provides schemas for:
- Vendor notifications (list, read, delete)
- Notification settings management
- Notification email templates
- Unread counts and statistics
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
# ============================================================================
# SHARED RESPONSE SCHEMAS
# ============================================================================
class MessageResponse(BaseModel):
"""Generic message response for simple operations."""
message: str
class UnreadCountResponse(BaseModel):
"""Response for unread notification count."""
unread_count: int
message: str | None = None
# ============================================================================
# NOTIFICATION SCHEMAS
# ============================================================================
class NotificationResponse(BaseModel):
"""Single notification response."""
id: int
type: str
title: str
message: str
is_read: bool
read_at: datetime | None = None
priority: str = "normal"
action_url: str | None = None
metadata: dict[str, Any] | None = None
created_at: datetime
model_config = {"from_attributes": True}
class NotificationListResponse(BaseModel):
"""Paginated list of notifications."""
notifications: list[NotificationResponse] = []
total: int = 0
unread_count: int = 0
message: str | None = None
# ============================================================================
# NOTIFICATION SETTINGS SCHEMAS
# ============================================================================
class NotificationSettingsResponse(BaseModel):
"""Notification preferences response."""
email_notifications: bool = True
in_app_notifications: bool = True
notification_types: dict[str, bool] = Field(default_factory=dict)
message: str | None = None
class NotificationSettingsUpdate(BaseModel):
"""Request model for updating notification settings."""
email_notifications: bool | None = None
in_app_notifications: bool | None = None
notification_types: dict[str, bool] | None = None
# ============================================================================
# NOTIFICATION TEMPLATE SCHEMAS
# ============================================================================
class NotificationTemplateResponse(BaseModel):
"""Single notification template response."""
id: int
name: str
type: str
subject: str
body_html: str | None = None
body_text: str | None = None
variables: list[str] = Field(default_factory=list)
is_active: bool = True
created_at: datetime
updated_at: datetime | None = None
model_config = {"from_attributes": True}
class NotificationTemplateListResponse(BaseModel):
"""List of notification templates."""
templates: list[NotificationTemplateResponse] = []
message: str | None = None
class NotificationTemplateUpdate(BaseModel):
"""Request model for updating notification template."""
subject: str | None = Field(None, max_length=200)
body_html: str | None = None
body_text: str | None = None
is_active: bool | None = None
# ============================================================================
# TEST NOTIFICATION SCHEMA
# ============================================================================
class TestNotificationRequest(BaseModel):
"""Request model for sending test notification."""
template_id: int | None = Field(None, description="Template to use")
email: str | None = Field(None, description="Override recipient email")
notification_type: str = Field(
default="test", description="Type of notification to send"
)
# ============================================================================
# ADMIN ALERT STATISTICS SCHEMA
# ============================================================================
class AlertStatisticsResponse(BaseModel):
"""Response for alert statistics."""
total_alerts: int = 0
active_alerts: int = 0
critical_alerts: int = 0
resolved_today: int = 0

View File

@@ -2,24 +2,29 @@
"""
Messaging module services.
Re-exports messaging-related services from their source locations.
This module contains the canonical implementations of messaging-related services.
"""
from app.services.messaging_service import (
from app.modules.messaging.services.messaging_service import (
messaging_service,
MessagingService,
)
from app.services.message_attachment_service import (
from app.modules.messaging.services.message_attachment_service import (
message_attachment_service,
MessageAttachmentService,
)
from app.services.admin_notification_service import (
from app.modules.messaging.services.admin_notification_service import (
admin_notification_service,
AdminNotificationService,
platform_alert_service,
PlatformAlertService,
# Constants
NotificationType,
Priority,
AlertType,
Severity,
)
# Note: notification_service is a placeholder - not yet implemented
__all__ = [
"messaging_service",
"MessagingService",
@@ -27,4 +32,11 @@ __all__ = [
"MessageAttachmentService",
"admin_notification_service",
"AdminNotificationService",
"platform_alert_service",
"PlatformAlertService",
# Constants
"NotificationType",
"Priority",
"AlertType",
"Severity",
]

View File

@@ -0,0 +1,702 @@
# app/modules/messaging/services/admin_notification_service.py
"""
Admin notification service.
Provides functionality for:
- Creating and managing admin notifications
- Managing platform alerts
- Notification statistics and queries
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from sqlalchemy import and_, case, func
from sqlalchemy.orm import Session
from app.modules.messaging.models.admin_notification import AdminNotification
from models.database.admin import PlatformAlert
from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate
logger = logging.getLogger(__name__)
# ============================================================================
# NOTIFICATION TYPES
# ============================================================================
class NotificationType:
"""Notification type constants."""
SYSTEM_ALERT = "system_alert"
IMPORT_FAILURE = "import_failure"
EXPORT_FAILURE = "export_failure"
ORDER_SYNC_FAILURE = "order_sync_failure"
VENDOR_ISSUE = "vendor_issue"
CUSTOMER_MESSAGE = "customer_message"
VENDOR_MESSAGE = "vendor_message"
SECURITY_ALERT = "security_alert"
PERFORMANCE_ALERT = "performance_alert"
ORDER_EXCEPTION = "order_exception"
CRITICAL_ERROR = "critical_error"
class Priority:
"""Priority level constants."""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
CRITICAL = "critical"
class AlertType:
"""Platform alert type constants."""
SECURITY = "security"
PERFORMANCE = "performance"
CAPACITY = "capacity"
INTEGRATION = "integration"
DATABASE = "database"
SYSTEM = "system"
class Severity:
"""Alert severity constants."""
INFO = "info"
WARNING = "warning"
ERROR = "error"
CRITICAL = "critical"
# ============================================================================
# ADMIN NOTIFICATION SERVICE
# ============================================================================
class AdminNotificationService:
"""Service for managing admin notifications."""
def create_notification(
self,
db: Session,
notification_type: str,
title: str,
message: str,
priority: str = Priority.NORMAL,
action_required: bool = False,
action_url: str | None = None,
metadata: dict[str, Any] | None = None,
) -> AdminNotification:
"""
Create a new admin notification.
Args:
db: Database session
notification_type: Type of notification
title: Notification title
message: Notification message
priority: Priority level (low, normal, high, critical)
action_required: Whether action is required
action_url: URL to relevant admin page
metadata: Additional contextual data
Returns:
Created AdminNotification
"""
notification = AdminNotification(
type=notification_type,
title=title,
message=message,
priority=priority,
action_required=action_required,
action_url=action_url,
notification_metadata=metadata,
)
db.add(notification)
db.flush()
logger.info(
f"Created notification: {notification_type} - {title} (priority: {priority})"
)
return notification
def create_from_schema(
self,
db: Session,
data: AdminNotificationCreate,
) -> AdminNotification:
"""Create notification from Pydantic schema."""
return self.create_notification(
db=db,
notification_type=data.type,
title=data.title,
message=data.message,
priority=data.priority,
action_required=data.action_required,
action_url=data.action_url,
metadata=data.metadata,
)
def get_notifications(
self,
db: Session,
priority: str | None = None,
is_read: bool | None = None,
notification_type: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[AdminNotification], int, int]:
"""
Get paginated admin notifications.
Returns:
Tuple of (notifications, total_count, unread_count)
"""
query = db.query(AdminNotification)
# Apply filters
if priority:
query = query.filter(AdminNotification.priority == priority)
if is_read is not None:
query = query.filter(AdminNotification.is_read == is_read)
if notification_type:
query = query.filter(AdminNotification.type == notification_type)
# Get counts
total = query.count()
unread_count = (
db.query(AdminNotification)
.filter(AdminNotification.is_read == False) # noqa: E712
.count()
)
# Get paginated results ordered by priority and date
priority_order = case(
(AdminNotification.priority == "critical", 1),
(AdminNotification.priority == "high", 2),
(AdminNotification.priority == "normal", 3),
(AdminNotification.priority == "low", 4),
else_=5,
)
notifications = (
query.order_by(
AdminNotification.is_read, # Unread first
priority_order,
AdminNotification.created_at.desc(),
)
.offset(skip)
.limit(limit)
.all()
)
return notifications, total, unread_count
def get_unread_count(self, db: Session) -> int:
"""Get count of unread notifications."""
return (
db.query(AdminNotification)
.filter(AdminNotification.is_read == False) # noqa: E712
.count()
)
def get_recent_notifications(
self,
db: Session,
limit: int = 5,
) -> list[AdminNotification]:
"""Get recent unread notifications for header dropdown."""
priority_order = case(
(AdminNotification.priority == "critical", 1),
(AdminNotification.priority == "high", 2),
(AdminNotification.priority == "normal", 3),
(AdminNotification.priority == "low", 4),
else_=5,
)
return (
db.query(AdminNotification)
.filter(AdminNotification.is_read == False) # noqa: E712
.order_by(priority_order, AdminNotification.created_at.desc())
.limit(limit)
.all()
)
def mark_as_read(
self,
db: Session,
notification_id: int,
user_id: int,
) -> AdminNotification | None:
"""Mark a notification as read."""
notification = (
db.query(AdminNotification)
.filter(AdminNotification.id == notification_id)
.first()
)
if notification and not notification.is_read:
notification.is_read = True
notification.read_at = datetime.utcnow()
notification.read_by_user_id = user_id
db.flush()
return notification
def mark_all_as_read(
self,
db: Session,
user_id: int,
) -> int:
"""Mark all unread notifications as read. Returns count of updated."""
now = datetime.utcnow()
count = (
db.query(AdminNotification)
.filter(AdminNotification.is_read == False) # noqa: E712
.update(
{
AdminNotification.is_read: True,
AdminNotification.read_at: now,
AdminNotification.read_by_user_id: user_id,
}
)
)
db.flush()
return count
def delete_old_notifications(
self,
db: Session,
days: int = 30,
) -> int:
"""Delete notifications older than specified days."""
cutoff = datetime.utcnow() - timedelta(days=days)
count = (
db.query(AdminNotification)
.filter(
and_(
AdminNotification.is_read == True, # noqa: E712
AdminNotification.created_at < cutoff,
)
)
.delete()
)
db.flush()
return count
def delete_notification(
self,
db: Session,
notification_id: int,
) -> bool:
"""
Delete a notification by ID.
Returns:
True if notification was deleted, False if not found.
"""
notification = (
db.query(AdminNotification)
.filter(AdminNotification.id == notification_id)
.first()
)
if notification:
db.delete(notification)
db.flush()
logger.info(f"Deleted notification {notification_id}")
return True
return False
# =========================================================================
# CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES
# =========================================================================
def notify_import_failure(
self,
db: Session,
vendor_name: str,
job_id: int,
error_message: str,
vendor_id: int | None = None,
) -> AdminNotification:
"""Create notification for import job failure."""
return self.create_notification(
db=db,
notification_type=NotificationType.IMPORT_FAILURE,
title=f"Import Failed: {vendor_name}",
message=error_message,
priority=Priority.HIGH,
action_required=True,
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
if vendor_id
else "/admin/marketplace",
metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id},
)
def notify_order_sync_failure(
self,
db: Session,
vendor_name: str,
error_message: str,
vendor_id: int | None = None,
) -> AdminNotification:
"""Create notification for order sync failure."""
return self.create_notification(
db=db,
notification_type=NotificationType.ORDER_SYNC_FAILURE,
title=f"Order Sync Failed: {vendor_name}",
message=error_message,
priority=Priority.HIGH,
action_required=True,
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
if vendor_id
else "/admin/marketplace/letzshop",
metadata={"vendor_name": vendor_name, "vendor_id": vendor_id},
)
def notify_order_exception(
self,
db: Session,
vendor_name: str,
order_number: str,
exception_count: int,
vendor_id: int | None = None,
) -> AdminNotification:
"""Create notification for order item exceptions."""
return self.create_notification(
db=db,
notification_type=NotificationType.ORDER_EXCEPTION,
title=f"Order Exception: {order_number}",
message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})",
priority=Priority.NORMAL,
action_required=True,
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions"
if vendor_id
else "/admin/marketplace/letzshop",
metadata={
"vendor_name": vendor_name,
"order_number": order_number,
"exception_count": exception_count,
"vendor_id": vendor_id,
},
)
def notify_critical_error(
self,
db: Session,
error_type: str,
error_message: str,
details: dict[str, Any] | None = None,
) -> AdminNotification:
"""Create notification for critical application errors."""
return self.create_notification(
db=db,
notification_type=NotificationType.CRITICAL_ERROR,
title=f"Critical Error: {error_type}",
message=error_message,
priority=Priority.CRITICAL,
action_required=True,
action_url="/admin/logs",
metadata=details,
)
def notify_vendor_issue(
self,
db: Session,
vendor_name: str,
issue_type: str,
message: str,
vendor_id: int | None = None,
) -> AdminNotification:
"""Create notification for vendor-related issues."""
return self.create_notification(
db=db,
notification_type=NotificationType.VENDOR_ISSUE,
title=f"Vendor Issue: {vendor_name}",
message=message,
priority=Priority.HIGH,
action_required=True,
action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors",
metadata={
"vendor_name": vendor_name,
"issue_type": issue_type,
"vendor_id": vendor_id,
},
)
def notify_security_alert(
self,
db: Session,
title: str,
message: str,
details: dict[str, Any] | None = None,
) -> AdminNotification:
"""Create notification for security-related alerts."""
return self.create_notification(
db=db,
notification_type=NotificationType.SECURITY_ALERT,
title=title,
message=message,
priority=Priority.CRITICAL,
action_required=True,
action_url="/admin/audit",
metadata=details,
)
# ============================================================================
# PLATFORM ALERT SERVICE
# ============================================================================
class PlatformAlertService:
"""Service for managing platform-wide alerts."""
def create_alert(
self,
db: Session,
alert_type: str,
severity: str,
title: str,
description: str | None = None,
affected_vendors: list[int] | None = None,
affected_systems: list[str] | None = None,
auto_generated: bool = True,
) -> PlatformAlert:
"""Create a new platform alert."""
now = datetime.utcnow()
alert = PlatformAlert(
alert_type=alert_type,
severity=severity,
title=title,
description=description,
affected_vendors=affected_vendors,
affected_systems=affected_systems,
auto_generated=auto_generated,
first_occurred_at=now,
last_occurred_at=now,
)
db.add(alert)
db.flush()
logger.info(f"Created platform alert: {alert_type} - {title} ({severity})")
return alert
def create_from_schema(
self,
db: Session,
data: PlatformAlertCreate,
) -> PlatformAlert:
"""Create alert from Pydantic schema."""
return self.create_alert(
db=db,
alert_type=data.alert_type,
severity=data.severity,
title=data.title,
description=data.description,
affected_vendors=data.affected_vendors,
affected_systems=data.affected_systems,
auto_generated=data.auto_generated,
)
def get_alerts(
self,
db: Session,
severity: str | None = None,
alert_type: str | None = None,
is_resolved: bool | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[PlatformAlert], int, int, int]:
"""
Get paginated platform alerts.
Returns:
Tuple of (alerts, total_count, active_count, critical_count)
"""
query = db.query(PlatformAlert)
# Apply filters
if severity:
query = query.filter(PlatformAlert.severity == severity)
if alert_type:
query = query.filter(PlatformAlert.alert_type == alert_type)
if is_resolved is not None:
query = query.filter(PlatformAlert.is_resolved == is_resolved)
# Get counts
total = query.count()
active_count = (
db.query(PlatformAlert)
.filter(PlatformAlert.is_resolved == False) # noqa: E712
.count()
)
critical_count = (
db.query(PlatformAlert)
.filter(
and_(
PlatformAlert.is_resolved == False, # noqa: E712
PlatformAlert.severity == Severity.CRITICAL,
)
)
.count()
)
# Get paginated results
severity_order = case(
(PlatformAlert.severity == "critical", 1),
(PlatformAlert.severity == "error", 2),
(PlatformAlert.severity == "warning", 3),
(PlatformAlert.severity == "info", 4),
else_=5,
)
alerts = (
query.order_by(
PlatformAlert.is_resolved, # Unresolved first
severity_order,
PlatformAlert.last_occurred_at.desc(),
)
.offset(skip)
.limit(limit)
.all()
)
return alerts, total, active_count, critical_count
def resolve_alert(
self,
db: Session,
alert_id: int,
user_id: int,
resolution_notes: str | None = None,
) -> PlatformAlert | None:
"""Resolve a platform alert."""
alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
if alert and not alert.is_resolved:
alert.is_resolved = True
alert.resolved_at = datetime.utcnow()
alert.resolved_by_user_id = user_id
alert.resolution_notes = resolution_notes
db.flush()
logger.info(f"Resolved platform alert {alert_id}")
return alert
def get_statistics(self, db: Session) -> dict[str, int]:
"""Get alert statistics."""
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
total = db.query(PlatformAlert).count()
active = (
db.query(PlatformAlert)
.filter(PlatformAlert.is_resolved == False) # noqa: E712
.count()
)
critical = (
db.query(PlatformAlert)
.filter(
and_(
PlatformAlert.is_resolved == False, # noqa: E712
PlatformAlert.severity == Severity.CRITICAL,
)
)
.count()
)
resolved_today = (
db.query(PlatformAlert)
.filter(
and_(
PlatformAlert.is_resolved == True, # noqa: E712
PlatformAlert.resolved_at >= today_start,
)
)
.count()
)
return {
"total_alerts": total,
"active_alerts": active,
"critical_alerts": critical,
"resolved_today": resolved_today,
}
def increment_occurrence(
self,
db: Session,
alert_id: int,
) -> PlatformAlert | None:
"""Increment occurrence count for repeated alert."""
alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
if alert:
alert.occurrence_count += 1
alert.last_occurred_at = datetime.utcnow()
db.flush()
return alert
def find_similar_active_alert(
self,
db: Session,
alert_type: str,
title: str,
) -> PlatformAlert | None:
"""Find an active alert with same type and title."""
return (
db.query(PlatformAlert)
.filter(
and_(
PlatformAlert.alert_type == alert_type,
PlatformAlert.title == title,
PlatformAlert.is_resolved == False, # noqa: E712
)
)
.first()
)
def create_or_increment_alert(
self,
db: Session,
alert_type: str,
severity: str,
title: str,
description: str | None = None,
affected_vendors: list[int] | None = None,
affected_systems: list[str] | None = None,
) -> PlatformAlert:
"""Create alert or increment occurrence if similar exists."""
existing = self.find_similar_active_alert(db, alert_type, title)
if existing:
self.increment_occurrence(db, existing.id)
return existing
return self.create_alert(
db=db,
alert_type=alert_type,
severity=severity,
title=title,
description=description,
affected_vendors=affected_vendors,
affected_systems=affected_systems,
)
# Singleton instances
admin_notification_service = AdminNotificationService()
platform_alert_service = PlatformAlertService()

View File

@@ -0,0 +1,225 @@
# app/modules/messaging/services/message_attachment_service.py
"""
Attachment handling service for messaging system.
Handles file upload, validation, storage, and retrieval.
"""
import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.services.admin_settings_service import admin_settings_service
logger = logging.getLogger(__name__)
# Allowed MIME types for attachments
ALLOWED_MIME_TYPES = {
# Images
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
# Documents
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
# Archives
"application/zip",
# Text
"text/plain",
"text/csv",
}
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
# Default max file size in MB
DEFAULT_MAX_FILE_SIZE_MB = 10
class MessageAttachmentService:
"""Service for handling message attachments."""
def __init__(self, storage_base: str = "uploads/messages"):
self.storage_base = storage_base
def get_max_file_size_bytes(self, db: Session) -> int:
"""Get maximum file size from platform settings."""
max_mb = admin_settings_service.get_setting_value(
db,
"message_attachment_max_size_mb",
default=DEFAULT_MAX_FILE_SIZE_MB,
)
try:
max_mb = int(max_mb)
except (TypeError, ValueError):
max_mb = DEFAULT_MAX_FILE_SIZE_MB
return max_mb * 1024 * 1024 # Convert to bytes
def validate_file_type(self, mime_type: str) -> bool:
"""Check if file type is allowed."""
return mime_type in ALLOWED_MIME_TYPES
def is_image(self, mime_type: str) -> bool:
"""Check if file is an image."""
return mime_type in IMAGE_MIME_TYPES
async def validate_and_store(
self,
db: Session,
file: UploadFile,
conversation_id: int,
) -> dict:
"""
Validate and store an uploaded file.
Returns dict with file metadata for MessageAttachment creation.
Raises:
ValueError: If file type or size is invalid
"""
# Validate MIME type
content_type = file.content_type or "application/octet-stream"
if not self.validate_file_type(content_type):
raise ValueError(
f"File type '{content_type}' not allowed. "
"Allowed types: images (JPEG, PNG, GIF, WebP), "
"PDF, Office documents, ZIP, text files."
)
# Read file content
content = await file.read()
file_size = len(content)
# Validate file size
max_size = self.get_max_file_size_bytes(db)
if file_size > max_size:
raise ValueError(
f"File size {file_size / 1024 / 1024:.1f}MB exceeds "
f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB"
)
# Generate unique filename
original_filename = file.filename or "attachment"
ext = Path(original_filename).suffix.lower()
unique_filename = f"{uuid.uuid4().hex}{ext}"
# Create storage path: uploads/messages/YYYY/MM/conversation_id/filename
now = datetime.utcnow()
relative_path = os.path.join(
self.storage_base,
str(now.year),
f"{now.month:02d}",
str(conversation_id),
)
# Ensure directory exists
os.makedirs(relative_path, exist_ok=True)
# Full file path
file_path = os.path.join(relative_path, unique_filename)
# Write file
with open(file_path, "wb") as f:
f.write(content)
# Prepare metadata
is_image = self.is_image(content_type)
metadata = {
"filename": unique_filename,
"original_filename": original_filename,
"file_path": file_path,
"file_size": file_size,
"mime_type": content_type,
"is_image": is_image,
}
# Generate thumbnail for images
if is_image:
thumbnail_data = self._create_thumbnail(content, file_path)
metadata.update(thumbnail_data)
logger.info(
f"Stored attachment {unique_filename} for conversation {conversation_id} "
f"({file_size} bytes, type: {content_type})"
)
return metadata
def _create_thumbnail(self, content: bytes, original_path: str) -> dict:
"""Create thumbnail for image attachments."""
try:
from PIL import Image
import io
img = Image.open(io.BytesIO(content))
width, height = img.size
# Create thumbnail
img.thumbnail((200, 200))
thumb_path = original_path.replace(".", "_thumb.")
img.save(thumb_path)
return {
"image_width": width,
"image_height": height,
"thumbnail_path": thumb_path,
}
except ImportError:
logger.warning("PIL not installed, skipping thumbnail generation")
return {}
except Exception as e:
logger.error(f"Failed to create thumbnail: {e}")
return {}
def delete_attachment(
self, file_path: str, thumbnail_path: str | None = None
) -> bool:
"""Delete attachment files from storage."""
try:
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Deleted attachment file: {file_path}")
if thumbnail_path and os.path.exists(thumbnail_path):
os.remove(thumbnail_path)
logger.info(f"Deleted thumbnail: {thumbnail_path}")
return True
except Exception as e:
logger.error(f"Failed to delete attachment {file_path}: {e}")
return False
def get_download_url(self, file_path: str) -> str:
"""
Get download URL for an attachment.
For local storage, returns a relative path that can be served
by the static file handler or a dedicated download endpoint.
"""
# Convert local path to URL path
# Assumes files are served from /static/uploads or similar
return f"/static/{file_path}"
def get_file_content(self, file_path: str) -> bytes | None:
"""Read file content from storage."""
try:
if os.path.exists(file_path):
with open(file_path, "rb") as f:
return f.read()
return None
except Exception as e:
logger.error(f"Failed to read file {file_path}: {e}")
return None
# Singleton instance
message_attachment_service = MessageAttachmentService()

View File

@@ -0,0 +1,684 @@
# app/modules/messaging/services/messaging_service.py
"""
Messaging service for conversation and message management.
Provides functionality for:
- Creating conversations between different participant types
- Sending messages with attachments
- Managing read status and unread counts
- Conversation listing with filters
- Multi-tenant data isolation
"""
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from app.modules.messaging.models.message import (
Conversation,
ConversationParticipant,
ConversationType,
Message,
MessageAttachment,
ParticipantType,
)
from models.database.customer import Customer
from models.database.user import User
logger = logging.getLogger(__name__)
class MessagingService:
"""Service for managing conversations and messages."""
# =========================================================================
# CONVERSATION MANAGEMENT
# =========================================================================
def create_conversation(
self,
db: Session,
conversation_type: ConversationType,
subject: str,
initiator_type: ParticipantType,
initiator_id: int,
recipient_type: ParticipantType,
recipient_id: int,
vendor_id: int | None = None,
initial_message: str | None = None,
) -> Conversation:
"""
Create a new conversation between two participants.
Args:
db: Database session
conversation_type: Type of conversation channel
subject: Conversation subject line
initiator_type: Type of initiating participant
initiator_id: ID of initiating participant
recipient_type: Type of receiving participant
recipient_id: ID of receiving participant
vendor_id: Required for vendor_customer/admin_customer types
initial_message: Optional first message content
Returns:
Created Conversation object
"""
# Validate vendor_id requirement
if conversation_type in [
ConversationType.VENDOR_CUSTOMER,
ConversationType.ADMIN_CUSTOMER,
]:
if not vendor_id:
raise ValueError(
f"vendor_id required for {conversation_type.value} conversations"
)
# Create conversation
conversation = Conversation(
conversation_type=conversation_type,
subject=subject,
vendor_id=vendor_id,
)
db.add(conversation)
db.flush()
# Add participants
initiator_vendor_id = (
vendor_id if initiator_type == ParticipantType.VENDOR else None
)
recipient_vendor_id = (
vendor_id if recipient_type == ParticipantType.VENDOR else None
)
initiator = ConversationParticipant(
conversation_id=conversation.id,
participant_type=initiator_type,
participant_id=initiator_id,
vendor_id=initiator_vendor_id,
unread_count=0, # Initiator has read their own message
)
recipient = ConversationParticipant(
conversation_id=conversation.id,
participant_type=recipient_type,
participant_id=recipient_id,
vendor_id=recipient_vendor_id,
unread_count=1 if initial_message else 0,
)
db.add(initiator)
db.add(recipient)
db.flush()
# Add initial message if provided
if initial_message:
self.send_message(
db=db,
conversation_id=conversation.id,
sender_type=initiator_type,
sender_id=initiator_id,
content=initial_message,
_skip_unread_update=True, # Already set above
)
logger.info(
f"Created {conversation_type.value} conversation {conversation.id}: "
f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}"
)
return conversation
def get_conversation(
self,
db: Session,
conversation_id: int,
participant_type: ParticipantType,
participant_id: int,
) -> Conversation | None:
"""
Get conversation if participant has access.
Validates that the requester is a participant.
"""
conversation = (
db.query(Conversation)
.options(
joinedload(Conversation.participants),
joinedload(Conversation.messages).joinedload(Message.attachments),
)
.filter(Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
# Verify participant access
has_access = any(
p.participant_type == participant_type
and p.participant_id == participant_id
for p in conversation.participants
)
if not has_access:
logger.warning(
f"Access denied to conversation {conversation_id} for "
f"{participant_type.value}:{participant_id}"
)
return None
return conversation
def list_conversations(
self,
db: Session,
participant_type: ParticipantType,
participant_id: int,
vendor_id: int | None = None,
conversation_type: ConversationType | None = None,
is_closed: bool | None = None,
skip: int = 0,
limit: int = 20,
) -> tuple[list[Conversation], int, int]:
"""
List conversations for a participant with filters.
Returns:
Tuple of (conversations, total_count, total_unread)
"""
# Base query: conversations where user is a participant
query = (
db.query(Conversation)
.join(ConversationParticipant)
.filter(
and_(
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
)
# Multi-tenant filter for vendor users
if participant_type == ParticipantType.VENDOR and vendor_id:
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
# Customer vendor isolation
if participant_type == ParticipantType.CUSTOMER and vendor_id:
query = query.filter(Conversation.vendor_id == vendor_id)
# Type filter
if conversation_type:
query = query.filter(Conversation.conversation_type == conversation_type)
# Status filter
if is_closed is not None:
query = query.filter(Conversation.is_closed == is_closed)
# Get total count
total = query.count()
# Get total unread across all conversations
unread_query = db.query(
func.sum(ConversationParticipant.unread_count)
).filter(
and_(
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
if participant_type == ParticipantType.VENDOR and vendor_id:
unread_query = unread_query.filter(
ConversationParticipant.vendor_id == vendor_id
)
total_unread = unread_query.scalar() or 0
# Get paginated results, ordered by last activity
conversations = (
query.options(joinedload(Conversation.participants))
.order_by(Conversation.last_message_at.desc().nullslast())
.offset(skip)
.limit(limit)
.all()
)
return conversations, total, total_unread
def close_conversation(
self,
db: Session,
conversation_id: int,
closer_type: ParticipantType,
closer_id: int,
) -> Conversation | None:
"""Close a conversation thread."""
conversation = self.get_conversation(
db, conversation_id, closer_type, closer_id
)
if not conversation:
return None
conversation.is_closed = True
conversation.closed_at = datetime.now(UTC)
conversation.closed_by_type = closer_type
conversation.closed_by_id = closer_id
# Add system message
self.send_message(
db=db,
conversation_id=conversation_id,
sender_type=closer_type,
sender_id=closer_id,
content="Conversation closed",
is_system_message=True,
)
db.flush()
return conversation
def reopen_conversation(
self,
db: Session,
conversation_id: int,
opener_type: ParticipantType,
opener_id: int,
) -> Conversation | None:
"""Reopen a closed conversation."""
conversation = self.get_conversation(
db, conversation_id, opener_type, opener_id
)
if not conversation:
return None
conversation.is_closed = False
conversation.closed_at = None
conversation.closed_by_type = None
conversation.closed_by_id = None
# Add system message
self.send_message(
db=db,
conversation_id=conversation_id,
sender_type=opener_type,
sender_id=opener_id,
content="Conversation reopened",
is_system_message=True,
)
db.flush()
return conversation
# =========================================================================
# MESSAGE MANAGEMENT
# =========================================================================
def send_message(
self,
db: Session,
conversation_id: int,
sender_type: ParticipantType,
sender_id: int,
content: str,
attachments: list[dict[str, Any]] | None = None,
is_system_message: bool = False,
_skip_unread_update: bool = False,
) -> Message:
"""
Send a message in a conversation.
Args:
db: Database session
conversation_id: Target conversation ID
sender_type: Type of sender
sender_id: ID of sender
content: Message text content
attachments: List of attachment dicts with file metadata
is_system_message: Whether this is a system-generated message
_skip_unread_update: Internal flag to skip unread increment
Returns:
Created Message object
"""
# Create message
message = Message(
conversation_id=conversation_id,
sender_type=sender_type,
sender_id=sender_id,
content=content,
is_system_message=is_system_message,
)
db.add(message)
db.flush()
# Add attachments if any
if attachments:
for att_data in attachments:
attachment = MessageAttachment(
message_id=message.id,
filename=att_data["filename"],
original_filename=att_data["original_filename"],
file_path=att_data["file_path"],
file_size=att_data["file_size"],
mime_type=att_data["mime_type"],
is_image=att_data.get("is_image", False),
image_width=att_data.get("image_width"),
image_height=att_data.get("image_height"),
thumbnail_path=att_data.get("thumbnail_path"),
)
db.add(attachment)
# Update conversation metadata
conversation = (
db.query(Conversation).filter(Conversation.id == conversation_id).first()
)
if conversation:
conversation.last_message_at = datetime.now(UTC)
conversation.message_count += 1
# Update unread counts for other participants
if not _skip_unread_update:
db.query(ConversationParticipant).filter(
and_(
ConversationParticipant.conversation_id == conversation_id,
or_(
ConversationParticipant.participant_type != sender_type,
ConversationParticipant.participant_id != sender_id,
),
)
).update(
{
ConversationParticipant.unread_count: ConversationParticipant.unread_count
+ 1
}
)
db.flush()
logger.info(
f"Message {message.id} sent in conversation {conversation_id} "
f"by {sender_type.value}:{sender_id}"
)
return message
def delete_message(
self,
db: Session,
message_id: int,
deleter_type: ParticipantType,
deleter_id: int,
) -> Message | None:
"""Soft delete a message (for moderation)."""
message = db.query(Message).filter(Message.id == message_id).first()
if not message:
return None
# Verify deleter has access to conversation
conversation = self.get_conversation(
db, message.conversation_id, deleter_type, deleter_id
)
if not conversation:
return None
message.is_deleted = True
message.deleted_at = datetime.now(UTC)
message.deleted_by_type = deleter_type
message.deleted_by_id = deleter_id
db.flush()
return message
def mark_conversation_read(
self,
db: Session,
conversation_id: int,
reader_type: ParticipantType,
reader_id: int,
) -> bool:
"""Mark all messages in conversation as read for participant."""
result = (
db.query(ConversationParticipant)
.filter(
and_(
ConversationParticipant.conversation_id == conversation_id,
ConversationParticipant.participant_type == reader_type,
ConversationParticipant.participant_id == reader_id,
)
)
.update(
{
ConversationParticipant.unread_count: 0,
ConversationParticipant.last_read_at: datetime.now(UTC),
}
)
)
db.flush()
return result > 0
def get_unread_count(
self,
db: Session,
participant_type: ParticipantType,
participant_id: int,
vendor_id: int | None = None,
) -> int:
"""Get total unread message count for a participant."""
query = db.query(func.sum(ConversationParticipant.unread_count)).filter(
and_(
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
if vendor_id:
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
return query.scalar() or 0
# =========================================================================
# PARTICIPANT HELPERS
# =========================================================================
def get_participant_info(
self,
db: Session,
participant_type: ParticipantType,
participant_id: int,
) -> dict[str, Any] | None:
"""Get display info for a participant (name, email, avatar)."""
if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]:
user = db.query(User).filter(User.id == participant_id).first()
if user:
return {
"id": user.id,
"type": participant_type.value,
"name": f"{user.first_name or ''} {user.last_name or ''}".strip()
or user.username,
"email": user.email,
"avatar_url": None, # Could add avatar support later
}
elif participant_type == ParticipantType.CUSTOMER:
customer = db.query(Customer).filter(Customer.id == participant_id).first()
if customer:
return {
"id": customer.id,
"type": participant_type.value,
"name": f"{customer.first_name or ''} {customer.last_name or ''}".strip()
or customer.email,
"email": customer.email,
"avatar_url": None,
}
return None
def get_other_participant(
self,
conversation: Conversation,
my_type: ParticipantType,
my_id: int,
) -> ConversationParticipant | None:
"""Get the other participant in a conversation."""
for p in conversation.participants:
if p.participant_type != my_type or p.participant_id != my_id:
return p
return None
# =========================================================================
# NOTIFICATION PREFERENCES
# =========================================================================
def update_notification_preferences(
self,
db: Session,
conversation_id: int,
participant_type: ParticipantType,
participant_id: int,
email_notifications: bool | None = None,
muted: bool | None = None,
) -> bool:
"""Update notification preferences for a participant in a conversation."""
updates = {}
if email_notifications is not None:
updates[ConversationParticipant.email_notifications] = email_notifications
if muted is not None:
updates[ConversationParticipant.muted] = muted
if not updates:
return False
result = (
db.query(ConversationParticipant)
.filter(
and_(
ConversationParticipant.conversation_id == conversation_id,
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
.update(updates)
)
db.flush()
return result > 0
# =========================================================================
# RECIPIENT QUERIES
# =========================================================================
def get_vendor_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of vendor users as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
from models.database.vendor import VendorUser
query = (
db.query(User, VendorUser)
.join(VendorUser, User.id == VendorUser.user_id)
.filter(User.is_active == True) # noqa: E712
)
if vendor_id:
query = query.filter(VendorUser.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(User.username.ilike(search_pattern))
| (User.email.ilike(search_pattern))
| (User.first_name.ilike(search_pattern))
| (User.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for user, vendor_user in results:
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
recipients.append({
"id": user.id,
"type": ParticipantType.VENDOR,
"name": name,
"email": user.email,
"vendor_id": vendor_user.vendor_id,
"vendor_name": vendor_user.vendor.name if vendor_user.vendor else None,
})
return recipients, total
def get_customer_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of customers as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter (required for vendor users)
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_pattern))
| (Customer.first_name.ilike(search_pattern))
| (Customer.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for customer in results:
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
recipients.append({
"id": customer.id,
"type": ParticipantType.CUSTOMER,
"name": name or customer.email,
"email": customer.email,
"vendor_id": customer.vendor_id,
})
return recipients, total
# Singleton instance
messaging_service = MessagingService()

View File

@@ -2,10 +2,10 @@
"""
Monitoring module services.
Re-exports monitoring-related services from their source locations.
This module contains the canonical implementations of monitoring-related services.
"""
from app.services.background_tasks_service import (
from app.modules.monitoring.services.background_tasks_service import (
background_tasks_service,
BackgroundTasksService,
)

View File

@@ -0,0 +1,194 @@
# app/modules/monitoring/services/background_tasks_service.py
"""
Background Tasks Service
Service for monitoring background tasks across the system
"""
from datetime import UTC, datetime
from sqlalchemy import case, desc, func
from sqlalchemy.orm import Session
from models.database.architecture_scan import ArchitectureScan
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.test_run import TestRun
class BackgroundTasksService:
"""Service for monitoring background tasks"""
def get_import_jobs(
self, db: Session, status: str | None = None, limit: int = 50
) -> list[MarketplaceImportJob]:
"""Get import jobs with optional status filter"""
query = db.query(MarketplaceImportJob)
if status:
query = query.filter(MarketplaceImportJob.status == status)
return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all()
def get_test_runs(
self, db: Session, status: str | None = None, limit: int = 50
) -> list[TestRun]:
"""Get test runs with optional status filter"""
query = db.query(TestRun)
if status:
query = query.filter(TestRun.status == status)
return query.order_by(desc(TestRun.timestamp)).limit(limit).all()
def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]:
"""Get currently running import jobs"""
return (
db.query(MarketplaceImportJob)
.filter(MarketplaceImportJob.status == "processing")
.all()
)
def get_running_test_runs(self, db: Session) -> list[TestRun]:
"""Get currently running test runs"""
# noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped
return db.query(TestRun).filter(TestRun.status == "running").all()
def get_import_stats(self, db: Session) -> dict:
"""Get import job statistics"""
today_start = datetime.now(UTC).replace(
hour=0, minute=0, second=0, microsecond=0
)
stats = db.query(
func.count(MarketplaceImportJob.id).label("total"),
func.sum(
case((MarketplaceImportJob.status == "processing", 1), else_=0)
).label("running"),
func.sum(
case(
(
MarketplaceImportJob.status.in_(
["completed", "completed_with_errors"]
),
1,
),
else_=0,
)
).label("completed"),
func.sum(
case((MarketplaceImportJob.status == "failed", 1), else_=0)
).label("failed"),
).first()
today_count = (
db.query(func.count(MarketplaceImportJob.id))
.filter(MarketplaceImportJob.created_at >= today_start)
.scalar()
or 0
)
return {
"total": stats.total or 0,
"running": stats.running or 0,
"completed": stats.completed or 0,
"failed": stats.failed or 0,
"today": today_count,
}
def get_test_run_stats(self, db: Session) -> dict:
"""Get test run statistics"""
today_start = datetime.now(UTC).replace(
hour=0, minute=0, second=0, microsecond=0
)
stats = db.query(
func.count(TestRun.id).label("total"),
func.sum(case((TestRun.status == "running", 1), else_=0)).label(
"running"
),
func.sum(case((TestRun.status == "passed", 1), else_=0)).label(
"completed"
),
func.sum(
case((TestRun.status.in_(["failed", "error"]), 1), else_=0)
).label("failed"),
func.avg(TestRun.duration_seconds).label("avg_duration"),
).first()
today_count = (
db.query(func.count(TestRun.id))
.filter(TestRun.timestamp >= today_start)
.scalar()
or 0
)
return {
"total": stats.total or 0,
"running": stats.running or 0,
"completed": stats.completed or 0,
"failed": stats.failed or 0,
"today": today_count,
"avg_duration": round(stats.avg_duration or 0, 1),
}
def get_code_quality_scans(
self, db: Session, status: str | None = None, limit: int = 50
) -> list[ArchitectureScan]:
"""Get code quality scans with optional status filter"""
query = db.query(ArchitectureScan)
if status:
query = query.filter(ArchitectureScan.status == status)
return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all()
def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
"""Get currently running code quality scans"""
return (
db.query(ArchitectureScan)
.filter(ArchitectureScan.status.in_(["pending", "running"]))
.all()
)
def get_scan_stats(self, db: Session) -> dict:
"""Get code quality scan statistics"""
today_start = datetime.now(UTC).replace(
hour=0, minute=0, second=0, microsecond=0
)
stats = db.query(
func.count(ArchitectureScan.id).label("total"),
func.sum(
case(
(ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0
)
).label("running"),
func.sum(
case(
(
ArchitectureScan.status.in_(
["completed", "completed_with_warnings"]
),
1,
),
else_=0,
)
).label("completed"),
func.sum(
case((ArchitectureScan.status == "failed", 1), else_=0)
).label("failed"),
func.avg(ArchitectureScan.duration_seconds).label("avg_duration"),
).first()
today_count = (
db.query(func.count(ArchitectureScan.id))
.filter(ArchitectureScan.timestamp >= today_start)
.scalar()
or 0
)
return {
"total": stats.total or 0,
"running": stats.running or 0,
"completed": stats.completed or 0,
"failed": stats.failed or 0,
"today": today_count,
"avg_duration": round(stats.avg_duration or 0, 1),
}
# Singleton instance
background_tasks_service = BackgroundTasksService()

View File

@@ -2,15 +2,12 @@
"""
Orders module database models.
Re-exports order-related models from their source locations.
This module contains the canonical implementations of order-related models.
"""
from models.database.order import (
Order,
OrderItem,
)
from models.database.order_item_exception import OrderItemException
from models.database.invoice import (
from app.modules.orders.models.order import Order, OrderItem
from app.modules.orders.models.order_item_exception import OrderItemException
from app.modules.orders.models.invoice import (
Invoice,
InvoiceStatus,
VATRegime,

View File

@@ -0,0 +1,215 @@
# app/modules/orders/models/invoice.py
"""
Invoice database models for the OMS.
Provides models for:
- VendorInvoiceSettings: Per-vendor invoice configuration (company details, VAT, numbering)
- Invoice: Invoice records with snapshots of seller/buyer details
"""
import enum
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
)
from sqlalchemy.dialects.sqlite import JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class VendorInvoiceSettings(Base, TimestampMixin):
"""
Per-vendor invoice configuration.
Stores company details, VAT number, invoice numbering preferences,
and payment information for invoice generation.
One-to-one relationship with Vendor.
"""
__tablename__ = "vendor_invoice_settings"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(
Integer, ForeignKey("vendors.id"), unique=True, nullable=False, index=True
)
# Legal company details for invoice header
company_name = Column(String(255), nullable=False) # Legal name for invoices
company_address = Column(String(255), nullable=True) # Street address
company_city = Column(String(100), nullable=True)
company_postal_code = Column(String(20), nullable=True)
company_country = Column(String(2), nullable=False, default="LU") # ISO country code
# VAT information
vat_number = Column(String(50), nullable=True) # e.g., "LU12345678"
is_vat_registered = Column(Boolean, default=True, nullable=False)
# OSS (One-Stop-Shop) for EU VAT
is_oss_registered = Column(Boolean, default=False, nullable=False)
oss_registration_country = Column(String(2), nullable=True) # ISO country code
# Invoice numbering
invoice_prefix = Column(String(20), default="INV", nullable=False)
invoice_next_number = Column(Integer, default=1, nullable=False)
invoice_number_padding = Column(Integer, default=5, nullable=False) # e.g., INV00001
# Payment information
payment_terms = Column(Text, nullable=True) # e.g., "Payment due within 30 days"
bank_name = Column(String(255), nullable=True)
bank_iban = Column(String(50), nullable=True)
bank_bic = Column(String(20), nullable=True)
# Invoice footer
footer_text = Column(Text, nullable=True) # Custom footer text
# Default VAT rate for Luxembourg invoices (17% standard)
default_vat_rate = Column(Numeric(5, 2), default=17.00, nullable=False)
# Relationships
vendor = relationship("Vendor", back_populates="invoice_settings")
def __repr__(self):
return f"<VendorInvoiceSettings(vendor_id={self.vendor_id}, company='{self.company_name}')>"
def get_next_invoice_number(self) -> str:
"""Generate the next invoice number and increment counter."""
number = str(self.invoice_next_number).zfill(self.invoice_number_padding)
return f"{self.invoice_prefix}{number}"
class InvoiceStatus(str, enum.Enum):
"""Invoice status enumeration."""
DRAFT = "draft"
ISSUED = "issued"
PAID = "paid"
CANCELLED = "cancelled"
class VATRegime(str, enum.Enum):
"""VAT regime for invoice calculation."""
DOMESTIC = "domestic" # Same country as seller
OSS = "oss" # EU cross-border with OSS registration
REVERSE_CHARGE = "reverse_charge" # B2B with valid VAT number
ORIGIN = "origin" # Cross-border without OSS (use origin VAT)
EXEMPT = "exempt" # VAT exempt
class Invoice(Base, TimestampMixin):
"""
Invoice record with snapshots of seller/buyer details.
Stores complete invoice data including snapshots of seller and buyer
details at time of creation for audit purposes.
"""
__tablename__ = "invoices"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True)
# Invoice identification
invoice_number = Column(String(50), nullable=False)
invoice_date = Column(DateTime(timezone=True), nullable=False)
# Status
status = Column(String(20), default=InvoiceStatus.DRAFT.value, nullable=False)
# Seller details snapshot (captured at invoice creation)
seller_details = Column(JSON, nullable=False)
# Structure: {
# "company_name": str,
# "address": str,
# "city": str,
# "postal_code": str,
# "country": str,
# "vat_number": str | None
# }
# Buyer details snapshot (captured at invoice creation)
buyer_details = Column(JSON, nullable=False)
# Structure: {
# "name": str,
# "email": str,
# "address": str,
# "city": str,
# "postal_code": str,
# "country": str,
# "vat_number": str | None (for B2B)
# }
# Line items snapshot
line_items = Column(JSON, nullable=False)
# Structure: [{
# "description": str,
# "quantity": int,
# "unit_price_cents": int,
# "total_cents": int,
# "sku": str | None,
# "ean": str | None
# }]
# VAT information
vat_regime = Column(String(20), default=VATRegime.DOMESTIC.value, nullable=False)
destination_country = Column(String(2), nullable=True) # For OSS invoices
vat_rate = Column(Numeric(5, 2), nullable=False) # e.g., 17.00 for 17%
vat_rate_label = Column(String(50), nullable=True) # e.g., "Luxembourg Standard VAT"
# Amounts (stored in cents for precision)
currency = Column(String(3), default="EUR", nullable=False)
subtotal_cents = Column(Integer, nullable=False) # Before VAT
vat_amount_cents = Column(Integer, nullable=False) # VAT amount
total_cents = Column(Integer, nullable=False) # After VAT
# Payment information
payment_terms = Column(Text, nullable=True)
bank_details = Column(JSON, nullable=True) # IBAN, BIC snapshot
footer_text = Column(Text, nullable=True)
# PDF storage
pdf_generated_at = Column(DateTime(timezone=True), nullable=True)
pdf_path = Column(String(500), nullable=True) # Path to stored PDF
# Notes
notes = Column(Text, nullable=True) # Internal notes
# Relationships
vendor = relationship("Vendor", back_populates="invoices")
order = relationship("Order", back_populates="invoices")
__table_args__ = (
Index("idx_invoice_vendor_number", "vendor_id", "invoice_number", unique=True),
Index("idx_invoice_vendor_date", "vendor_id", "invoice_date"),
Index("idx_invoice_status", "vendor_id", "status"),
)
def __repr__(self):
return f"<Invoice(id={self.id}, number='{self.invoice_number}', status='{self.status}')>"
@property
def subtotal(self) -> float:
"""Get subtotal in EUR."""
return self.subtotal_cents / 100
@property
def vat_amount(self) -> float:
"""Get VAT amount in EUR."""
return self.vat_amount_cents / 100
@property
def total(self) -> float:
"""Get total in EUR."""
return self.total_cents / 100

View File

@@ -0,0 +1,406 @@
# app/modules/orders/models/order.py
"""
Unified Order model for all sales channels.
Supports:
- Direct orders (from vendor's own storefront)
- Marketplace orders (Letzshop, etc.)
Design principles:
- Customer/address data is snapshotted at order time (preserves history)
- customer_id FK links to Customer record (may be inactive for marketplace imports)
- channel field distinguishes order source
- external_* fields store marketplace-specific references
Money values are stored as integer cents (e.g., €105.91 = 10591).
See docs/architecture/money-handling.md for details.
"""
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.modules.orders.models.order_item_exception import OrderItemException
from sqlalchemy.dialects.sqlite import JSON
from sqlalchemy.orm import relationship
from app.core.database import Base
from app.utils.money import cents_to_euros, euros_to_cents
from models.database.base import TimestampMixin
class Order(Base, TimestampMixin):
"""
Unified order model for all sales channels.
Stores orders from direct sales and marketplaces (Letzshop, etc.)
with snapshotted customer and address data.
All monetary amounts are stored as integer cents for precision.
"""
__tablename__ = "orders"
id = Column(Integer, primary_key=True, index=True)
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
customer_id = Column(
Integer, ForeignKey("customers.id"), nullable=False, index=True
)
order_number = Column(String(100), nullable=False, unique=True, index=True)
# === Channel/Source ===
channel = Column(
String(50), default="direct", nullable=False, index=True
) # direct, letzshop
# External references (for marketplace orders)
external_order_id = Column(
String(100), nullable=True, index=True
) # Marketplace order ID
external_shipment_id = Column(
String(100), nullable=True, index=True
) # Marketplace shipment ID
external_order_number = Column(String(100), nullable=True) # Marketplace order #
external_data = Column(JSON, nullable=True) # Raw marketplace data for debugging
# === Status ===
# pending: awaiting confirmation
# processing: confirmed, being prepared
# shipped: shipped with tracking
# delivered: delivered to customer
# cancelled: order cancelled/declined
# refunded: order refunded
status = Column(String(50), nullable=False, default="pending", index=True)
# === Financials (stored as integer cents) ===
subtotal_cents = Column(Integer, nullable=True) # May not be available from marketplace
tax_amount_cents = Column(Integer, nullable=True)
shipping_amount_cents = Column(Integer, nullable=True)
discount_amount_cents = Column(Integer, nullable=True)
total_amount_cents = Column(Integer, nullable=False)
currency = Column(String(10), default="EUR")
# === VAT Information ===
# VAT regime: domestic, oss, reverse_charge, origin, exempt
vat_regime = Column(String(20), nullable=True)
# VAT rate as percentage (e.g., 17.00 for 17%)
vat_rate = Column(Numeric(5, 2), nullable=True)
# Human-readable VAT label (e.g., "Luxembourg VAT 17%")
vat_rate_label = Column(String(100), nullable=True)
# Destination country for cross-border sales (ISO code)
vat_destination_country = Column(String(2), nullable=True)
# === Customer Snapshot (preserved at order time) ===
customer_first_name = Column(String(100), nullable=False)
customer_last_name = Column(String(100), nullable=False)
customer_email = Column(String(255), nullable=False)
customer_phone = Column(String(50), nullable=True)
customer_locale = Column(String(10), nullable=True) # en, fr, de, lb
# === Shipping Address Snapshot ===
ship_first_name = Column(String(100), nullable=False)
ship_last_name = Column(String(100), nullable=False)
ship_company = Column(String(200), nullable=True)
ship_address_line_1 = Column(String(255), nullable=False)
ship_address_line_2 = Column(String(255), nullable=True)
ship_city = Column(String(100), nullable=False)
ship_postal_code = Column(String(20), nullable=False)
ship_country_iso = Column(String(5), nullable=False)
# === Billing Address Snapshot ===
bill_first_name = Column(String(100), nullable=False)
bill_last_name = Column(String(100), nullable=False)
bill_company = Column(String(200), nullable=True)
bill_address_line_1 = Column(String(255), nullable=False)
bill_address_line_2 = Column(String(255), nullable=True)
bill_city = Column(String(100), nullable=False)
bill_postal_code = Column(String(20), nullable=False)
bill_country_iso = Column(String(5), nullable=False)
# === Tracking ===
shipping_method = Column(String(100), nullable=True)
tracking_number = Column(String(100), nullable=True)
tracking_provider = Column(String(100), nullable=True)
tracking_url = Column(String(500), nullable=True) # Full tracking URL
shipment_number = Column(String(100), nullable=True) # Carrier shipment number (e.g., H74683403433)
shipping_carrier = Column(String(50), nullable=True) # Carrier code (greco, colissimo, etc.)
# === Notes ===
customer_notes = Column(Text, nullable=True)
internal_notes = Column(Text, nullable=True)
# === Timestamps ===
order_date = Column(
DateTime(timezone=True), nullable=False
) # When customer placed order
confirmed_at = Column(DateTime(timezone=True), nullable=True)
shipped_at = Column(DateTime(timezone=True), nullable=True)
delivered_at = Column(DateTime(timezone=True), nullable=True)
cancelled_at = Column(DateTime(timezone=True), nullable=True)
# === Relationships ===
vendor = relationship("Vendor")
customer = relationship("Customer", back_populates="orders")
items = relationship(
"OrderItem", back_populates="order", cascade="all, delete-orphan"
)
invoices = relationship(
"Invoice", back_populates="order", cascade="all, delete-orphan"
)
# Composite indexes for common queries
__table_args__ = (
Index("idx_order_vendor_status", "vendor_id", "status"),
Index("idx_order_vendor_channel", "vendor_id", "channel"),
Index("idx_order_vendor_date", "vendor_id", "order_date"),
)
def __repr__(self):
return f"<Order(id={self.id}, order_number='{self.order_number}', channel='{self.channel}', status='{self.status}')>"
# === PRICE PROPERTIES (Euro convenience accessors) ===
@property
def subtotal(self) -> float | None:
"""Get subtotal in euros."""
if self.subtotal_cents is not None:
return cents_to_euros(self.subtotal_cents)
return None
@subtotal.setter
def subtotal(self, value: float | None):
"""Set subtotal from euros."""
self.subtotal_cents = euros_to_cents(value) if value is not None else None
@property
def tax_amount(self) -> float | None:
"""Get tax amount in euros."""
if self.tax_amount_cents is not None:
return cents_to_euros(self.tax_amount_cents)
return None
@tax_amount.setter
def tax_amount(self, value: float | None):
"""Set tax amount from euros."""
self.tax_amount_cents = euros_to_cents(value) if value is not None else None
@property
def shipping_amount(self) -> float | None:
"""Get shipping amount in euros."""
if self.shipping_amount_cents is not None:
return cents_to_euros(self.shipping_amount_cents)
return None
@shipping_amount.setter
def shipping_amount(self, value: float | None):
"""Set shipping amount from euros."""
self.shipping_amount_cents = euros_to_cents(value) if value is not None else None
@property
def discount_amount(self) -> float | None:
"""Get discount amount in euros."""
if self.discount_amount_cents is not None:
return cents_to_euros(self.discount_amount_cents)
return None
@discount_amount.setter
def discount_amount(self, value: float | None):
"""Set discount amount from euros."""
self.discount_amount_cents = euros_to_cents(value) if value is not None else None
@property
def total_amount(self) -> float:
"""Get total amount in euros."""
return cents_to_euros(self.total_amount_cents)
@total_amount.setter
def total_amount(self, value: float):
"""Set total amount from euros."""
self.total_amount_cents = euros_to_cents(value)
# === NAME PROPERTIES ===
@property
def customer_full_name(self) -> str:
"""Customer full name from snapshot."""
return f"{self.customer_first_name} {self.customer_last_name}".strip()
@property
def ship_full_name(self) -> str:
"""Shipping address full name."""
return f"{self.ship_first_name} {self.ship_last_name}".strip()
@property
def bill_full_name(self) -> str:
"""Billing address full name."""
return f"{self.bill_first_name} {self.bill_last_name}".strip()
@property
def is_marketplace_order(self) -> bool:
"""Check if this is a marketplace order."""
return self.channel != "direct"
@property
def is_fully_shipped(self) -> bool:
"""Check if all items are fully shipped."""
if not self.items:
return False
return all(item.is_fully_shipped for item in self.items)
@property
def is_partially_shipped(self) -> bool:
"""Check if some items are shipped but not all."""
if not self.items:
return False
has_shipped = any(item.shipped_quantity > 0 for item in self.items)
all_shipped = all(item.is_fully_shipped for item in self.items)
return has_shipped and not all_shipped
@property
def shipped_item_count(self) -> int:
"""Count of fully shipped items."""
return sum(1 for item in self.items if item.is_fully_shipped)
@property
def total_shipped_units(self) -> int:
"""Total quantity shipped across all items."""
return sum(item.shipped_quantity for item in self.items)
@property
def total_ordered_units(self) -> int:
"""Total quantity ordered across all items."""
return sum(item.quantity for item in self.items)
class OrderItem(Base, TimestampMixin):
"""
Individual items in an order.
Stores product snapshot at time of order plus external references
for marketplace items.
All monetary amounts are stored as integer cents for precision.
"""
__tablename__ = "order_items"
id = Column(Integer, primary_key=True, index=True)
order_id = Column(Integer, ForeignKey("orders.id"), nullable=False, index=True)
product_id = Column(Integer, ForeignKey("products.id"), nullable=False)
# === Product Snapshot (preserved at order time) ===
product_name = Column(String(255), nullable=False)
product_sku = Column(String(100), nullable=True)
gtin = Column(String(50), nullable=True) # EAN/UPC/ISBN etc.
gtin_type = Column(String(20), nullable=True) # ean13, upc, isbn, etc.
# === Pricing (stored as integer cents) ===
quantity = Column(Integer, nullable=False)
unit_price_cents = Column(Integer, nullable=False)
total_price_cents = Column(Integer, nullable=False)
# === External References (for marketplace items) ===
external_item_id = Column(String(100), nullable=True) # e.g., Letzshop inventory unit ID
external_variant_id = Column(String(100), nullable=True) # e.g., Letzshop variant ID
# === Item State (for marketplace confirmation flow) ===
# confirmed_available: item confirmed and available
# confirmed_unavailable: item confirmed but not available (declined)
item_state = Column(String(50), nullable=True)
# === Inventory Tracking ===
inventory_reserved = Column(Boolean, default=False)
inventory_fulfilled = Column(Boolean, default=False)
# === Shipment Tracking ===
shipped_quantity = Column(Integer, default=0, nullable=False) # Units shipped so far
# === Exception Tracking ===
# True if product was not found by GTIN during import (linked to placeholder)
needs_product_match = Column(Boolean, default=False, index=True)
# === Relationships ===
order = relationship("Order", back_populates="items")
product = relationship("Product")
exception = relationship(
"OrderItemException",
back_populates="order_item",
uselist=False,
cascade="all, delete-orphan",
)
def __repr__(self):
return f"<OrderItem(id={self.id}, order_id={self.order_id}, product_id={self.product_id}, gtin='{self.gtin}')>"
# === PRICE PROPERTIES (Euro convenience accessors) ===
@property
def unit_price(self) -> float:
"""Get unit price in euros."""
return cents_to_euros(self.unit_price_cents)
@unit_price.setter
def unit_price(self, value: float):
"""Set unit price from euros."""
self.unit_price_cents = euros_to_cents(value)
@property
def total_price(self) -> float:
"""Get total price in euros."""
return cents_to_euros(self.total_price_cents)
@total_price.setter
def total_price(self, value: float):
"""Set total price from euros."""
self.total_price_cents = euros_to_cents(value)
# === STATUS PROPERTIES ===
@property
def is_confirmed(self) -> bool:
"""Check if item has been confirmed (available or unavailable)."""
return self.item_state in ("confirmed_available", "confirmed_unavailable")
@property
def is_available(self) -> bool:
"""Check if item is confirmed as available."""
return self.item_state == "confirmed_available"
@property
def is_declined(self) -> bool:
"""Check if item was declined (unavailable)."""
return self.item_state == "confirmed_unavailable"
@property
def has_unresolved_exception(self) -> bool:
"""Check if item has an unresolved exception blocking confirmation."""
if not self.exception:
return False
return self.exception.blocks_confirmation
# === SHIPMENT PROPERTIES ===
@property
def remaining_quantity(self) -> int:
"""Quantity not yet shipped."""
return max(0, self.quantity - self.shipped_quantity)
@property
def is_fully_shipped(self) -> bool:
"""Check if all units have been shipped."""
return self.shipped_quantity >= self.quantity
@property
def is_partially_shipped(self) -> bool:
"""Check if some but not all units have been shipped."""
return 0 < self.shipped_quantity < self.quantity

View File

@@ -0,0 +1,117 @@
# app/modules/orders/models/order_item_exception.py
"""
Order Item Exception model for tracking unmatched products during marketplace imports.
When a marketplace order contains a GTIN that doesn't match any product in the
vendor's catalog, the order is still imported but the item is linked to a
placeholder product and an exception is recorded here for resolution.
"""
from sqlalchemy import (
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.orm import relationship
from app.core.database import Base
from models.database.base import TimestampMixin
class OrderItemException(Base, TimestampMixin):
"""
Tracks unmatched order items requiring admin/vendor resolution.
When a marketplace order is imported and a product cannot be found by GTIN,
the order item is linked to a placeholder product and this exception record
is created. The order cannot be confirmed until all exceptions are resolved.
"""
__tablename__ = "order_item_exceptions"
id = Column(Integer, primary_key=True, index=True)
# Link to the order item (one-to-one)
order_item_id = Column(
Integer,
ForeignKey("order_items.id", ondelete="CASCADE"),
nullable=False,
unique=True,
)
# Vendor ID for efficient querying (denormalized from order)
vendor_id = Column(
Integer, ForeignKey("vendors.id"), nullable=False, index=True
)
# Original data from marketplace (preserved for matching)
original_gtin = Column(String(50), nullable=True, index=True)
original_product_name = Column(String(500), nullable=True)
original_sku = Column(String(100), nullable=True)
# Exception classification
# product_not_found: GTIN not in vendor catalog
# gtin_mismatch: GTIN format issue
# duplicate_gtin: Multiple products with same GTIN
exception_type = Column(
String(50), nullable=False, default="product_not_found"
)
# Resolution status
# pending: Awaiting resolution
# resolved: Product has been assigned
# ignored: Marked as ignored (still blocks confirmation)
status = Column(String(50), nullable=False, default="pending", index=True)
# Resolution details (populated when resolved)
resolved_product_id = Column(
Integer, ForeignKey("products.id"), nullable=True
)
resolved_at = Column(DateTime(timezone=True), nullable=True)
resolved_by = Column(Integer, ForeignKey("users.id"), nullable=True)
resolution_notes = Column(Text, nullable=True)
# Relationships
order_item = relationship("OrderItem", back_populates="exception")
vendor = relationship("Vendor")
resolved_product = relationship("Product")
resolver = relationship("User")
# Composite indexes for common queries
__table_args__ = (
Index("idx_exception_vendor_status", "vendor_id", "status"),
Index("idx_exception_gtin", "vendor_id", "original_gtin"),
)
def __repr__(self):
return (
f"<OrderItemException(id={self.id}, "
f"order_item_id={self.order_item_id}, "
f"gtin='{self.original_gtin}', "
f"status='{self.status}')>"
)
@property
def is_pending(self) -> bool:
"""Check if exception is pending resolution."""
return self.status == "pending"
@property
def is_resolved(self) -> bool:
"""Check if exception has been resolved."""
return self.status == "resolved"
@property
def is_ignored(self) -> bool:
"""Check if exception has been ignored."""
return self.status == "ignored"
@property
def blocks_confirmation(self) -> bool:
"""Check if this exception blocks order confirmation."""
# Both pending and ignored exceptions block confirmation
return self.status in ("pending", "ignored")

View File

@@ -2,34 +2,133 @@
"""
Orders module Pydantic schemas.
Re-exports order-related schemas from their source locations.
This module contains the canonical implementations of order-related schemas.
"""
from models.schema.order import (
OrderCreate,
OrderItemCreate,
OrderResponse,
OrderItemResponse,
OrderListResponse,
from app.modules.orders.schemas.order import (
# Address schemas
AddressSnapshot,
AddressSnapshotResponse,
# Order item schemas
OrderItemCreate,
OrderItemExceptionBrief,
OrderItemResponse,
# Customer schemas
CustomerSnapshot,
CustomerSnapshotResponse,
# Order CRUD schemas
OrderCreate,
OrderUpdate,
OrderTrackingUpdate,
OrderItemStateUpdate,
# Order response schemas
OrderResponse,
OrderDetailResponse,
OrderListResponse,
OrderListItem,
# Admin schemas
AdminOrderItem,
AdminOrderListResponse,
AdminOrderStats,
AdminOrderStatusUpdate,
AdminVendorWithOrders,
AdminVendorsWithOrdersResponse,
# Letzshop schemas
LetzshopOrderImport,
LetzshopShippingInfo,
LetzshopOrderConfirmItem,
LetzshopOrderConfirmRequest,
# Shipping schemas
MarkAsShippedRequest,
ShippingLabelInfo,
)
from models.schema.invoice import (
from app.modules.orders.schemas.invoice import (
# Invoice settings schemas
VendorInvoiceSettingsCreate,
VendorInvoiceSettingsUpdate,
VendorInvoiceSettingsResponse,
# Line item schemas
InvoiceLineItem,
InvoiceLineItemResponse,
# Address schemas
InvoiceSellerDetails,
InvoiceBuyerDetails,
# Invoice CRUD schemas
InvoiceCreate,
InvoiceManualCreate,
InvoiceResponse,
InvoiceListResponse,
InvoiceStatusUpdate,
# Pagination
InvoiceListPaginatedResponse,
# PDF
InvoicePDFGeneratedResponse,
InvoiceStatsResponse,
# Backward compatibility
InvoiceSettingsCreate,
InvoiceSettingsUpdate,
InvoiceSettingsResponse,
)
__all__ = [
"OrderCreate",
"OrderItemCreate",
"OrderResponse",
"OrderItemResponse",
"OrderListResponse",
# Address schemas
"AddressSnapshot",
"AddressSnapshotResponse",
# Order item schemas
"OrderItemCreate",
"OrderItemExceptionBrief",
"OrderItemResponse",
# Customer schemas
"CustomerSnapshot",
"CustomerSnapshotResponse",
# Order CRUD schemas
"OrderCreate",
"OrderUpdate",
"OrderTrackingUpdate",
"OrderItemStateUpdate",
# Order response schemas
"OrderResponse",
"OrderDetailResponse",
"OrderListResponse",
"OrderListItem",
# Admin schemas
"AdminOrderItem",
"AdminOrderListResponse",
"AdminOrderStats",
"AdminOrderStatusUpdate",
"AdminVendorWithOrders",
"AdminVendorsWithOrdersResponse",
# Letzshop schemas
"LetzshopOrderImport",
"LetzshopShippingInfo",
"LetzshopOrderConfirmItem",
"LetzshopOrderConfirmRequest",
# Shipping schemas
"MarkAsShippedRequest",
"ShippingLabelInfo",
# Invoice settings schemas
"VendorInvoiceSettingsCreate",
"VendorInvoiceSettingsUpdate",
"VendorInvoiceSettingsResponse",
# Line item schemas
"InvoiceLineItem",
"InvoiceLineItemResponse",
# Invoice address schemas
"InvoiceSellerDetails",
"InvoiceBuyerDetails",
# Invoice CRUD schemas
"InvoiceCreate",
"InvoiceManualCreate",
"InvoiceResponse",
"InvoiceListResponse",
"InvoiceStatusUpdate",
# Pagination
"InvoiceListPaginatedResponse",
# PDF
"InvoicePDFGeneratedResponse",
"InvoiceStatsResponse",
# Backward compatibility
"InvoiceSettingsCreate",
"InvoiceSettingsUpdate",
"InvoiceSettingsResponse",

View File

@@ -0,0 +1,316 @@
# app/modules/orders/schemas/invoice.py
"""
Pydantic schemas for invoice operations.
Supports invoice settings management and invoice generation.
"""
from datetime import datetime
from decimal import Decimal
from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# Invoice Settings Schemas
# ============================================================================
class VendorInvoiceSettingsCreate(BaseModel):
"""Schema for creating vendor invoice settings."""
company_name: str = Field(..., min_length=1, max_length=255)
company_address: str | None = Field(None, max_length=255)
company_city: str | None = Field(None, max_length=100)
company_postal_code: str | None = Field(None, max_length=20)
company_country: str = Field(default="LU", min_length=2, max_length=2)
vat_number: str | None = Field(None, max_length=50)
is_vat_registered: bool = True
is_oss_registered: bool = False
oss_registration_country: str | None = Field(None, min_length=2, max_length=2)
invoice_prefix: str = Field(default="INV", max_length=20)
invoice_number_padding: int = Field(default=5, ge=1, le=10)
payment_terms: str | None = None
bank_name: str | None = Field(None, max_length=255)
bank_iban: str | None = Field(None, max_length=50)
bank_bic: str | None = Field(None, max_length=20)
footer_text: str | None = None
default_vat_rate: Decimal = Field(default=Decimal("17.00"), ge=0, le=100)
class VendorInvoiceSettingsUpdate(BaseModel):
"""Schema for updating vendor invoice settings."""
company_name: str | None = Field(None, min_length=1, max_length=255)
company_address: str | None = Field(None, max_length=255)
company_city: str | None = Field(None, max_length=100)
company_postal_code: str | None = Field(None, max_length=20)
company_country: str | None = Field(None, min_length=2, max_length=2)
vat_number: str | None = None
is_vat_registered: bool | None = None
is_oss_registered: bool | None = None
oss_registration_country: str | None = None
invoice_prefix: str | None = Field(None, max_length=20)
invoice_number_padding: int | None = Field(None, ge=1, le=10)
payment_terms: str | None = None
bank_name: str | None = Field(None, max_length=255)
bank_iban: str | None = Field(None, max_length=50)
bank_bic: str | None = Field(None, max_length=20)
footer_text: str | None = None
default_vat_rate: Decimal | None = Field(None, ge=0, le=100)
class VendorInvoiceSettingsResponse(BaseModel):
"""Schema for vendor invoice settings response."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
company_name: str
company_address: str | None
company_city: str | None
company_postal_code: str | None
company_country: str
vat_number: str | None
is_vat_registered: bool
is_oss_registered: bool
oss_registration_country: str | None
invoice_prefix: str
invoice_next_number: int
invoice_number_padding: int
payment_terms: str | None
bank_name: str | None
bank_iban: str | None
bank_bic: str | None
footer_text: str | None
default_vat_rate: Decimal
created_at: datetime
updated_at: datetime
# ============================================================================
# Invoice Line Item Schemas
# ============================================================================
class InvoiceLineItem(BaseModel):
"""Schema for invoice line item."""
description: str
quantity: int = Field(..., ge=1)
unit_price_cents: int
total_cents: int
sku: str | None = None
ean: str | None = None
class InvoiceLineItemResponse(BaseModel):
"""Schema for invoice line item in response."""
description: str
quantity: int
unit_price_cents: int
total_cents: int
sku: str | None = None
ean: str | None = None
@property
def unit_price(self) -> float:
return self.unit_price_cents / 100
@property
def total(self) -> float:
return self.total_cents / 100
# ============================================================================
# Invoice Address Schemas
# ============================================================================
class InvoiceSellerDetails(BaseModel):
"""Seller details for invoice."""
company_name: str
address: str | None = None
city: str | None = None
postal_code: str | None = None
country: str
vat_number: str | None = None
class InvoiceBuyerDetails(BaseModel):
"""Buyer details for invoice."""
name: str
email: str | None = None
address: str | None = None
city: str | None = None
postal_code: str | None = None
country: str
vat_number: str | None = None # For B2B
# ============================================================================
# Invoice Schemas
# ============================================================================
class InvoiceCreate(BaseModel):
"""Schema for creating an invoice from an order."""
order_id: int
notes: str | None = None
class InvoiceManualCreate(BaseModel):
"""Schema for creating a manual invoice (without order)."""
buyer_details: InvoiceBuyerDetails
line_items: list[InvoiceLineItem]
notes: str | None = None
payment_terms: str | None = None
class InvoiceResponse(BaseModel):
"""Schema for invoice response."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
order_id: int | None
invoice_number: str
invoice_date: datetime
status: str
seller_details: dict
buyer_details: dict
line_items: list[dict]
vat_regime: str
destination_country: str | None
vat_rate: Decimal
vat_rate_label: str | None
currency: str
subtotal_cents: int
vat_amount_cents: int
total_cents: int
payment_terms: str | None
bank_details: dict | None
footer_text: str | None
pdf_generated_at: datetime | None
pdf_path: str | None
notes: str | None
created_at: datetime
updated_at: datetime
@property
def subtotal(self) -> float:
return self.subtotal_cents / 100
@property
def vat_amount(self) -> float:
return self.vat_amount_cents / 100
@property
def total(self) -> float:
return self.total_cents / 100
class InvoiceListResponse(BaseModel):
"""Schema for invoice list response (summary)."""
model_config = ConfigDict(from_attributes=True)
id: int
invoice_number: str
invoice_date: datetime
status: str
currency: str
total_cents: int
order_id: int | None
# Buyer name for display
buyer_name: str | None = None
@property
def total(self) -> float:
return self.total_cents / 100
class InvoiceStatusUpdate(BaseModel):
"""Schema for updating invoice status."""
status: str = Field(..., pattern="^(draft|issued|paid|cancelled)$")
# ============================================================================
# Paginated Response
# ============================================================================
class InvoiceListPaginatedResponse(BaseModel):
"""Paginated invoice list response."""
items: list[InvoiceListResponse]
total: int
page: int
per_page: int
pages: int
# ============================================================================
# PDF Response
# ============================================================================
class InvoicePDFGeneratedResponse(BaseModel):
"""Response for PDF generation."""
pdf_path: str
message: str = "PDF generated successfully"
class InvoiceStatsResponse(BaseModel):
"""Invoice statistics response."""
total_invoices: int
total_revenue_cents: int
draft_count: int
issued_count: int
paid_count: int
cancelled_count: int
@property
def total_revenue(self) -> float:
return self.total_revenue_cents / 100
# Backward compatibility re-exports
InvoiceSettingsCreate = VendorInvoiceSettingsCreate
InvoiceSettingsUpdate = VendorInvoiceSettingsUpdate
InvoiceSettingsResponse = VendorInvoiceSettingsResponse

View File

@@ -0,0 +1,584 @@
# app/modules/orders/schemas/order.py
"""
Pydantic schemas for unified order operations.
Supports both direct orders and marketplace orders (Letzshop, etc.)
with snapshotted customer and address data.
"""
from datetime import datetime
from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# Address Snapshot Schemas
# ============================================================================
class AddressSnapshot(BaseModel):
"""Address snapshot for order creation."""
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
company: str | None = Field(None, max_length=200)
address_line_1: str = Field(..., min_length=1, max_length=255)
address_line_2: str | None = Field(None, max_length=255)
city: str = Field(..., min_length=1, max_length=100)
postal_code: str = Field(..., min_length=1, max_length=20)
country_iso: str = Field(..., min_length=2, max_length=5)
class AddressSnapshotResponse(BaseModel):
"""Address snapshot in order response."""
first_name: str
last_name: str
company: str | None
address_line_1: str
address_line_2: str | None
city: str
postal_code: str
country_iso: str
@property
def full_name(self) -> str:
return f"{self.first_name} {self.last_name}".strip()
# ============================================================================
# Order Item Schemas
# ============================================================================
class OrderItemCreate(BaseModel):
"""Schema for creating an order item."""
product_id: int
quantity: int = Field(..., ge=1)
class OrderItemExceptionBrief(BaseModel):
"""Brief exception info for embedding in order item responses."""
model_config = ConfigDict(from_attributes=True)
id: int
original_gtin: str | None
original_product_name: str | None
exception_type: str
status: str
resolved_product_id: int | None
class OrderItemResponse(BaseModel):
"""Schema for order item response."""
model_config = ConfigDict(from_attributes=True)
id: int
order_id: int
product_id: int
product_name: str
product_sku: str | None
gtin: str | None
gtin_type: str | None
quantity: int
unit_price: float
total_price: float
# External references (for marketplace items)
external_item_id: str | None = None
external_variant_id: str | None = None
# Item state (for marketplace confirmation flow)
item_state: str | None = None
# Inventory tracking
inventory_reserved: bool
inventory_fulfilled: bool
# Exception tracking
needs_product_match: bool = False
exception: OrderItemExceptionBrief | None = None
created_at: datetime
updated_at: datetime
@property
def is_confirmed(self) -> bool:
"""Check if item has been confirmed (available or unavailable)."""
return self.item_state in ("confirmed_available", "confirmed_unavailable")
@property
def is_available(self) -> bool:
"""Check if item is confirmed as available."""
return self.item_state == "confirmed_available"
@property
def is_declined(self) -> bool:
"""Check if item was declined (unavailable)."""
return self.item_state == "confirmed_unavailable"
@property
def has_unresolved_exception(self) -> bool:
"""Check if item has an unresolved exception blocking confirmation."""
if not self.exception:
return False
return self.exception.status in ("pending", "ignored")
# ============================================================================
# Customer Snapshot Schemas
# ============================================================================
class CustomerSnapshot(BaseModel):
"""Customer snapshot for order creation."""
first_name: str = Field(..., min_length=1, max_length=100)
last_name: str = Field(..., min_length=1, max_length=100)
email: str = Field(..., max_length=255)
phone: str | None = Field(None, max_length=50)
locale: str | None = Field(None, max_length=10)
class CustomerSnapshotResponse(BaseModel):
"""Customer snapshot in order response."""
first_name: str
last_name: str
email: str
phone: str | None
locale: str | None
@property
def full_name(self) -> str:
return f"{self.first_name} {self.last_name}".strip()
# ============================================================================
# Order Create/Update Schemas
# ============================================================================
class OrderCreate(BaseModel):
"""Schema for creating an order (direct channel)."""
customer_id: int | None = None # Optional for guest checkout
items: list[OrderItemCreate] = Field(..., min_length=1)
# Customer info snapshot
customer: CustomerSnapshot
# Addresses (snapshots)
shipping_address: AddressSnapshot
billing_address: AddressSnapshot | None = None # Use shipping if not provided
# Optional fields
shipping_method: str | None = None
customer_notes: str | None = Field(None, max_length=1000)
# Cart/session info
session_id: str | None = None
class OrderUpdate(BaseModel):
"""Schema for updating order status."""
status: str | None = Field(
None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
)
tracking_number: str | None = None
tracking_provider: str | None = None
internal_notes: str | None = None
class OrderTrackingUpdate(BaseModel):
"""Schema for setting tracking information."""
tracking_number: str = Field(..., min_length=1, max_length=100)
tracking_provider: str = Field(..., min_length=1, max_length=100)
class OrderItemStateUpdate(BaseModel):
"""Schema for updating item state (marketplace confirmation)."""
item_id: int
state: str = Field(..., pattern="^(confirmed_available|confirmed_unavailable)$")
# ============================================================================
# Order Response Schemas
# ============================================================================
class OrderResponse(BaseModel):
"""Schema for order response."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
customer_id: int
order_number: str
# Channel/Source
channel: str
external_order_id: str | None = None
external_shipment_id: str | None = None
external_order_number: str | None = None
# Status
status: str
# Financial
subtotal: float | None
tax_amount: float | None
shipping_amount: float | None
discount_amount: float | None
total_amount: float
currency: str
# VAT information
vat_regime: str | None = None
vat_rate: float | None = None
vat_rate_label: str | None = None
vat_destination_country: str | None = None
# Customer snapshot
customer_first_name: str
customer_last_name: str
customer_email: str
customer_phone: str | None
customer_locale: str | None
# Shipping address snapshot
ship_first_name: str
ship_last_name: str
ship_company: str | None
ship_address_line_1: str
ship_address_line_2: str | None
ship_city: str
ship_postal_code: str
ship_country_iso: str
# Billing address snapshot
bill_first_name: str
bill_last_name: str
bill_company: str | None
bill_address_line_1: str
bill_address_line_2: str | None
bill_city: str
bill_postal_code: str
bill_country_iso: str
# Tracking
shipping_method: str | None
tracking_number: str | None
tracking_provider: str | None
tracking_url: str | None = None
shipment_number: str | None = None
shipping_carrier: str | None = None
# Notes
customer_notes: str | None
internal_notes: str | None
# Timestamps
order_date: datetime
confirmed_at: datetime | None
shipped_at: datetime | None
delivered_at: datetime | None
cancelled_at: datetime | None
created_at: datetime
updated_at: datetime
@property
def customer_full_name(self) -> str:
return f"{self.customer_first_name} {self.customer_last_name}".strip()
@property
def ship_full_name(self) -> str:
return f"{self.ship_first_name} {self.ship_last_name}".strip()
@property
def is_marketplace_order(self) -> bool:
return self.channel != "direct"
class OrderDetailResponse(OrderResponse):
"""Schema for detailed order response with items."""
items: list[OrderItemResponse] = []
# Vendor info (enriched by API)
vendor_name: str | None = None
vendor_code: str | None = None
class OrderListResponse(BaseModel):
"""Schema for paginated order list."""
orders: list[OrderResponse]
total: int
skip: int
limit: int
# ============================================================================
# Order List Item (Simplified for list views)
# ============================================================================
class OrderListItem(BaseModel):
"""Simplified order item for list views."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
order_number: str
channel: str
status: str
# External references
external_order_number: str | None = None
# Customer
customer_full_name: str
customer_email: str
# Financial
total_amount: float
currency: str
# Shipping
ship_country_iso: str
# Tracking
tracking_number: str | None
tracking_provider: str | None
tracking_url: str | None = None
shipment_number: str | None = None
shipping_carrier: str | None = None
# Item count
item_count: int = 0
# Timestamps
order_date: datetime
confirmed_at: datetime | None
shipped_at: datetime | None
# ============================================================================
# Admin Order Schemas
# ============================================================================
class AdminOrderItem(BaseModel):
"""Order item with vendor info for admin list view."""
model_config = ConfigDict(from_attributes=True)
id: int
vendor_id: int
vendor_name: str | None = None
vendor_code: str | None = None
customer_id: int
order_number: str
channel: str
status: str
# External references
external_order_number: str | None = None
external_shipment_id: str | None = None
# Customer snapshot
customer_full_name: str
customer_email: str
# Financial
subtotal: float | None
tax_amount: float | None
shipping_amount: float | None
discount_amount: float | None
total_amount: float
currency: str
# VAT information
vat_regime: str | None = None
vat_rate: float | None = None
vat_rate_label: str | None = None
vat_destination_country: str | None = None
# Shipping
ship_country_iso: str
tracking_number: str | None
tracking_provider: str | None
tracking_url: str | None = None
shipment_number: str | None = None
shipping_carrier: str | None = None
# Item count
item_count: int = 0
# Timestamps
order_date: datetime
confirmed_at: datetime | None
shipped_at: datetime | None
delivered_at: datetime | None
cancelled_at: datetime | None
created_at: datetime
updated_at: datetime
class AdminOrderListResponse(BaseModel):
"""Cross-vendor order list for admin."""
orders: list[AdminOrderItem]
total: int
skip: int
limit: int
class AdminOrderStats(BaseModel):
"""Order statistics for admin dashboard."""
total_orders: int = 0
pending_orders: int = 0
processing_orders: int = 0
shipped_orders: int = 0
delivered_orders: int = 0
cancelled_orders: int = 0
refunded_orders: int = 0
total_revenue: float = 0.0
# By channel
direct_orders: int = 0
letzshop_orders: int = 0
# Vendors
vendors_with_orders: int = 0
class AdminOrderStatusUpdate(BaseModel):
"""Admin version of status update with reason."""
status: str = Field(
..., pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
)
tracking_number: str | None = None
tracking_provider: str | None = None
reason: str | None = Field(None, description="Reason for status change")
class AdminVendorWithOrders(BaseModel):
"""Vendor with order count."""
id: int
name: str
vendor_code: str
order_count: int = 0
class AdminVendorsWithOrdersResponse(BaseModel):
"""Response for vendors with orders list."""
vendors: list[AdminVendorWithOrders]
# ============================================================================
# Letzshop-specific Schemas
# ============================================================================
class LetzshopOrderImport(BaseModel):
"""Schema for importing a Letzshop order from shipment data."""
shipment_id: str
order_id: str
order_number: str
order_date: datetime
# Customer
customer_email: str
customer_locale: str | None = None
# Shipping address
ship_first_name: str
ship_last_name: str
ship_company: str | None = None
ship_address_line_1: str
ship_address_line_2: str | None = None
ship_city: str
ship_postal_code: str
ship_country_iso: str
# Billing address
bill_first_name: str
bill_last_name: str
bill_company: str | None = None
bill_address_line_1: str
bill_address_line_2: str | None = None
bill_city: str
bill_postal_code: str
bill_country_iso: str
# Totals
total_amount: float
currency: str = "EUR"
# State
letzshop_state: str # unconfirmed, confirmed, declined
# Items
inventory_units: list[dict]
# Raw data
raw_data: dict | None = None
class LetzshopShippingInfo(BaseModel):
"""Shipping info retrieved from Letzshop."""
tracking_number: str
tracking_provider: str
shipment_id: str
class LetzshopOrderConfirmItem(BaseModel):
"""Schema for confirming/declining a single item."""
item_id: int
external_item_id: str
action: str = Field(..., pattern="^(confirm|decline)$")
class LetzshopOrderConfirmRequest(BaseModel):
"""Schema for confirming/declining order items."""
items: list[LetzshopOrderConfirmItem]
# ============================================================================
# Mark as Shipped Schemas
# ============================================================================
class MarkAsShippedRequest(BaseModel):
"""Schema for marking an order as shipped with tracking info."""
tracking_number: str | None = Field(None, max_length=100)
tracking_url: str | None = Field(None, max_length=500)
shipping_carrier: str | None = Field(None, max_length=50)
class ShippingLabelInfo(BaseModel):
"""Shipping label information for an order."""
shipment_number: str | None = None
shipping_carrier: str | None = None
label_url: str | None = None
tracking_number: str | None = None
tracking_url: str | None = None

View File

@@ -2,26 +2,26 @@
"""
Orders module services.
Re-exports order-related services from their source locations.
This module contains the canonical implementations of order-related services.
"""
from app.services.order_service import (
from app.modules.orders.services.order_service import (
order_service,
OrderService,
)
from app.services.order_inventory_service import (
from app.modules.orders.services.order_inventory_service import (
order_inventory_service,
OrderInventoryService,
)
from app.services.order_item_exception_service import (
from app.modules.orders.services.order_item_exception_service import (
order_item_exception_service,
OrderItemExceptionService,
)
from app.services.invoice_service import (
from app.modules.orders.services.invoice_service import (
invoice_service,
InvoiceService,
)
from app.services.invoice_pdf_service import (
from app.modules.orders.services.invoice_pdf_service import (
invoice_pdf_service,
InvoicePDFService,
)

View File

@@ -0,0 +1,150 @@
# app/modules/orders/services/invoice_pdf_service.py
"""
Invoice PDF generation service using WeasyPrint.
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
Stores generated PDFs in the configured storage location.
"""
import logging
from datetime import UTC, datetime
from pathlib import Path
from jinja2 import Environment, FileSystemLoader
from sqlalchemy.orm import Session
from app.modules.orders.models.invoice import Invoice
logger = logging.getLogger(__name__)
# Template directory
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
# PDF storage directory (relative to project root)
PDF_STORAGE_DIR = Path("storage") / "invoices"
class InvoicePDFService:
"""Service for generating invoice PDFs."""
def __init__(self):
"""Initialize the PDF service with Jinja2 environment."""
self.env = Environment(
loader=FileSystemLoader(str(TEMPLATE_DIR)),
autoescape=True,
)
def _ensure_storage_dir(self, vendor_id: int) -> Path:
"""Ensure the storage directory exists for a vendor."""
storage_path = PDF_STORAGE_DIR / str(vendor_id)
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def _get_pdf_filename(self, invoice: Invoice) -> str:
"""Generate PDF filename for an invoice."""
safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
return f"{safe_number}.pdf"
def generate_pdf(
self,
db: Session,
invoice: Invoice,
force_regenerate: bool = False,
) -> str:
"""
Generate PDF for an invoice.
Args:
db: Database session
invoice: Invoice to generate PDF for
force_regenerate: If True, regenerate even if PDF already exists
Returns:
Path to the generated PDF file
"""
# Check if PDF already exists
if invoice.pdf_path and not force_regenerate:
if Path(invoice.pdf_path).exists():
logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
return invoice.pdf_path
# Ensure storage directory exists
storage_dir = self._ensure_storage_dir(invoice.vendor_id)
pdf_filename = self._get_pdf_filename(invoice)
pdf_path = storage_dir / pdf_filename
# Render HTML template
html_content = self._render_html(invoice)
# Generate PDF using WeasyPrint
try:
from weasyprint import HTML
html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
html_doc.write_pdf(str(pdf_path))
logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
except ImportError:
logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
raise RuntimeError("WeasyPrint not installed")
except Exception as e:
logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
raise
# Update invoice record with PDF path and timestamp
invoice.pdf_path = str(pdf_path)
invoice.pdf_generated_at = datetime.now(UTC)
db.flush()
return str(pdf_path)
def _render_html(self, invoice: Invoice) -> str:
"""Render the invoice HTML template."""
template = self.env.get_template("invoice.html")
context = {
"invoice": invoice,
"seller": invoice.seller_details,
"buyer": invoice.buyer_details,
"line_items": invoice.line_items,
"bank_details": invoice.bank_details,
"payment_terms": invoice.payment_terms,
"footer_text": invoice.footer_text,
"now": datetime.now(UTC),
}
return template.render(**context)
def get_pdf_path(self, invoice: Invoice) -> str | None:
"""Get the PDF path for an invoice if it exists."""
if invoice.pdf_path and Path(invoice.pdf_path).exists():
return invoice.pdf_path
return None
def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
"""Delete the PDF file for an invoice."""
if not invoice.pdf_path:
return False
pdf_path = Path(invoice.pdf_path)
if pdf_path.exists():
try:
pdf_path.unlink()
logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
except Exception as e:
logger.error(f"Failed to delete PDF {pdf_path}: {e}")
return False
invoice.pdf_path = None
invoice.pdf_generated_at = None
db.flush()
return True
def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
"""Force regenerate PDF for an invoice."""
return self.generate_pdf(db, invoice, force_regenerate=True)
# Singleton instance
invoice_pdf_service = InvoicePDFService()

View File

@@ -0,0 +1,587 @@
# app/modules/orders/services/invoice_service.py
"""
Invoice service for generating and managing invoices.
Handles:
- Vendor invoice settings management
- Invoice generation from orders
- VAT calculation (Luxembourg, EU, B2B reverse charge)
- Invoice number sequencing
- PDF generation (via separate module)
"""
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from app.exceptions import ValidationException
from app.exceptions.invoice import (
InvoiceNotFoundException,
InvoiceSettingsNotFoundException,
OrderNotFoundException,
)
from app.modules.orders.models.invoice import (
Invoice,
InvoiceStatus,
VATRegime,
VendorInvoiceSettings,
)
from app.modules.orders.models.order import Order
from app.modules.orders.schemas.invoice import (
VendorInvoiceSettingsCreate,
VendorInvoiceSettingsUpdate,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
# EU VAT rates by country code (2024 standard rates)
EU_VAT_RATES: dict[str, Decimal] = {
"AT": Decimal("20.00"),
"BE": Decimal("21.00"),
"BG": Decimal("20.00"),
"HR": Decimal("25.00"),
"CY": Decimal("19.00"),
"CZ": Decimal("21.00"),
"DK": Decimal("25.00"),
"EE": Decimal("22.00"),
"FI": Decimal("24.00"),
"FR": Decimal("20.00"),
"DE": Decimal("19.00"),
"GR": Decimal("24.00"),
"HU": Decimal("27.00"),
"IE": Decimal("23.00"),
"IT": Decimal("22.00"),
"LV": Decimal("21.00"),
"LT": Decimal("21.00"),
"LU": Decimal("17.00"),
"MT": Decimal("18.00"),
"NL": Decimal("21.00"),
"PL": Decimal("23.00"),
"PT": Decimal("23.00"),
"RO": Decimal("19.00"),
"SK": Decimal("20.00"),
"SI": Decimal("22.00"),
"ES": Decimal("21.00"),
"SE": Decimal("25.00"),
}
LU_VAT_RATES = {
"standard": Decimal("17.00"),
"intermediate": Decimal("14.00"),
"reduced": Decimal("8.00"),
"super_reduced": Decimal("3.00"),
}
class InvoiceService:
"""Service for invoice operations."""
# =========================================================================
# VAT Calculation
# =========================================================================
def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
"""Get standard VAT rate for EU country."""
return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))
def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
"""Get human-readable VAT rate label."""
country_names = {
"AT": "Austria", "BE": "Belgium", "BG": "Bulgaria", "HR": "Croatia",
"CY": "Cyprus", "CZ": "Czech Republic", "DK": "Denmark", "EE": "Estonia",
"FI": "Finland", "FR": "France", "DE": "Germany", "GR": "Greece",
"HU": "Hungary", "IE": "Ireland", "IT": "Italy", "LV": "Latvia",
"LT": "Lithuania", "LU": "Luxembourg", "MT": "Malta", "NL": "Netherlands",
"PL": "Poland", "PT": "Portugal", "RO": "Romania", "SK": "Slovakia",
"SI": "Slovenia", "ES": "Spain", "SE": "Sweden",
}
country_name = country_names.get(country_iso.upper(), country_iso)
return f"{country_name} VAT {vat_rate}%"
def determine_vat_regime(
self,
seller_country: str,
buyer_country: str,
buyer_vat_number: str | None,
seller_oss_registered: bool,
) -> tuple[VATRegime, Decimal, str | None]:
"""Determine VAT regime and rate for invoice."""
seller_country = seller_country.upper()
buyer_country = buyer_country.upper()
if seller_country == buyer_country:
vat_rate = self.get_vat_rate_for_country(seller_country)
return VATRegime.DOMESTIC, vat_rate, None
if buyer_country in EU_VAT_RATES:
if buyer_vat_number:
return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country
if seller_oss_registered:
vat_rate = self.get_vat_rate_for_country(buyer_country)
return VATRegime.OSS, vat_rate, buyer_country
else:
vat_rate = self.get_vat_rate_for_country(seller_country)
return VATRegime.ORIGIN, vat_rate, buyer_country
return VATRegime.EXEMPT, Decimal("0.00"), buyer_country
# =========================================================================
# Invoice Settings Management
# =========================================================================
def get_settings(
self, db: Session, vendor_id: int
) -> VendorInvoiceSettings | None:
"""Get vendor invoice settings."""
return (
db.query(VendorInvoiceSettings)
.filter(VendorInvoiceSettings.vendor_id == vendor_id)
.first()
)
def get_settings_or_raise(
self, db: Session, vendor_id: int
) -> VendorInvoiceSettings:
"""Get vendor invoice settings or raise exception."""
settings = self.get_settings(db, vendor_id)
if not settings:
raise InvoiceSettingsNotFoundException(vendor_id)
return settings
def create_settings(
self,
db: Session,
vendor_id: int,
data: VendorInvoiceSettingsCreate,
) -> VendorInvoiceSettings:
"""Create vendor invoice settings."""
existing = self.get_settings(db, vendor_id)
if existing:
raise ValidationException(
"Invoice settings already exist for this vendor"
)
settings = VendorInvoiceSettings(
vendor_id=vendor_id,
**data.model_dump(),
)
db.add(settings)
db.flush()
db.refresh(settings)
logger.info(f"Created invoice settings for vendor {vendor_id}")
return settings
def update_settings(
self,
db: Session,
vendor_id: int,
data: VendorInvoiceSettingsUpdate,
) -> VendorInvoiceSettings:
"""Update vendor invoice settings."""
settings = self.get_settings_or_raise(db, vendor_id)
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(settings, key, value)
settings.updated_at = datetime.now(UTC)
db.flush()
db.refresh(settings)
logger.info(f"Updated invoice settings for vendor {vendor_id}")
return settings
def create_settings_from_vendor(
self,
db: Session,
vendor: Vendor,
) -> VendorInvoiceSettings:
"""Create invoice settings from vendor/company info."""
company = vendor.company
settings = VendorInvoiceSettings(
vendor_id=vendor.id,
company_name=company.legal_name if company else vendor.name,
company_address=vendor.effective_business_address,
company_city=None,
company_postal_code=None,
company_country="LU",
vat_number=vendor.effective_tax_number,
is_vat_registered=bool(vendor.effective_tax_number),
)
db.add(settings)
db.flush()
db.refresh(settings)
logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
return settings
# =========================================================================
# Invoice Number Generation
# =========================================================================
def _get_next_invoice_number(
self, db: Session, settings: VendorInvoiceSettings
) -> str:
"""Generate next invoice number and increment counter."""
number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
invoice_number = f"{settings.invoice_prefix}{number}"
settings.invoice_next_number += 1
db.flush()
return invoice_number
# =========================================================================
# Invoice Creation
# =========================================================================
def create_invoice_from_order(
self,
db: Session,
vendor_id: int,
order_id: int,
notes: str | None = None,
) -> Invoice:
"""Create an invoice from an order."""
settings = self.get_settings_or_raise(db, vendor_id)
order = (
db.query(Order)
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
.first()
)
if not order:
raise OrderNotFoundException(f"Order {order_id} not found")
existing = (
db.query(Invoice)
.filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
.first()
)
if existing:
raise ValidationException(f"Invoice already exists for order {order_id}")
buyer_country = order.bill_country_iso
vat_regime, vat_rate, destination_country = self.determine_vat_regime(
seller_country=settings.company_country,
buyer_country=buyer_country,
buyer_vat_number=None,
seller_oss_registered=settings.is_oss_registered,
)
seller_details = {
"company_name": settings.company_name,
"address": settings.company_address,
"city": settings.company_city,
"postal_code": settings.company_postal_code,
"country": settings.company_country,
"vat_number": settings.vat_number,
}
buyer_details = {
"name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
"email": order.customer_email,
"address": order.bill_address_line_1,
"city": order.bill_city,
"postal_code": order.bill_postal_code,
"country": order.bill_country_iso,
"vat_number": None,
}
if order.bill_company:
buyer_details["company"] = order.bill_company
line_items = []
for item in order.items:
line_items.append({
"description": item.product_name,
"quantity": item.quantity,
"unit_price_cents": item.unit_price_cents,
"total_cents": item.total_price_cents,
"sku": item.product_sku,
"ean": item.gtin,
})
subtotal_cents = sum(item["total_cents"] for item in line_items)
if vat_rate > 0:
vat_amount_cents = int(subtotal_cents * float(vat_rate) / 100)
else:
vat_amount_cents = 0
total_cents = subtotal_cents + vat_amount_cents
vat_rate_label = None
if vat_rate > 0:
if destination_country:
vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
else:
vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)
invoice_number = self._get_next_invoice_number(db, settings)
invoice = Invoice(
vendor_id=vendor_id,
order_id=order_id,
invoice_number=invoice_number,
invoice_date=datetime.now(UTC),
status=InvoiceStatus.DRAFT.value,
seller_details=seller_details,
buyer_details=buyer_details,
line_items=line_items,
vat_regime=vat_regime.value,
destination_country=destination_country,
vat_rate=vat_rate,
vat_rate_label=vat_rate_label,
currency=order.currency,
subtotal_cents=subtotal_cents,
vat_amount_cents=vat_amount_cents,
total_cents=total_cents,
payment_terms=settings.payment_terms,
bank_details={
"bank_name": settings.bank_name,
"iban": settings.bank_iban,
"bic": settings.bank_bic,
} if settings.bank_iban else None,
footer_text=settings.footer_text,
notes=notes,
)
db.add(invoice)
db.flush()
db.refresh(invoice)
logger.info(
f"Created invoice {invoice_number} for order {order_id} "
f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
)
return invoice
# =========================================================================
# Invoice Retrieval
# =========================================================================
def get_invoice(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice | None:
"""Get invoice by ID."""
return (
db.query(Invoice)
.filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
.first()
)
def get_invoice_or_raise(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Get invoice by ID or raise exception."""
invoice = self.get_invoice(db, vendor_id, invoice_id)
if not invoice:
raise InvoiceNotFoundException(invoice_id)
return invoice
def get_invoice_by_number(
self, db: Session, vendor_id: int, invoice_number: str
) -> Invoice | None:
"""Get invoice by invoice number."""
return (
db.query(Invoice)
.filter(
and_(
Invoice.invoice_number == invoice_number,
Invoice.vendor_id == vendor_id,
)
)
.first()
)
def get_invoice_by_order_id(
self, db: Session, vendor_id: int, order_id: int
) -> Invoice | None:
"""Get invoice by order ID."""
return (
db.query(Invoice)
.filter(
and_(
Invoice.order_id == order_id,
Invoice.vendor_id == vendor_id,
)
)
.first()
)
def list_invoices(
self,
db: Session,
vendor_id: int,
status: str | None = None,
page: int = 1,
per_page: int = 20,
) -> tuple[list[Invoice], int]:
"""List invoices for vendor with pagination."""
query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)
if status:
query = query.filter(Invoice.status == status)
total = query.count()
invoices = (
query.order_by(Invoice.invoice_date.desc())
.offset((page - 1) * per_page)
.limit(per_page)
.all()
)
return invoices, total
# =========================================================================
# Invoice Status Management
# =========================================================================
def update_status(
self,
db: Session,
vendor_id: int,
invoice_id: int,
new_status: str,
) -> Invoice:
"""Update invoice status."""
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
valid_statuses = [s.value for s in InvoiceStatus]
if new_status not in valid_statuses:
raise ValidationException(f"Invalid status: {new_status}")
if invoice.status == InvoiceStatus.CANCELLED.value:
raise ValidationException("Cannot change status of cancelled invoice")
invoice.status = new_status
invoice.updated_at = datetime.now(UTC)
db.flush()
db.refresh(invoice)
logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
return invoice
def mark_as_issued(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Mark invoice as issued."""
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)
def mark_as_paid(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Mark invoice as paid."""
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)
def cancel_invoice(
self, db: Session, vendor_id: int, invoice_id: int
) -> Invoice:
"""Cancel invoice."""
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value)
# =========================================================================
# Statistics
# =========================================================================
def get_invoice_stats(
self, db: Session, vendor_id: int
) -> dict[str, Any]:
"""Get invoice statistics for vendor."""
total_count = (
db.query(func.count(Invoice.id))
.filter(Invoice.vendor_id == vendor_id)
.scalar()
or 0
)
total_revenue = (
db.query(func.sum(Invoice.total_cents))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status.in_([
InvoiceStatus.ISSUED.value,
InvoiceStatus.PAID.value,
]),
)
)
.scalar()
or 0
)
draft_count = (
db.query(func.count(Invoice.id))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status == InvoiceStatus.DRAFT.value,
)
)
.scalar()
or 0
)
paid_count = (
db.query(func.count(Invoice.id))
.filter(
and_(
Invoice.vendor_id == vendor_id,
Invoice.status == InvoiceStatus.PAID.value,
)
)
.scalar()
or 0
)
return {
"total_invoices": total_count,
"total_revenue_cents": total_revenue,
"total_revenue": total_revenue / 100 if total_revenue else 0,
"draft_count": draft_count,
"paid_count": paid_count,
}
# =========================================================================
# PDF Generation
# =========================================================================
def generate_pdf(
self,
db: Session,
vendor_id: int,
invoice_id: int,
force_regenerate: bool = False,
) -> str:
"""Generate PDF for an invoice."""
from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)
def get_pdf_path(
self,
db: Session,
vendor_id: int,
invoice_id: int,
) -> str | None:
"""Get PDF path for an invoice if it exists."""
from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
return invoice_pdf_service.get_pdf_path(invoice)
# Singleton instance
invoice_service = InvoiceService()

View File

@@ -0,0 +1,591 @@
# app/modules/orders/services/order_inventory_service.py
"""
Order-Inventory Integration Service.
This service orchestrates inventory operations for orders:
- Reserve inventory when orders are confirmed
- Fulfill (deduct) inventory when orders are shipped
- Release reservations when orders are cancelled
All operations are logged to the inventory_transactions table for audit trail.
"""
import logging
from sqlalchemy.orm import Session
from app.exceptions import (
InsufficientInventoryException,
InventoryNotFoundException,
OrderNotFoundException,
ValidationException,
)
from app.modules.inventory.models.inventory import Inventory
from app.modules.inventory.models.inventory_transaction import (
InventoryTransaction,
TransactionType,
)
from app.modules.inventory.schemas.inventory import InventoryReserve
from app.modules.inventory.services.inventory_service import inventory_service
from app.modules.orders.models.order import Order, OrderItem
logger = logging.getLogger(__name__)
# Default location for inventory operations
DEFAULT_LOCATION = "DEFAULT"
class OrderInventoryService:
"""
Orchestrate order and inventory operations together.
"""
def get_order_with_items(
self, db: Session, vendor_id: int, order_id: int
) -> Order:
"""Get order with items or raise OrderNotFoundException."""
order = (
db.query(Order)
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
.first()
)
if not order:
raise OrderNotFoundException(f"Order {order_id} not found")
return order
def _find_inventory_location(
self, db: Session, product_id: int, vendor_id: int
) -> str | None:
"""
Find the location with available inventory for a product.
"""
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == product_id,
Inventory.vendor_id == vendor_id,
Inventory.quantity > Inventory.reserved_quantity,
)
.first()
)
return inventory.location if inventory else None
def _is_placeholder_product(self, order_item: OrderItem) -> bool:
"""Check if the order item uses a placeholder product."""
if not order_item.product:
return True
return order_item.product.gtin == "0000000000000"
def _log_transaction(
self,
db: Session,
vendor_id: int,
product_id: int,
inventory: Inventory,
transaction_type: TransactionType,
quantity_change: int,
order: Order,
reason: str | None = None,
) -> InventoryTransaction:
"""Create an inventory transaction record for audit trail."""
transaction = InventoryTransaction.create_transaction(
vendor_id=vendor_id,
product_id=product_id,
inventory_id=inventory.id if inventory else None,
transaction_type=transaction_type,
quantity_change=quantity_change,
quantity_after=inventory.quantity if inventory else 0,
reserved_after=inventory.reserved_quantity if inventory else 0,
location=inventory.location if inventory else None,
warehouse=inventory.warehouse if inventory else None,
order_id=order.id,
order_number=order.order_number,
reason=reason,
created_by="system",
)
db.add(transaction)
return transaction
def reserve_for_order(
self,
db: Session,
vendor_id: int,
order_id: int,
skip_missing: bool = True,
) -> dict:
"""Reserve inventory for all items in an order."""
order = self.get_order_with_items(db, vendor_id, order_id)
reserved_count = 0
skipped_items = []
for item in order.items:
if self._is_placeholder_product(item):
skipped_items.append({
"item_id": item.id,
"reason": "placeholder_product",
})
continue
location = self._find_inventory_location(db, item.product_id, vendor_id)
if not location:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": "no_inventory",
})
continue
else:
raise InventoryNotFoundException(
f"No inventory found for product {item.product_id}"
)
try:
reserve_data = InventoryReserve(
product_id=item.product_id,
location=location,
quantity=item.quantity,
)
updated_inventory = inventory_service.reserve_inventory(
db, vendor_id, reserve_data
)
reserved_count += 1
self._log_transaction(
db=db,
vendor_id=vendor_id,
product_id=item.product_id,
inventory=updated_inventory,
transaction_type=TransactionType.RESERVE,
quantity_change=0,
order=order,
reason=f"Reserved for order {order.order_number}",
)
logger.info(
f"Reserved {item.quantity} units of product {item.product_id} "
f"for order {order.order_number}"
)
except InsufficientInventoryException:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": "insufficient_inventory",
})
else:
raise
logger.info(
f"Order {order.order_number}: reserved {reserved_count} items, "
f"skipped {len(skipped_items)}"
)
return {
"order_id": order_id,
"order_number": order.order_number,
"reserved_count": reserved_count,
"skipped_items": skipped_items,
}
def fulfill_order(
self,
db: Session,
vendor_id: int,
order_id: int,
skip_missing: bool = True,
) -> dict:
"""Fulfill (deduct) inventory when an order is shipped."""
order = self.get_order_with_items(db, vendor_id, order_id)
fulfilled_count = 0
skipped_items = []
for item in order.items:
if item.is_fully_shipped:
continue
if self._is_placeholder_product(item):
skipped_items.append({
"item_id": item.id,
"reason": "placeholder_product",
})
continue
quantity_to_fulfill = item.remaining_quantity
location = self._find_inventory_location(db, item.product_id, vendor_id)
if not location:
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == item.product_id,
Inventory.vendor_id == vendor_id,
)
.first()
)
if inventory:
location = inventory.location
if not location:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": "no_inventory",
})
continue
else:
raise InventoryNotFoundException(
f"No inventory found for product {item.product_id}"
)
try:
reserve_data = InventoryReserve(
product_id=item.product_id,
location=location,
quantity=quantity_to_fulfill,
)
updated_inventory = inventory_service.fulfill_reservation(
db, vendor_id, reserve_data
)
fulfilled_count += 1
item.shipped_quantity = item.quantity
item.inventory_fulfilled = True
self._log_transaction(
db=db,
vendor_id=vendor_id,
product_id=item.product_id,
inventory=updated_inventory,
transaction_type=TransactionType.FULFILL,
quantity_change=-quantity_to_fulfill,
order=order,
reason=f"Fulfilled for order {order.order_number}",
)
logger.info(
f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} "
f"for order {order.order_number}"
)
except (InsufficientInventoryException, InventoryNotFoundException) as e:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": str(e),
})
else:
raise
logger.info(
f"Order {order.order_number}: fulfilled {fulfilled_count} items, "
f"skipped {len(skipped_items)}"
)
return {
"order_id": order_id,
"order_number": order.order_number,
"fulfilled_count": fulfilled_count,
"skipped_items": skipped_items,
}
def fulfill_item(
self,
db: Session,
vendor_id: int,
order_id: int,
item_id: int,
quantity: int | None = None,
skip_missing: bool = True,
) -> dict:
"""Fulfill (deduct) inventory for a specific order item."""
order = self.get_order_with_items(db, vendor_id, order_id)
item = None
for order_item in order.items:
if order_item.id == item_id:
item = order_item
break
if not item:
raise ValidationException(f"Item {item_id} not found in order {order_id}")
if item.is_fully_shipped:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Item already fully shipped",
}
quantity_to_fulfill = quantity or item.remaining_quantity
if quantity_to_fulfill > item.remaining_quantity:
raise ValidationException(
f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining"
)
if quantity_to_fulfill <= 0:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Nothing to fulfill",
}
if self._is_placeholder_product(item):
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Placeholder product - skipped",
}
location = self._find_inventory_location(db, item.product_id, vendor_id)
if not location:
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == item.product_id,
Inventory.vendor_id == vendor_id,
)
.first()
)
if inventory:
location = inventory.location
if not location:
if skip_missing:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "No inventory found",
}
else:
raise InventoryNotFoundException(
f"No inventory found for product {item.product_id}"
)
try:
reserve_data = InventoryReserve(
product_id=item.product_id,
location=location,
quantity=quantity_to_fulfill,
)
updated_inventory = inventory_service.fulfill_reservation(
db, vendor_id, reserve_data
)
item.shipped_quantity += quantity_to_fulfill
if item.is_fully_shipped:
item.inventory_fulfilled = True
self._log_transaction(
db=db,
vendor_id=vendor_id,
product_id=item.product_id,
inventory=updated_inventory,
transaction_type=TransactionType.FULFILL,
quantity_change=-quantity_to_fulfill,
order=order,
reason=f"Partial shipment for order {order.order_number}",
)
logger.info(
f"Fulfilled {quantity_to_fulfill} of {item.quantity} units "
f"for item {item_id} in order {order.order_number}"
)
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": quantity_to_fulfill,
"shipped_quantity": item.shipped_quantity,
"remaining_quantity": item.remaining_quantity,
"is_fully_shipped": item.is_fully_shipped,
}
except (InsufficientInventoryException, InventoryNotFoundException) as e:
if skip_missing:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": str(e),
}
else:
raise
def release_order_reservation(
self,
db: Session,
vendor_id: int,
order_id: int,
skip_missing: bool = True,
) -> dict:
"""Release reserved inventory when an order is cancelled."""
order = self.get_order_with_items(db, vendor_id, order_id)
released_count = 0
skipped_items = []
for item in order.items:
if self._is_placeholder_product(item):
skipped_items.append({
"item_id": item.id,
"reason": "placeholder_product",
})
continue
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == item.product_id,
Inventory.vendor_id == vendor_id,
)
.first()
)
if not inventory:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": "no_inventory",
})
continue
else:
raise InventoryNotFoundException(
f"No inventory found for product {item.product_id}"
)
try:
reserve_data = InventoryReserve(
product_id=item.product_id,
location=inventory.location,
quantity=item.quantity,
)
updated_inventory = inventory_service.release_reservation(
db, vendor_id, reserve_data
)
released_count += 1
self._log_transaction(
db=db,
vendor_id=vendor_id,
product_id=item.product_id,
inventory=updated_inventory,
transaction_type=TransactionType.RELEASE,
quantity_change=0,
order=order,
reason=f"Released for cancelled order {order.order_number}",
)
logger.info(
f"Released {item.quantity} units of product {item.product_id} "
f"for cancelled order {order.order_number}"
)
except Exception as e:
if skip_missing:
skipped_items.append({
"item_id": item.id,
"product_id": item.product_id,
"reason": str(e),
})
else:
raise
logger.info(
f"Order {order.order_number}: released {released_count} items, "
f"skipped {len(skipped_items)}"
)
return {
"order_id": order_id,
"order_number": order.order_number,
"released_count": released_count,
"skipped_items": skipped_items,
}
def handle_status_change(
self,
db: Session,
vendor_id: int,
order_id: int,
old_status: str | None,
new_status: str,
) -> dict | None:
"""Handle inventory operations based on order status changes."""
if old_status == new_status:
return None
result = None
if new_status == "processing":
result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True)
logger.info(f"Order {order_id} confirmed: inventory reserved")
elif new_status == "shipped":
result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True)
logger.info(f"Order {order_id} shipped: inventory fulfilled")
elif new_status == "partially_shipped":
logger.info(
f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment"
)
result = {"order_id": order_id, "status": "partially_shipped"}
elif new_status == "cancelled":
if old_status and old_status not in ("cancelled", "refunded"):
result = self.release_order_reservation(
db, vendor_id, order_id, skip_missing=True
)
logger.info(f"Order {order_id} cancelled: reservations released")
return result
def get_shipment_status(
self,
db: Session,
vendor_id: int,
order_id: int,
) -> dict:
"""Get detailed shipment status for an order."""
order = self.get_order_with_items(db, vendor_id, order_id)
items = []
for item in order.items:
items.append({
"item_id": item.id,
"product_id": item.product_id,
"product_name": item.product_name,
"quantity": item.quantity,
"shipped_quantity": item.shipped_quantity,
"remaining_quantity": item.remaining_quantity,
"is_fully_shipped": item.is_fully_shipped,
"is_partially_shipped": item.is_partially_shipped,
})
return {
"order_id": order_id,
"order_number": order.order_number,
"order_status": order.status,
"is_fully_shipped": order.is_fully_shipped,
"is_partially_shipped": order.is_partially_shipped,
"shipped_item_count": order.shipped_item_count,
"total_item_count": len(order.items),
"total_shipped_units": order.total_shipped_units,
"total_ordered_units": order.total_ordered_units,
"items": items,
}
# Create service instance
order_inventory_service = OrderInventoryService()

View File

@@ -0,0 +1,466 @@
# app/modules/orders/services/order_item_exception_service.py
"""
Service for managing order item exceptions (unmatched products).
This service handles:
- Creating exceptions when products are not found during order import
- Resolving exceptions by assigning products
- Auto-matching when new products are imported
- Querying and statistics for exceptions
"""
import logging
from datetime import UTC, datetime
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from app.exceptions import (
ExceptionAlreadyResolvedException,
InvalidProductForExceptionException,
OrderItemExceptionNotFoundException,
ProductNotFoundException,
)
from app.modules.orders.models.order import Order, OrderItem
from app.modules.orders.models.order_item_exception import OrderItemException
from models.database.product import Product
logger = logging.getLogger(__name__)
class OrderItemExceptionService:
"""Service for order item exception CRUD and resolution workflow."""
# =========================================================================
# Exception Creation
# =========================================================================
def create_exception(
self,
db: Session,
order_item: OrderItem,
vendor_id: int,
original_gtin: str | None,
original_product_name: str | None,
original_sku: str | None,
exception_type: str = "product_not_found",
) -> OrderItemException:
"""Create an exception record for an unmatched order item."""
exception = OrderItemException(
order_item_id=order_item.id,
vendor_id=vendor_id,
original_gtin=original_gtin,
original_product_name=original_product_name,
original_sku=original_sku,
exception_type=exception_type,
status="pending",
)
db.add(exception)
db.flush()
logger.info(
f"Created order item exception {exception.id} for order item "
f"{order_item.id}, GTIN: {original_gtin}"
)
return exception
# =========================================================================
# Exception Retrieval
# =========================================================================
def get_exception_by_id(
self,
db: Session,
exception_id: int,
vendor_id: int | None = None,
) -> OrderItemException:
"""Get an exception by ID, optionally filtered by vendor."""
query = db.query(OrderItemException).filter(
OrderItemException.id == exception_id
)
if vendor_id is not None:
query = query.filter(OrderItemException.vendor_id == vendor_id)
exception = query.first()
if not exception:
raise OrderItemExceptionNotFoundException(exception_id)
return exception
def get_pending_exceptions(
self,
db: Session,
vendor_id: int | None = None,
status: str | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[OrderItemException], int]:
"""Get exceptions with pagination and filtering."""
query = (
db.query(OrderItemException)
.join(OrderItem)
.join(Order)
.options(
joinedload(OrderItemException.order_item).joinedload(OrderItem.order)
)
)
if vendor_id is not None:
query = query.filter(OrderItemException.vendor_id == vendor_id)
if status:
query = query.filter(OrderItemException.status == status)
if search:
search_pattern = f"%{search}%"
query = query.filter(
or_(
OrderItemException.original_gtin.ilike(search_pattern),
OrderItemException.original_product_name.ilike(search_pattern),
OrderItemException.original_sku.ilike(search_pattern),
Order.order_number.ilike(search_pattern),
)
)
total = query.count()
exceptions = (
query.order_by(OrderItemException.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return exceptions, total
def get_exceptions_for_order(
self,
db: Session,
order_id: int,
) -> list[OrderItemException]:
"""Get all exceptions for items in an order."""
return (
db.query(OrderItemException)
.join(OrderItem)
.filter(OrderItem.order_id == order_id)
.all()
)
# =========================================================================
# Exception Statistics
# =========================================================================
def get_exception_stats(
self,
db: Session,
vendor_id: int | None = None,
) -> dict[str, int]:
"""Get exception counts by status."""
query = db.query(
OrderItemException.status,
func.count(OrderItemException.id).label("count"),
)
if vendor_id is not None:
query = query.filter(OrderItemException.vendor_id == vendor_id)
results = query.group_by(OrderItemException.status).all()
stats = {
"pending": 0,
"resolved": 0,
"ignored": 0,
"total": 0,
}
for status, count in results:
if status in stats:
stats[status] = count
stats["total"] += count
orders_query = (
db.query(func.count(func.distinct(OrderItem.order_id)))
.join(OrderItemException)
.filter(OrderItemException.status == "pending")
)
if vendor_id is not None:
orders_query = orders_query.filter(
OrderItemException.vendor_id == vendor_id
)
stats["orders_with_exceptions"] = orders_query.scalar() or 0
return stats
# =========================================================================
# Exception Resolution
# =========================================================================
def resolve_exception(
self,
db: Session,
exception_id: int,
product_id: int,
resolved_by: int,
notes: str | None = None,
vendor_id: int | None = None,
) -> OrderItemException:
"""Resolve an exception by assigning a product."""
exception = self.get_exception_by_id(db, exception_id, vendor_id)
if exception.status == "resolved":
raise ExceptionAlreadyResolvedException(exception_id)
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise ProductNotFoundException(product_id)
if product.vendor_id != exception.vendor_id:
raise InvalidProductForExceptionException(
product_id, "Product belongs to a different vendor"
)
if not product.is_active:
raise InvalidProductForExceptionException(
product_id, "Product is not active"
)
exception.status = "resolved"
exception.resolved_product_id = product_id
exception.resolved_at = datetime.now(UTC)
exception.resolved_by = resolved_by
exception.resolution_notes = notes
order_item = exception.order_item
order_item.product_id = product_id
order_item.needs_product_match = False
if product.marketplace_product:
order_item.product_name = product.marketplace_product.get_title("en")
order_item.product_sku = product.vendor_sku or order_item.product_sku
db.flush()
logger.info(
f"Resolved exception {exception_id} with product {product_id} "
f"by user {resolved_by}"
)
return exception
def ignore_exception(
self,
db: Session,
exception_id: int,
resolved_by: int,
notes: str,
vendor_id: int | None = None,
) -> OrderItemException:
"""Mark an exception as ignored."""
exception = self.get_exception_by_id(db, exception_id, vendor_id)
if exception.status == "resolved":
raise ExceptionAlreadyResolvedException(exception_id)
exception.status = "ignored"
exception.resolved_at = datetime.now(UTC)
exception.resolved_by = resolved_by
exception.resolution_notes = notes
db.flush()
logger.info(
f"Ignored exception {exception_id} by user {resolved_by}: {notes}"
)
return exception
# =========================================================================
# Auto-Matching
# =========================================================================
def auto_match_by_gtin(
self,
db: Session,
vendor_id: int,
gtin: str,
product_id: int,
) -> list[OrderItemException]:
"""Auto-resolve pending exceptions matching a GTIN."""
if not gtin:
return []
pending = (
db.query(OrderItemException)
.filter(
and_(
OrderItemException.vendor_id == vendor_id,
OrderItemException.original_gtin == gtin,
OrderItemException.status == "pending",
)
)
.all()
)
if not pending:
return []
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
logger.warning(f"Product {product_id} not found for auto-match")
return []
resolved = []
now = datetime.now(UTC)
for exception in pending:
exception.status = "resolved"
exception.resolved_product_id = product_id
exception.resolved_at = now
exception.resolution_notes = "Auto-matched during product import"
order_item = exception.order_item
order_item.product_id = product_id
order_item.needs_product_match = False
if product.marketplace_product:
order_item.product_name = product.marketplace_product.get_title("en")
resolved.append(exception)
if resolved:
db.flush()
logger.info(
f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} "
f"with product {product_id}"
)
return resolved
def auto_match_batch(
self,
db: Session,
vendor_id: int,
gtin_to_product: dict[str, int],
) -> int:
"""Batch auto-match multiple GTINs after bulk import."""
if not gtin_to_product:
return 0
total_resolved = 0
for gtin, product_id in gtin_to_product.items():
resolved = self.auto_match_by_gtin(db, vendor_id, gtin, product_id)
total_resolved += len(resolved)
return total_resolved
# =========================================================================
# Confirmation Checks
# =========================================================================
def order_has_unresolved_exceptions(
self,
db: Session,
order_id: int,
) -> bool:
"""Check if order has any unresolved exceptions."""
count = (
db.query(func.count(OrderItemException.id))
.join(OrderItem)
.filter(
and_(
OrderItem.order_id == order_id,
OrderItemException.status.in_(["pending", "ignored"]),
)
)
.scalar()
)
return count > 0
def get_unresolved_exception_count(
self,
db: Session,
order_id: int,
) -> int:
"""Get count of unresolved exceptions for an order."""
return (
db.query(func.count(OrderItemException.id))
.join(OrderItem)
.filter(
and_(
OrderItem.order_id == order_id,
OrderItemException.status.in_(["pending", "ignored"]),
)
)
.scalar()
) or 0
# =========================================================================
# Bulk Operations
# =========================================================================
def bulk_resolve_by_gtin(
self,
db: Session,
vendor_id: int,
gtin: str,
product_id: int,
resolved_by: int,
notes: str | None = None,
) -> int:
"""Bulk resolve all pending exceptions for a GTIN."""
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise ProductNotFoundException(product_id)
if product.vendor_id != vendor_id:
raise InvalidProductForExceptionException(
product_id, "Product belongs to a different vendor"
)
pending = (
db.query(OrderItemException)
.filter(
and_(
OrderItemException.vendor_id == vendor_id,
OrderItemException.original_gtin == gtin,
OrderItemException.status == "pending",
)
)
.all()
)
now = datetime.now(UTC)
resolution_notes = notes or f"Bulk resolved for GTIN {gtin}"
for exception in pending:
exception.status = "resolved"
exception.resolved_product_id = product_id
exception.resolved_at = now
exception.resolved_by = resolved_by
exception.resolution_notes = resolution_notes
order_item = exception.order_item
order_item.product_id = product_id
order_item.needs_product_match = False
if product.marketplace_product:
order_item.product_name = product.marketplace_product.get_title("en")
db.flush()
logger.info(
f"Bulk resolved {len(pending)} exceptions for GTIN {gtin} "
f"with product {product_id} by user {resolved_by}"
)
return len(pending)
# Global service instance
order_item_exception_service = OrderItemExceptionService()

File diff suppressed because it is too large Load Diff