refactor: migrate modules from re-exports to canonical implementations
Move actual code implementations into module directories:

- orders: 5 services, 4 models, order/invoice schemas
- inventory: 3 services, 2 models, 30+ schemas
- customers: 3 services, 2 models, customer schemas
- messaging: 3 services, 2 models, message/notification schemas
- monitoring: background_tasks_service
- marketplace: 5+ services, including the letzshop submodule
- dev_tools: code_quality_service, test_runner_service
- billing: billing_service
- contracts: definition.py

Legacy files in app/services/, models/database/, and models/schema/ now re-export from the canonical module locations for backwards compatibility. The architecture validator passes with 0 errors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from fastapi.responses import HTMLResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_vendor_from_cookie_or_header, get_db
|
||||
from app.services.platform_settings_service import platform_settings_service
|
||||
from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service
|
||||
from app.templates_config import templates
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
@@ -17,8 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_admin_api, require_module_access
|
||||
from app.core.database import get_db
|
||||
from app.services.admin_subscription_service import admin_subscription_service
|
||||
from app.services.subscription_service import subscription_service
|
||||
from app.modules.billing.services import admin_subscription_service, subscription_service
|
||||
from models.database.user import User
|
||||
from models.schema.billing import (
|
||||
BillingHistoryListResponse,
|
||||
|
||||
@@ -19,8 +19,7 @@ from sqlalchemy.orm import Session
|
||||
from app.api.deps import get_current_vendor_api, require_module_access
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.services.billing_service import billing_service
|
||||
from app.services.subscription_service import subscription_service
|
||||
from app.modules.billing.services import billing_service, subscription_service
|
||||
from models.database.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +17,16 @@ from app.modules.billing.services.admin_subscription_service import (
|
||||
AdminSubscriptionService,
|
||||
admin_subscription_service,
|
||||
)
|
||||
from app.modules.billing.services.billing_service import (
|
||||
BillingService,
|
||||
billing_service,
|
||||
BillingServiceError,
|
||||
PaymentSystemNotConfiguredError,
|
||||
TierNotFoundError,
|
||||
StripePriceNotConfiguredError,
|
||||
NoActiveSubscriptionError,
|
||||
SubscriptionNotCancelledError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SubscriptionService",
|
||||
@@ -25,4 +35,12 @@ __all__ = [
|
||||
"stripe_service",
|
||||
"AdminSubscriptionService",
|
||||
"admin_subscription_service",
|
||||
"BillingService",
|
||||
"billing_service",
|
||||
"BillingServiceError",
|
||||
"PaymentSystemNotConfiguredError",
|
||||
"TierNotFoundError",
|
||||
"StripePriceNotConfiguredError",
|
||||
"NoActiveSubscriptionError",
|
||||
"SubscriptionNotCancelledError",
|
||||
]
|
||||
|
||||
588
app/modules/billing/services/billing_service.py
Normal file
588
app/modules/billing/services/billing_service.py
Normal file
@@ -0,0 +1,588 @@
|
||||
# app/modules/billing/services/billing_service.py
|
||||
"""
|
||||
Billing service for subscription and payment operations.
|
||||
|
||||
Provides:
|
||||
- Subscription status and usage queries
|
||||
- Tier management
|
||||
- Invoice history
|
||||
- Add-on management
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.billing.services.stripe_service import stripe_service
|
||||
from app.modules.billing.services.subscription_service import subscription_service
|
||||
from app.modules.billing.models import (
|
||||
AddOnProduct,
|
||||
BillingHistory,
|
||||
SubscriptionTier,
|
||||
VendorAddOn,
|
||||
VendorSubscription,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
    """Base class for every error raised by the billing service layer."""
|
||||
|
||||
|
||||
class PaymentSystemNotConfiguredError(BillingServiceError):
    """Raised when Stripe is not configured."""

    def __init__(self) -> None:
        # Fixed message: this error carries no extra state.
        super().__init__("Payment system not configured")
|
||||
|
||||
|
||||
class TierNotFoundError(BillingServiceError):
    """Raised when a tier is not found."""

    def __init__(self, tier_code: str) -> None:
        self.tier_code = tier_code
        message = f"Tier '{tier_code}' not found"
        super().__init__(message)
|
||||
|
||||
|
||||
class StripePriceNotConfiguredError(BillingServiceError):
    """Raised when Stripe price is not configured for a tier."""

    def __init__(self, tier_code: str) -> None:
        self.tier_code = tier_code
        message = f"Stripe price not configured for tier '{tier_code}'"
        super().__init__(message)
|
||||
|
||||
|
||||
class NoActiveSubscriptionError(BillingServiceError):
    """Raised when no active subscription exists."""

    def __init__(self) -> None:
        # Fixed message: this error carries no extra state.
        super().__init__("No active subscription found")
|
||||
|
||||
|
||||
class SubscriptionNotCancelledError(BillingServiceError):
    """Raised when trying to reactivate a non-cancelled subscription."""

    def __init__(self) -> None:
        # Fixed message: this error carries no extra state.
        super().__init__("Subscription is not cancelled")
|
||||
|
||||
|
||||
class BillingService:
    """Service for billing operations.

    Orchestrates subscription lifecycle, tier changes, invoice history and
    add-ons on top of ``stripe_service`` (payment provider calls) and
    ``subscription_service`` (local subscription records).

    NOTE(review): mutating methods update ORM objects but never call
    ``db.commit()`` — presumably the request/session layer commits;
    confirm before reusing this service outside that context.
    """

    def get_subscription_with_tier(
        self, db: Session, vendor_id: int
    ) -> tuple[VendorSubscription, SubscriptionTier | None]:
        """
        Get subscription and its tier info.

        Returns:
            Tuple of (subscription, tier) where tier may be None
        """
        # Creates the subscription row on first access if missing.
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)

        tier = (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == subscription.tier)
            .first()
        )

        return subscription, tier

    def get_available_tiers(
        self, db: Session, current_tier: str
    ) -> tuple[list[dict], dict[str, int]]:
        """
        Get all available tiers with upgrade/downgrade flags.

        Returns:
            Tuple of (tier_list, tier_order_map) where tier_order_map maps
            tier code -> display_order.
        """
        tiers = (
            db.query(SubscriptionTier)
            .filter(
                SubscriptionTier.is_active == True,  # noqa: E712
                SubscriptionTier.is_public == True,  # noqa: E712
            )
            .order_by(SubscriptionTier.display_order)
            .all()
        )

        # display_order doubles as the upgrade/downgrade ranking.
        tier_order = {t.code: t.display_order for t in tiers}
        # Unknown current tier ranks as 0, so every listed tier looks like
        # an upgrade in that case.
        current_order = tier_order.get(current_tier, 0)

        tier_list = []
        for tier in tiers:
            tier_list.append({
                "code": tier.code,
                "name": tier.name,
                "description": tier.description,
                "price_monthly_cents": tier.price_monthly_cents,
                "price_annual_cents": tier.price_annual_cents,
                "orders_per_month": tier.orders_per_month,
                "products_limit": tier.products_limit,
                "team_members": tier.team_members,
                "features": tier.features or [],
                "is_current": tier.code == current_tier,
                "can_upgrade": tier.display_order > current_order,
                "can_downgrade": tier.display_order < current_order,
            })

        return tier_list, tier_order

    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
        """
        Get a tier by its code.

        Only active tiers are returned; an inactive tier behaves as missing.

        Raises:
            TierNotFoundError: If tier doesn't exist
        """
        tier = (
            db.query(SubscriptionTier)
            .filter(
                SubscriptionTier.code == tier_code,
                SubscriptionTier.is_active == True,  # noqa: E712
            )
            .first()
        )

        if not tier:
            raise TierNotFoundError(tier_code)

        return tier

    def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
        """
        Get vendor by ID.

        Raises:
            VendorNotFoundException from app.exceptions
        """
        # Local import — presumably avoids an import cycle with
        # app.exceptions; TODO confirm.
        from app.exceptions import VendorNotFoundException

        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            raise VendorNotFoundException(str(vendor_id), identifier_type="id")

        return vendor

    def create_checkout_session(
        self,
        db: Session,
        vendor_id: int,
        tier_code: str,
        is_annual: bool,
        success_url: str,
        cancel_url: str,
    ) -> dict:
        """
        Create a Stripe checkout session.

        Returns:
            Dict with checkout_url and session_id

        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            TierNotFoundError: If tier doesn't exist
            StripePriceNotConfiguredError: If price not configured
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()

        vendor = self.get_vendor(db, vendor_id)
        tier = self.get_tier_by_code(db, tier_code)

        # Use the annual price only when requested AND configured;
        # otherwise fall back to the monthly price.
        price_id = (
            tier.stripe_price_annual_id
            if is_annual and tier.stripe_price_annual_id
            else tier.stripe_price_monthly_id
        )

        if not price_id:
            raise StripePriceNotConfiguredError(tier_code)

        # Check if this is a new subscription (for trial)
        existing_sub = subscription_service.get_subscription(db, vendor_id)
        trial_days = None
        if not existing_sub or not existing_sub.stripe_subscription_id:
            from app.core.config import settings
            trial_days = settings.stripe_trial_days

        session = stripe_service.create_checkout_session(
            db=db,
            vendor=vendor,
            price_id=price_id,
            success_url=success_url,
            cancel_url=cancel_url,
            trial_days=trial_days,
        )

        # Update subscription with tier info
        # NOTE(review): unlike change_tier(), tier_id is NOT set here —
        # confirm whether that is intentional (e.g. set later by a webhook).
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)
        subscription.tier = tier_code
        subscription.is_annual = is_annual

        return {
            "checkout_url": session.url,
            "session_id": session.id,
        }

    def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict:
        """
        Create a Stripe customer portal session.

        Returns:
            Dict with portal_url

        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            NoActiveSubscriptionError: If no subscription with customer ID
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()

        subscription = subscription_service.get_subscription(db, vendor_id)

        if not subscription or not subscription.stripe_customer_id:
            raise NoActiveSubscriptionError()

        session = stripe_service.create_portal_session(
            customer_id=subscription.stripe_customer_id,
            return_url=return_url,
        )

        return {"portal_url": session.url}

    def get_invoices(
        self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20
    ) -> tuple[list[BillingHistory], int]:
        """
        Get invoice history for a vendor, newest first.

        Returns:
            Tuple of (invoices, total_count); total_count ignores paging.
        """
        query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id)

        total = query.count()

        invoices = (
            query.order_by(BillingHistory.invoice_date.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

        return invoices, total

    def get_available_addons(
        self, db: Session, category: str | None = None
    ) -> list[AddOnProduct]:
        """Get available add-on products, optionally filtered by category."""
        query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True)  # noqa: E712

        if category:
            query = query.filter(AddOnProduct.category == category)

        return query.order_by(AddOnProduct.display_order).all()

    def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]:
        """Get vendor's purchased add-ons (all statuses, no filtering)."""
        return (
            db.query(VendorAddOn)
            .filter(VendorAddOn.vendor_id == vendor_id)
            .all()
        )

    def cancel_subscription(
        self, db: Session, vendor_id: int, reason: str | None, immediately: bool
    ) -> dict:
        """
        Cancel a subscription.

        Args:
            immediately: True cancels now; False cancels at period end.

        Returns:
            Dict with message and effective_date (ISO 8601 string)

        Raises:
            NoActiveSubscriptionError: If no subscription to cancel
        """
        subscription = subscription_service.get_subscription(db, vendor_id)

        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()

        # Local record is updated even when Stripe is not configured.
        if stripe_service.is_configured:
            stripe_service.cancel_subscription(
                subscription_id=subscription.stripe_subscription_id,
                immediately=immediately,
                cancellation_reason=reason,
            )

        subscription.cancelled_at = datetime.utcnow()
        subscription.cancellation_reason = reason

        # Nested conditional: immediate cancel -> now; otherwise the period
        # end when known, falling back to now.
        effective_date = (
            datetime.utcnow().isoformat()
            if immediately
            else subscription.period_end.isoformat()
            if subscription.period_end
            else datetime.utcnow().isoformat()
        )

        return {
            "message": "Subscription cancelled successfully",
            "effective_date": effective_date,
        }

    def reactivate_subscription(self, db: Session, vendor_id: int) -> dict:
        """
        Reactivate a cancelled subscription.

        Returns:
            Dict with success message

        Raises:
            NoActiveSubscriptionError: If no subscription
            SubscriptionNotCancelledError: If not cancelled
        """
        subscription = subscription_service.get_subscription(db, vendor_id)

        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()

        if not subscription.cancelled_at:
            raise SubscriptionNotCancelledError()

        if stripe_service.is_configured:
            stripe_service.reactivate_subscription(subscription.stripe_subscription_id)

        # Clear the local cancellation markers.
        subscription.cancelled_at = None
        subscription.cancellation_reason = None

        return {"message": "Subscription reactivated successfully"}

    def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict:
        """
        Get upcoming invoice preview.

        Returns:
            Dict with amount_due_cents, currency, next_payment_date, line_items.
            An empty preview (zero amount, no line items) is returned when
            Stripe is not configured or reports no upcoming invoice.

        Raises:
            NoActiveSubscriptionError: If no subscription with customer ID
        """
        subscription = subscription_service.get_subscription(db, vendor_id)

        if not subscription or not subscription.stripe_customer_id:
            raise NoActiveSubscriptionError()

        if not stripe_service.is_configured:
            # Return empty preview if Stripe not configured
            return {
                "amount_due_cents": 0,
                "currency": "EUR",
                "next_payment_date": None,
                "line_items": [],
            }

        invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id)

        if not invoice:
            return {
                "amount_due_cents": 0,
                "currency": "EUR",
                "next_payment_date": None,
                "line_items": [],
            }

        line_items = []
        if invoice.lines and invoice.lines.data:
            for line in invoice.lines.data:
                line_items.append({
                    "description": line.description or "",
                    "amount_cents": line.amount,
                    "quantity": line.quantity or 1,
                })

        return {
            "amount_due_cents": invoice.amount_due,
            "currency": invoice.currency.upper(),
            # next_payment_attempt is a Unix timestamp when present.
            "next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat()
            if invoice.next_payment_attempt
            else None,
            "line_items": line_items,
        }

    def change_tier(
        self,
        db: Session,
        vendor_id: int,
        new_tier_code: str,
        is_annual: bool,
    ) -> dict:
        """
        Change subscription tier (upgrade/downgrade).

        Returns:
            Dict with message, new_tier, effective_immediately

        Raises:
            TierNotFoundError: If tier doesn't exist
            NoActiveSubscriptionError: If no subscription
            StripePriceNotConfiguredError: If price not configured
        """
        subscription = subscription_service.get_subscription(db, vendor_id)

        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()

        tier = self.get_tier_by_code(db, new_tier_code)

        # Same annual-then-monthly fallback as create_checkout_session().
        price_id = (
            tier.stripe_price_annual_id
            if is_annual and tier.stripe_price_annual_id
            else tier.stripe_price_monthly_id
        )

        if not price_id:
            raise StripePriceNotConfiguredError(new_tier_code)

        # Update in Stripe
        if stripe_service.is_configured:
            stripe_service.update_subscription(
                subscription_id=subscription.stripe_subscription_id,
                new_price_id=price_id,
            )

        # Update local subscription
        old_tier = subscription.tier
        subscription.tier = new_tier_code
        subscription.tier_id = tier.id
        subscription.is_annual = is_annual
        subscription.updated_at = datetime.utcnow()

        # Upgrade vs. downgrade only affects the message wording.
        is_upgrade = self._is_upgrade(db, old_tier, new_tier_code)

        return {
            "message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}",
            "new_tier": new_tier_code,
            "effective_immediately": True,
        }

    def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool:
        """Check if tier change is an upgrade (higher display_order)."""
        old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first()
        new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first()

        # If either side is unknown, conservatively report "not an upgrade".
        if not old or not new:
            return False

        return new.display_order > old.display_order

    def purchase_addon(
        self,
        db: Session,
        vendor_id: int,
        addon_code: str,
        domain_name: str | None,
        quantity: int,
        success_url: str,
        cancel_url: str,
    ) -> dict:
        """
        Create checkout session for add-on purchase.

        Returns:
            Dict with checkout_url and session_id

        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            BillingServiceError: If the add-on is missing or has no Stripe price
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()

        addon = (
            db.query(AddOnProduct)
            .filter(
                AddOnProduct.code == addon_code,
                AddOnProduct.is_active == True,  # noqa: E712
            )
            .first()
        )

        if not addon:
            raise BillingServiceError(f"Add-on '{addon_code}' not found")

        if not addon.stripe_price_id:
            raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'")

        vendor = self.get_vendor(db, vendor_id)
        # NOTE(review): `subscription` is never read below — presumably the
        # call is kept for its side effect of ensuring a subscription row
        # exists; confirm.
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)

        # Create checkout session for add-on
        session = stripe_service.create_checkout_session(
            db=db,
            vendor=vendor,
            price_id=addon.stripe_price_id,
            success_url=success_url,
            cancel_url=cancel_url,
            quantity=quantity,
            metadata={
                "addon_code": addon_code,
                "domain_name": domain_name or "",
            },
        )

        return {
            "checkout_url": session.url,
            "session_id": session.id,
        }

    def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict:
        """
        Cancel a purchased add-on.

        Returns:
            Dict with message and addon_code

        Raises:
            BillingServiceError: If addon not found or not owned by vendor
        """
        vendor_addon = (
            db.query(VendorAddOn)
            .filter(
                VendorAddOn.id == addon_id,
                VendorAddOn.vendor_id == vendor_id,
            )
            .first()
        )

        if not vendor_addon:
            raise BillingServiceError("Add-on not found")

        addon_code = vendor_addon.addon_product.code

        # Cancel in Stripe if applicable
        if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id:
            try:
                stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id)
            except Exception as e:
                # Best-effort: the local record is still marked cancelled
                # even if the Stripe call fails.
                logger.warning(f"Failed to cancel addon in Stripe: {e}")

        # Mark as cancelled
        vendor_addon.status = "cancelled"
        vendor_addon.cancelled_at = datetime.utcnow()

        return {
            "message": "Add-on cancelled successfully",
            "addon_code": addon_code,
        }


# Create service instance (module-level singleton used by the routers).
billing_service = BillingService()
|
||||
@@ -25,7 +25,7 @@ from app.modules.cms.schemas import (
|
||||
CMSUsageResponse,
|
||||
)
|
||||
from app.modules.cms.services import content_page_service
|
||||
from app.services.vendor_service import VendorService
|
||||
from app.services.vendor_service import VendorService # noqa: MOD-004 - shared platform service
|
||||
from models.database.user import User
|
||||
|
||||
vendor_service = VendorService()
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.deps import get_current_vendor_from_cookie_or_header, get_db
|
||||
from app.modules.cms.services import content_page_service
|
||||
from app.services.platform_settings_service import platform_settings_service
|
||||
from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service
|
||||
from app.templates_config import templates
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
24
app/modules/contracts/definition.py
Normal file
24
app/modules/contracts/definition.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# app/modules/contracts/definition.py
|
||||
"""
|
||||
Contracts module definition.
|
||||
|
||||
Cross-module contracts and Protocol interfaces.
|
||||
Infrastructure module - cannot be disabled.
|
||||
"""
|
||||
|
||||
from app.modules.base import ModuleDefinition
|
||||
|
||||
# Module registration record consumed by the module framework.
# Declarative only: no services, models, or routes are defined here.
contracts_module = ModuleDefinition(
    code="contracts",
    name="Module Contracts",
    description="Cross-module contracts using Protocol pattern for type-safe inter-module communication.",
    version="1.0.0",
    is_core=True,  # Infrastructure module - cannot be disabled
    features=[
        "service_protocols",
        "cross_module_interfaces",
    ],
    menu_items={},  # Infrastructure module - no UI
)

__all__ = ["contracts_module"]
|
||||
@@ -2,14 +2,23 @@
|
||||
"""
|
||||
Customers module database models.
|
||||
|
||||
Re-exports customer-related models from their source locations.
|
||||
This is the canonical location for customer models. Module models are
|
||||
automatically discovered and registered with SQLAlchemy's Base.metadata
|
||||
at startup.
|
||||
|
||||
Usage:
|
||||
from app.modules.customers.models import (
|
||||
Customer,
|
||||
CustomerAddress,
|
||||
PasswordResetToken,
|
||||
)
|
||||
"""
|
||||
|
||||
from models.database.customer import (
|
||||
from app.modules.customers.models.customer import (
|
||||
Customer,
|
||||
CustomerAddress,
|
||||
)
|
||||
from models.database.password_reset_token import PasswordResetToken
|
||||
from app.modules.customers.models.password_reset_token import PasswordResetToken
|
||||
|
||||
__all__ = [
|
||||
"Customer",
|
||||
|
||||
93
app/modules/customers/models/customer.py
Normal file
93
app/modules/customers/models/customer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# app/modules/customers/models/customer.py
|
||||
"""
|
||||
Customer database models.
|
||||
|
||||
Provides Customer and CustomerAddress models for vendor-scoped
|
||||
customer management.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class Customer(Base, TimestampMixin):
    """Customer model with vendor isolation.

    Every customer row belongs to exactly one vendor (``vendor_id``);
    ``email`` and ``customer_number`` are indexed but not globally unique —
    per the inline comments they are scoped to the vendor.
    """

    __tablename__ = "customers"

    id = Column(Integer, primary_key=True, index=True)
    vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
    email = Column(
        String(255), nullable=False, index=True
    )  # Unique within vendor scope
    hashed_password = Column(String(255), nullable=False)
    first_name = Column(String(100))
    last_name = Column(String(100))
    phone = Column(String(50))
    customer_number = Column(
        String(100), nullable=False, index=True
    )  # Vendor-specific ID
    preferences = Column(JSON, default=dict)
    marketing_consent = Column(Boolean, default=False)
    # Denormalized order statistics — presumably maintained by the order
    # workflow; confirm where these are updated.
    last_order_date = Column(DateTime)
    total_orders = Column(Integer, default=0)
    total_spent = Column(Numeric(10, 2), default=0)
    is_active = Column(Boolean, default=True, nullable=False)

    # Language preference (NULL = use vendor storefront_language default)
    # Supported: en, fr, de, lb
    preferred_language = Column(String(5), nullable=True)

    # Relationships
    vendor = relationship("Vendor", back_populates="customers")
    addresses = relationship("CustomerAddress", back_populates="customer")
    orders = relationship("Order", back_populates="customer")

    def __repr__(self):
        return f"<Customer(id={self.id}, vendor_id={self.vendor_id}, email='{self.email}')>"

    @property
    def full_name(self):
        """Return "first last" when both names are set, else the email."""
        if self.first_name and self.last_name:
            return f"{self.first_name} {self.last_name}"
        return self.email
|
||||
|
||||
|
||||
class CustomerAddress(Base, TimestampMixin):
    """Customer address model for shipping and billing addresses.

    Carries its own ``vendor_id`` in addition to ``customer_id`` —
    presumably a denormalization for vendor-scoped queries; confirm.
    """

    __tablename__ = "customer_addresses"

    id = Column(Integer, primary_key=True, index=True)
    vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False)
    customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False)
    address_type = Column(String(50), nullable=False)  # 'billing', 'shipping'
    first_name = Column(String(100), nullable=False)
    last_name = Column(String(100), nullable=False)
    company = Column(String(200))
    address_line_1 = Column(String(255), nullable=False)
    address_line_2 = Column(String(255))
    city = Column(String(100), nullable=False)
    postal_code = Column(String(20), nullable=False)
    country_name = Column(String(100), nullable=False)
    country_iso = Column(String(5), nullable=False)
    # Default flag per customer — uniqueness is not enforced at the DB
    # level here; TODO confirm application code keeps a single default.
    is_default = Column(Boolean, default=False)

    # Relationships
    vendor = relationship("Vendor")
    customer = relationship("Customer", back_populates="addresses")

    def __repr__(self):
        return f"<CustomerAddress(id={self.id}, customer_id={self.customer_id}, type='{self.address_type}')>"
|
||||
91
app/modules/customers/models/password_reset_token.py
Normal file
91
app/modules/customers/models/password_reset_token.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# app/modules/customers/models/password_reset_token.py
|
||||
"""
|
||||
Password reset token model for customer accounts.
|
||||
|
||||
Security features:
|
||||
- Tokens are stored as SHA256 hashes, not plaintext
|
||||
- Tokens expire after 1 hour
|
||||
- Only one active token per customer (old tokens invalidated on new request)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class PasswordResetToken(Base):
    """Password reset token for customer accounts.

    Security properties (see module docstring):
      - only the SHA256 hash of the token is persisted, never the plaintext;
      - tokens expire after ``TOKEN_EXPIRY_HOURS``;
      - requesting a new token deletes the customer's unused tokens.

    All timestamps use naive ``datetime.utcnow()`` — assumes the DateTime
    columns store naive UTC as well; TODO confirm.
    """

    __tablename__ = "password_reset_tokens"

    # Token expiry in hours
    TOKEN_EXPIRY_HOURS = 1

    id = Column(Integer, primary_key=True, index=True)
    customer_id = Column(
        Integer, ForeignKey("customers.id", ondelete="CASCADE"), nullable=False
    )
    # SHA256 hex digest of the plaintext token (always 64 hex chars).
    token_hash = Column(String(64), nullable=False, index=True)
    expires_at = Column(DateTime, nullable=False)
    used_at = Column(DateTime, nullable=True)  # NULL while the token is unused
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)

    # Relationships
    customer = relationship("Customer")

    def __repr__(self):
        return f"<PasswordResetToken(id={self.id}, customer_id={self.customer_id}, expires_at={self.expires_at})>"

    @staticmethod
    def hash_token(token: str) -> str:
        """Hash a token using SHA256; returns the hex digest."""
        return hashlib.sha256(token.encode()).hexdigest()

    @classmethod
    def create_for_customer(cls, db: Session, customer_id: int) -> str:
        """Create a new password reset token for a customer.

        Invalidates any existing tokens for the customer.
        Returns the plaintext token (to be sent via email).
        """
        # Invalidate existing tokens for this customer
        # (only unused ones are deleted; used tokens remain on record).
        db.query(cls).filter(
            cls.customer_id == customer_id,
            cls.used_at.is_(None),
        ).delete()

        # Generate new token
        # secrets.token_urlsafe(32): cryptographically strong, URL-safe.
        plaintext_token = secrets.token_urlsafe(32)
        token_hash = cls.hash_token(plaintext_token)

        # Create token record
        token = cls(
            customer_id=customer_id,
            token_hash=token_hash,
            expires_at=datetime.utcnow() + timedelta(hours=cls.TOKEN_EXPIRY_HOURS),
        )
        db.add(token)
        # flush (not commit) so the caller keeps transaction control.
        db.flush()

        return plaintext_token

    @classmethod
    def find_valid_token(cls, db: Session, plaintext_token: str) -> "PasswordResetToken | None":
        """Find a valid (not expired, not used) token, or None."""
        # Look up by hash — the plaintext is never stored.
        token_hash = cls.hash_token(plaintext_token)

        return db.query(cls).filter(
            cls.token_hash == token_hash,
            cls.expires_at > datetime.utcnow(),
            cls.used_at.is_(None),
        ).first()

    def mark_used(self, db: Session) -> None:
        """Mark this token as used (single-use enforcement)."""
        self.used_at = datetime.utcnow()
        db.flush()
|
||||
@@ -2,27 +2,68 @@
|
||||
"""
|
||||
Customers module Pydantic schemas.
|
||||
|
||||
Re-exports customer-related schemas from their source locations.
|
||||
This is the canonical location for customer schemas.
|
||||
|
||||
Usage:
|
||||
from app.modules.customers.schemas import (
|
||||
CustomerRegister,
|
||||
CustomerUpdate,
|
||||
CustomerResponse,
|
||||
)
|
||||
"""
|
||||
|
||||
from models.schema.customer import (
|
||||
from app.modules.customers.schemas.customer import (
|
||||
# Registration & Authentication
|
||||
CustomerRegister,
|
||||
CustomerUpdate,
|
||||
CustomerLogin,
|
||||
CustomerPasswordChange,
|
||||
# Customer Response
|
||||
CustomerResponse,
|
||||
CustomerListResponse,
|
||||
# Address
|
||||
CustomerAddressCreate,
|
||||
CustomerAddressUpdate,
|
||||
CustomerAddressResponse,
|
||||
CustomerAddressListResponse,
|
||||
# Preferences
|
||||
CustomerPreferencesUpdate,
|
||||
# Vendor Management
|
||||
CustomerMessageResponse,
|
||||
VendorCustomerListResponse,
|
||||
CustomerDetailResponse,
|
||||
CustomerOrderInfo,
|
||||
CustomerOrdersResponse,
|
||||
CustomerStatisticsResponse,
|
||||
# Admin Management
|
||||
AdminCustomerItem,
|
||||
AdminCustomerListResponse,
|
||||
AdminCustomerDetailResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registration & Authentication
|
||||
"CustomerRegister",
|
||||
"CustomerUpdate",
|
||||
"CustomerLogin",
|
||||
"CustomerPasswordChange",
|
||||
# Customer Response
|
||||
"CustomerResponse",
|
||||
"CustomerListResponse",
|
||||
# Address
|
||||
"CustomerAddressCreate",
|
||||
"CustomerAddressUpdate",
|
||||
"CustomerAddressResponse",
|
||||
"CustomerAddressListResponse",
|
||||
# Preferences
|
||||
"CustomerPreferencesUpdate",
|
||||
# Vendor Management
|
||||
"CustomerMessageResponse",
|
||||
"VendorCustomerListResponse",
|
||||
"CustomerDetailResponse",
|
||||
"CustomerOrderInfo",
|
||||
"CustomerOrdersResponse",
|
||||
"CustomerStatisticsResponse",
|
||||
# Admin Management
|
||||
"AdminCustomerItem",
|
||||
"AdminCustomerListResponse",
|
||||
"AdminCustomerDetailResponse",
|
||||
]
|
||||
|
||||
340
app/modules/customers/schemas/customer.py
Normal file
340
app/modules/customers/schemas/customer.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# app/modules/customers/schemas/customer.py
|
||||
"""
|
||||
Pydantic schemas for customer-related operations.
|
||||
|
||||
Provides schemas for:
|
||||
- Customer registration and authentication
|
||||
- Customer profile management
|
||||
- Customer addresses
|
||||
- Admin customer management
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Customer Registration & Authentication
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerRegister(BaseModel):
    """Payload for registering a new storefront customer.

    Emails are normalized to lowercase; passwords must be at least
    8 characters and contain at least one letter and one digit.
    """

    email: EmailStr = Field(..., description="Customer email address")
    password: str = Field(
        ..., min_length=8, description="Password (minimum 8 characters)"
    )
    first_name: str = Field(..., min_length=1, max_length=100)
    last_name: str = Field(..., min_length=1, max_length=100)
    phone: str | None = Field(None, max_length=50)
    marketing_consent: bool = Field(default=False)
    preferred_language: str | None = Field(
        None, description="Preferred language (en, fr, de, lb)"
    )

    @field_validator("email")
    @classmethod
    def email_lowercase(cls, v: str) -> str:
        """Normalize the email to lowercase for case-insensitive matching."""
        return v.lower()

    @field_validator("password")
    @classmethod
    def password_strength(cls, v: str) -> str:
        """Reject passwords that are too short or lack a letter/digit mix."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        has_digit = any(ch.isdigit() for ch in v)
        if not has_digit:
            raise ValueError("Password must contain at least one digit")
        has_letter = any(ch.isalpha() for ch in v)
        if not has_letter:
            raise ValueError("Password must contain at least one letter")
        return v
|
||||
|
||||
|
||||
class CustomerUpdate(BaseModel):
    """Partial update of a customer profile; all fields optional."""

    email: EmailStr | None = None
    first_name: str | None = Field(None, min_length=1, max_length=100)
    last_name: str | None = Field(None, min_length=1, max_length=100)
    phone: str | None = Field(None, max_length=50)
    marketing_consent: bool | None = None
    preferred_language: str | None = Field(
        None, description="Preferred language (en, fr, de, lb)"
    )

    @field_validator("email")
    @classmethod
    def email_lowercase(cls, v: str | None) -> str | None:
        """Normalize the email to lowercase; falsy values become None."""
        if not v:
            return None
        return v.lower()
|
||||
|
||||
|
||||
class CustomerPasswordChange(BaseModel):
    """Payload for a customer changing their own password.

    NOTE(review): ``confirm_password`` is carried but not compared to
    ``new_password`` here — presumably checked by the service layer;
    confirm against the caller.
    """

    current_password: str = Field(..., description="Current password")
    new_password: str = Field(
        ..., min_length=8, description="New password (minimum 8 characters)"
    )
    confirm_password: str = Field(..., description="Confirm new password")

    @field_validator("new_password")
    @classmethod
    def password_strength(cls, v: str) -> str:
        """Reject passwords that are too short or lack a letter/digit mix."""
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters")
        has_digit = any(ch.isdigit() for ch in v)
        if not has_digit:
            raise ValueError("Password must contain at least one digit")
        has_letter = any(ch.isalpha() for ch in v)
        if not has_letter:
            raise ValueError("Password must contain at least one letter")
        return v
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Customer Response
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerResponse(BaseModel):
    """Customer read model (never includes the password hash).

    Built directly from the ORM object via ``from_attributes``.
    """

    id: int
    vendor_id: int
    email: str
    first_name: str | None
    last_name: str | None
    phone: str | None
    customer_number: str
    marketing_consent: bool
    preferred_language: str | None
    last_order_date: datetime | None
    total_orders: int
    total_spent: Decimal
    is_active: bool
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}

    @property
    def full_name(self) -> str:
        """Display name: 'First Last' when both parts exist, else the email."""
        if not (self.first_name and self.last_name):
            return self.email
        return f"{self.first_name} {self.last_name}"
|
||||
|
||||
|
||||
class CustomerListResponse(BaseModel):
    """Paginated customer list using page/per_page pagination.

    ``total`` is the full match count across all pages, not
    ``len(customers)`` — NOTE(review): assumed from naming; confirm
    against the producing service.
    """

    customers: list[CustomerResponse]
    total: int
    page: int
    per_page: int
    total_pages: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Customer Address
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerAddressCreate(BaseModel):
    """Payload for creating a customer address.

    ``address_type`` is restricted to 'billing' or 'shipping' by the
    field pattern; ``country_iso`` accepts 2–5 characters.
    """

    address_type: str = Field(..., pattern="^(billing|shipping)$")
    first_name: str = Field(..., min_length=1, max_length=100)
    last_name: str = Field(..., min_length=1, max_length=100)
    company: str | None = Field(None, max_length=200)
    address_line_1: str = Field(..., min_length=1, max_length=255)
    address_line_2: str | None = Field(None, max_length=255)
    city: str = Field(..., min_length=1, max_length=100)
    postal_code: str = Field(..., min_length=1, max_length=20)
    country_name: str = Field(..., min_length=2, max_length=100)
    country_iso: str = Field(..., min_length=2, max_length=5)
    is_default: bool = Field(default=False)
|
||||
|
||||
|
||||
class CustomerAddressUpdate(BaseModel):
    """Partial update of a customer address; all fields optional.

    Mirrors ``CustomerAddressCreate`` with every field made nullable so
    callers can patch individual fields.
    """

    address_type: str | None = Field(None, pattern="^(billing|shipping)$")
    first_name: str | None = Field(None, min_length=1, max_length=100)
    last_name: str | None = Field(None, min_length=1, max_length=100)
    company: str | None = Field(None, max_length=200)
    address_line_1: str | None = Field(None, min_length=1, max_length=255)
    address_line_2: str | None = Field(None, max_length=255)
    city: str | None = Field(None, min_length=1, max_length=100)
    postal_code: str | None = Field(None, min_length=1, max_length=20)
    country_name: str | None = Field(None, min_length=2, max_length=100)
    country_iso: str | None = Field(None, min_length=2, max_length=5)
    is_default: bool | None = None
|
||||
|
||||
|
||||
class CustomerAddressResponse(BaseModel):
    """Customer address read model, built from the ORM object."""

    id: int
    vendor_id: int
    customer_id: int
    address_type: str
    first_name: str
    last_name: str
    company: str | None
    address_line_1: str
    address_line_2: str | None
    city: str
    postal_code: str
    country_name: str
    country_iso: str
    is_default: bool
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CustomerAddressListResponse(BaseModel):
    """Unpaginated list of a customer's addresses plus a count."""

    addresses: list[CustomerAddressResponse]
    total: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Customer Preferences
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerPreferencesUpdate(BaseModel):
    """Partial update of customer preference settings.

    ``notification_preferences`` maps a channel/topic name to an
    opt-in flag — NOTE(review): key vocabulary not visible here;
    confirm against the consuming service.
    """

    marketing_consent: bool | None = None
    preferred_language: str | None = Field(
        None, description="Preferred language (en, fr, de, lb)"
    )
    currency: str | None = Field(None, max_length=3)
    notification_preferences: dict[str, bool] | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Vendor Customer Management Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerMessageResponse(BaseModel):
    """Simple message-only response for customer operations."""

    message: str
|
||||
|
||||
|
||||
class VendorCustomerListResponse(BaseModel):
    """Vendor-facing customer list using skip/limit pagination.

    Defaults allow an empty payload; ``message`` carries optional
    status text alongside the data.
    """

    customers: list[CustomerResponse] = []
    total: int = 0
    skip: int = 0
    limit: int = 100
    message: str | None = None
|
||||
|
||||
|
||||
class CustomerDetailResponse(BaseModel):
    """Detailed customer response for vendor management.

    Every field is optional, so a message-only payload (e.g. an error
    or placeholder) still validates.
    """

    id: int | None = None
    vendor_id: int | None = None
    email: str | None = None
    first_name: str | None = None
    last_name: str | None = None
    phone: str | None = None
    customer_number: str | None = None
    marketing_consent: bool | None = None
    preferred_language: str | None = None
    last_order_date: datetime | None = None
    total_orders: int | None = None
    total_spent: Decimal | None = None
    is_active: bool | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None
    message: str | None = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CustomerOrderInfo(BaseModel):
    """Minimal order summary used in customer order history."""

    id: int
    order_number: str
    status: str
    total: Decimal
    created_at: datetime
|
||||
|
||||
|
||||
class CustomerOrdersResponse(BaseModel):
    """Customer order history with an optional status message."""

    orders: list[CustomerOrderInfo] = []
    total: int = 0
    message: str | None = None
|
||||
|
||||
|
||||
class CustomerStatisticsResponse(BaseModel):
    """Aggregate customer statistics.

    Field names match the dict built by
    ``AdminCustomerService.get_customer_stats``.
    """

    total: int = 0
    active: int = 0
    inactive: int = 0
    with_orders: int = 0
    total_spent: float = 0.0
    total_orders: int = 0
    avg_order_value: float = 0.0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Customer Management Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminCustomerItem(BaseModel):
    """Admin customer list item, including joined vendor info.

    ``total_spent`` is a float here (the admin service converts the
    Decimal column before serializing).
    """

    id: int
    vendor_id: int
    email: str
    first_name: str | None = None
    last_name: str | None = None
    phone: str | None = None
    customer_number: str
    marketing_consent: bool = False
    preferred_language: str | None = None
    last_order_date: datetime | None = None
    total_orders: int = 0
    total_spent: float = 0.0
    is_active: bool = True
    created_at: datetime
    updated_at: datetime
    vendor_name: str | None = None
    vendor_code: str | None = None

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class AdminCustomerListResponse(BaseModel):
    """Admin paginated customer list using skip/limit pagination."""

    customers: list[AdminCustomerItem] = []
    total: int = 0
    skip: int = 0
    limit: int = 20
|
||||
|
||||
|
||||
class AdminCustomerDetailResponse(AdminCustomerItem):
    """Detailed customer response for admin.

    Currently identical to ``AdminCustomerItem``; exists as a distinct
    type so the detail endpoint can grow fields without changing the
    list schema.
    """

    pass
|
||||
@@ -2,18 +2,25 @@
|
||||
"""
|
||||
Customers module services.
|
||||
|
||||
Re-exports customer-related services from their source locations.
|
||||
This is the canonical location for customer services.
|
||||
|
||||
Usage:
|
||||
from app.modules.customers.services import (
|
||||
customer_service,
|
||||
admin_customer_service,
|
||||
customer_address_service,
|
||||
)
|
||||
"""
|
||||
|
||||
from app.services.customer_service import (
|
||||
from app.modules.customers.services.customer_service import (
|
||||
customer_service,
|
||||
CustomerService,
|
||||
)
|
||||
from app.services.admin_customer_service import (
|
||||
from app.modules.customers.services.admin_customer_service import (
|
||||
admin_customer_service,
|
||||
AdminCustomerService,
|
||||
)
|
||||
from app.services.customer_address_service import (
|
||||
from app.modules.customers.services.customer_address_service import (
|
||||
customer_address_service,
|
||||
CustomerAddressService,
|
||||
)
|
||||
|
||||
242
app/modules/customers/services/admin_customer_service.py
Normal file
242
app/modules/customers/services/admin_customer_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# app/modules/customers/services/admin_customer_service.py
|
||||
"""
|
||||
Admin customer management service.
|
||||
|
||||
Handles customer operations for admin users across all vendors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions.customer import CustomerNotFoundException
|
||||
from app.modules.customers.models import Customer
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminCustomerService:
    """Service for admin-level customer management across vendors.

    Unlike vendor-scoped services, these methods are not isolated to a
    single vendor: they join vendor info onto each customer and accept
    an optional ``vendor_id`` filter instead.
    """

    def _customer_to_dict(
        self,
        customer: Any,
        vendor_name: str | None,
        vendor_code: str | None,
    ) -> dict[str, Any]:
        """Serialize a Customer row plus joined vendor columns to a plain dict.

        Shared by ``list_customers`` and ``get_customer`` so the two
        response shapes cannot drift apart (previously duplicated).
        """
        return {
            "id": customer.id,
            "vendor_id": customer.vendor_id,
            "email": customer.email,
            "first_name": customer.first_name,
            "last_name": customer.last_name,
            "phone": customer.phone,
            "customer_number": customer.customer_number,
            "marketing_consent": customer.marketing_consent,
            "preferred_language": customer.preferred_language,
            "last_order_date": customer.last_order_date,
            "total_orders": customer.total_orders,
            # Decimal -> float for JSON; None (and 0) serialize as 0.
            "total_spent": float(customer.total_spent) if customer.total_spent else 0,
            "is_active": customer.is_active,
            "created_at": customer.created_at,
            "updated_at": customer.updated_at,
            "vendor_name": vendor_name,
            "vendor_code": vendor_code,
        }

    def list_customers(
        self,
        db: Session,
        vendor_id: int | None = None,
        search: str | None = None,
        is_active: bool | None = None,
        skip: int = 0,
        limit: int = 20,
    ) -> tuple[list[dict[str, Any]], int]:
        """
        Get paginated list of customers across all vendors.

        Args:
            db: Database session
            vendor_id: Optional vendor ID filter
            search: Search by email, name, or customer number
            is_active: Filter by active status
            skip: Number of records to skip
            limit: Maximum records to return

        Returns:
            Tuple of (customers list, total count)
        """
        query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id)

        if vendor_id:
            query = query.filter(Customer.vendor_id == vendor_id)

        if search:
            search_term = f"%{search}%"
            query = query.filter(
                (Customer.email.ilike(search_term))
                | (Customer.first_name.ilike(search_term))
                | (Customer.last_name.ilike(search_term))
                | (Customer.customer_number.ilike(search_term))
            )

        if is_active is not None:
            query = query.filter(Customer.is_active == is_active)

        # Count before pagination so `total` reflects all matches.
        total = query.count()

        rows = (
            query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
            .order_by(Customer.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

        # Each row is (Customer, vendor_name, vendor_code).
        result = [self._customer_to_dict(row[0], row[1], row[2]) for row in rows]
        return result, total

    def get_customer_stats(
        self,
        db: Session,
        vendor_id: int | None = None,
    ) -> dict[str, Any]:
        """
        Get customer statistics.

        Args:
            db: Database session
            vendor_id: Optional vendor ID filter

        Returns:
            Dict with customer statistics (matches CustomerStatisticsResponse)
        """
        query = db.query(Customer)

        if vendor_id:
            query = query.filter(Customer.vendor_id == vendor_id)

        total = query.count()
        active = query.filter(Customer.is_active == True).count()  # noqa: E712
        inactive = query.filter(Customer.is_active == False).count()  # noqa: E712
        with_orders = query.filter(Customer.total_orders > 0).count()

        # Aggregate spend/orders across the filtered customer set.
        total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar()
        total_spent = float(total_spent_result) if total_spent_result else 0

        total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar()
        total_orders = int(total_orders_result) if total_orders_result else 0
        avg_order_value = total_spent / total_orders if total_orders > 0 else 0

        return {
            "total": total,
            "active": active,
            "inactive": inactive,
            "with_orders": with_orders,
            "total_spent": total_spent,
            "total_orders": total_orders,
            "avg_order_value": round(avg_order_value, 2),
        }

    def get_customer(
        self,
        db: Session,
        customer_id: int,
    ) -> dict[str, Any]:
        """
        Get customer details by ID.

        Args:
            db: Database session
            customer_id: Customer ID

        Returns:
            Customer dict with vendor info

        Raises:
            CustomerNotFoundException: If customer not found
        """
        result = (
            db.query(Customer)
            .join(Vendor, Customer.vendor_id == Vendor.id)
            .add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
            .filter(Customer.id == customer_id)
            .first()
        )

        if not result:
            raise CustomerNotFoundException(str(customer_id))

        return self._customer_to_dict(result[0], result[1], result[2])

    def toggle_customer_status(
        self,
        db: Session,
        customer_id: int,
        admin_email: str,
    ) -> dict[str, Any]:
        """
        Toggle customer active status.

        Args:
            db: Database session
            customer_id: Customer ID
            admin_email: Admin user email for audit logging

        Returns:
            Dict with customer ID, new status, and message

        Raises:
            CustomerNotFoundException: If customer not found
        """
        customer = db.query(Customer).filter(Customer.id == customer_id).first()

        if not customer:
            raise CustomerNotFoundException(str(customer_id))

        customer.is_active = not customer.is_active
        db.flush()
        db.refresh(customer)

        status = "activated" if customer.is_active else "deactivated"
        # Audit trail: who flipped which account.
        logger.info(f"Customer {customer.email} {status} by admin {admin_email}")

        return {
            "id": customer.id,
            "is_active": customer.is_active,
            "message": f"Customer {status} successfully",
        }
|
||||
|
||||
|
||||
# Singleton instance
|
||||
admin_customer_service = AdminCustomerService()
|
||||
314
app/modules/customers/services/customer_address_service.py
Normal file
314
app/modules/customers/services/customer_address_service.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# app/modules/customers/services/customer_address_service.py
|
||||
"""
|
||||
Customer Address Service
|
||||
|
||||
Business logic for managing customer addresses with vendor isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
AddressLimitExceededException,
|
||||
AddressNotFoundException,
|
||||
)
|
||||
from app.modules.customers.models import CustomerAddress
|
||||
from app.modules.customers.schemas import CustomerAddressCreate, CustomerAddressUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomerAddressService:
    """Service for managing customer addresses with vendor isolation.

    Every query filters on both ``vendor_id`` and ``customer_id``, so a
    customer can never read or modify another tenant's addresses.
    """

    # Hard cap on stored addresses per customer.
    MAX_ADDRESSES_PER_CUSTOMER = 10

    def list_addresses(
        self, db: Session, vendor_id: int, customer_id: int
    ) -> list[CustomerAddress]:
        """
        Get all addresses for a customer.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID

        Returns:
            List of customer addresses, defaults first, then newest first
        """
        return (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc())
            .all()
        )

    def get_address(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> CustomerAddress:
        """
        Get a specific address with ownership validation.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Returns:
            Customer address

        Raises:
            AddressNotFoundException: If address not found or doesn't belong to customer
        """
        address = (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.id == address_id,
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .first()
        )

        if not address:
            raise AddressNotFoundException(address_id)

        return address

    def get_default_address(
        self, db: Session, vendor_id: int, customer_id: int, address_type: str
    ) -> CustomerAddress | None:
        """
        Get the default address for a specific type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_type: 'shipping' or 'billing'

        Returns:
            Default address or None if not set
        """
        return (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
                CustomerAddress.address_type == address_type,
                CustomerAddress.is_default == True,  # noqa: E712
            )
            .first()
        )

    def create_address(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_data: CustomerAddressCreate,
    ) -> CustomerAddress:
        """
        Create a new address for a customer.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_data: Address creation data

        Returns:
            Created customer address

        Raises:
            AddressLimitExceededException: If max addresses reached
        """
        # Enforce the per-customer cap before inserting.
        current_count = (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .count()
        )

        if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER:
            raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER)

        # Defaults are exclusive per type: clear others BEFORE inserting so
        # at most one default of each type exists at any time.
        if address_data.is_default:
            self._clear_other_defaults(
                db, vendor_id, customer_id, address_data.address_type
            )

        # Create the address
        address = CustomerAddress(
            vendor_id=vendor_id,
            customer_id=customer_id,
            address_type=address_data.address_type,
            first_name=address_data.first_name,
            last_name=address_data.last_name,
            company=address_data.company,
            address_line_1=address_data.address_line_1,
            address_line_2=address_data.address_line_2,
            city=address_data.city,
            postal_code=address_data.postal_code,
            country_name=address_data.country_name,
            country_iso=address_data.country_iso,
            is_default=address_data.is_default,
        )

        db.add(address)
        db.flush()

        logger.info(
            f"Created address {address.id} for customer {customer_id} "
            f"(type={address_data.address_type}, default={address_data.is_default})"
        )

        return address

    def update_address(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_id: int,
        address_data: CustomerAddressUpdate,
    ) -> CustomerAddress:
        """
        Update an existing address.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID
            address_data: Address update data (only set fields are applied)

        Returns:
            Updated customer address

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)

        # Update only provided fields
        update_data = address_data.model_dump(exclude_unset=True)

        # Handle default flag - clear others if setting to default
        if update_data.get("is_default") is True:
            # Use updated type if provided, otherwise current type
            address_type = update_data.get("address_type", address.address_type)
            self._clear_other_defaults(
                db, vendor_id, customer_id, address_type, exclude_id=address_id
            )

        for field, value in update_data.items():
            setattr(address, field, value)

        db.flush()

        logger.info(f"Updated address {address_id} for customer {customer_id}")

        return address

    def delete_address(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> None:
        """
        Delete an address.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)

        db.delete(address)
        db.flush()

        logger.info(f"Deleted address {address_id} for customer {customer_id}")

    def set_default(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> CustomerAddress:
        """
        Set an address as the default for its type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Returns:
            Updated customer address

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)

        # Clear other defaults of same type
        self._clear_other_defaults(
            db, vendor_id, customer_id, address.address_type, exclude_id=address_id
        )

        # Set this one as default
        address.is_default = True
        db.flush()

        logger.info(
            f"Set address {address_id} as default {address.address_type} "
            f"for customer {customer_id}"
        )

        return address

    def _clear_other_defaults(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_type: str,
        exclude_id: int | None = None,
    ) -> None:
        """
        Clear the default flag on other addresses of the same type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_type: 'shipping' or 'billing'
            exclude_id: Address ID to exclude from clearing

        NOTE(review): ``if exclude_id:`` treats an id of 0 as "no
        exclusion"; harmless if ids start at 1 — confirm.
        """
        query = db.query(CustomerAddress).filter(
            CustomerAddress.vendor_id == vendor_id,
            CustomerAddress.customer_id == customer_id,
            CustomerAddress.address_type == address_type,
            CustomerAddress.is_default == True,  # noqa: E712
        )

        if exclude_id:
            query = query.filter(CustomerAddress.id != exclude_id)

        # Bulk UPDATE without loading rows; in-session objects are not
        # synchronized (synchronize_session=False).
        query.update({"is_default": False}, synchronize_session=False)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
customer_address_service = CustomerAddressService()
|
||||
659
app/modules/customers/services/customer_service.py
Normal file
659
app/modules/customers/services/customer_service.py
Normal file
@@ -0,0 +1,659 @@
|
||||
# app/modules/customers/services/customer_service.py
|
||||
"""
|
||||
Customer management service.
|
||||
|
||||
Handles customer registration, authentication, and profile management
|
||||
with complete vendor isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions.customer import (
|
||||
CustomerNotActiveException,
|
||||
CustomerNotFoundException,
|
||||
CustomerValidationException,
|
||||
DuplicateCustomerEmailException,
|
||||
InvalidCustomerCredentialsException,
|
||||
InvalidPasswordResetTokenException,
|
||||
PasswordTooShortException,
|
||||
)
|
||||
from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException
|
||||
from app.services.auth_service import AuthService
|
||||
from app.modules.customers.models import Customer, PasswordResetToken
|
||||
from app.modules.customers.schemas import CustomerRegister, CustomerUpdate
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomerService:
|
||||
"""Service for managing vendor-scoped customers."""
|
||||
|
||||
def __init__(self):
|
||||
self.auth_service = AuthService()
|
||||
|
||||
def register_customer(
|
||||
self, db: Session, vendor_id: int, customer_data: CustomerRegister
|
||||
) -> Customer:
|
||||
"""
|
||||
Register a new customer for a specific vendor.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_data: Customer registration data
|
||||
|
||||
Returns:
|
||||
Customer: Created customer object
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException: If vendor doesn't exist
|
||||
VendorNotActiveException: If vendor is not active
|
||||
DuplicateCustomerEmailException: If email already exists for this vendor
|
||||
CustomerValidationException: If customer data is invalid
|
||||
"""
|
||||
# Verify vendor exists and is active
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
if not vendor.is_active:
|
||||
raise VendorNotActiveException(vendor.vendor_code)
|
||||
|
||||
# Check if email already exists for this vendor
|
||||
existing_customer = (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == customer_data.email.lower(),
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_customer:
|
||||
raise DuplicateCustomerEmailException(
|
||||
customer_data.email, vendor.vendor_code
|
||||
)
|
||||
|
||||
# Generate unique customer number for this vendor
|
||||
customer_number = self._generate_customer_number(
|
||||
db, vendor_id, vendor.vendor_code
|
||||
)
|
||||
|
||||
# Hash password
|
||||
hashed_password = self.auth_service.hash_password(customer_data.password)
|
||||
|
||||
# Create customer
|
||||
customer = Customer(
|
||||
vendor_id=vendor_id,
|
||||
email=customer_data.email.lower(),
|
||||
hashed_password=hashed_password,
|
||||
first_name=customer_data.first_name,
|
||||
last_name=customer_data.last_name,
|
||||
phone=customer_data.phone,
|
||||
customer_number=customer_number,
|
||||
marketing_consent=(
|
||||
customer_data.marketing_consent
|
||||
if hasattr(customer_data, "marketing_consent")
|
||||
else False
|
||||
),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(customer)
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
logger.info(
|
||||
f"Customer registered successfully: {customer.email} "
|
||||
f"(ID: {customer.id}, Number: {customer.customer_number}) "
|
||||
f"for vendor {vendor.vendor_code}"
|
||||
)
|
||||
|
||||
return customer
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering customer: {str(e)}")
|
||||
raise CustomerValidationException(
|
||||
message="Failed to register customer", details={"error": str(e)}
|
||||
)
|
||||
|
||||
def login_customer(
|
||||
self, db: Session, vendor_id: int, credentials
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Authenticate customer and generate JWT token.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
credentials: Login credentials (UserLogin schema)
|
||||
|
||||
Returns:
|
||||
Dict containing customer and token data
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException: If vendor doesn't exist
|
||||
InvalidCustomerCredentialsException: If credentials are invalid
|
||||
CustomerNotActiveException: If customer account is inactive
|
||||
"""
|
||||
# Verify vendor exists
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
# Find customer by email (vendor-scoped)
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == credentials.email_or_username.lower(),
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer:
|
||||
raise InvalidCustomerCredentialsException()
|
||||
|
||||
# Verify password using auth_manager directly
|
||||
if not self.auth_service.auth_manager.verify_password(
|
||||
credentials.password, customer.hashed_password
|
||||
):
|
||||
raise InvalidCustomerCredentialsException()
|
||||
|
||||
# Check if customer is active
|
||||
if not customer.is_active:
|
||||
raise CustomerNotActiveException(customer.email)
|
||||
|
||||
# Generate JWT token with customer context
|
||||
from jose import jwt
|
||||
|
||||
auth_manager = self.auth_service.auth_manager
|
||||
expires_delta = timedelta(minutes=auth_manager.token_expire_minutes)
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": str(customer.id),
|
||||
"email": customer.email,
|
||||
"vendor_id": vendor_id,
|
||||
"type": "customer",
|
||||
"exp": expire,
|
||||
"iat": datetime.now(UTC),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
token_data = {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": auth_manager.token_expire_minutes * 60,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Customer login successful: {customer.email} "
|
||||
f"for vendor {vendor.vendor_code}"
|
||||
)
|
||||
|
||||
return {"customer": customer, "token_data": token_data}
|
||||
|
||||
def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
|
||||
"""
|
||||
Get customer by ID with vendor isolation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Customer object
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer:
|
||||
raise CustomerNotFoundException(str(customer_id))
|
||||
|
||||
return customer
|
||||
|
||||
def get_customer_by_email(
|
||||
self, db: Session, vendor_id: int, email: str
|
||||
) -> Customer | None:
|
||||
"""
|
||||
Get customer by email (vendor-scoped).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
email: Customer email
|
||||
|
||||
Returns:
|
||||
Optional[Customer]: Customer object or None
|
||||
"""
|
||||
return (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_vendor_customers(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
) -> tuple[list[Customer], int]:
|
||||
"""
|
||||
Get all customers for a vendor with filtering and pagination.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
search: Search in name/email
|
||||
is_active: Filter by active status
|
||||
|
||||
Returns:
|
||||
Tuple of (customers, total_count)
|
||||
"""
|
||||
from sqlalchemy import or_
|
||||
|
||||
query = db.query(Customer).filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
Customer.email.ilike(search_pattern),
|
||||
Customer.first_name.ilike(search_pattern),
|
||||
Customer.last_name.ilike(search_pattern),
|
||||
Customer.customer_number.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(Customer.is_active == is_active)
|
||||
|
||||
# Order by most recent first
|
||||
query = query.order_by(Customer.created_at.desc())
|
||||
|
||||
total = query.count()
|
||||
customers = query.offset(skip).limit(limit).all()
|
||||
|
||||
return customers, total
|
||||
|
||||
def get_customer_orders(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list, int]:
|
||||
"""
|
||||
Get orders for a specific customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
|
||||
Returns:
|
||||
Tuple of (orders, total_count)
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
from models.database.order import Order
|
||||
|
||||
# Verify customer belongs to vendor
|
||||
self.get_customer(db, vendor_id, customer_id)
|
||||
|
||||
# Get customer orders
|
||||
query = (
|
||||
db.query(Order)
|
||||
.filter(
|
||||
Order.customer_id == customer_id,
|
||||
Order.vendor_id == vendor_id,
|
||||
)
|
||||
.order_by(Order.created_at.desc())
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
orders = query.offset(skip).limit(limit).all()
|
||||
|
||||
return orders, total
|
||||
|
||||
def get_customer_statistics(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> dict:
|
||||
"""
|
||||
Get detailed statistics for a customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Dict with customer statistics
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
from models.database.order import Order
|
||||
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
|
||||
# Get order statistics
|
||||
order_stats = (
|
||||
db.query(
|
||||
func.count(Order.id).label("total_orders"),
|
||||
func.sum(Order.total_cents).label("total_spent_cents"),
|
||||
func.avg(Order.total_cents).label("avg_order_cents"),
|
||||
func.max(Order.created_at).label("last_order_date"),
|
||||
)
|
||||
.filter(Order.customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
total_orders = order_stats.total_orders or 0
|
||||
total_spent_cents = order_stats.total_spent_cents or 0
|
||||
avg_order_cents = order_stats.avg_order_cents or 0
|
||||
|
||||
return {
|
||||
"customer_id": customer_id,
|
||||
"total_orders": total_orders,
|
||||
"total_spent": total_spent_cents / 100, # Convert to euros
|
||||
"average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0,
|
||||
"last_order_date": order_stats.last_order_date,
|
||||
"member_since": customer.created_at,
|
||||
"is_active": customer.is_active,
|
||||
}
|
||||
|
||||
def toggle_customer_status(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> Customer:
|
||||
"""
|
||||
Toggle customer active status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Updated customer
|
||||
"""
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
customer.is_active = not customer.is_active
|
||||
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
action = "activated" if customer.is_active else "deactivated"
|
||||
logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
|
||||
    def update_customer(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        customer_data: CustomerUpdate,
    ) -> Customer:
        """
        Update customer profile.

        Only fields explicitly set on ``customer_data`` are applied.
        Email changes are normalized to lowercase and checked for
        uniqueness within the vendor.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID
            customer_data: Updated customer data

        Returns:
            Customer: Updated customer object

        Raises:
            CustomerNotFoundException: If customer not found
            DuplicateCustomerEmailException: If the new email is already taken
            CustomerValidationException: If update data is invalid
        """
        customer = self.get_customer(db, vendor_id, customer_id)

        # Update fields
        # exclude_unset: leave untouched anything the caller did not send
        update_data = customer_data.model_dump(exclude_unset=True)

        for field, value in update_data.items():
            if field == "email" and value:
                # Check if new email already exists for this vendor
                # (excluding this customer's own row)
                existing = (
                    db.query(Customer)
                    .filter(
                        and_(
                            Customer.vendor_id == vendor_id,
                            Customer.email == value.lower(),
                            Customer.id != customer_id,
                        )
                    )
                    .first()
                )

                if existing:
                    raise DuplicateCustomerEmailException(value, "vendor")

                # Emails are stored lowercase for case-insensitive matching
                setattr(customer, field, value.lower())
            elif hasattr(customer, field):
                setattr(customer, field, value)

        try:
            db.flush()
            db.refresh(customer)

            logger.info(f"Customer updated: {customer.email} (ID: {customer.id})")

            return customer

        except Exception as e:
            logger.error(f"Error updating customer: {str(e)}")
            raise CustomerValidationException(
                message="Failed to update customer", details={"error": str(e)}
            )
|
||||
|
||||
def deactivate_customer(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> Customer:
|
||||
"""
|
||||
Deactivate customer account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Deactivated customer object
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
customer.is_active = False
|
||||
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
|
||||
def update_customer_stats(
|
||||
self, db: Session, customer_id: int, order_total: float
|
||||
) -> None:
|
||||
"""
|
||||
Update customer statistics after order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
customer_id: Customer ID
|
||||
order_total: Order total amount
|
||||
"""
|
||||
customer = db.query(Customer).filter(Customer.id == customer_id).first()
|
||||
|
||||
if customer:
|
||||
customer.total_orders += 1
|
||||
customer.total_spent += order_total
|
||||
customer.last_order_date = datetime.utcnow()
|
||||
|
||||
logger.debug(f"Updated stats for customer {customer.email}")
|
||||
|
||||
def _generate_customer_number(
|
||||
self, db: Session, vendor_id: int, vendor_code: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate unique customer number for vendor.
|
||||
|
||||
Format: {VENDOR_CODE}-CUST-{SEQUENCE}
|
||||
Example: VENDORA-CUST-00001
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
vendor_code: Vendor code
|
||||
|
||||
Returns:
|
||||
str: Unique customer number
|
||||
"""
|
||||
# Get count of customers for this vendor
|
||||
count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
|
||||
|
||||
# Generate number with padding
|
||||
sequence = str(count + 1).zfill(5)
|
||||
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
|
||||
|
||||
# Ensure uniqueness (in case of deletions)
|
||||
while (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.customer_number == customer_number,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
):
|
||||
count += 1
|
||||
sequence = str(count + 1).zfill(5)
|
||||
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
|
||||
|
||||
return customer_number
|
||||
|
||||
def get_customer_for_password_reset(
|
||||
self, db: Session, vendor_id: int, email: str
|
||||
) -> Customer | None:
|
||||
"""
|
||||
Get active customer by email for password reset.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
email: Customer email
|
||||
|
||||
Returns:
|
||||
Customer if found and active, None otherwise
|
||||
"""
|
||||
return (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == email.lower(),
|
||||
Customer.is_active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def validate_and_reset_password(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reset_token: str,
|
||||
new_password: str,
|
||||
) -> Customer:
|
||||
"""
|
||||
Validate reset token and update customer password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
reset_token: Password reset token from email
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
Customer: Updated customer
|
||||
|
||||
Raises:
|
||||
PasswordTooShortException: If password too short
|
||||
InvalidPasswordResetTokenException: If token invalid/expired
|
||||
CustomerNotActiveException: If customer not active
|
||||
"""
|
||||
# Validate password length
|
||||
if len(new_password) < 8:
|
||||
raise PasswordTooShortException(min_length=8)
|
||||
|
||||
# Find valid token
|
||||
token_record = PasswordResetToken.find_valid_token(db, reset_token)
|
||||
|
||||
if not token_record:
|
||||
raise InvalidPasswordResetTokenException()
|
||||
|
||||
# Get the customer and verify they belong to this vendor
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(Customer.id == token_record.customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer or customer.vendor_id != vendor_id:
|
||||
raise InvalidPasswordResetTokenException()
|
||||
|
||||
if not customer.is_active:
|
||||
raise CustomerNotActiveException(customer.email)
|
||||
|
||||
# Hash the new password and update customer
|
||||
hashed_password = self.auth_service.hash_password(new_password)
|
||||
customer.hashed_password = hashed_password
|
||||
|
||||
# Mark token as used
|
||||
token_record.mark_used(db)
|
||||
|
||||
logger.info(f"Password reset completed for customer {customer.id}")
|
||||
|
||||
return customer
|
||||
|
||||
|
||||
# Singleton instance
# Module-level singleton shared by all importers of this module.
customer_service = CustomerService()
|
||||
@@ -2,16 +2,14 @@
|
||||
"""
|
||||
Dev-Tools module services.
|
||||
|
||||
This module re-exports services from their current locations.
|
||||
In future cleanup phases, the actual service implementations
|
||||
may be moved here.
|
||||
Provides code quality scanning and test running functionality.
|
||||
|
||||
Services:
|
||||
- code_quality_service: Code quality scanning and violation management
|
||||
- test_runner_service: Test execution and results management
|
||||
"""
|
||||
|
||||
from app.services.code_quality_service import (
|
||||
from app.modules.dev_tools.services.code_quality_service import (
|
||||
code_quality_service,
|
||||
CodeQualityService,
|
||||
VALIDATOR_ARCHITECTURE,
|
||||
@@ -21,7 +19,7 @@ from app.services.code_quality_service import (
|
||||
VALIDATOR_SCRIPTS,
|
||||
VALIDATOR_NAMES,
|
||||
)
|
||||
from app.services.test_runner_service import (
|
||||
from app.modules.dev_tools.services.test_runner_service import (
|
||||
test_runner_service,
|
||||
TestRunnerService,
|
||||
)
|
||||
|
||||
820
app/modules/dev_tools/services/code_quality_service.py
Normal file
820
app/modules/dev_tools/services/code_quality_service.py
Normal file
@@ -0,0 +1,820 @@
|
||||
"""
|
||||
Code Quality Service
|
||||
Business logic for managing code quality scans and violations
|
||||
Supports multiple validator types: architecture, security, performance
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
ScanParseException,
|
||||
ScanTimeoutException,
|
||||
ViolationNotFoundException,
|
||||
)
|
||||
from app.modules.dev_tools.models import (
|
||||
ArchitectureScan,
|
||||
ArchitectureViolation,
|
||||
ViolationAssignment,
|
||||
ViolationComment,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Validator type constants
VALIDATOR_ARCHITECTURE = "architecture"
VALIDATOR_SECURITY = "security"
VALIDATOR_PERFORMANCE = "performance"

# Closed set of validator types accepted by run_scan()
VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE]

# Map validator types to their scripts (run via subprocess with --json)
VALIDATOR_SCRIPTS = {
    VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py",
    VALIDATOR_SECURITY: "scripts/validate_security.py",
    VALIDATOR_PERFORMANCE: "scripts/validate_performance.py",
}

# Human-readable names (used in log messages)
VALIDATOR_NAMES = {
    VALIDATOR_ARCHITECTURE: "Architecture",
    VALIDATOR_SECURITY: "Security",
    VALIDATOR_PERFORMANCE: "Performance",
}
|
||||
|
||||
|
||||
class CodeQualityService:
|
||||
"""Service for managing code quality scans and violations"""
|
||||
|
||||
def run_scan(
|
||||
self,
|
||||
db: Session,
|
||||
triggered_by: str = "manual",
|
||||
validator_type: str = VALIDATOR_ARCHITECTURE,
|
||||
) -> ArchitectureScan:
|
||||
"""
|
||||
Run a code quality validator and store results in database
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd')
|
||||
validator_type: Type of validator ('architecture', 'security', 'performance')
|
||||
|
||||
Returns:
|
||||
ArchitectureScan object with results
|
||||
|
||||
Raises:
|
||||
ValueError: If validator_type is invalid
|
||||
ScanTimeoutException: If validator times out
|
||||
ScanParseException: If validator output cannot be parsed
|
||||
"""
|
||||
if validator_type not in VALID_VALIDATOR_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid validator type: {validator_type}. "
|
||||
f"Must be one of: {VALID_VALIDATOR_TYPES}"
|
||||
)
|
||||
|
||||
script_path = VALIDATOR_SCRIPTS[validator_type]
|
||||
validator_name = VALIDATOR_NAMES[validator_type]
|
||||
|
||||
logger.info(
|
||||
f"Starting {validator_name} scan (triggered by: {triggered_by})"
|
||||
)
|
||||
|
||||
# Get git commit hash
|
||||
git_commit = self._get_git_commit_hash()
|
||||
|
||||
# Run validator with JSON output
|
||||
start_time = datetime.now()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["python", script_path, "--json"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minute timeout
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"{validator_name} scan timed out after 5 minutes")
|
||||
raise ScanTimeoutException(timeout_seconds=300)
|
||||
|
||||
duration = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
# Parse JSON output (get only the JSON part, skip progress messages)
|
||||
try:
|
||||
# Find the JSON part in stdout
|
||||
lines = result.stdout.strip().split("\n")
|
||||
json_start = -1
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("{"):
|
||||
json_start = i
|
||||
break
|
||||
|
||||
if json_start == -1:
|
||||
raise ValueError("No JSON output found")
|
||||
|
||||
json_output = "\n".join(lines[json_start:])
|
||||
data = json.loads(json_output)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"Failed to parse {validator_name} validator output: {e}")
|
||||
logger.error(f"Stdout: {result.stdout}")
|
||||
logger.error(f"Stderr: {result.stderr}")
|
||||
raise ScanParseException(reason=str(e))
|
||||
|
||||
# Create scan record
|
||||
scan = ArchitectureScan(
|
||||
timestamp=datetime.now(),
|
||||
validator_type=validator_type,
|
||||
total_files=data.get("files_checked", 0),
|
||||
total_violations=data.get("total_violations", 0),
|
||||
errors=data.get("errors", 0),
|
||||
warnings=data.get("warnings", 0),
|
||||
duration_seconds=duration,
|
||||
triggered_by=triggered_by,
|
||||
git_commit_hash=git_commit,
|
||||
)
|
||||
db.add(scan)
|
||||
db.flush() # Get scan.id
|
||||
|
||||
# Create violation records
|
||||
violations_data = data.get("violations", [])
|
||||
logger.info(f"Creating {len(violations_data)} {validator_name} violation records")
|
||||
|
||||
for v in violations_data:
|
||||
violation = ArchitectureViolation(
|
||||
scan_id=scan.id,
|
||||
validator_type=validator_type,
|
||||
rule_id=v["rule_id"],
|
||||
rule_name=v["rule_name"],
|
||||
severity=v["severity"],
|
||||
file_path=v["file_path"],
|
||||
line_number=v["line_number"],
|
||||
message=v["message"],
|
||||
context=v.get("context", ""),
|
||||
suggestion=v.get("suggestion", ""),
|
||||
status="open",
|
||||
)
|
||||
db.add(violation)
|
||||
|
||||
db.flush()
|
||||
db.refresh(scan)
|
||||
|
||||
logger.info(
|
||||
f"{validator_name} scan completed: {scan.total_violations} violations found"
|
||||
)
|
||||
return scan
|
||||
|
||||
def run_all_scans(
|
||||
self, db: Session, triggered_by: str = "manual"
|
||||
) -> list[ArchitectureScan]:
|
||||
"""
|
||||
Run all validators and return list of scans
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
triggered_by: Who/what triggered the scan
|
||||
|
||||
Returns:
|
||||
List of ArchitectureScan objects (one per validator)
|
||||
"""
|
||||
results = []
|
||||
for validator_type in VALID_VALIDATOR_TYPES:
|
||||
try:
|
||||
scan = self.run_scan(db, triggered_by, validator_type)
|
||||
results.append(scan)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run {validator_type} scan: {e}")
|
||||
# Continue with other validators even if one fails
|
||||
return results
|
||||
|
||||
    def get_latest_scan(
        self, db: Session, validator_type: str | None = None
    ) -> ArchitectureScan | None:
        """
        Get the most recent scan.

        Args:
            db: Database session
            validator_type: Optional filter by validator type

        Returns:
            Most recent ArchitectureScan or None
        """
        # Newest first, so first() yields the latest matching scan
        query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))

        if validator_type:
            query = query.filter(ArchitectureScan.validator_type == validator_type)

        return query.first()
|
||||
|
||||
def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]:
|
||||
"""
|
||||
Get the most recent scan for each validator type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping validator_type to its latest scan
|
||||
"""
|
||||
result = {}
|
||||
for vtype in VALID_VALIDATOR_TYPES:
|
||||
scan = self.get_latest_scan(db, validator_type=vtype)
|
||||
if scan:
|
||||
result[vtype] = scan
|
||||
return result
|
||||
|
||||
def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None:
|
||||
"""Get scan by ID"""
|
||||
return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first()
|
||||
|
||||
def create_pending_scan(
|
||||
self, db: Session, validator_type: str, triggered_by: str
|
||||
) -> ArchitectureScan:
|
||||
"""
|
||||
Create a new scan record with pending status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
validator_type: Type of validator (architecture, security, performance)
|
||||
triggered_by: Who triggered the scan (e.g., "manual:username")
|
||||
|
||||
Returns:
|
||||
The created ArchitectureScan record with ID populated
|
||||
"""
|
||||
scan = ArchitectureScan(
|
||||
timestamp=datetime.now(UTC),
|
||||
validator_type=validator_type,
|
||||
status="pending",
|
||||
triggered_by=triggered_by,
|
||||
)
|
||||
db.add(scan)
|
||||
db.flush() # Get scan.id
|
||||
return scan
|
||||
|
||||
def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
|
||||
"""
|
||||
Get all currently running scans (pending or running status).
|
||||
|
||||
Returns:
|
||||
List of scans with status 'pending' or 'running', newest first
|
||||
"""
|
||||
return (
|
||||
db.query(ArchitectureScan)
|
||||
.filter(ArchitectureScan.status.in_(["pending", "running"]))
|
||||
.order_by(ArchitectureScan.timestamp.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
    def get_scan_history(
        self, db: Session, limit: int = 30, validator_type: str | None = None
    ) -> list[ArchitectureScan]:
        """
        Get scan history for trend graphs.

        Args:
            db: Database session
            limit: Maximum number of scans to return
            validator_type: Optional filter by validator type

        Returns:
            List of ArchitectureScan objects, newest first
        """
        query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))

        if validator_type:
            query = query.filter(ArchitectureScan.validator_type == validator_type)

        return query.limit(limit).all()
|
||||
|
||||
    def get_violations(
        self,
        db: Session,
        scan_id: int | None = None,
        validator_type: str | None = None,
        severity: str | None = None,
        status: str | None = None,
        rule_id: str | None = None,
        file_path: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[list[ArchitectureViolation], int]:
        """
        Get violations with filtering and pagination.

        Args:
            db: Database session
            scan_id: Filter by scan ID (if None, use latest scan(s))
            validator_type: Filter by validator type
            severity: Filter by severity ('error', 'warning')
            status: Filter by status ('open', 'assigned', 'resolved', etc.)
            rule_id: Filter by rule ID
            file_path: Filter by file path (partial match)
            limit: Page size
            offset: Page offset

        Returns:
            Tuple of (violations list, total count)
        """
        # Build query
        query = db.query(ArchitectureViolation)

        # If scan_id specified, filter by it
        if scan_id is not None:
            query = query.filter(ArchitectureViolation.scan_id == scan_id)
        else:
            # If no scan_id, get violations from latest scan(s)
            if validator_type:
                # Get latest scan for specific validator type
                latest_scan = self.get_latest_scan(db, validator_type)
                if not latest_scan:
                    # No scans yet for this validator -> nothing to show
                    return [], 0
                query = query.filter(ArchitectureViolation.scan_id == latest_scan.id)
            else:
                # Get violations from latest scans of all types
                latest_scans = self.get_latest_scans_by_type(db)
                if not latest_scans:
                    return [], 0
                scan_ids = [s.id for s in latest_scans.values()]
                query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids))

        # Apply validator_type filter if specified (for scan_id queries)
        # (for the latest-scan path the type was already narrowed above)
        if validator_type and scan_id is not None:
            query = query.filter(ArchitectureViolation.validator_type == validator_type)

        # Apply other filters
        if severity:
            query = query.filter(ArchitectureViolation.severity == severity)

        if status:
            query = query.filter(ArchitectureViolation.status == status)

        if rule_id:
            query = query.filter(ArchitectureViolation.rule_id == rule_id)

        if file_path:
            query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))

        # Get total count before pagination
        total = query.count()

        # Get page of results; errors before warnings, grouped by type/file
        violations = (
            query.order_by(
                ArchitectureViolation.severity.desc(),
                ArchitectureViolation.validator_type,
                ArchitectureViolation.file_path,
            )
            .limit(limit)
            .offset(offset)
            .all()
        )

        return violations, total
|
||||
|
||||
def get_violation_by_id(
    self, db: Session, violation_id: int
) -> ArchitectureViolation | None:
    """Look up a single violation by primary key; None when no row matches."""
    query = db.query(ArchitectureViolation)
    return query.filter_by(id=violation_id).first()
|
||||
|
||||
def assign_violation(
    self,
    db: Session,
    violation_id: int,
    user_id: int,
    assigned_by: int,
    due_date: datetime | None = None,
    priority: str = "medium",
) -> ViolationAssignment:
    """
    Assign violation to a developer

    Args:
        db: Database session
        violation_id: Violation ID
        user_id: User to assign to
        assigned_by: User who is assigning
        due_date: Due date (optional)
        priority: Priority level ('low', 'medium', 'high', 'critical')

    Returns:
        ViolationAssignment object

    Raises:
        ViolationNotFoundException: If the violation does not exist.
    """
    violation = self.get_violation_by_id(db, violation_id)
    if not violation:
        # Previously an assignment row was still created when the violation
        # was missing; fail fast instead, matching resolve/ignore behavior.
        raise ViolationNotFoundException(violation_id)

    # Update violation status
    violation.status = "assigned"
    violation.assigned_to = user_id

    # Create assignment record
    assignment = ViolationAssignment(
        violation_id=violation_id,
        user_id=user_id,
        assigned_by=assigned_by,
        due_date=due_date,
        priority=priority,
    )
    db.add(assignment)
    db.flush()

    # Lazy %-style args avoid formatting when INFO is disabled
    logger.info("Violation %s assigned to user %s", violation_id, user_id)
    return assignment
|
||||
|
||||
def resolve_violation(
    self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
) -> ArchitectureViolation:
    """
    Mark violation as resolved

    Args:
        db: Database session
        violation_id: Violation ID
        resolved_by: User who resolved it
        resolution_note: Note about resolution

    Returns:
        Updated ArchitectureViolation object

    Raises:
        ViolationNotFoundException: If no violation exists with this ID.
    """
    violation = self.get_violation_by_id(db, violation_id)
    if not violation:
        raise ViolationNotFoundException(violation_id)

    violation.status = "resolved"
    # NOTE(review): naive local timestamp; other services here use
    # datetime.now(UTC) — confirm which convention this column expects.
    violation.resolved_at = datetime.now()
    violation.resolved_by = resolved_by
    violation.resolution_note = resolution_note

    # flush (not commit): the caller's transaction controls durability
    db.flush()
    logger.info(f"Violation {violation_id} resolved by user {resolved_by}")
    return violation
|
||||
|
||||
def ignore_violation(
    self, db: Session, violation_id: int, ignored_by: int, reason: str
) -> ArchitectureViolation:
    """
    Mark a violation as ignored (won't fix).

    Args:
        db: Database session
        violation_id: Violation ID
        ignored_by: User who ignored it
        reason: Reason for ignoring

    Returns:
        Updated ArchitectureViolation object

    Raises:
        ViolationNotFoundException: If no violation exists with this ID.
    """
    violation = self.get_violation_by_id(db, violation_id)
    if violation is None:
        raise ViolationNotFoundException(violation_id)

    # Reuses the resolution columns, with the reason recorded in the note.
    violation.resolution_note = f"Ignored: {reason}"
    violation.resolved_by = ignored_by
    violation.resolved_at = datetime.now()
    violation.status = "ignored"

    db.flush()
    logger.info(f"Violation {violation_id} ignored by user {ignored_by}")
    return violation
|
||||
|
||||
def add_comment(
    self, db: Session, violation_id: int, user_id: int, comment: str
) -> ViolationComment:
    """
    Attach a comment to a violation.

    Args:
        db: Database session
        violation_id: Violation ID
        user_id: User posting comment
        comment: Comment text

    Returns:
        The newly created ViolationComment.
    """
    new_comment = ViolationComment(
        violation_id=violation_id,
        user_id=user_id,
        comment=comment,
    )
    db.add(new_comment)
    db.flush()

    logger.info(f"Comment added to violation {violation_id} by user {user_id}")
    return new_comment
|
||||
|
||||
def get_dashboard_stats(
    self, db: Session, validator_type: str = None
) -> dict:
    """
    Get statistics for dashboard

    Args:
        db: Database session
        validator_type: Optional filter by validator type. If None, returns combined stats.

    Returns:
        Dictionary with various statistics including per-validator breakdown
    """
    latest_scans = self.get_latest_scans_by_type(db)
    if not latest_scans:
        # Nothing scanned yet -> zeroed skeleton
        return self._empty_dashboard_stats()

    # Scope to one validator's latest scan when requested and available
    scan = latest_scans.get(validator_type) if validator_type else None
    if scan is not None:
        return self._get_stats_for_scan(db, scan, validator_type)

    # Otherwise aggregate across the latest scan of every validator
    return self._get_combined_stats(db, latest_scans)
|
||||
|
||||
def _empty_dashboard_stats(self) -> dict:
    """Return the zeroed dashboard payload used when no scans exist yet."""
    counters = (
        "total_violations",
        "errors",
        "warnings",
        "info",
        "open",
        "assigned",
        "resolved",
        "ignored",
    )
    stats = {name: 0 for name in counters}
    # A repo with no recorded violations gets a perfect debt score.
    stats["technical_debt_score"] = 100
    stats["trend"] = []
    stats["by_severity"] = {}
    stats["by_rule"] = {}
    stats["by_module"] = {}
    stats["top_files"] = []
    stats["last_scan"] = None
    stats["by_validator"] = {}
    return stats
|
||||
|
||||
def _get_stats_for_scan(
    self, db: Session, scan: ArchitectureScan, validator_type: str
) -> dict:
    """Get stats for a single scan/validator type.

    Aggregates the scan's violations by status, severity, rule, file and
    module, and attaches a 7-scan trend for this validator.
    """
    # Get violation counts by status
    status_counts = (
        db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
        .filter(ArchitectureViolation.scan_id == scan.id)
        .group_by(ArchitectureViolation.status)
        .all()
    )
    status_dict = {status: count for status, count in status_counts}

    # Get violations by severity
    severity_counts = (
        db.query(
            ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
        )
        .filter(ArchitectureViolation.scan_id == scan.id)
        .group_by(ArchitectureViolation.severity)
        .all()
    )
    by_severity = {sev: count for sev, count in severity_counts}

    # Get violations by rule (top 10 by count)
    rule_counts = (
        db.query(
            ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
        )
        .filter(ArchitectureViolation.scan_id == scan.id)
        .group_by(ArchitectureViolation.rule_id)
        .all()
    )
    by_rule = {
        rule: count
        for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
    }

    # Get top violating files
    file_counts = (
        db.query(
            ArchitectureViolation.file_path,
            func.count(ArchitectureViolation.id).label("count"),
        )
        .filter(ArchitectureViolation.scan_id == scan.id)
        .group_by(ArchitectureViolation.file_path)
        .order_by(desc("count"))
        .limit(10)
        .all()
    )
    top_files = [{"file": file, "count": count} for file, count in file_counts]

    # Get violations by module
    by_module = self._get_violations_by_module(db, scan.id)

    # Get trend for this validator type (last 7 scans, oldest first)
    trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type)
    trend = [
        {
            "timestamp": s.timestamp.isoformat(),
            "violations": s.total_violations,
            "errors": s.errors,
            "warnings": s.warnings,
        }
        for s in reversed(trend_scans)
    ]

    return {
        "total_violations": scan.total_violations,
        "errors": scan.errors,
        "warnings": scan.warnings,
        "info": by_severity.get("info", 0),
        "open": status_dict.get("open", 0),
        "assigned": status_dict.get("assigned", 0),
        "resolved": status_dict.get("resolved", 0),
        "ignored": status_dict.get("ignored", 0),
        "technical_debt_score": self._calculate_score(scan.errors, scan.warnings),
        "trend": trend,
        "by_severity": by_severity,
        "by_rule": by_rule,
        "by_module": by_module,
        "top_files": top_files,
        "last_scan": scan.timestamp.isoformat(),
        "validator_type": validator_type,
        # Single-entry breakdown keeps the payload shape identical to the
        # combined (all-validators) variant.
        "by_validator": {
            validator_type: {
                "total_violations": scan.total_violations,
                "errors": scan.errors,
                "warnings": scan.warnings,
                "last_scan": scan.timestamp.isoformat(),
            }
        },
    }
|
||||
|
||||
def _get_combined_stats(
    self, db: Session, latest_scans: dict[str, ArchitectureScan]
) -> dict:
    """Get combined stats across all validators.

    Takes the latest scan per validator type and aggregates their
    violations into one dashboard payload.
    """
    # Aggregate totals
    total_violations = sum(s.total_violations for s in latest_scans.values())
    total_errors = sum(s.errors for s in latest_scans.values())
    total_warnings = sum(s.warnings for s in latest_scans.values())

    # Get all scan IDs
    scan_ids = [s.id for s in latest_scans.values()]

    # Get violation counts by status
    status_counts = (
        db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
        .filter(ArchitectureViolation.scan_id.in_(scan_ids))
        .group_by(ArchitectureViolation.status)
        .all()
    )
    status_dict = {status: count for status, count in status_counts}

    # Get violations by severity
    severity_counts = (
        db.query(
            ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
        )
        .filter(ArchitectureViolation.scan_id.in_(scan_ids))
        .group_by(ArchitectureViolation.severity)
        .all()
    )
    by_severity = {sev: count for sev, count in severity_counts}

    # Get violations by rule (across all validators)
    rule_counts = (
        db.query(
            ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
        )
        .filter(ArchitectureViolation.scan_id.in_(scan_ids))
        .group_by(ArchitectureViolation.rule_id)
        .all()
    )
    by_rule = {
        rule: count
        for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
    }

    # Get top violating files
    file_counts = (
        db.query(
            ArchitectureViolation.file_path,
            func.count(ArchitectureViolation.id).label("count"),
        )
        .filter(ArchitectureViolation.scan_id.in_(scan_ids))
        .group_by(ArchitectureViolation.file_path)
        .order_by(desc("count"))
        .limit(10)
        .all()
    )
    top_files = [{"file": file, "count": count} for file, count in file_counts]

    # Get violations by module
    # NOTE(review): _get_violations_by_module returns only each scan's top-10
    # modules, so merging here can undercount modules that rank below 10 in
    # individual scans — acceptable for a dashboard, but not exact.
    by_module = {}
    for scan_id in scan_ids:
        module_counts = self._get_violations_by_module(db, scan_id)
        for module, count in module_counts.items():
            by_module[module] = by_module.get(module, 0) + count
    by_module = dict(
        sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
    )

    # Per-validator breakdown
    by_validator = {}
    for vtype, scan in latest_scans.items():
        by_validator[vtype] = {
            "total_violations": scan.total_violations,
            "errors": scan.errors,
            "warnings": scan.warnings,
            "last_scan": scan.timestamp.isoformat(),
        }

    # Get most recent scan timestamp
    most_recent = max(latest_scans.values(), key=lambda s: s.timestamp)

    return {
        "total_violations": total_violations,
        "errors": total_errors,
        "warnings": total_warnings,
        "info": by_severity.get("info", 0),
        "open": status_dict.get("open", 0),
        "assigned": status_dict.get("assigned", 0),
        "resolved": status_dict.get("resolved", 0),
        "ignored": status_dict.get("ignored", 0),
        "technical_debt_score": self._calculate_score(total_errors, total_warnings),
        "trend": [],  # Combined trend would need special handling
        "by_severity": by_severity,
        "by_rule": by_rule,
        "by_module": by_module,
        "top_files": top_files,
        "last_scan": most_recent.timestamp.isoformat(),
        "by_validator": by_validator,
    }
|
||||
|
||||
def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]:
    """Count a scan's violations per module (first two path segments), top 10."""
    rows = (
        db.query(ArchitectureViolation.file_path)
        .filter(ArchitectureViolation.scan_id == scan_id)
        .all()
    )

    counts: dict[str, int] = {}
    for row in rows:
        parts = row.file_path.split("/")
        # "app/modules/foo.py" -> "app/modules"; bare filenames count as themselves
        module = "/".join(parts[:2]) if len(parts) >= 2 else parts[0]
        counts[module] = counts.get(module, 0) + 1

    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:10])
|
||||
|
||||
def _calculate_score(self, errors: int, warnings: int) -> int:
    """Technical-debt score in [0, 100]: each error costs 0.5, each warning 0.05."""
    penalty = errors * 0.5 + warnings * 0.05
    clamped = min(100.0, max(0.0, 100 - penalty))
    return int(clamped)
|
||||
|
||||
def calculate_technical_debt_score(
    self, db: Session, scan_id: int = None, validator_type: str = None
) -> int:
    """
    Calculate technical debt score (0-100)

    Formula: 100 - (errors * 0.5 + warnings * 0.05)
    Capped at 0 minimum

    Args:
        db: Database session
        scan_id: Scan ID (if None, use latest)
        validator_type: Filter by validator type

    Returns:
        Score from 0-100 (100 when no scan data is available)
    """
    if scan_id is None:
        # Fall back to the most recent scan for the requested validator
        latest = self.get_latest_scan(db, validator_type)
        if not latest:
            return 100
        scan_id = latest.id

    scan = self.get_scan_by_id(db, scan_id)
    return self._calculate_score(scan.errors, scan.warnings) if scan else 100
|
||||
|
||||
def _get_git_commit_hash(self) -> str | None:
    """Best-effort lookup of the current git commit hash; None on any failure."""
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
        )
    except Exception:
        # Missing git binary, timeout, etc. — treat as "unknown commit"
        return None
    if proc.returncode != 0:
        return None
    return proc.stdout.strip()[:40]
|
||||
|
||||
|
||||
# Singleton instance
# Shared module-level instance; methods take the Session explicitly, so this
# object appears to hold no per-request state (confirm against full class).
code_quality_service = CodeQualityService()
|
||||
507
app/modules/dev_tools/services/test_runner_service.py
Normal file
507
app/modules/dev_tools/services/test_runner_service.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Test Runner Service
|
||||
Service for running pytest and storing results
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.dev_tools.models import TestCollection, TestResult, TestRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestRunnerService:
|
||||
"""Service for managing pytest test runs"""
|
||||
|
||||
def __init__(self):
    # Five .parent hops from app/modules/dev_tools/services/test_runner_service.py
    # up to the repository root — keep in sync if this file ever moves.
    self.project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
|
||||
def create_test_run(
    self,
    db: Session,
    test_path: str = "tests",
    triggered_by: str = "manual",
) -> TestRun:
    """Create a test run record without executing tests.

    The row is persisted with status 'running' and stays that way until
    _execute_tests (or equivalent) records the outcome.
    """
    test_run = TestRun(
        timestamp=datetime.now(UTC),
        triggered_by=triggered_by,
        test_path=test_path,
        status="running",
        # Snapshot of the repo state the run was started from (best effort)
        git_commit_hash=self._get_git_commit(),
        git_branch=self._get_git_branch(),
    )
    db.add(test_run)
    # flush (not commit) so the caller's transaction controls durability
    db.flush()
    return test_run
|
||||
|
||||
def run_tests(
    self,
    db: Session,
    test_path: str = "tests",
    triggered_by: str = "manual",
    extra_args: list[str] | None = None,
) -> TestRun:
    """
    Run pytest synchronously and store results in database

    Args:
        db: Database session
        test_path: Path to tests (relative to project root)
        triggered_by: Who triggered the run
        extra_args: Additional pytest arguments

    Returns:
        TestRun object with results
    """
    # Persist a 'running' record first so the run is visible while pytest executes
    test_run = self.create_test_run(db, test_path, triggered_by)
    self._execute_tests(db, test_run, test_path, extra_args)
    return test_run
|
||||
|
||||
def _execute_tests(
    self,
    db: Session,
    test_run: TestRun,
    test_path: str,
    extra_args: list[str] | None,
) -> None:
    """Execute pytest and update the test run record in place.

    Results are preferentially read from a pytest-json-report file; if that
    fails, the textual summary in stdout is parsed as a fallback.
    """
    try:
        # Build pytest command with JSON output.
        # The handle is only used to reserve a unique temp path; the
        # json-report plugin rewrites the file itself.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            json_report_path = f.name

        pytest_args = [
            "python",
            "-m",
            "pytest",
            test_path,
            "--json-report",
            f"--json-report-file={json_report_path}",
            "-v",
            "--tb=short",
        ]

        if extra_args:
            pytest_args.extend(extra_args)

        test_run.pytest_args = " ".join(pytest_args)

        # Run pytest
        start_time = datetime.now(UTC)
        result = subprocess.run(
            pytest_args,
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=600,  # 10 minute timeout
        )
        end_time = datetime.now(UTC)

        test_run.duration_seconds = (end_time - start_time).total_seconds()

        # Parse JSON report
        try:
            with open(json_report_path) as f:
                report = json.load(f)

            self._process_json_report(db, test_run, report)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # Fallback to parsing stdout if JSON report failed
            logger.warning(f"JSON report unavailable ({e}), parsing stdout")
            self._parse_pytest_output(test_run, result.stdout, result.stderr)
        finally:
            # Clean up temp file
            # NOTE(review): if subprocess.run above raises (e.g. timeout),
            # this finally is never reached and the temp file leaks.
            try:
                Path(json_report_path).unlink()
            except Exception:
                pass

        # Set final status
        if test_run.failed > 0 or test_run.errors > 0:
            test_run.status = "failed"
        else:
            test_run.status = "passed"

    except subprocess.TimeoutExpired:
        test_run.status = "error"
        logger.error("Pytest run timed out")
    except Exception as e:
        test_run.status = "error"
        logger.error(f"Error running tests: {e}")
|
||||
|
||||
def _process_json_report(self, db: Session, test_run: TestRun, report: dict):
    """Process pytest-json-report output.

    Copies the summary counters onto the run row and creates one TestResult
    per reported test, including failure details when available.
    """
    summary = report.get("summary", {})

    test_run.total_tests = summary.get("total", 0)
    test_run.passed = summary.get("passed", 0)
    test_run.failed = summary.get("failed", 0)
    # json-report uses the singular key "error"
    test_run.errors = summary.get("error", 0)
    test_run.skipped = summary.get("skipped", 0)
    test_run.xfailed = summary.get("xfailed", 0)
    test_run.xpassed = summary.get("xpassed", 0)

    # Process individual test results
    tests = report.get("tests", [])
    for test in tests:
        node_id = test.get("nodeid", "")
        outcome = test.get("outcome", "unknown")

        # Parse node_id to get file, class, function
        test_file, test_class, test_name = self._parse_node_id(node_id)

        # Get failure details (only the call phase's longrepr is inspected)
        error_message = None
        traceback = None
        if outcome in ("failed", "error"):
            call_info = test.get("call", {})
            if "longrepr" in call_info:
                traceback = call_info["longrepr"]
                # Extract error message from traceback
                if isinstance(traceback, str):
                    lines = traceback.strip().split("\n")
                    if lines:
                        error_message = lines[-1][:500]  # Last line, limited length

        test_result = TestResult(
            run_id=test_run.id,
            node_id=node_id,
            test_name=test_name,
            test_file=test_file,
            test_class=test_class,
            outcome=outcome,
            duration_seconds=test.get("duration", 0.0),
            error_message=error_message,
            traceback=traceback,
            # Raw pytest keywords; presumably rendered as marker tags — verify
            markers=test.get("keywords", []),
        )
        db.add(test_result)
|
||||
|
||||
def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]:
|
||||
"""Parse pytest node_id into file, class, function"""
|
||||
# Format: tests/unit/test_foo.py::TestClass::test_method
|
||||
# or: tests/unit/test_foo.py::test_function
|
||||
parts = node_id.split("::")
|
||||
|
||||
test_file = parts[0] if parts else ""
|
||||
test_class = None
|
||||
test_name = parts[-1] if parts else ""
|
||||
|
||||
if len(parts) == 3:
|
||||
test_class = parts[1]
|
||||
elif len(parts) == 2:
|
||||
# Could be Class::method or file::function
|
||||
if parts[1].startswith("Test"):
|
||||
test_class = parts[1]
|
||||
test_name = parts[1]
|
||||
|
||||
# Handle parametrized tests
|
||||
if "[" in test_name:
|
||||
test_name = test_name.split("[")[0]
|
||||
|
||||
return test_file, test_class, test_name
|
||||
|
||||
def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str):
    """Fallback parser for pytest text output.

    Only the summary counters can be recovered from stdout; per-test
    results are unavailable on this path. `stderr` is accepted for
    interface symmetry but not inspected.
    """
    # Parse summary line like: "10 passed, 2 failed, 1 skipped"
    # The "\d+\s+" prefix anchors each match, so plural forms ("2 errors")
    # and the x-prefixed outcomes still resolve to the right alternative.
    summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)"

    for match in re.finditer(summary_pattern, stdout):
        count = int(match.group(1))
        status = match.group(2)

        if status == "passed":
            test_run.passed = count
        elif status == "failed":
            test_run.failed = count
        elif status == "error":
            test_run.errors = count
        elif status == "skipped":
            test_run.skipped = count
        elif status == "xfailed":
            test_run.xfailed = count
        elif status == "xpassed":
            test_run.xpassed = count

    # Total is reconstructed, since the text summary has no single total
    test_run.total_tests = (
        test_run.passed
        + test_run.failed
        + test_run.errors
        + test_run.skipped
        + test_run.xfailed
        + test_run.xpassed
    )
|
||||
|
||||
def _get_git_commit(self) -> str | None:
    """Return the current git commit hash (40 chars max), or None if unavailable."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=5,
        )
        return result.stdout.strip()[:40] if result.returncode == 0 else None
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit
        return None
|
||||
|
||||
def _get_git_branch(self) -> str | None:
    """Return the current git branch name, or None if it cannot be determined."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=5,
        )
        return result.stdout.strip() if result.returncode == 0 else None
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit
        return None
|
||||
|
||||
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
    """Return the most recent test runs, newest first."""
    recent = db.query(TestRun).order_by(desc(TestRun.timestamp))
    return recent.limit(limit).all()
|
||||
|
||||
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
    """Fetch one test run by primary key, or None when absent."""
    return db.query(TestRun).filter_by(id=run_id).first()
|
||||
|
||||
def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]:
    """Return the failed or errored results belonging to one run."""
    bad_outcome = TestResult.outcome.in_(["failed", "error"])
    return (
        db.query(TestResult)
        .filter(TestResult.run_id == run_id)
        .filter(bad_outcome)
        .all()
    )
|
||||
|
||||
def get_run_results(
    self, db: Session, run_id: int, outcome: str | None = None
) -> list[TestResult]:
    """Results for one run, optionally restricted to a single outcome."""
    conditions = [TestResult.run_id == run_id]
    if outcome:
        conditions.append(TestResult.outcome == outcome)
    return db.query(TestResult).filter(*conditions).all()
|
||||
|
||||
def get_dashboard_stats(self, db: Session) -> dict:
    """Get statistics for the testing dashboard.

    Combines the latest finished run, the latest test collection, a
    10-run trend, per-category counts, and the historically most
    frequently failing tests into a single payload.
    """
    # Get latest run (only finished ones; in-flight runs are excluded)
    latest_run = (
        db.query(TestRun)
        .filter(TestRun.status != "running")
        .order_by(desc(TestRun.timestamp))
        .first()
    )

    # Get test collection info (or calculate from latest run)
    collection = (
        db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
    )

    # Get trend data (last 10 runs)
    trend_runs = (
        db.query(TestRun)
        .filter(TestRun.status != "running")
        .order_by(desc(TestRun.timestamp))
        .limit(10)
        .all()
    )

    # Calculate stats by category from latest run
    by_category = {}
    if latest_run:
        results = (
            db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
        )
        for result in results:
            # Categorize by test path (simple substring matching)
            if "unit" in result.test_file:
                category = "Unit Tests"
            elif "integration" in result.test_file:
                category = "Integration Tests"
            elif "performance" in result.test_file:
                category = "Performance Tests"
            elif "system" in result.test_file:
                category = "System Tests"
            else:
                category = "Other"

            if category not in by_category:
                by_category[category] = {"total": 0, "passed": 0, "failed": 0}
            by_category[category]["total"] += 1
            if result.outcome == "passed":
                by_category[category]["passed"] += 1
            elif result.outcome in ("failed", "error"):
                by_category[category]["failed"] += 1

    # Get top failing tests — aggregated over ALL recorded runs, not just
    # the latest, so chronically flaky tests surface.
    top_failing = (
        db.query(
            TestResult.test_name,
            TestResult.test_file,
            func.count(TestResult.id).label("failure_count"),
        )
        .filter(TestResult.outcome.in_(["failed", "error"]))
        .group_by(TestResult.test_name, TestResult.test_file)
        .order_by(desc("failure_count"))
        .limit(10)
        .all()
    )

    # NOTE(review): `run.pass_rate` is assumed to be a property on the
    # TestRun model (not visible here) — confirm it exists and is a float.
    return {
        # Current run stats
        "total_tests": latest_run.total_tests if latest_run else 0,
        "passed": latest_run.passed if latest_run else 0,
        "failed": latest_run.failed if latest_run else 0,
        "errors": latest_run.errors if latest_run else 0,
        "skipped": latest_run.skipped if latest_run else 0,
        "pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
        "duration_seconds": round(latest_run.duration_seconds, 2)
        if latest_run
        else 0,
        "coverage_percent": latest_run.coverage_percent if latest_run else None,
        "last_run": latest_run.timestamp.isoformat() if latest_run else None,
        "last_run_status": latest_run.status if latest_run else None,
        # Collection stats
        "total_test_files": collection.total_files if collection else 0,
        "collected_tests": collection.total_tests if collection else 0,
        "unit_tests": collection.unit_tests if collection else 0,
        "integration_tests": collection.integration_tests if collection else 0,
        "performance_tests": collection.performance_tests if collection else 0,
        "system_tests": collection.system_tests if collection else 0,
        "last_collected": collection.collected_at.isoformat()
        if collection
        else None,
        # Trend data (oldest first for charting)
        "trend": [
            {
                "timestamp": run.timestamp.isoformat(),
                "total": run.total_tests,
                "passed": run.passed,
                "failed": run.failed,
                "pass_rate": round(run.pass_rate, 1),
                "duration": round(run.duration_seconds, 1),
            }
            for run in reversed(trend_runs)
        ],
        # By category
        "by_category": by_category,
        # Top failing tests
        "top_failing": [
            {
                "test_name": t.test_name,
                "test_file": t.test_file,
                "failure_count": t.failure_count,
            }
            for t in top_failing
        ],
    }
|
||||
|
||||
def collect_tests(self, db: Session) -> TestCollection:
    """Collect test information without running tests.

    Runs `pytest --collect-only` with a JSON report, counts tests per
    file, and categorizes them by directory (unit/integration/
    performance/system). Errors are logged and an (empty) collection
    row is still persisted.
    """
    # NOTE(review): the counter increments below assume the model's
    # total_files / *_tests columns default to 0 — confirm on TestCollection.
    collection = TestCollection(
        collected_at=datetime.now(UTC),
    )

    try:
        # Run pytest --collect-only with JSON report.
        # The handle only reserves a unique temp path for the plugin.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            json_report_path = f.name

        result = subprocess.run(
            [
                "python",
                "-m",
                "pytest",
                "--collect-only",
                "--json-report",
                f"--json-report-file={json_report_path}",
                "tests",
            ],
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=120,
        )

        # Parse JSON report
        json_path = Path(json_report_path)
        if json_path.exists():
            with open(json_path) as f:
                report = json.load(f)

            # Get total from summary
            collection.total_tests = report.get("summary", {}).get("collected", 0)

            # Parse collectors to get test files and counts
            test_files = {}
            for collector in report.get("collectors", []):
                for item in collector.get("result", []):
                    if item.get("type") == "Function":
                        node_id = item.get("nodeid", "")
                        if "::" in node_id:
                            file_path = node_id.split("::")[0]
                            if file_path not in test_files:
                                test_files[file_path] = 0
                            test_files[file_path] += 1

            # Count files and categorize
            for file_path, count in test_files.items():
                collection.total_files += 1

                if "/unit/" in file_path or file_path.startswith("tests/unit"):
                    collection.unit_tests += count
                elif "/integration/" in file_path or file_path.startswith(
                    "tests/integration"
                ):
                    collection.integration_tests += count
                elif "/performance/" in file_path or file_path.startswith(
                    "tests/performance"
                ):
                    collection.performance_tests += count
                elif "/system/" in file_path or file_path.startswith(
                    "tests/system"
                ):
                    collection.system_tests += count

            # Per-file counts, busiest files first
            collection.test_files = [
                {"file": f, "count": c}
                for f, c in sorted(test_files.items(), key=lambda x: -x[1])
            ]

        # Cleanup
        json_path.unlink(missing_ok=True)

        logger.info(
            f"Collected {collection.total_tests} tests from {collection.total_files} files"
        )

    except Exception as e:
        # Best-effort: keep the (possibly empty) collection row on failure
        logger.error(f"Error collecting tests: {e}", exc_info=True)

    db.add(collection)
    return collection
|
||||
|
||||
|
||||
# Singleton instance
# Shared module-level instance; per-run state lives in the DB rows, the
# service itself only caches the project root path.
test_runner_service = TestRunnerService()
|
||||
@@ -2,11 +2,11 @@
|
||||
"""
|
||||
Inventory module database models.
|
||||
|
||||
Re-exports inventory-related models from their source locations.
|
||||
This module contains the canonical implementations of inventory-related models.
|
||||
"""
|
||||
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.inventory_transaction import (
|
||||
from app.modules.inventory.models.inventory import Inventory
|
||||
from app.modules.inventory.models.inventory_transaction import (
|
||||
InventoryTransaction,
|
||||
TransactionType,
|
||||
)
|
||||
|
||||
61
app/modules/inventory/models/inventory.py
Normal file
61
app/modules/inventory/models/inventory.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# app/modules/inventory/models/inventory.py
|
||||
"""
|
||||
Inventory model for tracking stock at warehouse/bin locations.
|
||||
|
||||
Each entry represents a quantity of a product at a specific bin location
|
||||
within a warehouse. Products can be scattered across multiple bins.
|
||||
|
||||
Example:
|
||||
Warehouse: "strassen"
|
||||
Bin: "SA-10-02"
|
||||
Product: GTIN 4007817144145
|
||||
Quantity: 3
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class Inventory(Base, TimestampMixin):
|
||||
__tablename__ = "inventory"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True)
|
||||
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
|
||||
|
||||
# Location: warehouse + bin
|
||||
warehouse = Column(String, nullable=False, default="strassen", index=True)
|
||||
bin_location = Column(String, nullable=False, index=True) # e.g., "SA-10-02"
|
||||
|
||||
# Legacy field - kept for backward compatibility, will be removed
|
||||
location = Column(String, index=True)
|
||||
|
||||
quantity = Column(Integer, nullable=False, default=0)
|
||||
reserved_quantity = Column(Integer, default=0)
|
||||
|
||||
# Keep GTIN for reference/reporting (matches Product.gtin)
|
||||
gtin = Column(String, index=True)
|
||||
|
||||
# Relationships
|
||||
product = relationship("Product", back_populates="inventory_entries")
|
||||
vendor = relationship("Vendor")
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"product_id", "warehouse", "bin_location", name="uq_inventory_product_warehouse_bin"
|
||||
),
|
||||
Index("idx_inventory_vendor_product", "vendor_id", "product_id"),
|
||||
Index("idx_inventory_warehouse_bin", "warehouse", "bin_location"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Inventory(product_id={self.product_id}, location='{self.location}', quantity={self.quantity})>"
|
||||
|
||||
@property
|
||||
def available_quantity(self):
|
||||
"""Calculate available quantity (total - reserved)."""
|
||||
return max(0, self.quantity - self.reserved_quantity)
|
||||
170
app/modules/inventory/models/inventory_transaction.py
Normal file
170
app/modules/inventory/models/inventory_transaction.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# app/modules/inventory/models/inventory_transaction.py
|
||||
"""
|
||||
Inventory Transaction Model - Audit trail for all stock movements.
|
||||
|
||||
This model tracks every change to inventory quantities, providing:
|
||||
- Complete audit trail for compliance and debugging
|
||||
- Order-linked transactions for traceability
|
||||
- Support for different transaction types (reserve, fulfill, adjust, etc.)
|
||||
|
||||
All stock movements should create a transaction record.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
|
||||
|
||||
class TransactionType(str, Enum):
|
||||
"""Types of inventory transactions."""
|
||||
|
||||
# Order-related
|
||||
RESERVE = "reserve" # Stock reserved for order
|
||||
FULFILL = "fulfill" # Reserved stock consumed (shipped)
|
||||
RELEASE = "release" # Reserved stock released (cancelled)
|
||||
|
||||
# Manual adjustments
|
||||
ADJUST = "adjust" # Manual adjustment (+/-)
|
||||
SET = "set" # Set to exact quantity
|
||||
|
||||
# Imports
|
||||
IMPORT = "import" # Initial import/sync
|
||||
|
||||
# Returns
|
||||
RETURN = "return" # Stock returned from customer
|
||||
|
||||
|
||||
class InventoryTransaction(Base):
|
||||
"""
|
||||
Audit log for inventory movements.
|
||||
|
||||
Every change to inventory quantity creates a transaction record,
|
||||
enabling complete traceability of stock levels over time.
|
||||
"""
|
||||
|
||||
__tablename__ = "inventory_transactions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Core references
|
||||
vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
|
||||
product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True)
|
||||
inventory_id = Column(
|
||||
Integer, ForeignKey("inventory.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Transaction details
|
||||
transaction_type = Column(
|
||||
SQLEnum(TransactionType), nullable=False, index=True
|
||||
)
|
||||
quantity_change = Column(Integer, nullable=False) # Positive = add, negative = remove
|
||||
|
||||
# Quantities after transaction (snapshot)
|
||||
quantity_after = Column(Integer, nullable=False)
|
||||
reserved_after = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Location context
|
||||
location = Column(String, nullable=True)
|
||||
warehouse = Column(String, nullable=True)
|
||||
|
||||
# Order reference (for order-related transactions)
|
||||
order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True)
|
||||
order_number = Column(String, nullable=True)
|
||||
|
||||
# Audit fields
|
||||
reason = Column(Text, nullable=True) # Human-readable reason
|
||||
created_by = Column(String, nullable=True) # User/system that created
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
vendor = relationship("Vendor")
|
||||
product = relationship("Product")
|
||||
inventory = relationship("Inventory")
|
||||
order = relationship("Order")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
Index("idx_inv_tx_vendor_product", "vendor_id", "product_id"),
|
||||
Index("idx_inv_tx_vendor_created", "vendor_id", "created_at"),
|
||||
Index("idx_inv_tx_order", "order_id"),
|
||||
Index("idx_inv_tx_type_created", "transaction_type", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<InventoryTransaction {self.id}: "
|
||||
f"{self.transaction_type.value} {self.quantity_change:+d} "
|
||||
f"for product {self.product_id}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_transaction(
|
||||
cls,
|
||||
vendor_id: int,
|
||||
product_id: int,
|
||||
transaction_type: TransactionType,
|
||||
quantity_change: int,
|
||||
quantity_after: int,
|
||||
reserved_after: int = 0,
|
||||
inventory_id: int | None = None,
|
||||
location: str | None = None,
|
||||
warehouse: str | None = None,
|
||||
order_id: int | None = None,
|
||||
order_number: str | None = None,
|
||||
reason: str | None = None,
|
||||
created_by: str | None = None,
|
||||
) -> "InventoryTransaction":
|
||||
"""
|
||||
Factory method to create a transaction record.
|
||||
|
||||
Args:
|
||||
vendor_id: Vendor ID
|
||||
product_id: Product ID
|
||||
transaction_type: Type of transaction
|
||||
quantity_change: Change in quantity (positive = add, negative = remove)
|
||||
quantity_after: Total quantity after this transaction
|
||||
reserved_after: Reserved quantity after this transaction
|
||||
inventory_id: Optional inventory record ID
|
||||
location: Optional location
|
||||
warehouse: Optional warehouse
|
||||
order_id: Optional order ID (for order-related transactions)
|
||||
order_number: Optional order number for display
|
||||
reason: Optional human-readable reason
|
||||
created_by: Optional user/system identifier
|
||||
|
||||
Returns:
|
||||
InventoryTransaction instance (not yet added to session)
|
||||
"""
|
||||
return cls(
|
||||
vendor_id=vendor_id,
|
||||
product_id=product_id,
|
||||
inventory_id=inventory_id,
|
||||
transaction_type=transaction_type,
|
||||
quantity_change=quantity_change,
|
||||
quantity_after=quantity_after,
|
||||
reserved_after=reserved_after,
|
||||
location=location,
|
||||
warehouse=warehouse,
|
||||
order_id=order_id,
|
||||
order_number=order_number,
|
||||
reason=reason,
|
||||
created_by=created_by,
|
||||
)
|
||||
@@ -2,27 +2,77 @@
|
||||
"""
|
||||
Inventory module Pydantic schemas.
|
||||
|
||||
Re-exports inventory-related schemas from their source locations.
|
||||
This module contains the canonical implementations of inventory-related schemas.
|
||||
"""
|
||||
|
||||
from models.schema.inventory import (
|
||||
from app.modules.inventory.schemas.inventory import (
|
||||
# Base schemas
|
||||
InventoryBase,
|
||||
InventoryCreate,
|
||||
InventoryAdjust,
|
||||
InventoryUpdate,
|
||||
InventoryReserve,
|
||||
# Response schemas
|
||||
InventoryResponse,
|
||||
InventoryLocationResponse,
|
||||
ProductInventorySummary,
|
||||
InventoryListResponse,
|
||||
InventoryTransactionResponse,
|
||||
InventoryMessageResponse,
|
||||
InventorySummaryResponse,
|
||||
# Admin schemas
|
||||
AdminInventoryCreate,
|
||||
AdminInventoryAdjust,
|
||||
AdminInventoryItem,
|
||||
AdminInventoryListResponse,
|
||||
AdminInventoryStats,
|
||||
AdminLowStockItem,
|
||||
AdminVendorWithInventory,
|
||||
AdminVendorsWithInventoryResponse,
|
||||
AdminInventoryLocationsResponse,
|
||||
# Transaction schemas
|
||||
InventoryTransactionResponse,
|
||||
InventoryTransactionWithProduct,
|
||||
InventoryTransactionListResponse,
|
||||
ProductTransactionHistoryResponse,
|
||||
OrderTransactionHistoryResponse,
|
||||
# Admin transaction schemas
|
||||
AdminInventoryTransactionItem,
|
||||
AdminInventoryTransactionListResponse,
|
||||
AdminTransactionStatsResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base schemas
|
||||
"InventoryBase",
|
||||
"InventoryCreate",
|
||||
"InventoryAdjust",
|
||||
"InventoryUpdate",
|
||||
"InventoryReserve",
|
||||
# Response schemas
|
||||
"InventoryResponse",
|
||||
"InventoryLocationResponse",
|
||||
"ProductInventorySummary",
|
||||
"InventoryListResponse",
|
||||
"InventoryTransactionResponse",
|
||||
"InventoryMessageResponse",
|
||||
"InventorySummaryResponse",
|
||||
# Admin schemas
|
||||
"AdminInventoryCreate",
|
||||
"AdminInventoryAdjust",
|
||||
"AdminInventoryItem",
|
||||
"AdminInventoryListResponse",
|
||||
"AdminInventoryStats",
|
||||
"AdminLowStockItem",
|
||||
"AdminVendorWithInventory",
|
||||
"AdminVendorsWithInventoryResponse",
|
||||
"AdminInventoryLocationsResponse",
|
||||
# Transaction schemas
|
||||
"InventoryTransactionResponse",
|
||||
"InventoryTransactionWithProduct",
|
||||
"InventoryTransactionListResponse",
|
||||
"ProductTransactionHistoryResponse",
|
||||
"OrderTransactionHistoryResponse",
|
||||
# Admin transaction schemas
|
||||
"AdminInventoryTransactionItem",
|
||||
"AdminInventoryTransactionListResponse",
|
||||
"AdminTransactionStatsResponse",
|
||||
]
|
||||
|
||||
294
app/modules/inventory/schemas/inventory.py
Normal file
294
app/modules/inventory/schemas/inventory.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# app/modules/inventory/schemas/inventory.py
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class InventoryBase(BaseModel):
|
||||
product_id: int = Field(..., description="Product ID in vendor catalog")
|
||||
location: str = Field(..., description="Storage location")
|
||||
|
||||
|
||||
class InventoryCreate(InventoryBase):
|
||||
"""Set exact inventory quantity (replaces existing)."""
|
||||
|
||||
quantity: int = Field(..., description="Exact inventory quantity", ge=0)
|
||||
|
||||
|
||||
class InventoryAdjust(InventoryBase):
|
||||
"""Add or remove inventory quantity."""
|
||||
|
||||
quantity: int = Field(
|
||||
..., description="Quantity to add (positive) or remove (negative)"
|
||||
)
|
||||
|
||||
|
||||
class InventoryUpdate(BaseModel):
|
||||
"""Update inventory fields."""
|
||||
|
||||
quantity: int | None = Field(None, ge=0)
|
||||
reserved_quantity: int | None = Field(None, ge=0)
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class InventoryReserve(BaseModel):
|
||||
"""Reserve inventory for orders."""
|
||||
|
||||
product_id: int
|
||||
location: str
|
||||
quantity: int = Field(..., gt=0)
|
||||
|
||||
|
||||
class InventoryResponse(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
product_id: int
|
||||
vendor_id: int
|
||||
location: str
|
||||
quantity: int
|
||||
reserved_quantity: int
|
||||
gtin: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def available_quantity(self):
|
||||
return max(0, self.quantity - self.reserved_quantity)
|
||||
|
||||
|
||||
class InventoryLocationResponse(BaseModel):
|
||||
location: str
|
||||
quantity: int
|
||||
reserved_quantity: int
|
||||
available_quantity: int
|
||||
|
||||
|
||||
class ProductInventorySummary(BaseModel):
|
||||
"""Inventory summary for a product."""
|
||||
|
||||
product_id: int
|
||||
vendor_id: int
|
||||
product_sku: str | None
|
||||
product_title: str
|
||||
total_quantity: int
|
||||
total_reserved: int
|
||||
total_available: int
|
||||
locations: list[InventoryLocationResponse]
|
||||
|
||||
|
||||
class InventoryListResponse(BaseModel):
|
||||
inventories: list[InventoryResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class InventoryMessageResponse(BaseModel):
|
||||
"""Simple message response for inventory operations."""
|
||||
|
||||
message: str
|
||||
|
||||
|
||||
class InventorySummaryResponse(BaseModel):
|
||||
"""Inventory summary response for marketplace product service."""
|
||||
|
||||
gtin: str
|
||||
total_quantity: int
|
||||
locations: list[InventoryLocationResponse]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Inventory Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminInventoryCreate(BaseModel):
|
||||
"""Admin version of inventory create - requires explicit vendor_id."""
|
||||
|
||||
vendor_id: int = Field(..., description="Target vendor ID")
|
||||
product_id: int = Field(..., description="Product ID in vendor catalog")
|
||||
location: str = Field(..., description="Storage location")
|
||||
quantity: int = Field(..., description="Exact inventory quantity", ge=0)
|
||||
|
||||
|
||||
class AdminInventoryAdjust(BaseModel):
|
||||
"""Admin version of inventory adjust - requires explicit vendor_id."""
|
||||
|
||||
vendor_id: int = Field(..., description="Target vendor ID")
|
||||
product_id: int = Field(..., description="Product ID in vendor catalog")
|
||||
location: str = Field(..., description="Storage location")
|
||||
quantity: int = Field(
|
||||
..., description="Quantity to add (positive) or remove (negative)"
|
||||
)
|
||||
reason: str | None = Field(None, description="Reason for adjustment")
|
||||
|
||||
|
||||
class AdminInventoryItem(BaseModel):
|
||||
"""Inventory item with vendor info for admin list view."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
product_id: int
|
||||
vendor_id: int
|
||||
vendor_name: str | None = None
|
||||
vendor_code: str | None = None
|
||||
product_title: str | None = None
|
||||
product_sku: str | None = None
|
||||
location: str
|
||||
quantity: int
|
||||
reserved_quantity: int
|
||||
available_quantity: int
|
||||
gtin: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class AdminInventoryListResponse(BaseModel):
|
||||
"""Cross-vendor inventory list for admin."""
|
||||
|
||||
inventories: list[AdminInventoryItem]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
vendor_filter: int | None = None
|
||||
location_filter: str | None = None
|
||||
|
||||
|
||||
class AdminInventoryStats(BaseModel):
|
||||
"""Inventory statistics for admin dashboard."""
|
||||
|
||||
total_entries: int
|
||||
total_quantity: int
|
||||
total_reserved: int
|
||||
total_available: int
|
||||
low_stock_count: int
|
||||
vendors_with_inventory: int
|
||||
unique_locations: int
|
||||
|
||||
|
||||
class AdminLowStockItem(BaseModel):
|
||||
"""Low stock item for admin alerts."""
|
||||
|
||||
id: int
|
||||
product_id: int
|
||||
vendor_id: int
|
||||
vendor_name: str | None = None
|
||||
product_title: str | None = None
|
||||
location: str
|
||||
quantity: int
|
||||
reserved_quantity: int
|
||||
available_quantity: int
|
||||
|
||||
|
||||
class AdminVendorWithInventory(BaseModel):
|
||||
"""Vendor with inventory entries."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
vendor_code: str
|
||||
|
||||
|
||||
class AdminVendorsWithInventoryResponse(BaseModel):
|
||||
"""Response for vendors with inventory list."""
|
||||
|
||||
vendors: list[AdminVendorWithInventory]
|
||||
|
||||
|
||||
class AdminInventoryLocationsResponse(BaseModel):
|
||||
"""Response for unique inventory locations."""
|
||||
|
||||
locations: list[str]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Inventory Transaction Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InventoryTransactionResponse(BaseModel):
|
||||
"""Single inventory transaction record."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
vendor_id: int
|
||||
product_id: int
|
||||
inventory_id: int | None = None
|
||||
transaction_type: str
|
||||
quantity_change: int
|
||||
quantity_after: int
|
||||
reserved_after: int
|
||||
location: str | None = None
|
||||
warehouse: str | None = None
|
||||
order_id: int | None = None
|
||||
order_number: str | None = None
|
||||
reason: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class InventoryTransactionWithProduct(InventoryTransactionResponse):
|
||||
"""Transaction with product details for list views."""
|
||||
|
||||
product_title: str | None = None
|
||||
product_sku: str | None = None
|
||||
|
||||
|
||||
class InventoryTransactionListResponse(BaseModel):
|
||||
"""Paginated list of inventory transactions."""
|
||||
|
||||
transactions: list[InventoryTransactionWithProduct]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class ProductTransactionHistoryResponse(BaseModel):
|
||||
"""Transaction history for a specific product."""
|
||||
|
||||
product_id: int
|
||||
product_title: str | None = None
|
||||
product_sku: str | None = None
|
||||
current_quantity: int
|
||||
current_reserved: int
|
||||
transactions: list[InventoryTransactionResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class OrderTransactionHistoryResponse(BaseModel):
|
||||
"""Transaction history for a specific order."""
|
||||
|
||||
order_id: int
|
||||
order_number: str
|
||||
transactions: list[InventoryTransactionWithProduct]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Inventory Transaction Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminInventoryTransactionItem(InventoryTransactionWithProduct):
|
||||
"""Transaction with vendor details for admin views."""
|
||||
|
||||
vendor_name: str | None = None
|
||||
vendor_code: str | None = None
|
||||
|
||||
|
||||
class AdminInventoryTransactionListResponse(BaseModel):
|
||||
"""Paginated list of transactions for admin."""
|
||||
|
||||
transactions: list[AdminInventoryTransactionItem]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class AdminTransactionStatsResponse(BaseModel):
|
||||
"""Transaction statistics for admin dashboard."""
|
||||
|
||||
total_transactions: int
|
||||
transactions_today: int
|
||||
by_type: dict[str, int]
|
||||
@@ -2,20 +2,21 @@
|
||||
"""
|
||||
Inventory module services.
|
||||
|
||||
Re-exports inventory-related services from their source locations.
|
||||
This module contains the canonical implementations of inventory-related services.
|
||||
"""
|
||||
|
||||
from app.services.inventory_service import (
|
||||
from app.modules.inventory.services.inventory_service import (
|
||||
inventory_service,
|
||||
InventoryService,
|
||||
)
|
||||
from app.services.inventory_transaction_service import (
|
||||
from app.modules.inventory.services.inventory_transaction_service import (
|
||||
inventory_transaction_service,
|
||||
InventoryTransactionService,
|
||||
)
|
||||
from app.services.inventory_import_service import (
|
||||
from app.modules.inventory.services.inventory_import_service import (
|
||||
inventory_import_service,
|
||||
InventoryImportService,
|
||||
ImportResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -25,4 +26,5 @@ __all__ = [
|
||||
"InventoryTransactionService",
|
||||
"inventory_import_service",
|
||||
"InventoryImportService",
|
||||
"ImportResult",
|
||||
]
|
||||
|
||||
250
app/modules/inventory/services/inventory_import_service.py
Normal file
250
app/modules/inventory/services/inventory_import_service.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# app/modules/inventory/services/inventory_import_service.py
|
||||
"""
|
||||
Inventory import service for bulk importing stock from TSV/CSV files.
|
||||
|
||||
Supports two formats:
|
||||
1. One row per unit (quantity = count of rows):
|
||||
BIN EAN PRODUCT
|
||||
SA-10-02 0810050910101 Product Name
|
||||
SA-10-02 0810050910101 Product Name (2nd unit)
|
||||
|
||||
2. With explicit quantity column:
|
||||
BIN EAN PRODUCT QUANTITY
|
||||
SA-10-02 0810050910101 Product Name 12
|
||||
|
||||
Products are matched by GTIN/EAN to existing vendor products.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.inventory.models.inventory import Inventory
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportResult:
|
||||
"""Result of an inventory import operation."""
|
||||
|
||||
success: bool = True
|
||||
total_rows: int = 0
|
||||
entries_created: int = 0
|
||||
entries_updated: int = 0
|
||||
quantity_imported: int = 0
|
||||
unmatched_gtins: list = field(default_factory=list)
|
||||
errors: list = field(default_factory=list)
|
||||
|
||||
|
||||
class InventoryImportService:
|
||||
"""Service for importing inventory from TSV/CSV files."""
|
||||
|
||||
def import_from_text(
|
||||
self,
|
||||
db: Session,
|
||||
content: str,
|
||||
vendor_id: int,
|
||||
warehouse: str = "strassen",
|
||||
delimiter: str = "\t",
|
||||
clear_existing: bool = False,
|
||||
) -> ImportResult:
|
||||
"""
|
||||
Import inventory from TSV/CSV text content.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
content: TSV/CSV content as string
|
||||
vendor_id: Vendor ID for inventory
|
||||
warehouse: Warehouse name (default: "strassen")
|
||||
delimiter: Column delimiter (default: tab)
|
||||
clear_existing: If True, clear existing inventory before import
|
||||
|
||||
Returns:
|
||||
ImportResult with summary and errors
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
try:
|
||||
# Parse CSV/TSV
|
||||
reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
|
||||
|
||||
# Normalize headers (case-insensitive, strip whitespace)
|
||||
if reader.fieldnames:
|
||||
reader.fieldnames = [h.strip().upper() for h in reader.fieldnames]
|
||||
|
||||
# Validate required columns
|
||||
required = {"BIN", "EAN"}
|
||||
if not reader.fieldnames or not required.issubset(set(reader.fieldnames)):
|
||||
result.success = False
|
||||
result.errors.append(
|
||||
f"Missing required columns. Found: {reader.fieldnames}, Required: {required}"
|
||||
)
|
||||
return result
|
||||
|
||||
has_quantity = "QUANTITY" in reader.fieldnames
|
||||
|
||||
# Group entries by (EAN, BIN)
|
||||
# Key: (ean, bin) -> quantity
|
||||
inventory_data: dict[tuple[str, str], int] = defaultdict(int)
|
||||
product_names: dict[str, str] = {} # EAN -> product name (for logging)
|
||||
|
||||
for row in reader:
|
||||
result.total_rows += 1
|
||||
|
||||
ean = row.get("EAN", "").strip()
|
||||
bin_loc = row.get("BIN", "").strip()
|
||||
product_name = row.get("PRODUCT", "").strip()
|
||||
|
||||
if not ean or not bin_loc:
|
||||
result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN")
|
||||
continue
|
||||
|
||||
# Get quantity
|
||||
if has_quantity:
|
||||
try:
|
||||
qty = int(row.get("QUANTITY", "1").strip())
|
||||
except ValueError:
|
||||
result.errors.append(
|
||||
f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
qty = 1 # Each row = 1 unit
|
||||
|
||||
inventory_data[(ean, bin_loc)] += qty
|
||||
if product_name:
|
||||
product_names[ean] = product_name
|
||||
|
||||
# Clear existing inventory if requested
|
||||
if clear_existing:
|
||||
db.query(Inventory).filter(
|
||||
Inventory.vendor_id == vendor_id,
|
||||
Inventory.warehouse == warehouse,
|
||||
).delete()
|
||||
db.flush()
|
||||
|
||||
# Build EAN to Product mapping for this vendor
|
||||
products = (
|
||||
db.query(Product)
|
||||
.filter(
|
||||
Product.vendor_id == vendor_id,
|
||||
Product.gtin.isnot(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin}
|
||||
|
||||
# Track unmatched GTINs
|
||||
unmatched: dict[str, int] = {} # EAN -> total quantity
|
||||
|
||||
# Process inventory entries
|
||||
for (ean, bin_loc), quantity in inventory_data.items():
|
||||
product = ean_to_product.get(ean)
|
||||
|
||||
if not product:
|
||||
# Track unmatched
|
||||
if ean not in unmatched:
|
||||
unmatched[ean] = 0
|
||||
unmatched[ean] += quantity
|
||||
continue
|
||||
|
||||
# Upsert inventory entry
|
||||
existing = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == product.id,
|
||||
Inventory.warehouse == warehouse,
|
||||
Inventory.bin_location == bin_loc,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.quantity = quantity
|
||||
existing.gtin = ean
|
||||
result.entries_updated += 1
|
||||
else:
|
||||
inv = Inventory(
|
||||
product_id=product.id,
|
||||
vendor_id=vendor_id,
|
||||
warehouse=warehouse,
|
||||
bin_location=bin_loc,
|
||||
location=bin_loc, # Legacy field
|
||||
quantity=quantity,
|
||||
gtin=ean,
|
||||
)
|
||||
db.add(inv)
|
||||
result.entries_created += 1
|
||||
|
||||
result.quantity_imported += quantity
|
||||
|
||||
db.flush()
|
||||
|
||||
# Format unmatched GTINs for result
|
||||
for ean, qty in unmatched.items():
|
||||
product_name = product_names.get(ean, "Unknown")
|
||||
result.unmatched_gtins.append(
|
||||
{"gtin": ean, "quantity": qty, "product_name": product_name}
|
||||
)
|
||||
|
||||
if result.unmatched_gtins:
|
||||
logger.warning(
|
||||
f"Import had {len(result.unmatched_gtins)} unmatched GTINs"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Inventory import failed")
|
||||
result.success = False
|
||||
result.errors.append(str(e))
|
||||
|
||||
return result
|
||||
|
||||
def import_from_file(
|
||||
self,
|
||||
db: Session,
|
||||
file_path: str,
|
||||
vendor_id: int,
|
||||
warehouse: str = "strassen",
|
||||
clear_existing: bool = False,
|
||||
) -> ImportResult:
|
||||
"""
|
||||
Import inventory from a TSV/CSV file.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
file_path: Path to TSV/CSV file
|
||||
vendor_id: Vendor ID for inventory
|
||||
warehouse: Warehouse name
|
||||
clear_existing: If True, clear existing inventory before import
|
||||
|
||||
Returns:
|
||||
ImportResult with summary and errors
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
return ImportResult(success=False, errors=[f"Failed to read file: {e}"])
|
||||
|
||||
# Detect delimiter
|
||||
first_line = content.split("\n")[0] if content else ""
|
||||
delimiter = "\t" if "\t" in first_line else ","
|
||||
|
||||
return self.import_from_text(
|
||||
db=db,
|
||||
content=content,
|
||||
vendor_id=vendor_id,
|
||||
warehouse=warehouse,
|
||||
delimiter=delimiter,
|
||||
clear_existing=clear_existing,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
inventory_import_service = InventoryImportService()
|
||||
949
app/modules/inventory/services/inventory_service.py
Normal file
949
app/modules/inventory/services/inventory_service.py
Normal file
@@ -0,0 +1,949 @@
|
||||
# app/modules/inventory/services/inventory_service.py
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
InsufficientInventoryException,
|
||||
InvalidQuantityException,
|
||||
InventoryNotFoundException,
|
||||
InventoryValidationException,
|
||||
ProductNotFoundException,
|
||||
ValidationException,
|
||||
VendorNotFoundException,
|
||||
)
|
||||
from app.modules.inventory.models.inventory import Inventory
|
||||
from app.modules.inventory.schemas.inventory import (
|
||||
AdminInventoryItem,
|
||||
AdminInventoryListResponse,
|
||||
AdminInventoryLocationsResponse,
|
||||
AdminInventoryStats,
|
||||
AdminLowStockItem,
|
||||
AdminVendorsWithInventoryResponse,
|
||||
AdminVendorWithInventory,
|
||||
InventoryAdjust,
|
||||
InventoryCreate,
|
||||
InventoryLocationResponse,
|
||||
InventoryReserve,
|
||||
InventoryUpdate,
|
||||
ProductInventorySummary,
|
||||
)
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InventoryService:
|
||||
"""Service for inventory operations with vendor isolation."""
|
||||
|
||||
def set_inventory(
|
||||
self, db: Session, vendor_id: int, inventory_data: InventoryCreate
|
||||
) -> Inventory:
|
||||
"""
|
||||
Set exact inventory quantity for a product at a location (replaces existing).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID (from middleware)
|
||||
inventory_data: Inventory data
|
||||
|
||||
Returns:
|
||||
Inventory object
|
||||
"""
|
||||
try:
|
||||
# Validate product belongs to vendor
|
||||
product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
|
||||
|
||||
# Validate location
|
||||
location = self._validate_location(inventory_data.location)
|
||||
|
||||
# Validate quantity
|
||||
self._validate_quantity(inventory_data.quantity, allow_zero=True)
|
||||
|
||||
# Check if inventory entry exists
|
||||
existing = self._get_inventory_entry(
|
||||
db, inventory_data.product_id, location
|
||||
)
|
||||
|
||||
if existing:
|
||||
old_qty = existing.quantity
|
||||
existing.quantity = inventory_data.quantity
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(existing)
|
||||
|
||||
logger.info(
|
||||
f"Set inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"{old_qty} → {inventory_data.quantity}"
|
||||
)
|
||||
return existing
|
||||
# Create new inventory entry
|
||||
new_inventory = Inventory(
|
||||
product_id=inventory_data.product_id,
|
||||
vendor_id=vendor_id,
|
||||
warehouse="strassen", # Default warehouse
|
||||
bin_location=location, # Use location as bin location
|
||||
location=location, # Keep for backward compatibility
|
||||
quantity=inventory_data.quantity,
|
||||
gtin=product.marketplace_product.gtin, # Optional reference
|
||||
)
|
||||
db.add(new_inventory)
|
||||
db.flush()
|
||||
db.refresh(new_inventory)
|
||||
|
||||
logger.info(
|
||||
f"Created inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"{inventory_data.quantity}"
|
||||
)
|
||||
return new_inventory
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InvalidQuantityException,
|
||||
InventoryValidationException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error setting inventory: {str(e)}")
|
||||
raise ValidationException("Failed to set inventory")
|
||||
|
||||
def adjust_inventory(
    self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
) -> Inventory:
    """
    Adjust inventory by adding or removing quantity.
    Positive quantity = add, negative = remove.

    If no inventory entry exists at the location, a positive adjustment
    creates one; a negative adjustment is an error.

    Args:
        db: Database session
        vendor_id: Vendor ID
        inventory_data: Adjustment data

    Returns:
        Updated Inventory object

    Raises:
        ProductNotFoundException: If the product does not belong to the vendor.
        InventoryNotFoundException: If removing from a non-existent entry.
        InsufficientInventoryException: If removal would drive quantity below zero.
        InventoryValidationException: If the location is blank/invalid.
        ValidationException: On any unexpected database error (after rollback).
    """
    try:
        # Validate product belongs to vendor
        product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)

        # Validate location
        location = self._validate_location(inventory_data.location)

        # Check if inventory exists
        existing = self._get_inventory_entry(
            db, inventory_data.product_id, location
        )

        if not existing:
            # Create new if adding, error if removing
            if inventory_data.quantity < 0:
                raise InventoryNotFoundException(
                    f"No inventory found for product {inventory_data.product_id} at {location}"
                )

            # Create with positive quantity
            new_inventory = Inventory(
                product_id=inventory_data.product_id,
                vendor_id=vendor_id,
                warehouse="strassen",  # Default warehouse
                bin_location=location,  # Use location as bin location
                location=location,  # Keep for backward compatibility
                quantity=inventory_data.quantity,
                gtin=product.marketplace_product.gtin,
            )
            db.add(new_inventory)
            db.flush()
            db.refresh(new_inventory)

            logger.info(
                f"Created inventory for product {inventory_data.product_id} at {location}: "
                f"+{inventory_data.quantity}"
            )
            return new_inventory

        # Adjust existing inventory
        old_qty = existing.quantity
        new_qty = old_qty + inventory_data.quantity

        # Validate resulting quantity — never allow negative stock.
        if new_qty < 0:
            raise InsufficientInventoryException(
                f"Insufficient inventory. Available: {old_qty}, "
                f"Requested removal: {abs(inventory_data.quantity)}"
            )

        existing.quantity = new_qty
        existing.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(existing)

        logger.info(
            f"Adjusted inventory for product {inventory_data.product_id} at {location}: "
            f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}"
        )
        return existing

    except (
        ProductNotFoundException,
        InventoryNotFoundException,
        InsufficientInventoryException,
        InventoryValidationException,
    ):
        # Domain errors: undo any pending changes and re-raise unchanged so
        # callers can handle the specific exception type.
        db.rollback()
        raise
    except Exception as e:
        # Unexpected failures are logged and mapped to a generic error.
        db.rollback()
        logger.error(f"Error adjusting inventory: {str(e)}")
        raise ValidationException("Failed to adjust inventory")
|
||||
|
||||
def reserve_inventory(
    self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
    """
    Reserve inventory for an order (increases reserved_quantity).

    Args:
        db: Database session
        vendor_id: Vendor ID
        reserve_data: Reservation data

    Returns:
        Updated Inventory object

    Raises:
        ProductNotFoundException: If the product does not belong to the vendor.
        InventoryNotFoundException: If no entry exists at the location.
        InsufficientInventoryException: If available (quantity - reserved)
            is less than the requested amount.
        InvalidQuantityException: If the quantity is not a positive integer.
        ValidationException: On any unexpected database error (after rollback).
    """
    try:
        # Validate product (ownership check only; return value unused here)
        product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)

        # Validate location and quantity (zero reservations are rejected)
        location = self._validate_location(reserve_data.location)
        self._validate_quantity(reserve_data.quantity, allow_zero=False)

        # Get inventory entry
        inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
        if not inventory:
            raise InventoryNotFoundException(
                f"No inventory found for product {reserve_data.product_id} at {location}"
            )

        # Check available quantity (on-hand minus already reserved)
        available = inventory.quantity - inventory.reserved_quantity
        if available < reserve_data.quantity:
            raise InsufficientInventoryException(
                f"Insufficient available inventory. Available: {available}, "
                f"Requested: {reserve_data.quantity}"
            )

        # Reserve inventory
        inventory.reserved_quantity += reserve_data.quantity
        inventory.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(inventory)

        logger.info(
            f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} "
            f"at {location}"
        )
        return inventory

    except (
        ProductNotFoundException,
        InventoryNotFoundException,
        InsufficientInventoryException,
        InvalidQuantityException,
    ):
        # Domain errors: roll back any pending changes, re-raise unchanged.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error reserving inventory: {str(e)}")
        raise ValidationException("Failed to reserve inventory")
|
||||
|
||||
def release_reservation(
    self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
    """
    Release reserved inventory (decreases reserved_quantity).

    Releasing more than is currently reserved is NOT an error: the
    reserved count is clamped to zero and a warning is logged.

    Args:
        db: Database session
        vendor_id: Vendor ID
        reserve_data: Reservation data

    Returns:
        Updated Inventory object

    Raises:
        ProductNotFoundException: If the product does not belong to the vendor.
        InventoryNotFoundException: If no entry exists at the location.
        InvalidQuantityException: If the quantity is not a positive integer.
        ValidationException: On any unexpected database error (after rollback).
    """
    try:
        # Validate product (ownership check only; return value unused here)
        product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)

        location = self._validate_location(reserve_data.location)
        self._validate_quantity(reserve_data.quantity, allow_zero=False)

        inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
        if not inventory:
            raise InventoryNotFoundException(
                f"No inventory found for product {reserve_data.product_id} at {location}"
            )

        # Validate reserved quantity — over-release clamps to zero rather
        # than failing, since the net effect (nothing reserved) is correct.
        if inventory.reserved_quantity < reserve_data.quantity:
            logger.warning(
                f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, "
                f"Requested: {reserve_data.quantity}"
            )
            inventory.reserved_quantity = 0
        else:
            inventory.reserved_quantity -= reserve_data.quantity

        inventory.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(inventory)

        logger.info(
            f"Released {reserve_data.quantity} units for product {reserve_data.product_id} "
            f"at {location}"
        )
        return inventory

    except (
        ProductNotFoundException,
        InventoryNotFoundException,
        InvalidQuantityException,
    ):
        # Domain errors: roll back any pending changes, re-raise unchanged.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error releasing reservation: {str(e)}")
        raise ValidationException("Failed to release reservation")
|
||||
|
||||
def fulfill_reservation(
    self, db: Session, vendor_id: int, reserve_data: InventoryReserve
) -> Inventory:
    """
    Fulfill a reservation (decreases both quantity and reserved_quantity).
    Use when order is shipped/completed.

    Fulfilling more than is reserved only logs a warning (reserved count
    is clamped at zero); fulfilling more than on-hand quantity is an error.

    Args:
        db: Database session
        vendor_id: Vendor ID
        reserve_data: Reservation data

    Returns:
        Updated Inventory object

    Raises:
        ProductNotFoundException: If the product does not belong to the vendor.
        InventoryNotFoundException: If no entry exists at the location.
        InsufficientInventoryException: If on-hand quantity is too low.
        InvalidQuantityException: If the quantity is not a positive integer.
        ValidationException: On any unexpected database error (after rollback).
    """
    try:
        # Ownership check only; return value unused here
        product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
        location = self._validate_location(reserve_data.location)
        self._validate_quantity(reserve_data.quantity, allow_zero=False)

        inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
        if not inventory:
            raise InventoryNotFoundException(
                f"No inventory found for product {reserve_data.product_id} at {location}"
            )

        # Validate quantities — hard failure only when physical stock is short.
        if inventory.quantity < reserve_data.quantity:
            raise InsufficientInventoryException(
                f"Insufficient inventory. Available: {inventory.quantity}, "
                f"Requested: {reserve_data.quantity}"
            )

        if inventory.reserved_quantity < reserve_data.quantity:
            logger.warning(
                f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, "
                f"Fulfilling: {reserve_data.quantity}"
            )

        # Fulfill (remove from both quantity and reserved; reserved is
        # clamped at zero to tolerate the over-fulfil case warned above)
        inventory.quantity -= reserve_data.quantity
        inventory.reserved_quantity = max(
            0, inventory.reserved_quantity - reserve_data.quantity
        )
        inventory.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(inventory)

        logger.info(
            f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} "
            f"at {location}"
        )
        return inventory

    except (
        ProductNotFoundException,
        InventoryNotFoundException,
        InsufficientInventoryException,
        InvalidQuantityException,
    ):
        # Domain errors: roll back any pending changes, re-raise unchanged.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error fulfilling reservation: {str(e)}")
        raise ValidationException("Failed to fulfill reservation")
|
||||
|
||||
def get_product_inventory(
    self, db: Session, vendor_id: int, product_id: int
) -> ProductInventorySummary:
    """
    Get inventory summary for a product across all locations.

    Args:
        db: Database session
        vendor_id: Vendor ID
        product_id: Product ID

    Returns:
        ProductInventorySummary (totals are 0 and locations empty when the
        product has no inventory entries)

    Raises:
        ProductNotFoundException: If the product does not belong to the vendor.
        ValidationException: On any unexpected database error.
    """
    try:
        product = self._get_vendor_product(db, vendor_id, product_id)

        inventory_entries = (
            db.query(Inventory).filter(Inventory.product_id == product_id).all()
        )

        # Sums over an empty sequence are 0 and the comprehension yields an
        # empty list, so the no-inventory case needs no separate branch
        # (the previous version duplicated the whole summary construction).
        locations = [
            InventoryLocationResponse(
                location=inv.location,
                quantity=inv.quantity,
                reserved_quantity=inv.reserved_quantity,
                available_quantity=inv.available_quantity,
            )
            for inv in inventory_entries
        ]

        return ProductInventorySummary(
            product_id=product_id,
            vendor_id=vendor_id,
            product_sku=product.vendor_sku,
            product_title=product.marketplace_product.get_title() or "",
            total_quantity=sum(inv.quantity for inv in inventory_entries),
            total_reserved=sum(inv.reserved_quantity for inv in inventory_entries),
            total_available=sum(inv.available_quantity for inv in inventory_entries),
            locations=locations,
        )

    except ProductNotFoundException:
        raise
    except Exception as e:
        logger.error(f"Error getting product inventory: {str(e)}")
        raise ValidationException("Failed to retrieve product inventory")
|
||||
|
||||
def get_vendor_inventory(
    self,
    db: Session,
    vendor_id: int,
    skip: int = 0,
    limit: int = 100,
    location: str | None = None,
    low_stock_threshold: int | None = None,
) -> list[Inventory]:
    """
    Get all inventory for a vendor with filtering.

    Args:
        db: Database session
        vendor_id: Vendor ID
        skip: Pagination offset
        limit: Pagination limit
        location: Filter by location (case-insensitive substring match)
        low_stock_threshold: Filter items at or below threshold

    Returns:
        List of Inventory objects

    Raises:
        ValidationException: On any unexpected database error.
    """
    try:
        # Collect filter conditions first, then apply them in one call.
        conditions = [Inventory.vendor_id == vendor_id]
        if location:
            conditions.append(Inventory.location.ilike(f"%{location}%"))
        if low_stock_threshold is not None:
            conditions.append(Inventory.quantity <= low_stock_threshold)

        return (
            db.query(Inventory)
            .filter(*conditions)
            .offset(skip)
            .limit(limit)
            .all()
        )

    except Exception as e:
        logger.error(f"Error getting vendor inventory: {str(e)}")
        raise ValidationException("Failed to retrieve vendor inventory")
|
||||
|
||||
def update_inventory(
    self,
    db: Session,
    vendor_id: int,
    inventory_id: int,
    inventory_update: InventoryUpdate,
) -> Inventory:
    """Update inventory entry.

    Only fields present (non-None) on the update payload are applied.
    NOTE(review): the location check uses truthiness, so an empty-string
    location is silently skipped rather than rejected.

    Args:
        db: Database session
        vendor_id: Vendor ID (ownership is verified)
        inventory_id: Inventory primary key
        inventory_update: Partial update payload

    Returns:
        Updated Inventory object

    Raises:
        InventoryNotFoundException: If the entry is missing or owned by
            another vendor (ownership failure is reported as not-found).
        InvalidQuantityException: If a supplied quantity is invalid.
        InventoryValidationException: If a supplied location is invalid.
        ValidationException: On any unexpected database error (after rollback).
    """
    try:
        inventory = self._get_inventory_by_id(db, inventory_id)

        # Verify ownership — other vendors' rows look like "not found"
        if inventory.vendor_id != vendor_id:
            raise InventoryNotFoundException(f"Inventory {inventory_id} not found")

        # Update fields (each only when explicitly provided)
        if inventory_update.quantity is not None:
            self._validate_quantity(inventory_update.quantity, allow_zero=True)
            inventory.quantity = inventory_update.quantity

        if inventory_update.reserved_quantity is not None:
            self._validate_quantity(
                inventory_update.reserved_quantity, allow_zero=True
            )
            inventory.reserved_quantity = inventory_update.reserved_quantity

        if inventory_update.location:
            inventory.location = self._validate_location(inventory_update.location)

        inventory.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(inventory)

        logger.info(f"Updated inventory {inventory_id}")
        return inventory

    except (
        InventoryNotFoundException,
        InvalidQuantityException,
        InventoryValidationException,
    ):
        # Domain errors: roll back any pending changes, re-raise unchanged.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error updating inventory: {str(e)}")
        raise ValidationException("Failed to update inventory")
|
||||
|
||||
def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
    """Delete an inventory entry owned by the given vendor.

    Returns True on success. Raises InventoryNotFoundException when the
    entry is missing or belongs to another vendor; wraps unexpected
    database errors in ValidationException after rolling back.
    """
    try:
        entry = self._get_inventory_by_id(db, inventory_id)

        # Ownership failure is reported as not-found so callers cannot
        # distinguish other vendors' rows from missing ones.
        if entry.vendor_id != vendor_id:
            raise InventoryNotFoundException(f"Inventory {inventory_id} not found")

        db.delete(entry)
        db.flush()

        logger.info(f"Deleted inventory {inventory_id}")
        return True

    except InventoryNotFoundException:
        raise
    except Exception as e:
        db.rollback()
        logger.error(f"Error deleting inventory: {str(e)}")
        raise ValidationException("Failed to delete inventory")
|
||||
|
||||
# =========================================================================
|
||||
# Admin Methods (cross-vendor operations)
|
||||
# =========================================================================
|
||||
|
||||
def get_all_inventory_admin(
    self,
    db: Session,
    skip: int = 0,
    limit: int = 50,
    vendor_id: int | None = None,
    location: str | None = None,
    low_stock: int | None = None,
    search: str | None = None,
) -> AdminInventoryListResponse:
    """
    Get inventory across all vendors with filtering (admin only).

    Args:
        db: Database session
        skip: Pagination offset
        limit: Pagination limit
        vendor_id: Filter by vendor
        location: Filter by location (case-insensitive substring match)
        low_stock: Filter items at or below threshold
        search: Search by product title or SKU (case-insensitive)

    Returns:
        AdminInventoryListResponse with per-item vendor/product details
        and the total (pre-pagination) match count.
    """
    query = db.query(Inventory).join(Product).join(Vendor)

    # Apply filters
    if vendor_id is not None:
        query = query.filter(Inventory.vendor_id == vendor_id)

    if location:
        query = query.filter(Inventory.location.ilike(f"%{location}%"))

    if low_stock is not None:
        query = query.filter(Inventory.quantity <= low_stock)

    if search:
        # Imported locally: these models are only needed for the search
        # join, keeping the module's top-level imports lighter.
        from models.database.marketplace_product import MarketplaceProduct
        from models.database.marketplace_product_translation import (
            MarketplaceProductTranslation,
        )

        # Outer join so products without translations still match on SKU.
        query = (
            query.join(MarketplaceProduct)
            .outerjoin(MarketplaceProductTranslation)
            .filter(
                (MarketplaceProductTranslation.title.ilike(f"%{search}%"))
                | (Product.vendor_sku.ilike(f"%{search}%"))
            )
        )

    # Get total count before pagination
    total = query.count()

    # Apply pagination
    inventories = query.offset(skip).limit(limit).all()

    # Build response with vendor/product info
    items = []
    for inv in inventories:
        product = inv.product
        vendor = inv.vendor
        title = None
        if product and product.marketplace_product:
            title = product.marketplace_product.get_title()

        items.append(
            AdminInventoryItem(
                id=inv.id,
                product_id=inv.product_id,
                vendor_id=inv.vendor_id,
                vendor_name=vendor.name if vendor else None,
                vendor_code=vendor.vendor_code if vendor else None,
                product_title=title,
                product_sku=product.vendor_sku if product else None,
                location=inv.location,
                quantity=inv.quantity,
                reserved_quantity=inv.reserved_quantity,
                available_quantity=inv.available_quantity,
                gtin=inv.gtin,
                created_at=inv.created_at,
                updated_at=inv.updated_at,
            )
        )

    return AdminInventoryListResponse(
        inventories=items,
        total=total,
        skip=skip,
        limit=limit,
        vendor_filter=vendor_id,
        location_filter=location,
    )
|
||||
|
||||
def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats:
    """Compute platform-wide inventory statistics (admin only).

    Aggregates entry count, total/reserved/available units, low-stock
    count (fixed threshold of 10), distinct vendors, and distinct
    locations across all inventory rows.
    """
    entry_count = db.query(func.count(Inventory.id)).scalar() or 0

    # One aggregate query for both quantity sums.
    sums = db.query(
        func.sum(Inventory.quantity).label("total_qty"),
        func.sum(Inventory.reserved_quantity).label("total_reserved"),
    ).first()
    qty_total = sums.total_qty or 0
    reserved_total = sums.total_reserved or 0

    # Entries at or below the default low-stock threshold of 10.
    low_stock = (
        db.query(func.count(Inventory.id))
        .filter(Inventory.quantity <= 10)
        .scalar()
        or 0
    )

    vendor_count = (
        db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0
    )
    location_count = (
        db.query(func.count(func.distinct(Inventory.location))).scalar() or 0
    )

    return AdminInventoryStats(
        total_entries=entry_count,
        total_quantity=qty_total,
        total_reserved=reserved_total,
        total_available=qty_total - reserved_total,
        low_stock_count=low_stock,
        vendors_with_inventory=vendor_count,
        unique_locations=location_count,
    )
|
||||
|
||||
def get_low_stock_items_admin(
    self,
    db: Session,
    threshold: int = 10,
    vendor_id: int | None = None,
    limit: int = 50,
) -> list[AdminLowStockItem]:
    """List inventory entries at or below a stock threshold (admin only).

    Results are ordered by quantity ascending so the most critical items
    come first; an optional vendor filter narrows the scope.
    """
    low_query = (
        db.query(Inventory)
        .join(Product)
        .join(Vendor)
        .filter(Inventory.quantity <= threshold)
    )
    if vendor_id is not None:
        low_query = low_query.filter(Inventory.vendor_id == vendor_id)

    # Lowest quantity first: most critical items lead the list.
    rows = low_query.order_by(Inventory.quantity.asc()).limit(limit).all()

    results: list[AdminLowStockItem] = []
    for row in rows:
        prod = row.product
        owner = row.vendor
        display_title = (
            prod.marketplace_product.get_title()
            if prod and prod.marketplace_product
            else None
        )

        results.append(
            AdminLowStockItem(
                id=row.id,
                product_id=row.product_id,
                vendor_id=row.vendor_id,
                vendor_name=owner.name if owner else None,
                product_title=display_title,
                location=row.location,
                quantity=row.quantity,
                reserved_quantity=row.reserved_quantity,
                available_quantity=row.available_quantity,
            )
        )

    return results
|
||||
|
||||
def get_vendors_with_inventory_admin(
    self, db: Session
) -> AdminVendorsWithInventoryResponse:
    """Get list of vendors that have inventory entries (admin only)."""
    # noqa: SVC-005 - Admin function, intentionally cross-vendor
    # DISTINCT only on the vendor_id column via a subquery: selecting
    # DISTINCT over full Vendor rows would fail because PostgreSQL cannot
    # compare JSON columns.
    vendor_id_sq = db.query(Inventory.vendor_id).distinct().subquery()

    matching_vendors = (
        db.query(Vendor)
        .filter(Vendor.id.in_(db.query(vendor_id_sq.c.vendor_id)))
        .order_by(Vendor.name)
        .all()
    )

    summaries = [
        AdminVendorWithInventory(id=v.id, name=v.name, vendor_code=v.vendor_code)
        for v in matching_vendors
    ]
    return AdminVendorsWithInventoryResponse(vendors=summaries)
|
||||
|
||||
def get_inventory_locations_admin(
    self, db: Session, vendor_id: int | None = None
) -> AdminInventoryLocationsResponse:
    """Return the sorted set of distinct inventory locations (admin only).

    When vendor_id is given, only that vendor's locations are listed.
    """
    loc_query = db.query(func.distinct(Inventory.location))
    if vendor_id is not None:
        loc_query = loc_query.filter(Inventory.vendor_id == vendor_id)

    # Each row is a 1-tuple; unpack and sort alphabetically.
    names = sorted(row[0] for row in loc_query.all())
    return AdminInventoryLocationsResponse(locations=names)
|
||||
|
||||
def get_vendor_inventory_admin(
    self,
    db: Session,
    vendor_id: int,
    skip: int = 0,
    limit: int = 50,
    location: str | None = None,
    low_stock: int | None = None,
) -> AdminInventoryListResponse:
    """Get inventory for a specific vendor (admin only).

    Delegates the filtered page fetch to get_vendor_inventory, then
    decorates each row with vendor/product details and runs a separate
    count query (with the same filters) for the pagination total.

    Args:
        db: Database session
        vendor_id: Vendor whose inventory to list
        skip: Pagination offset
        limit: Pagination limit
        location: Filter by location (case-insensitive substring match)
        low_stock: Filter items at or below threshold

    Returns:
        AdminInventoryListResponse

    Raises:
        VendorNotFoundException: If the vendor id is unknown.
    """
    # Verify vendor exists
    vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
    if not vendor:
        raise VendorNotFoundException(f"Vendor {vendor_id} not found")

    # Use the existing method
    inventories = self.get_vendor_inventory(
        db=db,
        vendor_id=vendor_id,
        skip=skip,
        limit=limit,
        location=location,
        low_stock_threshold=low_stock,
    )

    # Build response with product info
    items = []
    for inv in inventories:
        product = inv.product
        title = None
        if product and product.marketplace_product:
            title = product.marketplace_product.get_title()

        items.append(
            AdminInventoryItem(
                id=inv.id,
                product_id=inv.product_id,
                vendor_id=inv.vendor_id,
                vendor_name=vendor.name,
                vendor_code=vendor.vendor_code,
                product_title=title,
                product_sku=product.vendor_sku if product else None,
                location=inv.location,
                quantity=inv.quantity,
                reserved_quantity=inv.reserved_quantity,
                available_quantity=inv.available_quantity,
                gtin=inv.gtin,
                created_at=inv.created_at,
                updated_at=inv.updated_at,
            )
        )

    # Get total count for pagination — must mirror the same filters
    # applied by get_vendor_inventory above.
    total_query = db.query(func.count(Inventory.id)).filter(
        Inventory.vendor_id == vendor_id
    )
    if location:
        total_query = total_query.filter(Inventory.location.ilike(f"%{location}%"))
    if low_stock is not None:
        total_query = total_query.filter(Inventory.quantity <= low_stock)
    total = total_query.scalar() or 0

    return AdminInventoryListResponse(
        inventories=items,
        total=total,
        skip=skip,
        limit=limit,
        vendor_filter=vendor_id,
        location_filter=location,
    )
|
||||
|
||||
def get_product_inventory_admin(
    self, db: Session, product_id: int
) -> ProductInventorySummary:
    """Get inventory summary for a product (admin only - no vendor check)."""
    found = db.query(Product).filter(Product.id == product_id).first()
    if found is None:
        raise ProductNotFoundException(f"Product {product_id} not found")

    # Delegate to the vendor-scoped method using the owning vendor's id.
    return self.get_product_inventory(db, found.vendor_id, product_id)
|
||||
|
||||
def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor:
    """Return the vendor row for vendor_id, raising if it does not exist."""
    found = db.query(Vendor).filter(Vendor.id == vendor_id).first()
    if found is None:
        raise VendorNotFoundException(f"Vendor {vendor_id} not found")
    return found
|
||||
|
||||
def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory:
    """Fetch an inventory row by primary key without any vendor scoping
    (admin only); raises InventoryNotFoundException when absent."""
    row = db.query(Inventory).filter(Inventory.id == inventory_id).first()
    if row is None:
        raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
    return row
|
||||
|
||||
# =========================================================================
|
||||
# Private helper methods
|
||||
# =========================================================================
|
||||
|
||||
def _get_vendor_product(
    self, db: Session, vendor_id: int, product_id: int
) -> Product:
    """Fetch a product, verifying it belongs to the given vendor.

    Raises ProductNotFoundException when the product is absent or owned
    by another vendor.
    """
    match = (
        db.query(Product)
        .filter(Product.id == product_id, Product.vendor_id == vendor_id)
        .first()
    )
    if match is None:
        raise ProductNotFoundException(
            f"Product {product_id} not found in your catalog"
        )
    return match
|
||||
|
||||
def _get_inventory_entry(
    self, db: Session, product_id: int, location: str
) -> Inventory | None:
    """Look up the inventory row for a (product, location) pair.

    Returns None when no matching entry exists.
    """
    lookup = db.query(Inventory).filter(
        Inventory.product_id == product_id,
        Inventory.location == location,
    )
    return lookup.first()
|
||||
|
||||
def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory:
    """Fetch an inventory row by primary key, raising when absent."""
    row = db.query(Inventory).filter(Inventory.id == inventory_id).first()
    if row is None:
        raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
    return row
|
||||
|
||||
def _validate_location(self, location: str) -> str:
|
||||
"""Validate and normalize location."""
|
||||
if not location or not location.strip():
|
||||
raise InventoryValidationException("Location is required")
|
||||
return location.strip().upper()
|
||||
|
||||
def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None:
|
||||
"""Validate quantity value."""
|
||||
if quantity is None:
|
||||
raise InvalidQuantityException("Quantity is required")
|
||||
|
||||
if not isinstance(quantity, int):
|
||||
raise InvalidQuantityException("Quantity must be an integer")
|
||||
|
||||
if quantity < 0:
|
||||
raise InvalidQuantityException("Quantity cannot be negative")
|
||||
|
||||
if not allow_zero and quantity == 0:
|
||||
raise InvalidQuantityException("Quantity must be positive")
|
||||
|
||||
|
||||
# Create service instance
# Module-level singleton: the service is stateless, so one shared
# instance is imported by API routes and other services.
inventory_service = InventoryService()
|
||||
431
app/modules/inventory/services/inventory_transaction_service.py
Normal file
431
app/modules/inventory/services/inventory_transaction_service.py
Normal file
@@ -0,0 +1,431 @@
|
||||
# app/modules/inventory/services/inventory_transaction_service.py
|
||||
"""
|
||||
Inventory Transaction Service.
|
||||
|
||||
Provides query operations for inventory transaction history.
|
||||
All transaction WRITES are handled by OrderInventoryService.
|
||||
This service handles transaction READS for reporting and auditing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import OrderNotFoundException, ProductNotFoundException
|
||||
from app.modules.inventory.models.inventory import Inventory
|
||||
from app.modules.inventory.models.inventory_transaction import InventoryTransaction
|
||||
from models.database.order import Order
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InventoryTransactionService:
|
||||
"""Service for querying inventory transaction history."""
|
||||
|
||||
def get_vendor_transactions(
    self,
    db: Session,
    vendor_id: int,
    skip: int = 0,
    limit: int = 50,
    product_id: int | None = None,
    transaction_type: str | None = None,
) -> tuple[list[dict], int]:
    """
    Get inventory transactions for a vendor with optional filters.

    Args:
        db: Database session
        vendor_id: Vendor ID
        skip: Pagination offset
        limit: Pagination limit
        product_id: Optional product filter
        transaction_type: Optional transaction type filter

    Returns:
        Tuple of (transactions with product details, total count)
    """
    # Build query
    query = db.query(InventoryTransaction).filter(
        InventoryTransaction.vendor_id == vendor_id
    )

    # Apply filters
    if product_id:
        query = query.filter(InventoryTransaction.product_id == product_id)
    if transaction_type:
        query = query.filter(
            InventoryTransaction.transaction_type == transaction_type
        )

    # Get total count
    total = query.count()

    # Get transactions with pagination (newest first)
    transactions = (
        query.order_by(InventoryTransaction.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    # Fetch all referenced products in ONE query instead of one query per
    # transaction (the previous per-row lookup was an N+1 query pattern).
    product_ids = {tx.product_id for tx in transactions}
    products_by_id: dict = {}
    if product_ids:
        products_by_id = {
            p.id: p
            for p in db.query(Product).filter(Product.id.in_(product_ids)).all()
        }

    # Build result with product details
    result = []
    for tx in transactions:
        product = products_by_id.get(tx.product_id)
        product_title = None
        product_sku = None
        if product:
            product_sku = product.vendor_sku
            if product.marketplace_product:
                product_title = product.marketplace_product.get_title()

        result.append(
            {
                "id": tx.id,
                "vendor_id": tx.vendor_id,
                "product_id": tx.product_id,
                "inventory_id": tx.inventory_id,
                "transaction_type": (
                    tx.transaction_type.value if tx.transaction_type else None
                ),
                "quantity_change": tx.quantity_change,
                "quantity_after": tx.quantity_after,
                "reserved_after": tx.reserved_after,
                "location": tx.location,
                "warehouse": tx.warehouse,
                "order_id": tx.order_id,
                "order_number": tx.order_number,
                "reason": tx.reason,
                "created_by": tx.created_by,
                "created_at": tx.created_at,
                "product_title": product_title,
                "product_sku": product_sku,
            }
        )

    return result, total
|
||||
|
||||
def get_product_history(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
product_id: int,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""
|
||||
Get transaction history for a specific product.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
product_id: Product ID
|
||||
limit: Max transactions to return
|
||||
|
||||
Returns:
|
||||
Dict with product info, current inventory, and transactions
|
||||
|
||||
Raises:
|
||||
ProductNotFoundException: If product not found or doesn't belong to vendor
|
||||
"""
|
||||
# Get product details
|
||||
product = (
|
||||
db.query(Product)
|
||||
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not product:
|
||||
raise ProductNotFoundException(
|
||||
f"Product {product_id} not found in vendor catalog"
|
||||
)
|
||||
|
||||
product_title = None
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
# Get current inventory
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
current_quantity = inventory.quantity if inventory else 0
|
||||
current_reserved = inventory.reserved_quantity if inventory else 0
|
||||
|
||||
# Get transactions
|
||||
transactions = (
|
||||
db.query(InventoryTransaction)
|
||||
.filter(
|
||||
InventoryTransaction.vendor_id == vendor_id,
|
||||
InventoryTransaction.product_id == product_id,
|
||||
)
|
||||
.order_by(InventoryTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = (
|
||||
db.query(func.count(InventoryTransaction.id))
|
||||
.filter(
|
||||
InventoryTransaction.vendor_id == vendor_id,
|
||||
InventoryTransaction.product_id == product_id,
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
"current_quantity": current_quantity,
|
||||
"current_reserved": current_reserved,
|
||||
"transactions": [
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
}
|
||||
for tx in transactions
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
def get_order_history(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Get all inventory transactions for a specific order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
Dict with order info and transactions
|
||||
|
||||
Raises:
|
||||
OrderNotFoundException: If order not found or doesn't belong to vendor
|
||||
"""
|
||||
# Verify order belongs to vendor
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
|
||||
# Get transactions for this order
|
||||
transactions = (
|
||||
db.query(InventoryTransaction)
|
||||
.filter(InventoryTransaction.order_id == order_id)
|
||||
.order_by(InventoryTransaction.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build result with product details
|
||||
result = []
|
||||
for tx in transactions:
|
||||
product = db.query(Product).filter(Product.id == tx.product_id).first()
|
||||
product_title = None
|
||||
product_sku = None
|
||||
if product:
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"transactions": result,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Admin Methods (cross-vendor operations)
|
||||
# =========================================================================
|
||||
|
||||
def get_all_transactions_admin(
|
||||
self,
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
vendor_id: int | None = None,
|
||||
product_id: int | None = None,
|
||||
transaction_type: str | None = None,
|
||||
order_id: int | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get inventory transactions across all vendors (admin only).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
vendor_id: Optional vendor filter
|
||||
product_id: Optional product filter
|
||||
transaction_type: Optional transaction type filter
|
||||
order_id: Optional order filter
|
||||
|
||||
Returns:
|
||||
Tuple of (transactions with details, total count)
|
||||
"""
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
# Build query
|
||||
query = db.query(InventoryTransaction)
|
||||
|
||||
# Apply filters
|
||||
if vendor_id:
|
||||
query = query.filter(InventoryTransaction.vendor_id == vendor_id)
|
||||
if product_id:
|
||||
query = query.filter(InventoryTransaction.product_id == product_id)
|
||||
if transaction_type:
|
||||
query = query.filter(
|
||||
InventoryTransaction.transaction_type == transaction_type
|
||||
)
|
||||
if order_id:
|
||||
query = query.filter(InventoryTransaction.order_id == order_id)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get transactions with pagination (newest first)
|
||||
transactions = (
|
||||
query.order_by(InventoryTransaction.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build result with vendor and product details
|
||||
result = []
|
||||
for tx in transactions:
|
||||
vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first()
|
||||
product = db.query(Product).filter(Product.id == tx.product_id).first()
|
||||
|
||||
product_title = None
|
||||
product_sku = None
|
||||
if product:
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"vendor_name": vendor.name if vendor else None,
|
||||
"vendor_code": vendor.vendor_code if vendor else None,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
}
|
||||
)
|
||||
|
||||
return result, total
|
||||
|
||||
def get_transaction_stats_admin(self, db: Session) -> dict:
|
||||
"""
|
||||
Get transaction statistics across the platform (admin only).
|
||||
|
||||
Returns:
|
||||
Dict with transaction counts by type
|
||||
"""
|
||||
from sqlalchemy import func as sql_func
|
||||
|
||||
# Count by transaction type
|
||||
type_counts = (
|
||||
db.query(
|
||||
InventoryTransaction.transaction_type,
|
||||
sql_func.count(InventoryTransaction.id).label("count"),
|
||||
)
|
||||
.group_by(InventoryTransaction.transaction_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Total transactions
|
||||
total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0
|
||||
|
||||
# Transactions today
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_count = (
|
||||
db.query(sql_func.count(InventoryTransaction.id))
|
||||
.filter(InventoryTransaction.created_at >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_transactions": total,
|
||||
"transactions_today": today_count,
|
||||
"by_type": {tc.transaction_type.value: tc.count for tc in type_counts},
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton shared by API routes; the service methods shown
# above keep no per-request state, so a single instance is safe to share.
inventory_transaction_service = InventoryTransactionService()
|
||||
@@ -2,45 +2,35 @@
|
||||
"""
|
||||
Marketplace module services.
|
||||
|
||||
Re-exports Letzshop and marketplace services from their current locations.
|
||||
Services remain in app/services/ for now to avoid breaking existing imports.
|
||||
|
||||
Usage:
|
||||
from app.modules.marketplace.services import (
|
||||
letzshop_export_service,
|
||||
marketplace_import_job_service,
|
||||
marketplace_product_service,
|
||||
)
|
||||
from app.modules.marketplace.services.letzshop import (
|
||||
LetzshopClient,
|
||||
LetzshopCredentialsService,
|
||||
LetzshopOrderService,
|
||||
LetzshopVendorSyncService,
|
||||
)
|
||||
This module contains the canonical implementations of marketplace-related services.
|
||||
"""
|
||||
|
||||
# Re-export from existing locations for convenience
|
||||
from app.services.letzshop_export_service import (
|
||||
# Main marketplace services
|
||||
from app.modules.marketplace.services.letzshop_export_service import (
|
||||
LetzshopExportService,
|
||||
letzshop_export_service,
|
||||
)
|
||||
from app.services.marketplace_import_job_service import (
|
||||
from app.modules.marketplace.services.marketplace_import_job_service import (
|
||||
MarketplaceImportJobService,
|
||||
marketplace_import_job_service,
|
||||
)
|
||||
from app.services.marketplace_product_service import (
|
||||
from app.modules.marketplace.services.marketplace_product_service import (
|
||||
MarketplaceProductService,
|
||||
marketplace_product_service,
|
||||
)
|
||||
|
||||
# Letzshop submodule re-exports
|
||||
from app.services.letzshop import (
|
||||
# Letzshop submodule services
|
||||
from app.modules.marketplace.services.letzshop import (
|
||||
LetzshopClient,
|
||||
LetzshopClientError,
|
||||
)
|
||||
from app.services.letzshop.credentials_service import LetzshopCredentialsService
|
||||
from app.services.letzshop.order_service import LetzshopOrderService
|
||||
from app.services.letzshop.vendor_sync_service import LetzshopVendorSyncService
|
||||
from app.modules.marketplace.services.letzshop.credentials_service import (
|
||||
LetzshopCredentialsService,
|
||||
)
|
||||
from app.modules.marketplace.services.letzshop.order_service import LetzshopOrderService
|
||||
from app.modules.marketplace.services.letzshop.vendor_sync_service import (
|
||||
LetzshopVendorSyncService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Export service
|
||||
|
||||
53
app/modules/marketplace/services/letzshop/__init__.py
Normal file
53
app/modules/marketplace/services/letzshop/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# app/modules/marketplace/services/letzshop/__init__.py
|
||||
"""
|
||||
Letzshop marketplace integration services.
|
||||
|
||||
Provides:
|
||||
- GraphQL client for API communication
|
||||
- Credential management service
|
||||
- Order import service
|
||||
- Fulfillment sync service
|
||||
- Vendor directory sync service
|
||||
"""
|
||||
|
||||
from .client_service import (
|
||||
LetzshopAPIError,
|
||||
LetzshopAuthError,
|
||||
LetzshopClient,
|
||||
LetzshopClientError,
|
||||
LetzshopConnectionError,
|
||||
)
|
||||
from .credentials_service import (
|
||||
CredentialsError,
|
||||
CredentialsNotFoundError,
|
||||
LetzshopCredentialsService,
|
||||
)
|
||||
from .order_service import (
|
||||
LetzshopOrderService,
|
||||
OrderNotFoundError,
|
||||
VendorNotFoundError,
|
||||
)
|
||||
from .vendor_sync_service import (
|
||||
LetzshopVendorSyncService,
|
||||
get_vendor_sync_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"LetzshopClient",
|
||||
"LetzshopClientError",
|
||||
"LetzshopAuthError",
|
||||
"LetzshopAPIError",
|
||||
"LetzshopConnectionError",
|
||||
# Credentials
|
||||
"LetzshopCredentialsService",
|
||||
"CredentialsError",
|
||||
"CredentialsNotFoundError",
|
||||
# Order Service
|
||||
"LetzshopOrderService",
|
||||
"OrderNotFoundError",
|
||||
"VendorNotFoundError",
|
||||
# Vendor Sync Service
|
||||
"LetzshopVendorSyncService",
|
||||
"get_vendor_sync_service",
|
||||
]
|
||||
1015
app/modules/marketplace/services/letzshop/client_service.py
Normal file
1015
app/modules/marketplace/services/letzshop/client_service.py
Normal file
File diff suppressed because it is too large
Load Diff
400
app/modules/marketplace/services/letzshop/credentials_service.py
Normal file
400
app/modules/marketplace/services/letzshop/credentials_service.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# app/modules/marketplace/services/letzshop/credentials_service.py
|
||||
"""
|
||||
Letzshop credentials management service.
|
||||
|
||||
Handles secure storage and retrieval of per-vendor Letzshop API credentials.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.encryption import decrypt_value, encrypt_value, mask_api_key
|
||||
from models.database.letzshop import VendorLetzshopCredentials
|
||||
|
||||
from .client_service import LetzshopClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default Letzshop GraphQL endpoint
|
||||
DEFAULT_ENDPOINT = "https://letzshop.lu/graphql"
|
||||
|
||||
|
||||
class CredentialsError(Exception):
    """Base exception for Letzshop credentials errors."""
|
||||
|
||||
|
||||
class CredentialsNotFoundError(CredentialsError):
    """Raised when credentials are not found for a vendor."""
|
||||
|
||||
|
||||
class LetzshopCredentialsService:
    """
    Service for managing Letzshop API credentials.

    Provides secure storage and retrieval of encrypted API keys,
    connection testing, and sync status updates.

    NOTE: Write operations call ``flush()`` only — committing the
    transaction is the caller's responsibility.
    """

    def __init__(self, db: Session):
        """
        Initialize the credentials service.

        Args:
            db: SQLAlchemy database session.
        """
        self.db = db

    # ========================================================================
    # CRUD Operations
    # ========================================================================

    def get_credentials(self, vendor_id: int) -> VendorLetzshopCredentials | None:
        """
        Get Letzshop credentials for a vendor.

        Args:
            vendor_id: The vendor ID.

        Returns:
            VendorLetzshopCredentials or None if not found.
        """
        return (
            self.db.query(VendorLetzshopCredentials)
            .filter(VendorLetzshopCredentials.vendor_id == vendor_id)
            .first()
        )

    def get_credentials_or_raise(self, vendor_id: int) -> VendorLetzshopCredentials:
        """
        Get Letzshop credentials for a vendor or raise an exception.

        Args:
            vendor_id: The vendor ID.

        Returns:
            VendorLetzshopCredentials.

        Raises:
            CredentialsNotFoundError: If credentials are not found.
        """
        credentials = self.get_credentials(vendor_id)
        if credentials is None:
            raise CredentialsNotFoundError(
                f"Letzshop credentials not found for vendor {vendor_id}"
            )
        return credentials

    def create_credentials(
        self,
        vendor_id: int,
        api_key: str,
        api_endpoint: str | None = None,
        auto_sync_enabled: bool = False,
        sync_interval_minutes: int = 15,
    ) -> VendorLetzshopCredentials:
        """
        Create Letzshop credentials for a vendor.

        Args:
            vendor_id: The vendor ID.
            api_key: The Letzshop API key (will be encrypted).
            api_endpoint: Custom API endpoint (optional; defaults to
                DEFAULT_ENDPOINT).
            auto_sync_enabled: Whether to enable automatic sync.
            sync_interval_minutes: Sync interval in minutes.

        Returns:
            Created VendorLetzshopCredentials (flushed, not committed).
        """
        # Encrypt the API key — plaintext is never stored.
        encrypted_key = encrypt_value(api_key)

        credentials = VendorLetzshopCredentials(
            vendor_id=vendor_id,
            api_key_encrypted=encrypted_key,
            api_endpoint=api_endpoint or DEFAULT_ENDPOINT,
            auto_sync_enabled=auto_sync_enabled,
            sync_interval_minutes=sync_interval_minutes,
        )

        self.db.add(credentials)
        self.db.flush()

        logger.info(f"Created Letzshop credentials for vendor {vendor_id}")
        return credentials

    def update_credentials(
        self,
        vendor_id: int,
        api_key: str | None = None,
        api_endpoint: str | None = None,
        auto_sync_enabled: bool | None = None,
        sync_interval_minutes: int | None = None,
    ) -> VendorLetzshopCredentials:
        """
        Update Letzshop credentials for a vendor.

        Only the fields passed as non-None are changed.

        Args:
            vendor_id: The vendor ID.
            api_key: New API key (optional, will be encrypted if provided).
            api_endpoint: New API endpoint (optional).
            auto_sync_enabled: New auto-sync setting (optional).
            sync_interval_minutes: New sync interval (optional).

        Returns:
            Updated VendorLetzshopCredentials.

        Raises:
            CredentialsNotFoundError: If credentials are not found.
        """
        credentials = self.get_credentials_or_raise(vendor_id)

        if api_key is not None:
            credentials.api_key_encrypted = encrypt_value(api_key)
        if api_endpoint is not None:
            credentials.api_endpoint = api_endpoint
        if auto_sync_enabled is not None:
            credentials.auto_sync_enabled = auto_sync_enabled
        if sync_interval_minutes is not None:
            credentials.sync_interval_minutes = sync_interval_minutes

        self.db.flush()

        logger.info(f"Updated Letzshop credentials for vendor {vendor_id}")
        return credentials

    def delete_credentials(self, vendor_id: int) -> bool:
        """
        Delete Letzshop credentials for a vendor.

        Args:
            vendor_id: The vendor ID.

        Returns:
            True if deleted, False if not found.
        """
        credentials = self.get_credentials(vendor_id)
        if credentials is None:
            return False

        self.db.delete(credentials)
        self.db.flush()

        logger.info(f"Deleted Letzshop credentials for vendor {vendor_id}")
        return True

    def upsert_credentials(
        self,
        vendor_id: int,
        api_key: str,
        api_endpoint: str | None = None,
        auto_sync_enabled: bool = False,
        sync_interval_minutes: int = 15,
    ) -> VendorLetzshopCredentials:
        """
        Create or update Letzshop credentials for a vendor.

        Args:
            vendor_id: The vendor ID.
            api_key: The Letzshop API key (will be encrypted).
            api_endpoint: Custom API endpoint (optional).
            auto_sync_enabled: Whether to enable automatic sync.
            sync_interval_minutes: Sync interval in minutes.

        Returns:
            Created or updated VendorLetzshopCredentials.
        """
        existing = self.get_credentials(vendor_id)

        if existing:
            return self.update_credentials(
                vendor_id=vendor_id,
                api_key=api_key,
                api_endpoint=api_endpoint,
                auto_sync_enabled=auto_sync_enabled,
                sync_interval_minutes=sync_interval_minutes,
            )

        return self.create_credentials(
            vendor_id=vendor_id,
            api_key=api_key,
            api_endpoint=api_endpoint,
            auto_sync_enabled=auto_sync_enabled,
            sync_interval_minutes=sync_interval_minutes,
        )

    # ========================================================================
    # Key Decryption and Client Creation
    # ========================================================================

    def get_decrypted_api_key(self, vendor_id: int) -> str:
        """
        Get the decrypted API key for a vendor.

        Args:
            vendor_id: The vendor ID.

        Returns:
            Decrypted API key.

        Raises:
            CredentialsNotFoundError: If credentials are not found.
        """
        credentials = self.get_credentials_or_raise(vendor_id)
        return decrypt_value(credentials.api_key_encrypted)

    def get_masked_api_key(self, vendor_id: int) -> str:
        """
        Get a masked version of the API key for display.

        Args:
            vendor_id: The vendor ID.

        Returns:
            Masked API key (e.g., "sk-a***************").

        Raises:
            CredentialsNotFoundError: If credentials are not found.
        """
        api_key = self.get_decrypted_api_key(vendor_id)
        return mask_api_key(api_key)

    def create_client(self, vendor_id: int) -> LetzshopClient:
        """
        Create a Letzshop client for a vendor.

        Args:
            vendor_id: The vendor ID.

        Returns:
            Configured LetzshopClient.

        Raises:
            CredentialsNotFoundError: If credentials are not found.
        """
        credentials = self.get_credentials_or_raise(vendor_id)
        api_key = decrypt_value(credentials.api_key_encrypted)

        return LetzshopClient(
            api_key=api_key,
            endpoint=credentials.api_endpoint,
        )

    # ========================================================================
    # Connection Testing
    # ========================================================================

    def test_connection(self, vendor_id: int) -> tuple[bool, float | None, str | None]:
        """
        Test the connection for a vendor's credentials.

        Never raises — failures are reported in the returned tuple.

        Args:
            vendor_id: The vendor ID.

        Returns:
            Tuple of (success, response_time_ms, error_message).
        """
        try:
            # Client is a context manager; closed automatically after the test.
            with self.create_client(vendor_id) as client:
                return client.test_connection()
        except CredentialsNotFoundError:
            return False, None, "Letzshop credentials not configured"
        except Exception as e:
            logger.error(f"Connection test failed for vendor {vendor_id}: {e}")
            return False, None, str(e)

    def test_api_key(
        self,
        api_key: str,
        api_endpoint: str | None = None,
    ) -> tuple[bool, float | None, str | None]:
        """
        Test an API key without saving it.

        Args:
            api_key: The API key to test.
            api_endpoint: Optional custom endpoint.

        Returns:
            Tuple of (success, response_time_ms, error_message).
        """
        try:
            with LetzshopClient(
                api_key=api_key,
                endpoint=api_endpoint or DEFAULT_ENDPOINT,
            ) as client:
                return client.test_connection()
        except Exception as e:
            logger.error(f"API key test failed: {e}")
            return False, None, str(e)

    # ========================================================================
    # Sync Status Updates
    # ========================================================================

    def update_sync_status(
        self,
        vendor_id: int,
        status: str,
        error: str | None = None,
    ) -> VendorLetzshopCredentials | None:
        """
        Update the last sync status for a vendor.

        Args:
            vendor_id: The vendor ID.
            status: Sync status (success, failed, partial).
            error: Error message if sync failed.

        Returns:
            Updated credentials or None if not found.
        """
        credentials = self.get_credentials(vendor_id)
        if credentials is None:
            return None

        credentials.last_sync_at = datetime.now(UTC)
        credentials.last_sync_status = status
        # Intentionally clears any previous error when error is None.
        credentials.last_sync_error = error

        self.db.flush()

        return credentials

    # ========================================================================
    # Status Helpers
    # ========================================================================

    def is_configured(self, vendor_id: int) -> bool:
        """Check if Letzshop is configured for a vendor."""
        return self.get_credentials(vendor_id) is not None

    def get_status(self, vendor_id: int) -> dict:
        """
        Get the Letzshop integration status for a vendor.

        Args:
            vendor_id: The vendor ID.

        Returns:
            Status dictionary with configuration and sync info.
        """
        credentials = self.get_credentials(vendor_id)

        if credentials is None:
            return {
                "is_configured": False,
                "is_connected": False,
                "last_sync_at": None,
                "last_sync_status": None,
                "auto_sync_enabled": False,
            }

        return {
            "is_configured": True,
            # "Connected" means the most recent sync attempt succeeded.
            "is_connected": credentials.last_sync_status == "success",
            "last_sync_at": credentials.last_sync_at,
            "last_sync_status": credentials.last_sync_status,
            "auto_sync_enabled": credentials.auto_sync_enabled,
        }
|
||||
1136
app/modules/marketplace/services/letzshop/order_service.py
Normal file
1136
app/modules/marketplace/services/letzshop/order_service.py
Normal file
File diff suppressed because it is too large
Load Diff
521
app/modules/marketplace/services/letzshop/vendor_sync_service.py
Normal file
521
app/modules/marketplace/services/letzshop/vendor_sync_service.py
Normal file
@@ -0,0 +1,521 @@
|
||||
# app/modules/marketplace/services/letzshop/vendor_sync_service.py
|
||||
"""
|
||||
Service for syncing Letzshop vendor directory to local cache.
|
||||
|
||||
Fetches vendor data from Letzshop's public GraphQL API and stores it
|
||||
in the letzshop_vendor_cache table for fast lookups during signup.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.letzshop.client_service import LetzshopClient
|
||||
from models.database.letzshop import LetzshopVendorCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LetzshopVendorSyncService:
|
||||
"""
|
||||
Service for syncing Letzshop vendor directory.
|
||||
|
||||
Usage:
|
||||
service = LetzshopVendorSyncService(db)
|
||||
stats = service.sync_all_vendors()
|
||||
"""
|
||||
|
||||
    def __init__(self, db: Session):
        """Initialize the sync service with a SQLAlchemy database session."""
        self.db = db
|
||||
|
||||
    def sync_all_vendors(
        self,
        progress_callback: Callable[[int, int, int], None] | None = None,
        max_pages: int | None = None,
    ) -> dict[str, Any]:
        """
        Sync all vendors from Letzshop to local cache.

        Args:
            progress_callback: Optional callback(page, fetched, total) for progress.
            max_pages: Optional cap on pages fetched (None fetches the full
                directory).

        Returns:
            Dictionary with sync statistics: started_at/completed_at,
            duration_seconds, total_fetched, created, updated, errors,
            and per-vendor error_details.

        Raises:
            Exception: Re-raises any fetch/commit failure after rolling back
                the session.
        """
        stats = {
            "started_at": datetime.now(UTC),
            "total_fetched": 0,
            "created": 0,
            "updated": 0,
            "errors": 0,
            "error_details": [],
        }

        logger.info("Starting Letzshop vendor directory sync...")

        # Create client (no API key needed for public vendor data)
        client = LetzshopClient(api_key="")

        try:
            # Fetch all vendors
            vendors = client.get_all_vendors_paginated(
                page_size=50,
                max_pages=max_pages,
                progress_callback=progress_callback,
            )

            stats["total_fetched"] = len(vendors)
            logger.info(f"Fetched {len(vendors)} vendors from Letzshop")

            # Process each vendor; per-vendor failures are recorded but do
            # not abort the overall sync.
            for vendor_data in vendors:
                try:
                    result = self._upsert_vendor(vendor_data)
                    if result == "created":
                        stats["created"] += 1
                    elif result == "updated":
                        stats["updated"] += 1
                except Exception as e:
                    stats["errors"] += 1
                    error_info = {
                        "vendor_id": vendor_data.get("id"),
                        "slug": vendor_data.get("slug"),
                        "error": str(e),
                    }
                    stats["error_details"].append(error_info)
                    logger.error(f"Error processing vendor {vendor_data.get('slug')}: {e}")

            # Commit all changes
            self.db.commit()
            logger.info(
                f"Sync complete: {stats['created']} created, "
                f"{stats['updated']} updated, {stats['errors']} errors"
            )

        except Exception as e:
            self.db.rollback()
            logger.error(f"Vendor sync failed: {e}")
            stats["error"] = str(e)
            raise

        finally:
            # Always release the HTTP client, even when re-raising.
            client.close()

        stats["completed_at"] = datetime.now(UTC)
        stats["duration_seconds"] = (
            stats["completed_at"] - stats["started_at"]
        ).total_seconds()

        return stats
|
||||
|
||||
def _upsert_vendor(self, vendor_data: dict[str, Any]) -> str:
|
||||
"""
|
||||
Insert or update a vendor in the cache.
|
||||
|
||||
Args:
|
||||
vendor_data: Raw vendor data from Letzshop API.
|
||||
|
||||
Returns:
|
||||
"created" or "updated" indicating the operation performed.
|
||||
"""
|
||||
letzshop_id = vendor_data.get("id")
|
||||
slug = vendor_data.get("slug")
|
||||
|
||||
if not letzshop_id or not slug:
|
||||
raise ValueError("Vendor missing required id or slug")
|
||||
|
||||
# Parse the vendor data
|
||||
parsed = self._parse_vendor_data(vendor_data)
|
||||
|
||||
# Check if exists
|
||||
existing = (
|
||||
self.db.query(LetzshopVendorCache)
|
||||
.filter(LetzshopVendorCache.letzshop_id == letzshop_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
# Update existing record (preserve claimed status)
|
||||
for key, value in parsed.items():
|
||||
if key not in ("claimed_by_vendor_id", "claimed_at"):
|
||||
setattr(existing, key, value)
|
||||
existing.last_synced_at = datetime.now(UTC)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new record
|
||||
cache_entry = LetzshopVendorCache(
|
||||
**parsed,
|
||||
last_synced_at=datetime.now(UTC),
|
||||
)
|
||||
self.db.add(cache_entry)
|
||||
return "created"
|
||||
|
||||
def _parse_vendor_data(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Parse raw Letzshop vendor data into cache model fields.
|
||||
|
||||
Args:
|
||||
data: Raw vendor data from Letzshop API.
|
||||
|
||||
Returns:
|
||||
Dictionary of parsed fields for LetzshopVendorCache.
|
||||
"""
|
||||
# Extract location
|
||||
location = data.get("location") or {}
|
||||
country = location.get("country") or {}
|
||||
|
||||
# Extract descriptions
|
||||
description = data.get("description") or {}
|
||||
|
||||
# Extract opening hours
|
||||
opening_hours = data.get("openingHours") or {}
|
||||
|
||||
# Extract categories (list of translated name objects)
|
||||
categories = []
|
||||
for cat in data.get("vendorCategories") or []:
|
||||
cat_name = cat.get("name") or {}
|
||||
# Prefer English, fallback to French or German
|
||||
name = cat_name.get("en") or cat_name.get("fr") or cat_name.get("de")
|
||||
if name:
|
||||
categories.append(name)
|
||||
|
||||
# Extract social media URLs
|
||||
social_links = []
|
||||
for link in data.get("socialMediaLinks") or []:
|
||||
url = link.get("url")
|
||||
if url:
|
||||
social_links.append(url)
|
||||
|
||||
# Extract background image
|
||||
bg_image = data.get("backgroundImage") or {}
|
||||
|
||||
return {
|
||||
"letzshop_id": data.get("id"),
|
||||
"slug": data.get("slug"),
|
||||
"name": data.get("name"),
|
||||
"company_name": data.get("companyName") or data.get("legalName"),
|
||||
"is_active": data.get("active", True),
|
||||
# Descriptions
|
||||
"description_en": description.get("en"),
|
||||
"description_fr": description.get("fr"),
|
||||
"description_de": description.get("de"),
|
||||
# Contact
|
||||
"email": data.get("email"),
|
||||
"phone": data.get("phone"),
|
||||
"fax": data.get("fax"),
|
||||
"website": data.get("homepage"),
|
||||
# Location
|
||||
"street": location.get("street"),
|
||||
"street_number": location.get("number"),
|
||||
"city": location.get("city"),
|
||||
"zipcode": location.get("zipcode"),
|
||||
"country_iso": country.get("iso", "LU"),
|
||||
"latitude": str(data.get("lat")) if data.get("lat") else None,
|
||||
"longitude": str(data.get("lng")) if data.get("lng") else None,
|
||||
# Categories and media
|
||||
"categories": categories,
|
||||
"background_image_url": bg_image.get("url"),
|
||||
"social_media_links": social_links,
|
||||
# Opening hours
|
||||
"opening_hours_en": opening_hours.get("en"),
|
||||
"opening_hours_fr": opening_hours.get("fr"),
|
||||
"opening_hours_de": opening_hours.get("de"),
|
||||
# Representative
|
||||
"representative_name": data.get("representative"),
|
||||
"representative_title": data.get("representativeTitle"),
|
||||
# Raw data for reference
|
||||
"raw_data": data,
|
||||
}
|
||||
|
||||
def sync_single_vendor(self, slug: str) -> LetzshopVendorCache | None:
    """
    Sync a single vendor by slug.

    Useful for on-demand refresh when a user looks up a vendor.

    Args:
        slug: The vendor's URL slug.

    Returns:
        The updated/created cache entry, or None if not found.
    """
    # NOTE(review): client is built with an empty api_key — presumably the
    # vendor endpoint is public; confirm against LetzshopClient.
    api_client = LetzshopClient(api_key="")

    try:
        payload = api_client.get_vendor_by_slug(slug)

        if not payload:
            logger.warning(f"Vendor not found on Letzshop: {slug}")
            return None

        operation = self._upsert_vendor(payload)
        self.db.commit()

        logger.info(f"Single vendor sync: {slug} ({operation})")

        query = self.db.query(LetzshopVendorCache)
        return query.filter(LetzshopVendorCache.slug == slug).first()

    finally:
        # Always release the HTTP client, even when the lookup fails.
        api_client.close()
|
||||
|
||||
def get_cached_vendor(self, slug: str) -> LetzshopVendorCache | None:
    """
    Get a vendor from cache by slug.

    Args:
        slug: The vendor's URL slug (matched lowercased).

    Returns:
        Cache entry or None if not found.
    """
    normalized = slug.lower()
    query = self.db.query(LetzshopVendorCache)
    return query.filter(LetzshopVendorCache.slug == normalized).first()
|
||||
|
||||
def search_cached_vendors(
    self,
    search: str | None = None,
    city: str | None = None,
    category: str | None = None,
    only_unclaimed: bool = False,
    page: int = 1,
    limit: int = 20,
) -> tuple[list[LetzshopVendorCache], int]:
    """
    Search cached vendors with filters.

    Args:
        search: Search term for name.
        city: Filter by city (case-insensitive exact match).
        category: Filter by category.
        only_unclaimed: Only return vendors not yet claimed.
        page: Page number (1-indexed).
        limit: Items per page.

    Returns:
        Tuple of (vendors list, total count).
    """
    # Inactive vendors are never returned.
    query = self.db.query(LetzshopVendorCache).filter(
        LetzshopVendorCache.is_active == True  # noqa: E712
    )

    if search:
        pattern = f"%{search.lower()}%"
        query = query.filter(func.lower(LetzshopVendorCache.name).like(pattern))

    if city:
        query = query.filter(func.lower(LetzshopVendorCache.city) == city.lower())

    if category:
        # Membership test inside the JSON categories array.
        query = query.filter(LetzshopVendorCache.categories.contains([category]))

    if only_unclaimed:
        query = query.filter(LetzshopVendorCache.claimed_by_vendor_id.is_(None))

    # Count before pagination so the total reflects all matches.
    total = query.count()

    rows = (
        query.order_by(LetzshopVendorCache.name)
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )

    return rows, total
|
||||
|
||||
def get_sync_stats(self) -> dict[str, Any]:
    """
    Get statistics about the vendor cache.

    Returns:
        Dictionary with cache statistics (totals, claim counts, distinct
        cities and the most recent sync timestamp).
    """
    base = self.db.query(LetzshopVendorCache)

    total = base.count()
    active = base.filter(
        LetzshopVendorCache.is_active == True  # noqa: E712
    ).count()
    claimed = base.filter(
        LetzshopVendorCache.claimed_by_vendor_id.isnot(None)
    ).count()

    # Most recent sync across all cached vendors.
    last_synced = self.db.query(
        func.max(LetzshopVendorCache.last_synced_at)
    ).scalar()

    # Number of distinct non-null cities.
    cities = (
        self.db.query(LetzshopVendorCache.city)
        .filter(LetzshopVendorCache.city.isnot(None))
        .distinct()
        .count()
    )

    return {
        "total_vendors": total,
        "active_vendors": active,
        "claimed_vendors": claimed,
        "unclaimed_vendors": active - claimed,
        "unique_cities": cities,
        "last_synced_at": last_synced.isoformat() if last_synced else None,
    }
|
||||
|
||||
def mark_vendor_claimed(
    self,
    letzshop_slug: str,
    vendor_id: int,
) -> bool:
    """
    Mark a Letzshop vendor as claimed by a platform vendor.

    Args:
        letzshop_slug: The Letzshop vendor slug.
        vendor_id: The platform vendor ID that claimed it.

    Returns:
        True if successful, False if vendor not found.
    """
    entry = self.get_cached_vendor(letzshop_slug)
    if entry is None:
        return False

    # Record ownership and the claim timestamp, then persist immediately.
    entry.claimed_by_vendor_id = vendor_id
    entry.claimed_at = datetime.now(UTC)
    self.db.commit()

    logger.info(f"Vendor {letzshop_slug} claimed by vendor_id={vendor_id}")
    return True
|
||||
|
||||
def create_vendor_from_cache(
    self,
    letzshop_slug: str,
    company_id: int,
) -> dict[str, Any]:
    """
    Create a platform vendor from a cached Letzshop vendor.

    Derives a unique vendor_code and subdomain from the Letzshop slug,
    creates the vendor via admin_service, then marks the cache entry as
    claimed by the newly created vendor.

    Args:
        letzshop_slug: The Letzshop vendor slug.
        company_id: The company ID to create the vendor under.

    Returns:
        Dictionary with created vendor info.

    Raises:
        ValueError: If vendor not found, already claimed, or company not found.
    """
    # Function-local imports — presumably to avoid circular module imports;
    # confirm before hoisting to module level.
    import random

    from sqlalchemy import func

    from app.services.admin_service import admin_service
    from models.database.company import Company
    from models.database.vendor import Vendor
    from models.schema.vendor import VendorCreate

    # Get cache entry
    cache_entry = self.get_cached_vendor(letzshop_slug)
    if not cache_entry:
        raise ValueError(f"Letzshop vendor '{letzshop_slug}' not found in cache")

    if cache_entry.is_claimed:
        raise ValueError(
            f"Letzshop vendor '{cache_entry.name}' is already claimed "
            f"by vendor ID {cache_entry.claimed_by_vendor_id}"
        )

    # Verify company exists
    company = self.db.query(Company).filter(Company.id == company_id).first()
    if not company:
        raise ValueError(f"Company with ID {company_id} not found")

    # Generate vendor code from slug (upper snake case, max 20 chars).
    vendor_code = letzshop_slug.upper().replace("-", "_")[:20]

    # Check if vendor code already exists; on collision append a random
    # 3-digit suffix. NOTE(review): a second collision is not retried.
    existing = (
        self.db.query(Vendor)
        .filter(func.upper(Vendor.vendor_code) == vendor_code)
        .first()
    )
    if existing:
        vendor_code = f"{vendor_code[:16]}_{random.randint(100, 999)}"

    # Generate subdomain from slug (lower kebab case, max 30 chars),
    # with the same random-suffix collision handling as the vendor code.
    subdomain = letzshop_slug.lower().replace("_", "-")[:30]
    existing_subdomain = (
        self.db.query(Vendor)
        .filter(func.lower(Vendor.subdomain) == subdomain)
        .first()
    )
    if existing_subdomain:
        subdomain = f"{subdomain[:26]}-{random.randint(100, 999)}"

    # Create vendor data from cache; fall back to the company email when
    # the cached vendor has none, and to "LU" for a missing country.
    address = f"{cache_entry.street or ''} {cache_entry.street_number or ''}".strip()
    vendor_data = VendorCreate(
        name=cache_entry.name,
        vendor_code=vendor_code,
        subdomain=subdomain,
        company_id=company_id,
        email=cache_entry.email or company.email,
        phone=cache_entry.phone,
        description=cache_entry.description_en or cache_entry.description_fr or "",
        city=cache_entry.city,
        country=cache_entry.country_iso or "LU",
        website=cache_entry.website,
        address_line_1=address or None,
        postal_code=cache_entry.zipcode,
    )

    # Create vendor
    vendor = admin_service.create_vendor(self.db, vendor_data)

    # Mark the Letzshop vendor as claimed (commits internally) # noqa: SVC-006
    self.mark_vendor_claimed(letzshop_slug, vendor.id)

    logger.info(
        f"Created vendor {vendor.vendor_code} from Letzshop vendor {letzshop_slug}"
    )

    return {
        "id": vendor.id,
        "vendor_code": vendor.vendor_code,
        "name": vendor.name,
        "subdomain": vendor.subdomain,
        "company_id": vendor.company_id,
    }
|
||||
|
||||
|
||||
# Singleton-style function for easy access
def get_vendor_sync_service(db: Session) -> LetzshopVendorSyncService:
    """Build a vendor sync service bound to the given database session."""
    service = LetzshopVendorSyncService(db)
    return service
|
||||
338
app/modules/marketplace/services/letzshop_export_service.py
Normal file
338
app/modules/marketplace/services/letzshop_export_service.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# app/services/letzshop_export_service.py
|
||||
"""
|
||||
Service for exporting products to Letzshop CSV format.
|
||||
|
||||
Generates Google Shopping compatible CSV files for Letzshop marketplace.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from models.database.letzshop import LetzshopSyncLog
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Letzshop CSV columns in order.
# Google Shopping feed attributes first, then Letzshop/atalanda-specific
# columns (tax rate, quantity, boost sort, delivery method) at the end.
# The order here defines the column order of the exported file.
LETZSHOP_CSV_COLUMNS = [
    "id",
    "title",
    "description",
    "link",
    "image_link",
    "additional_image_link",
    "availability",
    "price",
    "sale_price",
    "brand",
    "gtin",
    "mpn",
    "google_product_category",
    "product_type",
    "condition",
    "adult",
    "multipack",
    "is_bundle",
    "age_group",
    "color",
    "gender",
    "material",
    "pattern",
    "size",
    "size_type",
    "size_system",
    "item_group_id",
    "custom_label_0",
    "custom_label_1",
    "custom_label_2",
    "custom_label_3",
    "custom_label_4",
    "identifier_exists",
    "unit_pricing_measure",
    "unit_pricing_base_measure",
    "shipping",
    "atalanda:tax_rate",
    "atalanda:quantity",
    "atalanda:boost_sort",
    "atalanda:delivery_method",
]
|
||||
|
||||
|
||||
class LetzshopExportService:
    """Service for exporting products to Letzshop CSV format.

    Produces tab-delimited, Google Shopping compatible feeds from either
    vendor-owned ``Product`` rows or raw ``MarketplaceProduct`` rows, and
    can record each export run in ``LetzshopSyncLog``.
    """

    def __init__(self, default_tax_rate: float = 17.0):
        """
        Initialize the export service.

        Args:
            default_tax_rate: Default VAT rate for Luxembourg (17%)
        """
        self.default_tax_rate = default_tax_rate

    def export_vendor_products(
        self,
        db: Session,
        vendor_id: int,
        language: str = "en",
        include_inactive: bool = False,
    ) -> str:
        """
        Export all products for a vendor in Letzshop CSV format.

        Args:
            db: Database session
            vendor_id: Vendor ID to export products for
            language: Language for title/description (en, fr, de)
            include_inactive: Whether to include inactive products

        Returns:
            CSV string content
        """
        # Query products for this vendor with their marketplace product data.
        # Translations are eager-loaded so CSV generation does no extra queries.
        query = (
            db.query(Product)
            .filter(Product.vendor_id == vendor_id)
            .options(
                joinedload(Product.marketplace_product).joinedload(
                    MarketplaceProduct.translations
                )
            )
        )

        if not include_inactive:
            # SQLAlchemy requires `== True` for the SQL comparison.
            query = query.filter(Product.is_active == True)  # noqa: E712

        products = query.all()

        logger.info(
            f"Exporting {len(products)} products for vendor {vendor_id} in {language}"
        )

        return self._generate_csv(products, language)

    def export_marketplace_products(
        self,
        db: Session,
        marketplace: str = "Letzshop",
        language: str = "en",
        limit: int | None = None,
    ) -> str:
        """
        Export marketplace products directly (admin use).

        Args:
            db: Database session
            marketplace: Filter by marketplace source (substring match)
            language: Language for title/description
            limit: Optional limit on number of products

        Returns:
            CSV string content
        """
        # Only active marketplace products are exported.
        query = (
            db.query(MarketplaceProduct)
            .filter(MarketplaceProduct.is_active == True)  # noqa: E712
            .options(joinedload(MarketplaceProduct.translations))
        )

        if marketplace:
            # Case-insensitive substring match on the marketplace name.
            query = query.filter(
                MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
            )

        if limit:
            query = query.limit(limit)

        products = query.all()

        logger.info(
            f"Exporting {len(products)} marketplace products for {marketplace} in {language}"
        )

        return self._generate_csv_from_marketplace_products(products, language)

    def _generate_csv(self, products: list[Product], language: str) -> str:
        """Generate CSV from vendor Product objects.

        Products without a linked marketplace product are silently skipped.
        """
        output = io.StringIO()
        # Tab-delimited feed, per the Letzshop/Google Shopping convention.
        writer = csv.DictWriter(
            output,
            fieldnames=LETZSHOP_CSV_COLUMNS,
            delimiter="\t",
            quoting=csv.QUOTE_MINIMAL,
        )
        writer.writeheader()

        for product in products:
            if product.marketplace_product:
                row = self._product_to_row(product, language)
                writer.writerow(row)

        return output.getvalue()

    def _generate_csv_from_marketplace_products(
        self, products: list[MarketplaceProduct], language: str
    ) -> str:
        """Generate CSV from MarketplaceProduct objects directly."""
        output = io.StringIO()
        writer = csv.DictWriter(
            output,
            fieldnames=LETZSHOP_CSV_COLUMNS,
            delimiter="\t",
            quoting=csv.QUOTE_MINIMAL,
        )
        writer.writeheader()

        for mp in products:
            row = self._marketplace_product_to_row(mp, language)
            writer.writerow(row)

        return output.getvalue()

    def _product_to_row(self, product: Product, language: str) -> dict:
        """Convert a Product (with MarketplaceProduct) to a CSV row.

        The vendor's own SKU takes precedence as the row "id".
        """
        mp = product.marketplace_product
        return self._marketplace_product_to_row(
            mp, language, vendor_sku=product.vendor_sku
        )

    def _marketplace_product_to_row(
        self,
        mp: MarketplaceProduct,
        language: str,
        vendor_sku: str | None = None,
    ) -> dict:
        """Convert a MarketplaceProduct to a CSV row dict.

        Args:
            mp: Source marketplace product.
            language: Language code for localized title/description.
            vendor_sku: Optional vendor SKU used as the row "id" instead of
                the marketplace product ID.

        Returns:
            Dict keyed by LETZSHOP_CSV_COLUMNS entries.
        """
        # Get localized title and description; fall back to empty strings.
        title = mp.get_title(language) or ""
        description = mp.get_description(language) or ""

        # Price: prefer the numeric value formatted with a currency suffix,
        # otherwise use the raw price string as stored.
        price = ""
        if mp.price_numeric:
            price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}"
        elif mp.price:
            price = mp.price

        # Sale price: same two-tier fallback as the regular price.
        sale_price = ""
        if mp.sale_price_numeric:
            sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}"
        elif mp.sale_price:
            sale_price = mp.sale_price

        # Additional images - join with comma if multiple.
        additional_images = ""
        if mp.additional_images:
            additional_images = ",".join(mp.additional_images)
        elif mp.additional_image_link:
            additional_images = mp.additional_image_link

        # identifier_exists: use the stored value, otherwise derive it from
        # whether a GTIN or MPN is present.
        identifier_exists = mp.identifier_exists
        if not identifier_exists:
            identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no"

        return {
            "id": vendor_sku or mp.marketplace_product_id,
            "title": title,
            "description": description,
            "link": mp.link or mp.source_url or "",
            "image_link": mp.image_link or "",
            "additional_image_link": additional_images,
            "availability": mp.availability or "in stock",
            "price": price,
            "sale_price": sale_price,
            "brand": mp.brand or "",
            "gtin": mp.gtin or "",
            "mpn": mp.mpn or "",
            "google_product_category": mp.google_product_category or "",
            "product_type": mp.product_type_raw or "",
            "condition": mp.condition or "new",
            "adult": mp.adult or "no",
            "multipack": str(mp.multipack) if mp.multipack else "",
            "is_bundle": mp.is_bundle or "no",
            "age_group": mp.age_group or "",
            "color": mp.color or "",
            "gender": mp.gender or "",
            "material": mp.material or "",
            "pattern": mp.pattern or "",
            "size": mp.size or "",
            "size_type": mp.size_type or "",
            "size_system": mp.size_system or "",
            "item_group_id": mp.item_group_id or "",
            "custom_label_0": mp.custom_label_0 or "",
            "custom_label_1": mp.custom_label_1 or "",
            "custom_label_2": mp.custom_label_2 or "",
            "custom_label_3": mp.custom_label_3 or "",
            "custom_label_4": mp.custom_label_4 or "",
            "identifier_exists": identifier_exists,
            "unit_pricing_measure": mp.unit_pricing_measure or "",
            "unit_pricing_base_measure": mp.unit_pricing_base_measure or "",
            "shipping": mp.shipping or "",
            "atalanda:tax_rate": str(self.default_tax_rate),
            "atalanda:quantity": "",  # Would need inventory data
            "atalanda:boost_sort": "",
            "atalanda:delivery_method": "",
        }

    def log_export(
        self,
        db: Session,
        vendor_id: int,
        started_at: datetime,
        completed_at: datetime,
        files_processed: int,
        files_succeeded: int,
        files_failed: int,
        products_exported: int,
        triggered_by: str,
        error_details: dict | None = None,
    ) -> LetzshopSyncLog:
        """
        Log an export operation to the sync log.

        Args:
            db: Database session
            vendor_id: Vendor ID
            started_at: When the export started
            completed_at: When the export completed
            files_processed: Number of language files to export (e.g., 3)
            files_succeeded: Number of files successfully exported
            files_failed: Number of files that failed
            products_exported: Total products in the export
            triggered_by: Who triggered the export (e.g., "admin:123")
            error_details: Optional error details if any failures

        Returns:
            Created LetzshopSyncLog entry (flushed, not committed)
        """
        sync_log = LetzshopSyncLog(
            vendor_id=vendor_id,
            operation_type="product_export",
            direction="outbound",
            # Any failed file downgrades the run from "completed" to "partial".
            status="completed" if files_failed == 0 else "partial",
            records_processed=files_processed,
            records_succeeded=files_succeeded,
            records_failed=files_failed,
            started_at=started_at,
            completed_at=completed_at,
            duration_seconds=int((completed_at - started_at).total_seconds()),
            triggered_by=triggered_by,
            # The exported-product count is stored alongside any error
            # details; omitted entirely when there is nothing to record.
            error_details={
                "products_exported": products_exported,
                **(error_details or {}),
            } if products_exported or error_details else None,
        )
        db.add(sync_log)
        # Flush (not commit) so the caller controls the transaction boundary.
        db.flush()
        return sync_log
|
||||
|
||||
|
||||
# Module-level singleton instance (default 17% Luxembourg VAT rate)
letzshop_export_service = LetzshopExportService()
|
||||
@@ -0,0 +1,334 @@
|
||||
# app/services/marketplace_import_job_service.py
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
ImportJobNotFoundException,
|
||||
ImportJobNotOwnedException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.database.marketplace_import_job import (
|
||||
MarketplaceImportError,
|
||||
MarketplaceImportJob,
|
||||
)
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.marketplace_import_job import (
|
||||
AdminMarketplaceImportJobResponse,
|
||||
MarketplaceImportJobRequest,
|
||||
MarketplaceImportJobResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketplaceImportJobService:
    """Service class for marketplace import-job operations.

    Handles creation and retrieval of import jobs with per-user and
    per-vendor access control, plus admin-side listing and error lookup.
    Unexpected database errors are wrapped in ``ValidationException`` with
    the original exception chained via ``raise ... from e`` so the root
    cause is preserved in tracebacks.
    """

    def create_import_job(
        self,
        db: Session,
        request: MarketplaceImportJobRequest,
        vendor: Vendor,  # CHANGED: Vendor object from middleware
        user: User,
    ) -> MarketplaceImportJob:
        """
        Create a new marketplace import job.

        Args:
            db: Database session
            request: Import request data
            vendor: Vendor object (from middleware)
            user: User creating the job

        Returns:
            Created MarketplaceImportJob object (flushed, not committed)

        Raises:
            ValidationException: If the job record cannot be created.
        """
        try:
            # Create marketplace import job record in "pending" state;
            # a background worker is expected to pick it up later.
            import_job = MarketplaceImportJob(
                status="pending",
                source_url=request.source_url,
                marketplace=request.marketplace,
                language=request.language,
                vendor_id=vendor.id,
                user_id=user.id,
            )

            db.add(import_job)
            # Flush + refresh so the generated ID is available without
            # committing — the caller owns the transaction boundary.
            db.flush()
            db.refresh(import_job)

            logger.info(
                f"Created marketplace import job {import_job.id}: "
                f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) "
                f"by user {user.username}"
            )

            return import_job

        except Exception as e:
            logger.error(f"Error creating import job: {str(e)}")
            # Chain the original error so the root cause is not lost.
            raise ValidationException("Failed to create import job") from e

    def get_import_job_by_id(
        self, db: Session, job_id: int, user: User
    ) -> MarketplaceImportJob:
        """Get a marketplace import job by ID with access control.

        Users can only see their own jobs; admins can see all.

        Raises:
            ImportJobNotFoundException: If no job with this ID exists.
            ImportJobNotOwnedException: If a non-admin requests another
                user's job.
            ValidationException: On unexpected database errors.
        """
        try:
            job = (
                db.query(MarketplaceImportJob)
                .filter(MarketplaceImportJob.id == job_id)
                .first()
            )

            if not job:
                raise ImportJobNotFoundException(job_id)

            # Users can only see their own jobs, admins can see all
            if user.role != "admin" and job.user_id != user.id:
                raise ImportJobNotOwnedException(job_id, user.id)

            return job

        except (ImportJobNotFoundException, ImportJobNotOwnedException):
            # Re-raise domain exceptions untouched for the API layer.
            raise
        except Exception as e:
            logger.error(f"Error getting import job {job_id}: {str(e)}")
            raise ValidationException("Failed to retrieve import job") from e

    def get_import_job_for_vendor(
        self, db: Session, job_id: int, vendor_id: int
    ) -> MarketplaceImportJob:
        """
        Get a marketplace import job by ID with vendor access control.

        Validates that the job belongs to the specified vendor.

        Args:
            db: Database session
            job_id: Import job ID
            vendor_id: Vendor ID from token (to verify ownership)

        Raises:
            ImportJobNotFoundException: If job not found
            UnauthorizedVendorAccessException: If job doesn't belong to vendor
            ValidationException: On unexpected database errors.
        """
        # Local import — presumably avoids a circular import; confirm
        # before hoisting to module level.
        from app.exceptions import UnauthorizedVendorAccessException

        try:
            job = (
                db.query(MarketplaceImportJob)
                .filter(MarketplaceImportJob.id == job_id)
                .first()
            )

            if not job:
                raise ImportJobNotFoundException(job_id)

            # Verify job belongs to vendor (service layer validation)
            if job.vendor_id != vendor_id:
                raise UnauthorizedVendorAccessException(
                    vendor_code=str(vendor_id),
                    user_id=0,  # Not user-specific, but vendor mismatch
                )

            return job

        except (ImportJobNotFoundException, UnauthorizedVendorAccessException):
            raise
        except Exception as e:
            logger.error(
                f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}"
            )
            raise ValidationException("Failed to retrieve import job") from e

    def get_import_jobs(
        self,
        db: Session,
        vendor: Vendor,  # ADDED: Vendor filter
        user: User,
        marketplace: str | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> list[MarketplaceImportJob]:
        """Get marketplace import jobs for a specific vendor.

        Non-admin users only see jobs they created; admins see every job
        for the vendor. Results are newest-first with offset/limit paging.

        Raises:
            ValidationException: On unexpected database errors.
        """
        try:
            query = db.query(MarketplaceImportJob).filter(
                MarketplaceImportJob.vendor_id == vendor.id
            )

            # Users can only see their own jobs, admins can see all vendor jobs
            if user.role != "admin":
                query = query.filter(MarketplaceImportJob.user_id == user.id)

            # Apply marketplace filter (case-insensitive substring match)
            if marketplace:
                query = query.filter(
                    MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
                )

            # Order by creation date (newest first) and apply pagination
            jobs = (
                query.order_by(
                    MarketplaceImportJob.created_at.desc(),
                    MarketplaceImportJob.id.desc(),  # Tiebreaker for same timestamp
                )
                .offset(skip)
                .limit(limit)
                .all()
            )

            return jobs

        except Exception as e:
            logger.error(f"Error getting import jobs: {str(e)}")
            raise ValidationException("Failed to retrieve import jobs") from e

    def convert_to_response_model(
        self, job: MarketplaceImportJob
    ) -> MarketplaceImportJobResponse:
        """Convert database model to API response model."""
        return MarketplaceImportJobResponse(
            job_id=job.id,
            status=job.status,
            marketplace=job.marketplace,
            language=job.language,
            vendor_id=job.vendor_id,
            vendor_code=job.vendor.vendor_code if job.vendor else None,
            vendor_name=job.vendor.name if job.vendor else None,
            source_url=job.source_url,
            imported=job.imported_count or 0,
            updated=job.updated_count or 0,
            total_processed=job.total_processed or 0,
            error_count=job.error_count or 0,
            error_message=job.error_message,
            created_at=job.created_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
        )

    def convert_to_admin_response_model(
        self, job: MarketplaceImportJob
    ) -> AdminMarketplaceImportJobResponse:
        """Convert database model to admin API response model with extra fields.

        Adds the raw id, an (empty) error_details list and the creating
        user's name on top of the regular response fields.
        """
        return AdminMarketplaceImportJobResponse(
            id=job.id,
            job_id=job.id,
            status=job.status,
            marketplace=job.marketplace,
            language=job.language,
            vendor_id=job.vendor_id,
            vendor_code=job.vendor.vendor_code if job.vendor else None,
            vendor_name=job.vendor.name if job.vendor else None,
            source_url=job.source_url,
            imported=job.imported_count or 0,
            updated=job.updated_count or 0,
            total_processed=job.total_processed or 0,
            error_count=job.error_count or 0,
            error_message=job.error_message,
            error_details=[],
            created_at=job.created_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
            created_by_name=job.user.username if job.user else None,
        )

    def get_all_import_jobs_paginated(
        self,
        db: Session,
        marketplace: str | None = None,
        status: str | None = None,
        page: int = 1,
        limit: int = 100,
    ) -> tuple[list[MarketplaceImportJob], int]:
        """Get all marketplace import jobs with pagination (for admin).

        Args:
            db: Database session
            marketplace: Optional case-insensitive marketplace filter
            status: Optional exact status filter
            page: Page number (1-indexed)
            limit: Items per page

        Returns:
            Tuple of (jobs for the requested page, total matching count)

        Raises:
            ValidationException: On unexpected database errors.
        """
        try:
            query = db.query(MarketplaceImportJob)

            if marketplace:
                query = query.filter(
                    MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
                )
            if status:
                query = query.filter(MarketplaceImportJob.status == status)

            # Count before pagination so the total reflects all matches.
            total = query.count()
            skip = (page - 1) * limit
            jobs = (
                query.order_by(
                    MarketplaceImportJob.created_at.desc(),
                    MarketplaceImportJob.id.desc(),  # Tiebreaker for same timestamp
                )
                .offset(skip)
                .limit(limit)
                .all()
            )

            return jobs, total

        except Exception as e:
            logger.error(f"Error getting all import jobs: {str(e)}")
            raise ValidationException("Failed to retrieve import jobs") from e

    def get_import_job_by_id_admin(
        self, db: Session, job_id: int
    ) -> MarketplaceImportJob:
        """Get a marketplace import job by ID (admin - no access control).

        Raises:
            ImportJobNotFoundException: If no job with this ID exists.
        """
        job = (
            db.query(MarketplaceImportJob)
            .filter(MarketplaceImportJob.id == job_id)
            .first()
        )
        if not job:
            raise ImportJobNotFoundException(job_id)
        return job

    def get_import_job_errors(
        self,
        db: Session,
        job_id: int,
        error_type: str | None = None,
        page: int = 1,
        limit: int = 50,
    ) -> tuple[list[MarketplaceImportError], int]:
        """
        Get import errors for a specific job with pagination.

        Args:
            db: Database session
            job_id: Import job ID
            error_type: Optional filter by error type
            page: Page number (1-indexed)
            limit: Number of items per page

        Returns:
            Tuple of (list of errors ordered by row number, total count)

        Raises:
            ValidationException: On unexpected database errors.
        """
        try:
            query = db.query(MarketplaceImportError).filter(
                MarketplaceImportError.import_job_id == job_id
            )

            if error_type:
                query = query.filter(MarketplaceImportError.error_type == error_type)

            total = query.count()

            offset = (page - 1) * limit
            errors = (
                query.order_by(MarketplaceImportError.row_number)
                .offset(offset)
                .limit(limit)
                .all()
            )

            return errors, total

        except Exception as e:
            logger.error(f"Error getting import job errors for job {job_id}: {str(e)}")
            raise ValidationException("Failed to retrieve import errors") from e
|
||||
|
||||
|
||||
# Module-level singleton instance used by the API layer
marketplace_import_job_service = MarketplaceImportJobService()
|
||||
1075
app/modules/marketplace/services/marketplace_product_service.py
Normal file
1075
app/modules/marketplace/services/marketplace_product_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,10 @@
|
||||
"""
|
||||
Messaging module database models.
|
||||
|
||||
Re-exports messaging-related models from their source locations.
|
||||
This module contains the canonical implementations of messaging-related models.
|
||||
"""
|
||||
|
||||
from models.database.message import (
|
||||
from app.modules.messaging.models.message import (
|
||||
Conversation,
|
||||
ConversationParticipant,
|
||||
ConversationType,
|
||||
@@ -13,7 +13,7 @@ from models.database.message import (
|
||||
MessageAttachment,
|
||||
ParticipantType,
|
||||
)
|
||||
from models.database.admin import AdminNotification
|
||||
from app.modules.messaging.models.admin_notification import AdminNotification
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
|
||||
54
app/modules/messaging/models/admin_notification.py
Normal file
54
app/modules/messaging/models/admin_notification.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# app/modules/messaging/models/admin_notification.py
|
||||
"""
|
||||
Admin notification database model.
|
||||
|
||||
This model handles admin-specific notifications for system alerts and warnings.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class AdminNotification(Base, TimestampMixin):
    """
    Admin-specific notifications for system alerts and warnings.

    Different from vendor/customer notifications - these are for platform
    administrators to track system health and issues requiring attention.
    """

    __tablename__ = "admin_notifications"

    id = Column(Integer, primary_key=True, index=True)
    # Free-form category string rather than an Enum so new types can be
    # introduced without a migration.
    type = Column(
        String(50), nullable=False, index=True
    )  # system_alert, vendor_issue, import_failure
    priority = Column(
        String(20), default="normal", index=True
    )  # low, normal, high, critical
    title = Column(String(200), nullable=False)
    message = Column(Text, nullable=False)
    # Read-state tracking: who acknowledged the notification and when.
    is_read = Column(Boolean, default=False, index=True)
    read_at = Column(DateTime, nullable=True)
    read_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True)
    # Flags notifications that need an admin to act, with a deep link target.
    action_required = Column(Boolean, default=False, index=True)
    action_url = Column(String(500))  # Link to relevant admin page
    notification_metadata = Column(JSON)  # Additional contextual data

    # Relationships
    read_by = relationship("User", foreign_keys=[read_by_user_id])

    def __repr__(self) -> str:
        return f"<AdminNotification(id={self.id}, type='{self.type}', priority='{self.priority}')>"
|
||||
272
app/modules/messaging/models/message.py
Normal file
272
app/modules/messaging/models/message.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# app/modules/messaging/models/message.py
|
||||
"""
|
||||
Messaging system database models.
|
||||
|
||||
Supports three communication channels:
|
||||
- Admin <-> Vendor
|
||||
- Vendor <-> Customer
|
||||
- Admin <-> Customer
|
||||
|
||||
Multi-tenant isolation is enforced via vendor_id for conversations
|
||||
involving customers.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class ConversationType(str, enum.Enum):
    """Defines the three supported conversation channels.

    Inherits from str so values compare equal to plain strings and
    serialize cleanly in JSON responses.
    """

    ADMIN_VENDOR = "admin_vendor"
    VENDOR_CUSTOMER = "vendor_customer"
    ADMIN_CUSTOMER = "admin_customer"
||||
|
||||
|
||||
class ParticipantType(str, enum.Enum):
    """Type of participant in a conversation.

    Used polymorphically: admin/vendor ids reference users, customer ids
    reference the Customer model (see ConversationParticipant).
    """

    ADMIN = "admin"  # User with role="admin"
    VENDOR = "vendor"  # User with role="vendor" (via VendorUser)
    CUSTOMER = "customer"  # Customer model
|
||||
|
||||
|
||||
def _enum_values(enum_class):
|
||||
"""Extract enum values for SQLAlchemy Enum column."""
|
||||
return [e.value for e in enum_class]
|
||||
|
||||
|
||||
class Conversation(Base, TimestampMixin):
    """
    Represents a threaded conversation between participants.

    Multi-tenancy: vendor_id is required for vendor_customer and admin_customer
    conversations to ensure customer data isolation.
    """

    __tablename__ = "conversations"

    id = Column(Integer, primary_key=True, index=True)

    # Conversation type determines participant structure
    conversation_type = Column(
        Enum(ConversationType, values_callable=_enum_values),
        nullable=False,
        index=True,
    )

    # Subject line for the conversation thread
    subject = Column(String(500), nullable=False)

    # For vendor_customer and admin_customer conversations.
    # Required for multi-tenant data isolation.
    # NOTE(review): nullable at the DB level — the vendor_id requirement for
    # customer-involving types is presumably enforced in the service layer; confirm.
    vendor_id = Column(
        Integer,
        ForeignKey("vendors.id"),
        nullable=True,
        index=True,
    )

    # Status flags. closed_by_* is a polymorphic reference (user or customer id
    # depending on closed_by_type), hence no foreign key on closed_by_id.
    is_closed = Column(Boolean, default=False, nullable=False)
    closed_at = Column(DateTime, nullable=True)
    closed_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True)
    closed_by_id = Column(Integer, nullable=True)

    # Last activity tracking for sorting (denormalized from messages)
    last_message_at = Column(DateTime, nullable=True, index=True)
    message_count = Column(Integer, default=0, nullable=False)

    # Relationships; deleting a conversation cascades to participants/messages.
    vendor = relationship("Vendor", foreign_keys=[vendor_id])
    participants = relationship(
        "ConversationParticipant",
        back_populates="conversation",
        cascade="all, delete-orphan",
    )
    messages = relationship(
        "Message",
        back_populates="conversation",
        cascade="all, delete-orphan",
        order_by="Message.created_at",
    )

    # Indexes for common queries (e.g. "all vendor_customer threads for vendor X")
    __table_args__ = (
        Index("ix_conversations_type_vendor", "conversation_type", "vendor_id"),
    )

    def __repr__(self) -> str:
        return (
            f"<Conversation(id={self.id}, type='{self.conversation_type.value}', "
            f"subject='{self.subject[:30]}...')>"
        )
|
||||
|
||||
|
||||
class ConversationParticipant(Base, TimestampMixin):
    """
    Links participants (users or customers) to conversations.

    Polymorphic relationship:
    - participant_type="admin" or "vendor": references users.id
    - participant_type="customer": references customers.id
    """

    __tablename__ = "conversation_participants"

    id = Column(Integer, primary_key=True, index=True)
    # CASCADE: rows are removed when the parent conversation is deleted.
    conversation_id = Column(
        Integer,
        ForeignKey("conversations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Polymorphic participant reference — no FK on participant_id because the
    # target table depends on participant_type.
    participant_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False)
    participant_id = Column(Integer, nullable=False, index=True)

    # For vendor participants, track which vendor they represent
    vendor_id = Column(
        Integer,
        ForeignKey("vendors.id"),
        nullable=True,
    )

    # Unread tracking per participant (denormalized counter)
    unread_count = Column(Integer, default=0, nullable=False)
    last_read_at = Column(DateTime, nullable=True)

    # Notification preferences for this conversation
    email_notifications = Column(Boolean, default=True, nullable=False)
    muted = Column(Boolean, default=False, nullable=False)

    # Relationships
    conversation = relationship("Conversation", back_populates="participants")
    vendor = relationship("Vendor", foreign_keys=[vendor_id])

    __table_args__ = (
        # A given user/customer may appear in a conversation at most once.
        UniqueConstraint(
            "conversation_id",
            "participant_type",
            "participant_id",
            name="uq_conversation_participant",
        ),
        # Reverse lookup: "all conversations this participant belongs to".
        Index(
            "ix_participant_lookup",
            "participant_type",
            "participant_id",
        ),
    )

    def __repr__(self) -> str:
        return (
            f"<ConversationParticipant(conversation_id={self.conversation_id}, "
            f"type='{self.participant_type.value}', id={self.participant_id})>"
        )
|
||||
|
||||
|
||||
class Message(Base, TimestampMixin):
    """
    Individual message within a conversation thread.

    Sender polymorphism follows same pattern as participant.
    """

    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, index=True)
    # CASCADE: messages are removed with their parent conversation.
    conversation_id = Column(
        Integer,
        ForeignKey("conversations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Polymorphic sender reference — target table depends on sender_type,
    # so there is intentionally no FK on sender_id.
    sender_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False)
    sender_id = Column(Integer, nullable=False, index=True)

    # Message content
    content = Column(Text, nullable=False)

    # System messages (e.g., "conversation closed")
    is_system_message = Column(Boolean, default=False, nullable=False)

    # Soft delete for moderation — the row is kept, with an audit trail of
    # who deleted it (polymorphic, same convention as sender_*).
    is_deleted = Column(Boolean, default=False, nullable=False)
    deleted_at = Column(DateTime, nullable=True)
    deleted_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True)
    deleted_by_id = Column(Integer, nullable=True)

    # Relationships
    conversation = relationship("Conversation", back_populates="messages")
    attachments = relationship(
        "MessageAttachment",
        back_populates="message",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        # Supports the common "messages of a thread in chronological order" query.
        Index("ix_messages_conversation_created", "conversation_id", "created_at"),
    )

    def __repr__(self) -> str:
        return (
            f"<Message(id={self.id}, conversation_id={self.conversation_id}, "
            f"sender={self.sender_type.value}:{self.sender_id})>"
        )
|
||||
|
||||
|
||||
class MessageAttachment(Base, TimestampMixin):
    """
    File attachments for messages.

    Files are stored in platform storage (local/S3) with references here.
    """

    __tablename__ = "message_attachments"

    id = Column(Integer, primary_key=True, index=True)
    # CASCADE: attachment rows are removed with their parent message.
    message_id = Column(
        Integer,
        ForeignKey("messages.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # File metadata. filename is the stored (sanitized/unique) name;
    # original_filename preserves what the uploader called it.
    filename = Column(String(255), nullable=False)
    original_filename = Column(String(255), nullable=False)
    file_path = Column(String(1000), nullable=False)  # Storage path
    file_size = Column(Integer, nullable=False)  # Size in bytes
    mime_type = Column(String(100), nullable=False)

    # For image attachments (dimensions/thumbnail only populated when is_image)
    is_image = Column(Boolean, default=False, nullable=False)
    image_width = Column(Integer, nullable=True)
    image_height = Column(Integer, nullable=True)
    thumbnail_path = Column(String(1000), nullable=True)

    # Relationships
    message = relationship("Message", back_populates="attachments")

    def __repr__(self) -> str:
        return f"<MessageAttachment(id={self.id}, filename='{self.original_filename}')>"
|
||||
@@ -2,27 +2,106 @@
|
||||
"""
|
||||
Messaging module Pydantic schemas.
|
||||
|
||||
Re-exports messaging-related schemas from their source locations.
|
||||
This module contains the canonical implementations of messaging-related schemas.
|
||||
"""
|
||||
|
||||
from models.schema.message import (
|
||||
ConversationResponse,
|
||||
ConversationListResponse,
|
||||
MessageResponse,
|
||||
MessageCreate,
|
||||
from app.modules.messaging.schemas.message import (
|
||||
# Attachment schemas
|
||||
AttachmentResponse,
|
||||
# Message schemas
|
||||
MessageCreate,
|
||||
MessageResponse,
|
||||
# Participant schemas
|
||||
ParticipantInfo,
|
||||
ParticipantResponse,
|
||||
# Conversation schemas
|
||||
ConversationCreate,
|
||||
ConversationSummary,
|
||||
ConversationDetailResponse,
|
||||
ConversationListResponse,
|
||||
ConversationResponse,
|
||||
# Unread count
|
||||
UnreadCountResponse,
|
||||
# Notification preferences
|
||||
NotificationPreferencesUpdate,
|
||||
# Conversation actions
|
||||
CloseConversationResponse,
|
||||
ReopenConversationResponse,
|
||||
MarkReadResponse,
|
||||
# Recipient selection
|
||||
RecipientOption,
|
||||
RecipientListResponse,
|
||||
# Admin schemas
|
||||
AdminConversationSummary,
|
||||
AdminConversationListResponse,
|
||||
AdminMessageStats,
|
||||
)
|
||||
from models.schema.notification import (
|
||||
|
||||
from app.modules.messaging.schemas.notification import (
|
||||
# Response schemas
|
||||
MessageResponse as NotificationMessageResponse,
|
||||
UnreadCountResponse as NotificationUnreadCountResponse,
|
||||
# Notification schemas
|
||||
NotificationResponse,
|
||||
NotificationListResponse,
|
||||
# Settings schemas
|
||||
NotificationSettingsResponse,
|
||||
NotificationSettingsUpdate,
|
||||
# Template schemas
|
||||
NotificationTemplateResponse,
|
||||
NotificationTemplateListResponse,
|
||||
NotificationTemplateUpdate,
|
||||
# Test notification
|
||||
TestNotificationRequest,
|
||||
# Alert statistics
|
||||
AlertStatisticsResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConversationResponse",
|
||||
"ConversationListResponse",
|
||||
"MessageResponse",
|
||||
"MessageCreate",
|
||||
# Attachment schemas
|
||||
"AttachmentResponse",
|
||||
# Message schemas
|
||||
"MessageCreate",
|
||||
"MessageResponse",
|
||||
# Participant schemas
|
||||
"ParticipantInfo",
|
||||
"ParticipantResponse",
|
||||
# Conversation schemas
|
||||
"ConversationCreate",
|
||||
"ConversationSummary",
|
||||
"ConversationDetailResponse",
|
||||
"ConversationListResponse",
|
||||
"ConversationResponse",
|
||||
# Unread count
|
||||
"UnreadCountResponse",
|
||||
# Notification preferences
|
||||
"NotificationPreferencesUpdate",
|
||||
# Conversation actions
|
||||
"CloseConversationResponse",
|
||||
"ReopenConversationResponse",
|
||||
"MarkReadResponse",
|
||||
# Recipient selection
|
||||
"RecipientOption",
|
||||
"RecipientListResponse",
|
||||
# Admin schemas
|
||||
"AdminConversationSummary",
|
||||
"AdminConversationListResponse",
|
||||
"AdminMessageStats",
|
||||
# Notification response schemas
|
||||
"NotificationMessageResponse",
|
||||
"NotificationUnreadCountResponse",
|
||||
# Notification schemas
|
||||
"NotificationResponse",
|
||||
"NotificationListResponse",
|
||||
# Settings schemas
|
||||
"NotificationSettingsResponse",
|
||||
"NotificationSettingsUpdate",
|
||||
# Template schemas
|
||||
"NotificationTemplateResponse",
|
||||
"NotificationTemplateListResponse",
|
||||
"NotificationTemplateUpdate",
|
||||
# Test notification
|
||||
"TestNotificationRequest",
|
||||
# Alert statistics
|
||||
"AlertStatisticsResponse",
|
||||
]
|
||||
|
||||
312
app/modules/messaging/schemas/message.py
Normal file
312
app/modules/messaging/schemas/message.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# app/modules/messaging/schemas/message.py
|
||||
"""
|
||||
Pydantic schemas for the messaging system.
|
||||
|
||||
Supports three communication channels:
|
||||
- Admin <-> Vendor
|
||||
- Vendor <-> Customer
|
||||
- Admin <-> Customer
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.modules.messaging.models.message import ConversationType, ParticipantType
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Attachment Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AttachmentResponse(BaseModel):
    """Schema for message attachment in responses."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    filename: str
    original_filename: str
    file_size: int
    mime_type: str
    is_image: bool
    image_width: int | None = None
    image_height: int | None = None
    download_url: str | None = None
    thumbnail_url: str | None = None

    @property
    def file_size_display(self) -> str:
        """Human-readable file size (B / KB / MB, one decimal place)."""
        size = self.file_size
        if size < 1024:
            return f"{size} B"
        if size < 1024 * 1024:
            return f"{size / 1024:.1f} KB"
        return f"{size / 1024 / 1024:.1f} MB"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Message Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
    """Schema for sending a new message.

    Content is capped at 10k characters; attachments are uploaded separately.
    """

    content: str = Field(..., min_length=1, max_length=10000)


class MessageResponse(BaseModel):
    """Schema for a single message in responses."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    conversation_id: int
    # Polymorphic sender reference, mirrors the Message model.
    sender_type: ParticipantType
    sender_id: int
    content: str
    is_system_message: bool
    is_deleted: bool
    created_at: datetime

    # Enriched sender info (populated by API, not stored on the model)
    sender_name: str | None = None
    sender_email: str | None = None

    # Attachments
    attachments: list[AttachmentResponse] = []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Participant Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ParticipantInfo(BaseModel):
    """Schema for participant information (display-ready identity details)."""

    id: int
    type: ParticipantType
    name: str
    email: str | None = None
    avatar_url: str | None = None


class ParticipantResponse(BaseModel):
    """Schema for conversation participant in responses."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    participant_type: ParticipantType
    participant_id: int
    # Per-participant read state and preferences (from ConversationParticipant)
    unread_count: int
    last_read_at: datetime | None
    email_notifications: bool
    muted: bool

    # Enriched info (populated by API)
    participant_info: ParticipantInfo | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversation Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ConversationCreate(BaseModel):
    """Schema for creating a new conversation.

    The creator is implied by the authenticated caller; recipient_* identifies
    the other party. vendor_id is expected for customer-involving types.
    """

    conversation_type: ConversationType
    subject: str = Field(..., min_length=1, max_length=500)
    recipient_type: ParticipantType
    recipient_id: int
    vendor_id: int | None = None
    # Optional first message posted atomically with the conversation.
    initial_message: str | None = Field(None, min_length=1, max_length=10000)
|
||||
|
||||
|
||||
class ConversationSummary(BaseModel):
    """Schema for conversation in list views (no message bodies)."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    conversation_type: ConversationType
    subject: str
    vendor_id: int | None = None
    is_closed: bool
    closed_at: datetime | None
    last_message_at: datetime | None
    message_count: int
    created_at: datetime

    # Unread count for current user (from participant)
    unread_count: int = 0

    # Other participant info (enriched by API)
    other_participant: ParticipantInfo | None = None

    # Last message preview
    last_message_preview: str | None = None
|
||||
|
||||
|
||||
class ConversationDetailResponse(BaseModel):
    """Schema for full conversation detail with messages.

    Superset of ConversationSummary: includes the full participant list
    and all messages in chronological order.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    conversation_type: ConversationType
    subject: str
    vendor_id: int | None = None
    is_closed: bool
    closed_at: datetime | None
    # Polymorphic closer reference, mirrors the Conversation model.
    closed_by_type: ParticipantType | None = None
    closed_by_id: int | None = None
    last_message_at: datetime | None
    message_count: int
    created_at: datetime
    updated_at: datetime

    # Participants with enriched info
    participants: list[ParticipantResponse] = []

    # Messages ordered by created_at
    messages: list[MessageResponse] = []

    # Current user's unread count
    unread_count: int = 0

    # Vendor info if applicable
    vendor_name: str | None = None
|
||||
|
||||
|
||||
class ConversationListResponse(BaseModel):
    """Schema for paginated conversation list (skip/limit pagination)."""

    conversations: list[ConversationSummary]
    total: int
    total_unread: int
    skip: int
    limit: int


# Backward compatibility alias: older callers imported ConversationResponse.
ConversationResponse = ConversationDetailResponse
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unread Count Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UnreadCountResponse(BaseModel):
    """Schema for unread message count (for header badge)."""

    total_unread: int


class NotificationPreferencesUpdate(BaseModel):
    """Schema for updating per-conversation notification preferences.

    All fields optional: only supplied fields are changed (partial update).
    """

    email_notifications: bool | None = None
    muted: bool | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Conversation Action Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CloseConversationResponse(BaseModel):
    """Response after closing a conversation."""

    success: bool
    message: str
    conversation_id: int


class ReopenConversationResponse(BaseModel):
    """Response after reopening a conversation."""

    success: bool
    message: str
    conversation_id: int


class MarkReadResponse(BaseModel):
    """Response after marking conversation as read."""

    success: bool
    conversation_id: int
    # Remaining unread count for the caller (expected 0 after a full mark-read).
    unread_count: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recipient Selection Schemas (for compose modal)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RecipientOption(BaseModel):
    """Schema for a selectable recipient in compose modal."""

    id: int
    type: ParticipantType
    name: str
    email: str | None = None
    vendor_id: int | None = None  # For vendor users
    vendor_name: str | None = None


class RecipientListResponse(BaseModel):
    """Schema for list of available recipients."""

    recipients: list[RecipientOption]
    total: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin-specific Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminConversationSummary(ConversationSummary):
    """Extended conversation summary with vendor info for admin views."""

    vendor_name: str | None = None
    vendor_code: str | None = None


class AdminConversationListResponse(BaseModel):
    """Schema for admin conversation list with vendor info (skip/limit pagination)."""

    conversations: list[AdminConversationSummary]
    total: int
    total_unread: int
    skip: int
    limit: int
|
||||
|
||||
|
||||
class AdminMessageStats(BaseModel):
|
||||
"""Messaging statistics for admin dashboard."""
|
||||
|
||||
total_conversations: int = 0
|
||||
open_conversations: int = 0
|
||||
closed_conversations: int = 0
|
||||
total_messages: int = 0
|
||||
|
||||
# By type
|
||||
admin_vendor_conversations: int = 0
|
||||
vendor_customer_conversations: int = 0
|
||||
admin_customer_conversations: int = 0
|
||||
|
||||
# Unread
|
||||
unread_admin: int = 0
|
||||
152
app/modules/messaging/schemas/notification.py
Normal file
152
app/modules/messaging/schemas/notification.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# app/modules/messaging/schemas/notification.py
|
||||
"""
|
||||
Notification Pydantic schemas for API validation and responses.
|
||||
|
||||
This module provides schemas for:
|
||||
- Vendor notifications (list, read, delete)
|
||||
- Notification settings management
|
||||
- Notification email templates
|
||||
- Unread counts and statistics
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ============================================================================
|
||||
# SHARED RESPONSE SCHEMAS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
    """Generic message response for simple operations.

    NOTE(review): same class name as the messaging MessageResponse schema;
    consumers re-export it under the alias NotificationMessageResponse.
    """

    message: str


class UnreadCountResponse(BaseModel):
    """Response for unread notification count."""

    unread_count: int
    message: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION SCHEMAS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class NotificationResponse(BaseModel):
    """Single notification response."""

    id: int
    type: str
    title: str
    message: str
    is_read: bool
    read_at: datetime | None = None
    priority: str = "normal"
    action_url: str | None = None
    # NOTE(review): the AdminNotification model stores this as
    # notification_metadata; from_attributes mapping may need an alias — confirm.
    metadata: dict[str, Any] | None = None
    created_at: datetime

    model_config = {"from_attributes": True}


class NotificationListResponse(BaseModel):
    """Paginated list of notifications."""

    notifications: list[NotificationResponse] = []
    total: int = 0
    unread_count: int = 0
    message: str | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION SETTINGS SCHEMAS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class NotificationSettingsResponse(BaseModel):
    """Notification preferences response."""

    email_notifications: bool = True
    in_app_notifications: bool = True
    # Per-type opt-in map, e.g. {"import_failure": True}.
    notification_types: dict[str, bool] = Field(default_factory=dict)
    message: str | None = None


class NotificationSettingsUpdate(BaseModel):
    """Request model for updating notification settings.

    All fields optional: only supplied fields are changed (partial update).
    """

    email_notifications: bool | None = None
    in_app_notifications: bool | None = None
    notification_types: dict[str, bool] | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION TEMPLATE SCHEMAS
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class NotificationTemplateResponse(BaseModel):
    """Single notification template response."""

    id: int
    name: str
    type: str
    subject: str
    # Either/both bodies may be present (HTML and plain-text variants).
    body_html: str | None = None
    body_text: str | None = None
    # Placeholder names available for substitution in subject/body.
    variables: list[str] = Field(default_factory=list)
    is_active: bool = True
    created_at: datetime
    updated_at: datetime | None = None

    model_config = {"from_attributes": True}


class NotificationTemplateListResponse(BaseModel):
    """List of notification templates."""

    templates: list[NotificationTemplateResponse] = []
    message: str | None = None


class NotificationTemplateUpdate(BaseModel):
    """Request model for updating notification template.

    All fields optional: only supplied fields are changed (partial update).
    """

    subject: str | None = Field(None, max_length=200)
    body_html: str | None = None
    body_text: str | None = None
    is_active: bool | None = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TEST NOTIFICATION SCHEMA
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestNotificationRequest(BaseModel):
    """Request model for sending test notification."""

    template_id: int | None = Field(None, description="Template to use")
    email: str | None = Field(None, description="Override recipient email")
    notification_type: str = Field(
        default="test", description="Type of notification to send"
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADMIN ALERT STATISTICS SCHEMA
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AlertStatisticsResponse(BaseModel):
    """Response for alert statistics (admin dashboard counters)."""

    total_alerts: int = 0
    active_alerts: int = 0
    critical_alerts: int = 0
    resolved_today: int = 0
|
||||
@@ -2,24 +2,29 @@
|
||||
"""
|
||||
Messaging module services.
|
||||
|
||||
Re-exports messaging-related services from their source locations.
|
||||
This module contains the canonical implementations of messaging-related services.
|
||||
"""
|
||||
|
||||
from app.services.messaging_service import (
|
||||
from app.modules.messaging.services.messaging_service import (
|
||||
messaging_service,
|
||||
MessagingService,
|
||||
)
|
||||
from app.services.message_attachment_service import (
|
||||
from app.modules.messaging.services.message_attachment_service import (
|
||||
message_attachment_service,
|
||||
MessageAttachmentService,
|
||||
)
|
||||
from app.services.admin_notification_service import (
|
||||
from app.modules.messaging.services.admin_notification_service import (
|
||||
admin_notification_service,
|
||||
AdminNotificationService,
|
||||
platform_alert_service,
|
||||
PlatformAlertService,
|
||||
# Constants
|
||||
NotificationType,
|
||||
Priority,
|
||||
AlertType,
|
||||
Severity,
|
||||
)
|
||||
|
||||
# Note: notification_service is a placeholder - not yet implemented
|
||||
|
||||
__all__ = [
|
||||
"messaging_service",
|
||||
"MessagingService",
|
||||
@@ -27,4 +32,11 @@ __all__ = [
|
||||
"MessageAttachmentService",
|
||||
"admin_notification_service",
|
||||
"AdminNotificationService",
|
||||
"platform_alert_service",
|
||||
"PlatformAlertService",
|
||||
# Constants
|
||||
"NotificationType",
|
||||
"Priority",
|
||||
"AlertType",
|
||||
"Severity",
|
||||
]
|
||||
|
||||
702
app/modules/messaging/services/admin_notification_service.py
Normal file
702
app/modules/messaging/services/admin_notification_service.py
Normal file
@@ -0,0 +1,702 @@
|
||||
# app/modules/messaging/services/admin_notification_service.py
|
||||
"""
|
||||
Admin notification service.
|
||||
|
||||
Provides functionality for:
|
||||
- Creating and managing admin notifications
|
||||
- Managing platform alerts
|
||||
- Notification statistics and queries
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.messaging.models.admin_notification import AdminNotification
|
||||
from models.database.admin import PlatformAlert
|
||||
from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION TYPES
|
||||
# ============================================================================
|
||||
|
||||
class NotificationType:
    """Notification type constants.

    Plain string constants (not an Enum): the values are stored verbatim
    in ``AdminNotification.type`` and compared as strings in queries.
    """

    SYSTEM_ALERT = "system_alert"
    IMPORT_FAILURE = "import_failure"
    EXPORT_FAILURE = "export_failure"
    ORDER_SYNC_FAILURE = "order_sync_failure"
    VENDOR_ISSUE = "vendor_issue"
    CUSTOMER_MESSAGE = "customer_message"
    VENDOR_MESSAGE = "vendor_message"
    SECURITY_ALERT = "security_alert"
    PERFORMANCE_ALERT = "performance_alert"
    ORDER_EXCEPTION = "order_exception"
    CRITICAL_ERROR = "critical_error"
|
||||
|
||||
|
||||
class Priority:
    """Priority level constants for admin notifications.

    Ordering (critical first) is not encoded here; queries enforce it via
    a SQL CASE expression in AdminNotificationService.
    """

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"
    CRITICAL = "critical"
|
||||
|
||||
|
||||
class AlertType:
    """Platform alert type constants.

    Stored verbatim in ``PlatformAlert.alert_type`` and used (together
    with the title) to de-duplicate repeated alerts.
    """

    SECURITY = "security"
    PERFORMANCE = "performance"
    CAPACITY = "capacity"
    INTEGRATION = "integration"
    DATABASE = "database"
    SYSTEM = "system"
|
||||
|
||||
|
||||
class Severity:
    """Alert severity constants.

    Stored verbatim in ``PlatformAlert.severity``; CRITICAL is singled
    out by the statistics and count queries.
    """

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADMIN NOTIFICATION SERVICE
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminNotificationService:
    """Service for managing admin notifications.

    Covers CRUD and read-state management for AdminNotification rows plus
    a set of convenience constructors for common notification types.
    All mutating methods call ``db.flush()`` rather than ``commit()`` —
    presumably the transaction boundary is owned by the caller; confirm.
    """

    def create_notification(
        self,
        db: Session,
        notification_type: str,
        title: str,
        message: str,
        priority: str = Priority.NORMAL,
        action_required: bool = False,
        action_url: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """
        Create a new admin notification.

        Args:
            db: Database session
            notification_type: Type of notification (see NotificationType)
            title: Notification title
            message: Notification message
            priority: Priority level (low, normal, high, critical)
            action_required: Whether action is required
            action_url: URL to relevant admin page
            metadata: Additional contextual data

        Returns:
            Created AdminNotification (flushed, so its ID is assigned)
        """
        notification = AdminNotification(
            type=notification_type,
            title=title,
            message=message,
            priority=priority,
            action_required=action_required,
            action_url=action_url,
            # Column is named notification_metadata — presumably to avoid
            # SQLAlchemy's reserved "metadata" attribute; confirm on model.
            notification_metadata=metadata,
        )
        db.add(notification)
        # Flush (not commit) so the new row gets an ID the caller can use.
        db.flush()

        logger.info(
            f"Created notification: {notification_type} - {title} (priority: {priority})"
        )

        return notification

    def create_from_schema(
        self,
        db: Session,
        data: AdminNotificationCreate,
    ) -> AdminNotification:
        """Create notification from a Pydantic schema (thin adapter over
        create_notification)."""
        return self.create_notification(
            db=db,
            notification_type=data.type,
            title=data.title,
            message=data.message,
            priority=data.priority,
            action_required=data.action_required,
            action_url=data.action_url,
            metadata=data.metadata,
        )

    def get_notifications(
        self,
        db: Session,
        priority: str | None = None,
        is_read: bool | None = None,
        notification_type: str | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list[AdminNotification], int, int]:
        """
        Get paginated admin notifications.

        Args:
            db: Database session
            priority: Optional filter on priority level
            is_read: Optional filter on read state
            notification_type: Optional filter on notification type
            skip: Pagination offset
            limit: Page size

        Returns:
            Tuple of (notifications, total_count, unread_count).
            NOTE: unread_count is the GLOBAL unread count and deliberately
            ignores the filters above; only total_count respects them.
        """
        query = db.query(AdminNotification)

        # Apply filters
        if priority:
            query = query.filter(AdminNotification.priority == priority)

        if is_read is not None:
            query = query.filter(AdminNotification.is_read == is_read)

        if notification_type:
            query = query.filter(AdminNotification.type == notification_type)

        # Get counts (unread count is unfiltered, see docstring)
        total = query.count()
        unread_count = (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .count()
        )

        # Get paginated results ordered by priority and date.
        # SQL CASE maps priority strings to sortable ranks (critical first);
        # unknown priorities sink to the bottom (rank 5).
        priority_order = case(
            (AdminNotification.priority == "critical", 1),
            (AdminNotification.priority == "high", 2),
            (AdminNotification.priority == "normal", 3),
            (AdminNotification.priority == "low", 4),
            else_=5,
        )

        notifications = (
            query.order_by(
                AdminNotification.is_read,  # Unread first
                priority_order,
                AdminNotification.created_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .all()
        )

        return notifications, total, unread_count

    def get_unread_count(self, db: Session) -> int:
        """Get count of unread notifications (global, unfiltered)."""
        return (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .count()
        )

    def get_recent_notifications(
        self,
        db: Session,
        limit: int = 5,
    ) -> list[AdminNotification]:
        """Get recent unread notifications for the header dropdown,
        highest priority first, newest first within a priority."""
        priority_order = case(
            (AdminNotification.priority == "critical", 1),
            (AdminNotification.priority == "high", 2),
            (AdminNotification.priority == "normal", 3),
            (AdminNotification.priority == "low", 4),
            else_=5,
        )

        return (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .order_by(priority_order, AdminNotification.created_at.desc())
            .limit(limit)
            .all()
        )

    def mark_as_read(
        self,
        db: Session,
        notification_id: int,
        user_id: int,
    ) -> AdminNotification | None:
        """Mark a notification as read.

        Idempotent: an already-read notification is returned unchanged.
        Returns None when the ID does not exist.
        """
        notification = (
            db.query(AdminNotification)
            .filter(AdminNotification.id == notification_id)
            .first()
        )

        if notification and not notification.is_read:
            notification.is_read = True
            notification.read_at = datetime.utcnow()
            notification.read_by_user_id = user_id
            db.flush()

        return notification

    def mark_all_as_read(
        self,
        db: Session,
        user_id: int,
    ) -> int:
        """Mark all unread notifications as read. Returns count of updated."""
        now = datetime.utcnow()
        # Bulk UPDATE executed at SQL level. NOTE(review): with the default
        # synchronize_session strategy, objects already loaded in this
        # session may be left stale — confirm callers re-query afterwards.
        count = (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .update(
                {
                    AdminNotification.is_read: True,
                    AdminNotification.read_at: now,
                    AdminNotification.read_by_user_id: user_id,
                }
            )
        )
        db.flush()
        return count

    def delete_old_notifications(
        self,
        db: Session,
        days: int = 30,
    ) -> int:
        """Delete notifications older than the given number of days.

        Only READ notifications are purged; unread ones are kept
        regardless of age. Returns the number of rows deleted.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)
        count = (
            db.query(AdminNotification)
            .filter(
                and_(
                    AdminNotification.is_read == True,  # noqa: E712
                    AdminNotification.created_at < cutoff,
                )
            )
            .delete()
        )
        db.flush()
        return count

    def delete_notification(
        self,
        db: Session,
        notification_id: int,
    ) -> bool:
        """
        Delete a notification by ID.

        Returns:
            True if notification was deleted, False if not found.
        """
        notification = (
            db.query(AdminNotification)
            .filter(AdminNotification.id == notification_id)
            .first()
        )

        if notification:
            db.delete(notification)
            db.flush()
            logger.info(f"Deleted notification {notification_id}")
            return True

        return False

    # =========================================================================
    # CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES
    # =========================================================================

    def notify_import_failure(
        self,
        db: Session,
        vendor_name: str,
        job_id: int,
        error_message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create a HIGH-priority notification for an import job failure."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.IMPORT_FAILURE,
            title=f"Import Failed: {vendor_name}",
            message=error_message,
            priority=Priority.HIGH,
            action_required=True,
            # Deep-link to the vendor's jobs tab when the vendor is known.
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
            if vendor_id
            else "/admin/marketplace",
            metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id},
        )

    def notify_order_sync_failure(
        self,
        db: Session,
        vendor_name: str,
        error_message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create a HIGH-priority notification for an order sync failure."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.ORDER_SYNC_FAILURE,
            title=f"Order Sync Failed: {vendor_name}",
            message=error_message,
            priority=Priority.HIGH,
            action_required=True,
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
            if vendor_id
            else "/admin/marketplace/letzshop",
            metadata={"vendor_name": vendor_name, "vendor_id": vendor_id},
        )

    def notify_order_exception(
        self,
        db: Session,
        vendor_name: str,
        order_number: str,
        exception_count: int,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create a NORMAL-priority notification for order item exceptions."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.ORDER_EXCEPTION,
            title=f"Order Exception: {order_number}",
            message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})",
            priority=Priority.NORMAL,
            action_required=True,
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions"
            if vendor_id
            else "/admin/marketplace/letzshop",
            metadata={
                "vendor_name": vendor_name,
                "order_number": order_number,
                "exception_count": exception_count,
                "vendor_id": vendor_id,
            },
        )

    def notify_critical_error(
        self,
        db: Session,
        error_type: str,
        error_message: str,
        details: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """Create a CRITICAL-priority notification for application errors,
        linking to the admin logs page."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.CRITICAL_ERROR,
            title=f"Critical Error: {error_type}",
            message=error_message,
            priority=Priority.CRITICAL,
            action_required=True,
            action_url="/admin/logs",
            metadata=details,
        )

    def notify_vendor_issue(
        self,
        db: Session,
        vendor_name: str,
        issue_type: str,
        message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create a HIGH-priority notification for vendor-related issues."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.VENDOR_ISSUE,
            title=f"Vendor Issue: {vendor_name}",
            message=message,
            priority=Priority.HIGH,
            action_required=True,
            action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors",
            metadata={
                "vendor_name": vendor_name,
                "issue_type": issue_type,
                "vendor_id": vendor_id,
            },
        )

    def notify_security_alert(
        self,
        db: Session,
        title: str,
        message: str,
        details: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """Create a CRITICAL-priority security notification, linking to the
        audit page."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.SECURITY_ALERT,
            title=title,
            message=message,
            priority=Priority.CRITICAL,
            action_required=True,
            action_url="/admin/audit",
            metadata=details,
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PLATFORM ALERT SERVICE
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PlatformAlertService:
    """Service for managing platform-wide alerts.

    Alerts support de-duplication: a repeated (alert_type, title) pair is
    folded into the existing unresolved alert by bumping its occurrence
    counter (see create_or_increment_alert). Mutating methods flush rather
    than commit — presumably the caller owns the transaction; confirm.
    """

    def create_alert(
        self,
        db: Session,
        alert_type: str,
        severity: str,
        title: str,
        description: str | None = None,
        affected_vendors: list[int] | None = None,
        affected_systems: list[str] | None = None,
        auto_generated: bool = True,
    ) -> PlatformAlert:
        """Create a new platform alert.

        Args:
            db: Database session
            alert_type: Alert category (see AlertType)
            severity: Severity level (see Severity)
            title: Short alert title
            description: Optional longer description
            affected_vendors: Optional list of affected vendor IDs
            affected_systems: Optional list of affected system names
            auto_generated: True when raised by automation (the default)

        Returns:
            Created PlatformAlert (flushed, so its ID is assigned)
        """
        now = datetime.utcnow()

        alert = PlatformAlert(
            alert_type=alert_type,
            severity=severity,
            title=title,
            description=description,
            affected_vendors=affected_vendors,
            affected_systems=affected_systems,
            auto_generated=auto_generated,
            # First and last occurrence start equal; repeats later bump
            # last_occurred_at via increment_occurrence().
            first_occurred_at=now,
            last_occurred_at=now,
        )
        db.add(alert)
        db.flush()

        logger.info(f"Created platform alert: {alert_type} - {title} ({severity})")

        return alert

    def create_from_schema(
        self,
        db: Session,
        data: PlatformAlertCreate,
    ) -> PlatformAlert:
        """Create alert from a Pydantic schema (thin adapter over
        create_alert)."""
        return self.create_alert(
            db=db,
            alert_type=data.alert_type,
            severity=data.severity,
            title=data.title,
            description=data.description,
            affected_vendors=data.affected_vendors,
            affected_systems=data.affected_systems,
            auto_generated=data.auto_generated,
        )

    def get_alerts(
        self,
        db: Session,
        severity: str | None = None,
        alert_type: str | None = None,
        is_resolved: bool | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list[PlatformAlert], int, int, int]:
        """
        Get paginated platform alerts.

        Args:
            db: Database session
            severity: Optional filter on severity
            alert_type: Optional filter on alert type
            is_resolved: Optional filter on resolution state
            skip: Pagination offset
            limit: Page size

        Returns:
            Tuple of (alerts, total_count, active_count, critical_count).
            NOTE: active_count and critical_count are GLOBAL and ignore the
            filters above; only total_count respects them.
        """
        query = db.query(PlatformAlert)

        # Apply filters
        if severity:
            query = query.filter(PlatformAlert.severity == severity)

        if alert_type:
            query = query.filter(PlatformAlert.alert_type == alert_type)

        if is_resolved is not None:
            query = query.filter(PlatformAlert.is_resolved == is_resolved)

        # Get counts (active/critical are unfiltered, see docstring)
        total = query.count()
        active_count = (
            db.query(PlatformAlert)
            .filter(PlatformAlert.is_resolved == False)  # noqa: E712
            .count()
        )
        critical_count = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == False,  # noqa: E712
                    PlatformAlert.severity == Severity.CRITICAL,
                )
            )
            .count()
        )

        # Get paginated results. SQL CASE maps severities to sortable ranks
        # (critical first); unknown severities sink to the bottom.
        severity_order = case(
            (PlatformAlert.severity == "critical", 1),
            (PlatformAlert.severity == "error", 2),
            (PlatformAlert.severity == "warning", 3),
            (PlatformAlert.severity == "info", 4),
            else_=5,
        )

        alerts = (
            query.order_by(
                PlatformAlert.is_resolved,  # Unresolved first
                severity_order,
                PlatformAlert.last_occurred_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .all()
        )

        return alerts, total, active_count, critical_count

    def resolve_alert(
        self,
        db: Session,
        alert_id: int,
        user_id: int,
        resolution_notes: str | None = None,
    ) -> PlatformAlert | None:
        """Resolve a platform alert.

        Idempotent: an already-resolved alert is returned unchanged
        (resolution notes are NOT overwritten). Returns None when the ID
        does not exist.
        """
        alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()

        if alert and not alert.is_resolved:
            alert.is_resolved = True
            alert.resolved_at = datetime.utcnow()
            alert.resolved_by_user_id = user_id
            alert.resolution_notes = resolution_notes
            db.flush()

            logger.info(f"Resolved platform alert {alert_id}")

        return alert

    def get_statistics(self, db: Session) -> dict[str, int]:
        """Get alert statistics.

        Returns total, active (unresolved), critical (unresolved +
        critical severity) counts, and alerts resolved since midnight
        (based on datetime.utcnow(), i.e. a UTC day boundary).
        """
        today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

        total = db.query(PlatformAlert).count()
        active = (
            db.query(PlatformAlert)
            .filter(PlatformAlert.is_resolved == False)  # noqa: E712
            .count()
        )
        critical = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == False,  # noqa: E712
                    PlatformAlert.severity == Severity.CRITICAL,
                )
            )
            .count()
        )
        resolved_today = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == True,  # noqa: E712
                    PlatformAlert.resolved_at >= today_start,
                )
            )
            .count()
        )

        return {
            "total_alerts": total,
            "active_alerts": active,
            "critical_alerts": critical,
            "resolved_today": resolved_today,
        }

    def increment_occurrence(
        self,
        db: Session,
        alert_id: int,
    ) -> PlatformAlert | None:
        """Increment the occurrence count for a repeated alert and refresh
        last_occurred_at. Returns None when the ID does not exist."""
        alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()

        if alert:
            alert.occurrence_count += 1
            alert.last_occurred_at = datetime.utcnow()
            db.flush()

        return alert

    def find_similar_active_alert(
        self,
        db: Session,
        alert_type: str,
        title: str,
    ) -> PlatformAlert | None:
        """Find an unresolved alert with the same (alert_type, title) —
        the de-duplication key used by create_or_increment_alert."""
        return (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.alert_type == alert_type,
                    PlatformAlert.title == title,
                    PlatformAlert.is_resolved == False,  # noqa: E712
                )
            )
            .first()
        )

    def create_or_increment_alert(
        self,
        db: Session,
        alert_type: str,
        severity: str,
        title: str,
        description: str | None = None,
        affected_vendors: list[int] | None = None,
        affected_systems: list[str] | None = None,
    ) -> PlatformAlert:
        """Create an alert, or increment the occurrence count of an
        existing unresolved alert with the same type and title.

        NOTE: when an existing alert is found, the new severity/
        description/affected_* arguments are ignored — the existing row is
        returned unchanged apart from its occurrence bookkeeping.
        """
        existing = self.find_similar_active_alert(db, alert_type, title)

        if existing:
            self.increment_occurrence(db, existing.id)
            return existing

        return self.create_alert(
            db=db,
            alert_type=alert_type,
            severity=severity,
            title=title,
            description=description,
            affected_vendors=affected_vendors,
            affected_systems=affected_systems,
        )
|
||||
|
||||
|
||||
# Singleton instances
|
||||
admin_notification_service = AdminNotificationService()
|
||||
platform_alert_service = PlatformAlertService()
|
||||
225
app/modules/messaging/services/message_attachment_service.py
Normal file
225
app/modules/messaging/services/message_attachment_service.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# app/modules/messaging/services/message_attachment_service.py
|
||||
"""
|
||||
Attachment handling service for messaging system.
|
||||
|
||||
Handles file upload, validation, storage, and retrieval.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.admin_settings_service import admin_settings_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed MIME types for attachments
|
||||
ALLOWED_MIME_TYPES = {
|
||||
# Images
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
# Documents
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
# Archives
|
||||
"application/zip",
|
||||
# Text
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
}
|
||||
|
||||
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
|
||||
|
||||
# Default max file size in MB
|
||||
DEFAULT_MAX_FILE_SIZE_MB = 10
|
||||
|
||||
|
||||
class MessageAttachmentService:
    """Service for handling message attachments.

    Validates uploaded files (MIME type and configurable size limit),
    stores them on local disk under ``storage_base``, generates thumbnails
    for images, and handles retrieval/deletion. Paths produced are local
    relative paths.
    """

    def __init__(self, storage_base: str = "uploads/messages"):
        # Root (relative) directory under which attachment files are stored.
        self.storage_base = storage_base

    def get_max_file_size_bytes(self, db: Session) -> int:
        """Return the maximum allowed attachment size in bytes.

        Reads ``message_attachment_max_size_mb`` from platform settings and
        falls back to DEFAULT_MAX_FILE_SIZE_MB when unset or non-numeric.
        """
        max_mb = admin_settings_service.get_setting_value(
            db,
            "message_attachment_max_size_mb",
            default=DEFAULT_MAX_FILE_SIZE_MB,
        )
        # Setting values may arrive as strings or None; tolerate both.
        try:
            max_mb = int(max_mb)
        except (TypeError, ValueError):
            max_mb = DEFAULT_MAX_FILE_SIZE_MB
        return max_mb * 1024 * 1024  # Convert MB to bytes

    def validate_file_type(self, mime_type: str) -> bool:
        """Check if file type is allowed."""
        return mime_type in ALLOWED_MIME_TYPES

    def is_image(self, mime_type: str) -> bool:
        """Check if file is an image."""
        return mime_type in IMAGE_MIME_TYPES

    async def validate_and_store(
        self,
        db: Session,
        file: UploadFile,
        conversation_id: int,
    ) -> dict:
        """
        Validate and store an uploaded file.

        Args:
            db: Database session (used to read the max-size setting)
            file: Incoming upload
            conversation_id: Conversation the attachment belongs to

        Returns:
            Dict with file metadata for MessageAttachment creation
            (filename, original_filename, file_path, file_size, mime_type,
            is_image, plus thumbnail fields for images).

        Raises:
            ValueError: If file type or size is invalid
        """
        # Validate the (client-supplied) MIME type; a missing content type
        # becomes octet-stream, which is rejected below.
        content_type = file.content_type or "application/octet-stream"
        if not self.validate_file_type(content_type):
            raise ValueError(
                f"File type '{content_type}' not allowed. "
                "Allowed types: images (JPEG, PNG, GIF, WebP), "
                "PDF, Office documents, ZIP, text files."
            )

        # Read the whole upload into memory to measure and persist it.
        content = await file.read()
        file_size = len(content)

        # Validate file size against the configurable platform limit.
        max_size = self.get_max_file_size_bytes(db)
        if file_size > max_size:
            raise ValueError(
                f"File size {file_size / 1024 / 1024:.1f}MB exceeds "
                f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB"
            )

        # Generate a unique on-disk name, keeping only the extension of the
        # original filename.
        original_filename = file.filename or "attachment"
        ext = Path(original_filename).suffix.lower()
        unique_filename = f"{uuid.uuid4().hex}{ext}"

        # Storage layout: uploads/messages/YYYY/MM/conversation_id/filename
        now = datetime.utcnow()
        relative_path = os.path.join(
            self.storage_base,
            str(now.year),
            f"{now.month:02d}",
            str(conversation_id),
        )

        # Ensure directory exists
        os.makedirs(relative_path, exist_ok=True)

        # Full file path
        file_path = os.path.join(relative_path, unique_filename)

        # Write file
        with open(file_path, "wb") as f:
            f.write(content)

        # Prepare metadata for the MessageAttachment row.
        is_image = self.is_image(content_type)
        metadata = {
            "filename": unique_filename,
            "original_filename": original_filename,
            "file_path": file_path,
            "file_size": file_size,
            "mime_type": content_type,
            "is_image": is_image,
        }

        # Generate thumbnail for images (best-effort: failures only omit
        # the thumbnail fields, the attachment itself is already stored).
        if is_image:
            thumbnail_data = self._create_thumbnail(content, file_path)
            metadata.update(thumbnail_data)

        logger.info(
            f"Stored attachment {unique_filename} for conversation {conversation_id} "
            f"({file_size} bytes, type: {content_type})"
        )

        return metadata

    def _create_thumbnail(self, content: bytes, original_path: str) -> dict:
        """Create a thumbnail (bounded to 200x200) for an image attachment.

        Returns the original image dimensions and the thumbnail path, or an
        empty dict when Pillow is unavailable or thumbnailing fails.
        """
        try:
            from PIL import Image
            import io

            img = Image.open(io.BytesIO(content))
            width, height = img.size

            # In-place, aspect-ratio-preserving downscale.
            img.thumbnail((200, 200))

            # BUGFIX: the previous code used
            # original_path.replace(".", "_thumb."), which replaced EVERY
            # dot in the path and, for extension-less files (no dot at
            # all), returned the original path unchanged — overwriting the
            # original file with its thumbnail. Build the sibling
            # "<stem>_thumb<suffix>" path explicitly instead.
            original = Path(original_path)
            thumb_path = str(
                original.with_name(f"{original.stem}_thumb{original.suffix}")
            )
            img.save(thumb_path)

            return {
                "image_width": width,
                "image_height": height,
                "thumbnail_path": thumb_path,
            }
        except ImportError:
            logger.warning("PIL not installed, skipping thumbnail generation")
            return {}
        except Exception as e:
            logger.error(f"Failed to create thumbnail: {e}")
            return {}

    def delete_attachment(
        self, file_path: str, thumbnail_path: str | None = None
    ) -> bool:
        """Delete attachment files from storage.

        Best-effort: missing files are treated as already deleted; any OS
        error is logged and reported as False instead of raising.
        """
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.info(f"Deleted attachment file: {file_path}")

            if thumbnail_path and os.path.exists(thumbnail_path):
                os.remove(thumbnail_path)
                logger.info(f"Deleted thumbnail: {thumbnail_path}")

            return True
        except Exception as e:
            logger.error(f"Failed to delete attachment {file_path}: {e}")
            return False

    def get_download_url(self, file_path: str) -> str:
        """
        Get download URL for an attachment.

        For local storage, returns a relative path that can be served
        by the static file handler or a dedicated download endpoint.
        """
        # Convert local path to URL path
        # Assumes files are served from /static/uploads or similar
        return f"/static/{file_path}"

    def get_file_content(self, file_path: str) -> bytes | None:
        """Read file content from storage.

        Returns None when the file is missing or unreadable (errors are
        logged, not raised).
        """
        try:
            if os.path.exists(file_path):
                with open(file_path, "rb") as f:
                    return f.read()
            return None
        except Exception as e:
            logger.error(f"Failed to read file {file_path}: {e}")
            return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
message_attachment_service = MessageAttachmentService()
|
||||
684
app/modules/messaging/services/messaging_service.py
Normal file
684
app/modules/messaging/services/messaging_service.py
Normal file
@@ -0,0 +1,684 @@
|
||||
# app/modules/messaging/services/messaging_service.py
|
||||
"""
|
||||
Messaging service for conversation and message management.
|
||||
|
||||
Provides functionality for:
|
||||
- Creating conversations between different participant types
|
||||
- Sending messages with attachments
|
||||
- Managing read status and unread counts
|
||||
- Conversation listing with filters
|
||||
- Multi-tenant data isolation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.modules.messaging.models.message import (
|
||||
Conversation,
|
||||
ConversationParticipant,
|
||||
ConversationType,
|
||||
Message,
|
||||
MessageAttachment,
|
||||
ParticipantType,
|
||||
)
|
||||
from models.database.customer import Customer
|
||||
from models.database.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagingService:
|
||||
"""Service for managing conversations and messages."""
|
||||
|
||||
# =========================================================================
|
||||
# CONVERSATION MANAGEMENT
|
||||
# =========================================================================
|
||||
|
||||
def create_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_type: ConversationType,
|
||||
subject: str,
|
||||
initiator_type: ParticipantType,
|
||||
initiator_id: int,
|
||||
recipient_type: ParticipantType,
|
||||
recipient_id: int,
|
||||
vendor_id: int | None = None,
|
||||
initial_message: str | None = None,
|
||||
) -> Conversation:
|
||||
"""
|
||||
Create a new conversation between two participants.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
conversation_type: Type of conversation channel
|
||||
subject: Conversation subject line
|
||||
initiator_type: Type of initiating participant
|
||||
initiator_id: ID of initiating participant
|
||||
recipient_type: Type of receiving participant
|
||||
recipient_id: ID of receiving participant
|
||||
vendor_id: Required for vendor_customer/admin_customer types
|
||||
initial_message: Optional first message content
|
||||
|
||||
Returns:
|
||||
Created Conversation object
|
||||
"""
|
||||
# Validate vendor_id requirement
|
||||
if conversation_type in [
|
||||
ConversationType.VENDOR_CUSTOMER,
|
||||
ConversationType.ADMIN_CUSTOMER,
|
||||
]:
|
||||
if not vendor_id:
|
||||
raise ValueError(
|
||||
f"vendor_id required for {conversation_type.value} conversations"
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation = Conversation(
|
||||
conversation_type=conversation_type,
|
||||
subject=subject,
|
||||
vendor_id=vendor_id,
|
||||
)
|
||||
db.add(conversation)
|
||||
db.flush()
|
||||
|
||||
# Add participants
|
||||
initiator_vendor_id = (
|
||||
vendor_id if initiator_type == ParticipantType.VENDOR else None
|
||||
)
|
||||
recipient_vendor_id = (
|
||||
vendor_id if recipient_type == ParticipantType.VENDOR else None
|
||||
)
|
||||
|
||||
initiator = ConversationParticipant(
|
||||
conversation_id=conversation.id,
|
||||
participant_type=initiator_type,
|
||||
participant_id=initiator_id,
|
||||
vendor_id=initiator_vendor_id,
|
||||
unread_count=0, # Initiator has read their own message
|
||||
)
|
||||
recipient = ConversationParticipant(
|
||||
conversation_id=conversation.id,
|
||||
participant_type=recipient_type,
|
||||
participant_id=recipient_id,
|
||||
vendor_id=recipient_vendor_id,
|
||||
unread_count=1 if initial_message else 0,
|
||||
)
|
||||
|
||||
db.add(initiator)
|
||||
db.add(recipient)
|
||||
db.flush()
|
||||
|
||||
# Add initial message if provided
|
||||
if initial_message:
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation.id,
|
||||
sender_type=initiator_type,
|
||||
sender_id=initiator_id,
|
||||
content=initial_message,
|
||||
_skip_unread_update=True, # Already set above
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {conversation_type.value} conversation {conversation.id}: "
|
||||
f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}"
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def get_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
) -> Conversation | None:
|
||||
"""
|
||||
Get conversation if participant has access.
|
||||
|
||||
Validates that the requester is a participant.
|
||||
"""
|
||||
conversation = (
|
||||
db.query(Conversation)
|
||||
.options(
|
||||
joinedload(Conversation.participants),
|
||||
joinedload(Conversation.messages).joinedload(Message.attachments),
|
||||
)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
# Verify participant access
|
||||
has_access = any(
|
||||
p.participant_type == participant_type
|
||||
and p.participant_id == participant_id
|
||||
for p in conversation.participants
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
f"Access denied to conversation {conversation_id} for "
|
||||
f"{participant_type.value}:{participant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
return conversation
|
||||
|
||||
    def list_conversations(
        self,
        db: Session,
        participant_type: ParticipantType,
        participant_id: int,
        vendor_id: int | None = None,
        conversation_type: ConversationType | None = None,
        is_closed: bool | None = None,
        skip: int = 0,
        limit: int = 20,
    ) -> tuple[list[Conversation], int, int]:
        """
        List conversations for a participant with filters.

        Args:
            db: Database session
            participant_type: Type of the requesting participant
            participant_id: ID of the requesting participant
            vendor_id: Tenant scope; restricts vendor users to their own
                participant rows and customers to conversations of that vendor
            conversation_type: Optional filter on conversation type
            is_closed: Optional filter on open/closed state (None = both)
            skip: Pagination offset
            limit: Page size

        Returns:
            Tuple of (conversations, total_count, total_unread)
        """
        # Base query: conversations where user is a participant
        query = (
            db.query(Conversation)
            .join(ConversationParticipant)
            .filter(
                and_(
                    ConversationParticipant.participant_type == participant_type,
                    ConversationParticipant.participant_id == participant_id,
                )
            )
        )

        # Multi-tenant filter for vendor users
        if participant_type == ParticipantType.VENDOR and vendor_id:
            query = query.filter(ConversationParticipant.vendor_id == vendor_id)

        # Customer vendor isolation
        if participant_type == ParticipantType.CUSTOMER and vendor_id:
            query = query.filter(Conversation.vendor_id == vendor_id)

        # Type filter
        if conversation_type:
            query = query.filter(Conversation.conversation_type == conversation_type)

        # Status filter (explicit None check so is_closed=False still filters)
        if is_closed is not None:
            query = query.filter(Conversation.is_closed == is_closed)

        # Get total count (before pagination is applied)
        total = query.count()

        # Get total unread across all conversations — separate aggregate query
        # over the participant rows, not limited to the current page.
        unread_query = db.query(
            func.sum(ConversationParticipant.unread_count)
        ).filter(
            and_(
                ConversationParticipant.participant_type == participant_type,
                ConversationParticipant.participant_id == participant_id,
            )
        )

        if participant_type == ParticipantType.VENDOR and vendor_id:
            unread_query = unread_query.filter(
                ConversationParticipant.vendor_id == vendor_id
            )

        # SUM over zero rows yields NULL/None; coerce to 0.
        total_unread = unread_query.scalar() or 0

        # Get paginated results, ordered by last activity (conversations that
        # never had a message sort last via nullslast).
        conversations = (
            query.options(joinedload(Conversation.participants))
            .order_by(Conversation.last_message_at.desc().nullslast())
            .offset(skip)
            .limit(limit)
            .all()
        )

        return conversations, total, total_unread
|
||||
|
||||
def close_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
closer_type: ParticipantType,
|
||||
closer_id: int,
|
||||
) -> Conversation | None:
|
||||
"""Close a conversation thread."""
|
||||
conversation = self.get_conversation(
|
||||
db, conversation_id, closer_type, closer_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
conversation.is_closed = True
|
||||
conversation.closed_at = datetime.now(UTC)
|
||||
conversation.closed_by_type = closer_type
|
||||
conversation.closed_by_id = closer_id
|
||||
|
||||
# Add system message
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
sender_type=closer_type,
|
||||
sender_id=closer_id,
|
||||
content="Conversation closed",
|
||||
is_system_message=True,
|
||||
)
|
||||
|
||||
db.flush()
|
||||
return conversation
|
||||
|
||||
def reopen_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
opener_type: ParticipantType,
|
||||
opener_id: int,
|
||||
) -> Conversation | None:
|
||||
"""Reopen a closed conversation."""
|
||||
conversation = self.get_conversation(
|
||||
db, conversation_id, opener_type, opener_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
conversation.is_closed = False
|
||||
conversation.closed_at = None
|
||||
conversation.closed_by_type = None
|
||||
conversation.closed_by_id = None
|
||||
|
||||
# Add system message
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
sender_type=opener_type,
|
||||
sender_id=opener_id,
|
||||
content="Conversation reopened",
|
||||
is_system_message=True,
|
||||
)
|
||||
|
||||
db.flush()
|
||||
return conversation
|
||||
|
||||
# =========================================================================
|
||||
# MESSAGE MANAGEMENT
|
||||
# =========================================================================
|
||||
|
||||
    def send_message(
        self,
        db: Session,
        conversation_id: int,
        sender_type: ParticipantType,
        sender_id: int,
        content: str,
        attachments: list[dict[str, Any]] | None = None,
        is_system_message: bool = False,
        _skip_unread_update: bool = False,
    ) -> Message:
        """
        Send a message in a conversation.

        Args:
            db: Database session
            conversation_id: Target conversation ID
            sender_type: Type of sender
            sender_id: ID of sender
            content: Message text content
            attachments: List of attachment dicts with file metadata;
                keys "filename", "original_filename", "file_path",
                "file_size", "mime_type" are required (KeyError if absent),
                the image-related keys are optional
            is_system_message: Whether this is a system-generated message
            _skip_unread_update: Internal flag to skip unread increment
                (used when the caller pre-set unread counts, e.g. on
                conversation creation)

        Returns:
            Created Message object
        """
        # Create message; flush so message.id is available for attachments.
        message = Message(
            conversation_id=conversation_id,
            sender_type=sender_type,
            sender_id=sender_id,
            content=content,
            is_system_message=is_system_message,
        )
        db.add(message)
        db.flush()

        # Add attachments if any
        if attachments:
            for att_data in attachments:
                attachment = MessageAttachment(
                    message_id=message.id,
                    filename=att_data["filename"],
                    original_filename=att_data["original_filename"],
                    file_path=att_data["file_path"],
                    file_size=att_data["file_size"],
                    mime_type=att_data["mime_type"],
                    is_image=att_data.get("is_image", False),
                    image_width=att_data.get("image_width"),
                    image_height=att_data.get("image_height"),
                    thumbnail_path=att_data.get("thumbnail_path"),
                )
                db.add(attachment)

        # Update conversation metadata (silently skipped if the conversation
        # row does not exist — the message is still created in that case).
        conversation = (
            db.query(Conversation).filter(Conversation.id == conversation_id).first()
        )

        if conversation:
            conversation.last_message_at = datetime.now(UTC)
            conversation.message_count += 1

        # Update unread counts for other participants: bulk UPDATE matching
        # every participant that is not exactly (sender_type, sender_id).
        if not _skip_unread_update:
            db.query(ConversationParticipant).filter(
                and_(
                    ConversationParticipant.conversation_id == conversation_id,
                    or_(
                        ConversationParticipant.participant_type != sender_type,
                        ConversationParticipant.participant_id != sender_id,
                    ),
                )
            ).update(
                {
                    ConversationParticipant.unread_count: ConversationParticipant.unread_count
                    + 1
                }
            )

        db.flush()

        logger.info(
            f"Message {message.id} sent in conversation {conversation_id} "
            f"by {sender_type.value}:{sender_id}"
        )

        return message
|
||||
|
||||
def delete_message(
|
||||
self,
|
||||
db: Session,
|
||||
message_id: int,
|
||||
deleter_type: ParticipantType,
|
||||
deleter_id: int,
|
||||
) -> Message | None:
|
||||
"""Soft delete a message (for moderation)."""
|
||||
message = db.query(Message).filter(Message.id == message_id).first()
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# Verify deleter has access to conversation
|
||||
conversation = self.get_conversation(
|
||||
db, message.conversation_id, deleter_type, deleter_id
|
||||
)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
message.is_deleted = True
|
||||
message.deleted_at = datetime.now(UTC)
|
||||
message.deleted_by_type = deleter_type
|
||||
message.deleted_by_id = deleter_id
|
||||
|
||||
db.flush()
|
||||
return message
|
||||
|
||||
def mark_conversation_read(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
reader_type: ParticipantType,
|
||||
reader_id: int,
|
||||
) -> bool:
|
||||
"""Mark all messages in conversation as read for participant."""
|
||||
result = (
|
||||
db.query(ConversationParticipant)
|
||||
.filter(
|
||||
and_(
|
||||
ConversationParticipant.conversation_id == conversation_id,
|
||||
ConversationParticipant.participant_type == reader_type,
|
||||
ConversationParticipant.participant_id == reader_id,
|
||||
)
|
||||
)
|
||||
.update(
|
||||
{
|
||||
ConversationParticipant.unread_count: 0,
|
||||
ConversationParticipant.last_read_at: datetime.now(UTC),
|
||||
}
|
||||
)
|
||||
)
|
||||
db.flush()
|
||||
return result > 0
|
||||
|
||||
def get_unread_count(
|
||||
self,
|
||||
db: Session,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
vendor_id: int | None = None,
|
||||
) -> int:
|
||||
"""Get total unread message count for a participant."""
|
||||
query = db.query(func.sum(ConversationParticipant.unread_count)).filter(
|
||||
and_(
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
|
||||
|
||||
return query.scalar() or 0
|
||||
|
||||
# =========================================================================
|
||||
# PARTICIPANT HELPERS
|
||||
# =========================================================================
|
||||
|
||||
def get_participant_info(
|
||||
self,
|
||||
db: Session,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get display info for a participant (name, email, avatar)."""
|
||||
if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]:
|
||||
user = db.query(User).filter(User.id == participant_id).first()
|
||||
if user:
|
||||
return {
|
||||
"id": user.id,
|
||||
"type": participant_type.value,
|
||||
"name": f"{user.first_name or ''} {user.last_name or ''}".strip()
|
||||
or user.username,
|
||||
"email": user.email,
|
||||
"avatar_url": None, # Could add avatar support later
|
||||
}
|
||||
elif participant_type == ParticipantType.CUSTOMER:
|
||||
customer = db.query(Customer).filter(Customer.id == participant_id).first()
|
||||
if customer:
|
||||
return {
|
||||
"id": customer.id,
|
||||
"type": participant_type.value,
|
||||
"name": f"{customer.first_name or ''} {customer.last_name or ''}".strip()
|
||||
or customer.email,
|
||||
"email": customer.email,
|
||||
"avatar_url": None,
|
||||
}
|
||||
return None
|
||||
|
||||
def get_other_participant(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
my_type: ParticipantType,
|
||||
my_id: int,
|
||||
) -> ConversationParticipant | None:
|
||||
"""Get the other participant in a conversation."""
|
||||
for p in conversation.participants:
|
||||
if p.participant_type != my_type or p.participant_id != my_id:
|
||||
return p
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# NOTIFICATION PREFERENCES
|
||||
# =========================================================================
|
||||
|
||||
def update_notification_preferences(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
email_notifications: bool | None = None,
|
||||
muted: bool | None = None,
|
||||
) -> bool:
|
||||
"""Update notification preferences for a participant in a conversation."""
|
||||
updates = {}
|
||||
if email_notifications is not None:
|
||||
updates[ConversationParticipant.email_notifications] = email_notifications
|
||||
if muted is not None:
|
||||
updates[ConversationParticipant.muted] = muted
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
result = (
|
||||
db.query(ConversationParticipant)
|
||||
.filter(
|
||||
and_(
|
||||
ConversationParticipant.conversation_id == conversation_id,
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
.update(updates)
|
||||
)
|
||||
db.flush()
|
||||
return result > 0
|
||||
|
||||
# =========================================================================
|
||||
# RECIPIENT QUERIES
|
||||
# =========================================================================
|
||||
|
||||
def get_vendor_recipients(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get list of vendor users as potential recipients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter
|
||||
search: Search term for name/email
|
||||
skip: Pagination offset
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Tuple of (recipients list, total count)
|
||||
"""
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
query = (
|
||||
db.query(User, VendorUser)
|
||||
.join(VendorUser, User.id == VendorUser.user_id)
|
||||
.filter(User.is_active == True) # noqa: E712
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(VendorUser.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
(User.username.ilike(search_pattern))
|
||||
| (User.email.ilike(search_pattern))
|
||||
| (User.first_name.ilike(search_pattern))
|
||||
| (User.last_name.ilike(search_pattern))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
results = query.offset(skip).limit(limit).all()
|
||||
|
||||
recipients = []
|
||||
for user, vendor_user in results:
|
||||
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
|
||||
recipients.append({
|
||||
"id": user.id,
|
||||
"type": ParticipantType.VENDOR,
|
||||
"name": name,
|
||||
"email": user.email,
|
||||
"vendor_id": vendor_user.vendor_id,
|
||||
"vendor_name": vendor_user.vendor.name if vendor_user.vendor else None,
|
||||
})
|
||||
|
||||
return recipients, total
|
||||
|
||||
def get_customer_recipients(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get list of customers as potential recipients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter (required for vendor users)
|
||||
search: Search term for name/email
|
||||
skip: Pagination offset
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Tuple of (recipients list, total count)
|
||||
"""
|
||||
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
(Customer.email.ilike(search_pattern))
|
||||
| (Customer.first_name.ilike(search_pattern))
|
||||
| (Customer.last_name.ilike(search_pattern))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
results = query.offset(skip).limit(limit).all()
|
||||
|
||||
recipients = []
|
||||
for customer in results:
|
||||
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
|
||||
recipients.append({
|
||||
"id": customer.id,
|
||||
"type": ParticipantType.CUSTOMER,
|
||||
"name": name or customer.email,
|
||||
"email": customer.email,
|
||||
"vendor_id": customer.vendor_id,
|
||||
})
|
||||
|
||||
return recipients, total
|
||||
|
||||
|
||||
# Module-level singleton; import this instead of instantiating MessagingService.
messaging_service = MessagingService()
|
||||
@@ -2,10 +2,10 @@
|
||||
"""
|
||||
Monitoring module services.
|
||||
|
||||
Re-exports monitoring-related services from their source locations.
|
||||
This module contains the canonical implementations of monitoring-related services.
|
||||
"""
|
||||
|
||||
from app.services.background_tasks_service import (
|
||||
from app.modules.monitoring.services.background_tasks_service import (
|
||||
background_tasks_service,
|
||||
BackgroundTasksService,
|
||||
)
|
||||
|
||||
194
app/modules/monitoring/services/background_tasks_service.py
Normal file
194
app/modules/monitoring/services/background_tasks_service.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# app/modules/monitoring/services/background_tasks_service.py
|
||||
"""
|
||||
Background Tasks Service
|
||||
Service for monitoring background tasks across the system
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import case, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.architecture_scan import ArchitectureScan
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.test_run import TestRun
|
||||
|
||||
|
||||
class BackgroundTasksService:
    """Read-only monitoring queries over background work in the system.

    Covers three job families:
      - Marketplace import jobs (MarketplaceImportJob)
      - Test runs (TestRun)
      - Code-quality / architecture scans (ArchitectureScan)

    Every method is a pure query; nothing here starts, stops, or mutates
    background jobs.
    """

    @staticmethod
    def _start_of_today() -> datetime:
        """Return midnight (UTC) of the current day — lower bound for 'today' stats."""
        return datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)

    def get_import_jobs(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[MarketplaceImportJob]:
        """Get the newest import jobs, optionally filtered by status."""
        query = db.query(MarketplaceImportJob)
        if status:
            query = query.filter(MarketplaceImportJob.status == status)
        return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all()

    def get_test_runs(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[TestRun]:
        """Get the newest test runs, optionally filtered by status."""
        query = db.query(TestRun)
        if status:
            query = query.filter(TestRun.status == status)
        return query.order_by(desc(TestRun.timestamp)).limit(limit).all()

    def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]:
        """Get import jobs currently in the 'processing' state."""
        return (
            db.query(MarketplaceImportJob)
            .filter(MarketplaceImportJob.status == "processing")
            .all()
        )

    def get_running_test_runs(self, db: Session) -> list[TestRun]:
        """Get test runs currently in the 'running' state."""
        # noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped
        return db.query(TestRun).filter(TestRun.status == "running").all()

    def get_import_stats(self, db: Session) -> dict:
        """Aggregate import-job counts: total/running/completed/failed/today.

        A job counts as completed in either the "completed" or the
        "completed_with_errors" state. Empty tables yield zeros (SQL SUM
        over no rows is NULL, coerced via `or 0`).
        """
        stats = db.query(
            func.count(MarketplaceImportJob.id).label("total"),
            func.sum(
                case((MarketplaceImportJob.status == "processing", 1), else_=0)
            ).label("running"),
            func.sum(
                case(
                    (
                        MarketplaceImportJob.status.in_(
                            ["completed", "completed_with_errors"]
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("completed"),
            func.sum(
                case((MarketplaceImportJob.status == "failed", 1), else_=0)
            ).label("failed"),
        ).first()

        today_count = (
            db.query(func.count(MarketplaceImportJob.id))
            .filter(MarketplaceImportJob.created_at >= self._start_of_today())
            .scalar()
            or 0
        )

        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
        }

    def get_test_run_stats(self, db: Session) -> dict:
        """Aggregate test-run counts plus average duration.

        "completed" means status == "passed"; "failed" covers both "failed"
        and "error". avg_duration is rounded to one decimal place.
        """
        stats = db.query(
            func.count(TestRun.id).label("total"),
            func.sum(case((TestRun.status == "running", 1), else_=0)).label(
                "running"
            ),
            func.sum(case((TestRun.status == "passed", 1), else_=0)).label(
                "completed"
            ),
            func.sum(
                case((TestRun.status.in_(["failed", "error"]), 1), else_=0)
            ).label("failed"),
            func.avg(TestRun.duration_seconds).label("avg_duration"),
        ).first()

        today_count = (
            db.query(func.count(TestRun.id))
            .filter(TestRun.timestamp >= self._start_of_today())
            .scalar()
            or 0
        )

        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
            "avg_duration": round(stats.avg_duration or 0, 1),
        }

    def get_code_quality_scans(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[ArchitectureScan]:
        """Get the newest code-quality scans, optionally filtered by status."""
        query = db.query(ArchitectureScan)
        if status:
            query = query.filter(ArchitectureScan.status == status)
        return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all()

    def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
        """Get code-quality scans in the 'pending' or 'running' state."""
        return (
            db.query(ArchitectureScan)
            .filter(ArchitectureScan.status.in_(["pending", "running"]))
            .all()
        )

    def get_scan_stats(self, db: Session) -> dict:
        """Aggregate code-quality scan counts plus average duration.

        "running" covers both "pending" and "running"; "completed" covers
        "completed" and "completed_with_warnings".
        """
        stats = db.query(
            func.count(ArchitectureScan.id).label("total"),
            func.sum(
                case(
                    (ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0
                )
            ).label("running"),
            func.sum(
                case(
                    (
                        ArchitectureScan.status.in_(
                            ["completed", "completed_with_warnings"]
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("completed"),
            func.sum(
                case((ArchitectureScan.status == "failed", 1), else_=0)
            ).label("failed"),
            func.avg(ArchitectureScan.duration_seconds).label("avg_duration"),
        ).first()

        today_count = (
            db.query(func.count(ArchitectureScan.id))
            .filter(ArchitectureScan.timestamp >= self._start_of_today())
            .scalar()
            or 0
        )

        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
            "avg_duration": round(stats.avg_duration or 0, 1),
        }
|
||||
|
||||
|
||||
# Module-level singleton; import this instead of instantiating BackgroundTasksService.
background_tasks_service = BackgroundTasksService()
|
||||
@@ -2,15 +2,12 @@
|
||||
"""
|
||||
Orders module database models.
|
||||
|
||||
Re-exports order-related models from their source locations.
|
||||
This module contains the canonical implementations of order-related models.
|
||||
"""
|
||||
|
||||
from models.database.order import (
|
||||
Order,
|
||||
OrderItem,
|
||||
)
|
||||
from models.database.order_item_exception import OrderItemException
|
||||
from models.database.invoice import (
|
||||
from app.modules.orders.models.order import Order, OrderItem
|
||||
from app.modules.orders.models.order_item_exception import OrderItemException
|
||||
from app.modules.orders.models.invoice import (
|
||||
Invoice,
|
||||
InvoiceStatus,
|
||||
VATRegime,
|
||||
|
||||
215
app/modules/orders/models/invoice.py
Normal file
215
app/modules/orders/models/invoice.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# app/modules/orders/models/invoice.py
|
||||
"""
|
||||
Invoice database models for the OMS.
|
||||
|
||||
Provides models for:
|
||||
- VendorInvoiceSettings: Per-vendor invoice configuration (company details, VAT, numbering)
|
||||
- Invoice: Invoice records with snapshots of seller/buyer details
|
||||
"""
|
||||
|
||||
import enum
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.sqlite import JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class VendorInvoiceSettings(Base, TimestampMixin):
    """
    Per-vendor invoice configuration.

    Stores company details, VAT number, invoice numbering preferences,
    and payment information for invoice generation.

    One-to-one relationship with Vendor (vendor_id is unique).
    """

    __tablename__ = "vendor_invoice_settings"

    id = Column(Integer, primary_key=True, index=True)
    vendor_id = Column(
        Integer, ForeignKey("vendors.id"), unique=True, nullable=False, index=True
    )

    # Legal company details for invoice header
    company_name = Column(String(255), nullable=False)  # Legal name for invoices
    company_address = Column(String(255), nullable=True)  # Street address
    company_city = Column(String(100), nullable=True)
    company_postal_code = Column(String(20), nullable=True)
    company_country = Column(String(2), nullable=False, default="LU")  # ISO country code

    # VAT information
    vat_number = Column(String(50), nullable=True)  # e.g., "LU12345678"
    is_vat_registered = Column(Boolean, default=True, nullable=False)

    # OSS (One-Stop-Shop) for EU VAT
    is_oss_registered = Column(Boolean, default=False, nullable=False)
    oss_registration_country = Column(String(2), nullable=True)  # ISO country code

    # Invoice numbering
    invoice_prefix = Column(String(20), default="INV", nullable=False)
    invoice_next_number = Column(Integer, default=1, nullable=False)
    invoice_number_padding = Column(Integer, default=5, nullable=False)  # e.g., INV00001

    # Payment information
    payment_terms = Column(Text, nullable=True)  # e.g., "Payment due within 30 days"
    bank_name = Column(String(255), nullable=True)
    bank_iban = Column(String(50), nullable=True)
    bank_bic = Column(String(20), nullable=True)

    # Invoice footer
    footer_text = Column(Text, nullable=True)  # Custom footer text

    # Default VAT rate for Luxembourg invoices (17% standard)
    default_vat_rate = Column(Numeric(5, 2), default=17.00, nullable=False)

    # Relationships
    vendor = relationship("Vendor", back_populates="invoice_settings")

    def __repr__(self):
        return f"<VendorInvoiceSettings(vendor_id={self.vendor_id}, company='{self.company_name}')>"

    def get_next_invoice_number(self) -> str:
        """Format the next invoice number from the current counter.

        Zero-pads invoice_next_number to invoice_number_padding digits and
        prefixes it with invoice_prefix (e.g. "INV00001").

        NOTE(review): the original summary said "and increment counter",
        but this method does NOT mutate invoice_next_number — the caller
        must bump it after issuing an invoice. Confirm against the invoice
        service before relying on either behavior.
        """
        number = str(self.invoice_next_number).zfill(self.invoice_number_padding)
        return f"{self.invoice_prefix}{number}"
|
||||
|
||||
|
||||
class InvoiceStatus(str, enum.Enum):
    """Invoice lifecycle status.

    str-valued so members compare/serialize as plain strings; stored in
    Invoice.status (which defaults to DRAFT).
    """

    DRAFT = "draft"
    ISSUED = "issued"
    PAID = "paid"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class VATRegime(str, enum.Enum):
    """VAT regime for invoice calculation.

    str-valued so members compare/serialize as plain strings; stored in
    Invoice.vat_regime (which defaults to DOMESTIC).
    """

    DOMESTIC = "domestic"  # Same country as seller
    OSS = "oss"  # EU cross-border with OSS registration
    REVERSE_CHARGE = "reverse_charge"  # B2B with valid VAT number
    ORIGIN = "origin"  # Cross-border without OSS (use origin VAT)
    EXEMPT = "exempt"  # VAT exempt
|
||||
|
||||
|
||||
class Invoice(Base, TimestampMixin):
    """
    Invoice record with snapshots of seller/buyer details.

    Stores complete invoice data including snapshots of seller and buyer
    details at time of creation for audit purposes. All monetary amounts
    are integer cents; the euro-valued properties below convert on read.
    """

    __tablename__ = "invoices"

    id = Column(Integer, primary_key=True, index=True)
    vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
    order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True)

    # Invoice identification (invoice_number is unique per vendor — see
    # idx_invoice_vendor_number below)
    invoice_number = Column(String(50), nullable=False)
    invoice_date = Column(DateTime(timezone=True), nullable=False)

    # Status (InvoiceStatus values stored as plain strings)
    status = Column(String(20), default=InvoiceStatus.DRAFT.value, nullable=False)

    # Seller details snapshot (captured at invoice creation)
    seller_details = Column(JSON, nullable=False)
    # Structure: {
    #   "company_name": str,
    #   "address": str,
    #   "city": str,
    #   "postal_code": str,
    #   "country": str,
    #   "vat_number": str | None
    # }

    # Buyer details snapshot (captured at invoice creation)
    buyer_details = Column(JSON, nullable=False)
    # Structure: {
    #   "name": str,
    #   "email": str,
    #   "address": str,
    #   "city": str,
    #   "postal_code": str,
    #   "country": str,
    #   "vat_number": str | None (for B2B)
    # }

    # Line items snapshot
    line_items = Column(JSON, nullable=False)
    # Structure: [{
    #   "description": str,
    #   "quantity": int,
    #   "unit_price_cents": int,
    #   "total_cents": int,
    #   "sku": str | None,
    #   "ean": str | None
    # }]

    # VAT information (VATRegime values stored as plain strings)
    vat_regime = Column(String(20), default=VATRegime.DOMESTIC.value, nullable=False)
    destination_country = Column(String(2), nullable=True)  # For OSS invoices
    vat_rate = Column(Numeric(5, 2), nullable=False)  # e.g., 17.00 for 17%
    vat_rate_label = Column(String(50), nullable=True)  # e.g., "Luxembourg Standard VAT"

    # Amounts (stored in cents for precision)
    currency = Column(String(3), default="EUR", nullable=False)
    subtotal_cents = Column(Integer, nullable=False)  # Before VAT
    vat_amount_cents = Column(Integer, nullable=False)  # VAT amount
    total_cents = Column(Integer, nullable=False)  # After VAT

    # Payment information
    payment_terms = Column(Text, nullable=True)
    bank_details = Column(JSON, nullable=True)  # IBAN, BIC snapshot
    footer_text = Column(Text, nullable=True)

    # PDF storage
    pdf_generated_at = Column(DateTime(timezone=True), nullable=True)
    pdf_path = Column(String(500), nullable=True)  # Path to stored PDF

    # Notes
    notes = Column(Text, nullable=True)  # Internal notes

    # Relationships
    vendor = relationship("Vendor", back_populates="invoices")
    order = relationship("Order", back_populates="invoices")

    __table_args__ = (
        Index("idx_invoice_vendor_number", "vendor_id", "invoice_number", unique=True),
        Index("idx_invoice_vendor_date", "vendor_id", "invoice_date"),
        Index("idx_invoice_status", "vendor_id", "status"),
    )

    def __repr__(self):
        return f"<Invoice(id={self.id}, number='{self.invoice_number}', status='{self.status}')>"

    @property
    def subtotal(self) -> float:
        """Get subtotal in EUR (converted from integer cents)."""
        return self.subtotal_cents / 100

    @property
    def vat_amount(self) -> float:
        """Get VAT amount in EUR (converted from integer cents)."""
        return self.vat_amount_cents / 100

    @property
    def total(self) -> float:
        """Get total in EUR (converted from integer cents)."""
        return self.total_cents / 100
|
||||
406
app/modules/orders/models/order.py
Normal file
406
app/modules/orders/models/order.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# app/modules/orders/models/order.py
|
||||
"""
|
||||
Unified Order model for all sales channels.
|
||||
|
||||
Supports:
|
||||
- Direct orders (from vendor's own storefront)
|
||||
- Marketplace orders (Letzshop, etc.)
|
||||
|
||||
Design principles:
|
||||
- Customer/address data is snapshotted at order time (preserves history)
|
||||
- customer_id FK links to Customer record (may be inactive for marketplace imports)
|
||||
- channel field distinguishes order source
|
||||
- external_* fields store marketplace-specific references
|
||||
|
||||
Money values are stored as integer cents (e.g., €105.91 = 10591).
|
||||
See docs/architecture/money-handling.md for details.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.modules.orders.models.order_item_exception import OrderItemException
|
||||
from sqlalchemy.dialects.sqlite import JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from app.utils.money import cents_to_euros, euros_to_cents
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class Order(Base, TimestampMixin):
    """
    Unified order model for all sales channels.

    Stores orders from direct sales and marketplaces (Letzshop, etc.)
    with snapshotted customer and address data.

    All monetary amounts are stored as integer cents for precision.
    """

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True, index=True)
    vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True)
    customer_id = Column(
        Integer, ForeignKey("customers.id"), nullable=False, index=True
    )
    # Globally unique human-facing order number (unique across all vendors).
    order_number = Column(String(100), nullable=False, unique=True, index=True)

    # === Channel/Source ===
    channel = Column(
        String(50), default="direct", nullable=False, index=True
    )  # direct, letzshop

    # External references (for marketplace orders)
    external_order_id = Column(
        String(100), nullable=True, index=True
    )  # Marketplace order ID
    external_shipment_id = Column(
        String(100), nullable=True, index=True
    )  # Marketplace shipment ID
    external_order_number = Column(String(100), nullable=True)  # Marketplace order #
    external_data = Column(JSON, nullable=True)  # Raw marketplace data for debugging

    # === Status ===
    # pending: awaiting confirmation
    # processing: confirmed, being prepared
    # shipped: shipped with tracking
    # delivered: delivered to customer
    # cancelled: order cancelled/declined
    # refunded: order refunded
    status = Column(String(50), nullable=False, default="pending", index=True)

    # === Financials (stored as integer cents) ===
    subtotal_cents = Column(Integer, nullable=True)  # May not be available from marketplace
    tax_amount_cents = Column(Integer, nullable=True)
    shipping_amount_cents = Column(Integer, nullable=True)
    discount_amount_cents = Column(Integer, nullable=True)
    total_amount_cents = Column(Integer, nullable=False)
    currency = Column(String(10), default="EUR")

    # === VAT Information ===
    # VAT regime: domestic, oss, reverse_charge, origin, exempt
    vat_regime = Column(String(20), nullable=True)
    # VAT rate as percentage (e.g., 17.00 for 17%)
    vat_rate = Column(Numeric(5, 2), nullable=True)
    # Human-readable VAT label (e.g., "Luxembourg VAT 17%")
    vat_rate_label = Column(String(100), nullable=True)
    # Destination country for cross-border sales (ISO code)
    vat_destination_country = Column(String(2), nullable=True)

    # === Customer Snapshot (preserved at order time) ===
    customer_first_name = Column(String(100), nullable=False)
    customer_last_name = Column(String(100), nullable=False)
    customer_email = Column(String(255), nullable=False)
    customer_phone = Column(String(50), nullable=True)
    customer_locale = Column(String(10), nullable=True)  # en, fr, de, lb

    # === Shipping Address Snapshot ===
    ship_first_name = Column(String(100), nullable=False)
    ship_last_name = Column(String(100), nullable=False)
    ship_company = Column(String(200), nullable=True)
    ship_address_line_1 = Column(String(255), nullable=False)
    ship_address_line_2 = Column(String(255), nullable=True)
    ship_city = Column(String(100), nullable=False)
    ship_postal_code = Column(String(20), nullable=False)
    ship_country_iso = Column(String(5), nullable=False)

    # === Billing Address Snapshot ===
    bill_first_name = Column(String(100), nullable=False)
    bill_last_name = Column(String(100), nullable=False)
    bill_company = Column(String(200), nullable=True)
    bill_address_line_1 = Column(String(255), nullable=False)
    bill_address_line_2 = Column(String(255), nullable=True)
    bill_city = Column(String(100), nullable=False)
    bill_postal_code = Column(String(20), nullable=False)
    bill_country_iso = Column(String(5), nullable=False)

    # === Tracking ===
    shipping_method = Column(String(100), nullable=True)
    tracking_number = Column(String(100), nullable=True)
    tracking_provider = Column(String(100), nullable=True)
    tracking_url = Column(String(500), nullable=True)  # Full tracking URL
    shipment_number = Column(String(100), nullable=True)  # Carrier shipment number (e.g., H74683403433)
    shipping_carrier = Column(String(50), nullable=True)  # Carrier code (greco, colissimo, etc.)

    # === Notes ===
    customer_notes = Column(Text, nullable=True)
    internal_notes = Column(Text, nullable=True)

    # === Timestamps ===
    order_date = Column(
        DateTime(timezone=True), nullable=False
    )  # When customer placed order
    confirmed_at = Column(DateTime(timezone=True), nullable=True)
    shipped_at = Column(DateTime(timezone=True), nullable=True)
    delivered_at = Column(DateTime(timezone=True), nullable=True)
    cancelled_at = Column(DateTime(timezone=True), nullable=True)

    # === Relationships ===
    vendor = relationship("Vendor")
    customer = relationship("Customer", back_populates="orders")
    # Items and invoices are owned by the order: deleting the order cascades.
    items = relationship(
        "OrderItem", back_populates="order", cascade="all, delete-orphan"
    )
    invoices = relationship(
        "Invoice", back_populates="order", cascade="all, delete-orphan"
    )

    # Composite indexes for common queries
    __table_args__ = (
        Index("idx_order_vendor_status", "vendor_id", "status"),
        Index("idx_order_vendor_channel", "vendor_id", "channel"),
        Index("idx_order_vendor_date", "vendor_id", "order_date"),
    )

    def __repr__(self):
        return f"<Order(id={self.id}, order_number='{self.order_number}', channel='{self.channel}', status='{self.status}')>"

    # === PRICE PROPERTIES (Euro convenience accessors) ===
    # NOTE(review): these round-trip through float; assumes cents_to_euros /
    # euros_to_cents (app.utils.money) handle rounding — confirm there.

    @property
    def subtotal(self) -> float | None:
        """Get subtotal in euros."""
        if self.subtotal_cents is not None:
            return cents_to_euros(self.subtotal_cents)
        return None

    @subtotal.setter
    def subtotal(self, value: float | None):
        """Set subtotal from euros."""
        self.subtotal_cents = euros_to_cents(value) if value is not None else None

    @property
    def tax_amount(self) -> float | None:
        """Get tax amount in euros."""
        if self.tax_amount_cents is not None:
            return cents_to_euros(self.tax_amount_cents)
        return None

    @tax_amount.setter
    def tax_amount(self, value: float | None):
        """Set tax amount from euros."""
        self.tax_amount_cents = euros_to_cents(value) if value is not None else None

    @property
    def shipping_amount(self) -> float | None:
        """Get shipping amount in euros."""
        if self.shipping_amount_cents is not None:
            return cents_to_euros(self.shipping_amount_cents)
        return None

    @shipping_amount.setter
    def shipping_amount(self, value: float | None):
        """Set shipping amount from euros."""
        self.shipping_amount_cents = euros_to_cents(value) if value is not None else None

    @property
    def discount_amount(self) -> float | None:
        """Get discount amount in euros."""
        if self.discount_amount_cents is not None:
            return cents_to_euros(self.discount_amount_cents)
        return None

    @discount_amount.setter
    def discount_amount(self, value: float | None):
        """Set discount amount from euros."""
        self.discount_amount_cents = euros_to_cents(value) if value is not None else None

    @property
    def total_amount(self) -> float:
        """Get total amount in euros."""
        return cents_to_euros(self.total_amount_cents)

    @total_amount.setter
    def total_amount(self, value: float):
        """Set total amount from euros."""
        self.total_amount_cents = euros_to_cents(value)

    # === NAME PROPERTIES ===

    @property
    def customer_full_name(self) -> str:
        """Customer full name from snapshot."""
        return f"{self.customer_first_name} {self.customer_last_name}".strip()

    @property
    def ship_full_name(self) -> str:
        """Shipping address full name."""
        return f"{self.ship_first_name} {self.ship_last_name}".strip()

    @property
    def bill_full_name(self) -> str:
        """Billing address full name."""
        return f"{self.bill_first_name} {self.bill_last_name}".strip()

    @property
    def is_marketplace_order(self) -> bool:
        """Check if this is a marketplace order (any channel other than direct)."""
        return self.channel != "direct"

    @property
    def is_fully_shipped(self) -> bool:
        """Check if all items are fully shipped. False for an order with no items."""
        if not self.items:
            return False
        return all(item.is_fully_shipped for item in self.items)

    @property
    def is_partially_shipped(self) -> bool:
        """Check if some items are shipped but not all."""
        if not self.items:
            return False
        has_shipped = any(item.shipped_quantity > 0 for item in self.items)
        all_shipped = all(item.is_fully_shipped for item in self.items)
        return has_shipped and not all_shipped

    @property
    def shipped_item_count(self) -> int:
        """Count of fully shipped items (line items, not units)."""
        return sum(1 for item in self.items if item.is_fully_shipped)

    @property
    def total_shipped_units(self) -> int:
        """Total quantity shipped across all items."""
        return sum(item.shipped_quantity for item in self.items)

    @property
    def total_ordered_units(self) -> int:
        """Total quantity ordered across all items."""
        return sum(item.quantity for item in self.items)
|
||||
|
||||
|
||||
class OrderItem(Base, TimestampMixin):
    """
    Individual items in an order.

    Stores product snapshot at time of order plus external references
    for marketplace items.

    All monetary amounts are stored as integer cents for precision.
    """

    __tablename__ = "order_items"

    id = Column(Integer, primary_key=True, index=True)
    order_id = Column(Integer, ForeignKey("orders.id"), nullable=False, index=True)
    product_id = Column(Integer, ForeignKey("products.id"), nullable=False)

    # === Product Snapshot (preserved at order time) ===
    product_name = Column(String(255), nullable=False)
    product_sku = Column(String(100), nullable=True)
    gtin = Column(String(50), nullable=True)  # EAN/UPC/ISBN etc.
    gtin_type = Column(String(20), nullable=True)  # ean13, upc, isbn, etc.

    # === Pricing (stored as integer cents) ===
    quantity = Column(Integer, nullable=False)
    unit_price_cents = Column(Integer, nullable=False)
    total_price_cents = Column(Integer, nullable=False)

    # === External References (for marketplace items) ===
    external_item_id = Column(String(100), nullable=True)  # e.g., Letzshop inventory unit ID
    external_variant_id = Column(String(100), nullable=True)  # e.g., Letzshop variant ID

    # === Item State (for marketplace confirmation flow) ===
    # confirmed_available: item confirmed and available
    # confirmed_unavailable: item confirmed but not available (declined)
    item_state = Column(String(50), nullable=True)

    # === Inventory Tracking ===
    inventory_reserved = Column(Boolean, default=False)
    inventory_fulfilled = Column(Boolean, default=False)

    # === Shipment Tracking ===
    shipped_quantity = Column(Integer, default=0, nullable=False)  # Units shipped so far

    # === Exception Tracking ===
    # True if product was not found by GTIN during import (linked to placeholder)
    needs_product_match = Column(Boolean, default=False, index=True)

    # === Relationships ===
    order = relationship("Order", back_populates="items")
    product = relationship("Product")
    # One-to-one (uselist=False); deleted together with the item.
    exception = relationship(
        "OrderItemException",
        back_populates="order_item",
        uselist=False,
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return f"<OrderItem(id={self.id}, order_id={self.order_id}, product_id={self.product_id}, gtin='{self.gtin}')>"

    # === PRICE PROPERTIES (Euro convenience accessors) ===
    # NOTE(review): same float round-trip caveat as Order's euro accessors.

    @property
    def unit_price(self) -> float:
        """Get unit price in euros."""
        return cents_to_euros(self.unit_price_cents)

    @unit_price.setter
    def unit_price(self, value: float):
        """Set unit price from euros."""
        self.unit_price_cents = euros_to_cents(value)

    @property
    def total_price(self) -> float:
        """Get total price in euros."""
        return cents_to_euros(self.total_price_cents)

    @total_price.setter
    def total_price(self, value: float):
        """Set total price from euros."""
        self.total_price_cents = euros_to_cents(value)

    # === STATUS PROPERTIES ===

    @property
    def is_confirmed(self) -> bool:
        """Check if item has been confirmed (available or unavailable)."""
        return self.item_state in ("confirmed_available", "confirmed_unavailable")

    @property
    def is_available(self) -> bool:
        """Check if item is confirmed as available."""
        return self.item_state == "confirmed_available"

    @property
    def is_declined(self) -> bool:
        """Check if item was declined (unavailable)."""
        return self.item_state == "confirmed_unavailable"

    @property
    def has_unresolved_exception(self) -> bool:
        """Check if item has an unresolved exception blocking confirmation."""
        if not self.exception:
            return False
        return self.exception.blocks_confirmation

    # === SHIPMENT PROPERTIES ===

    @property
    def remaining_quantity(self) -> int:
        """Quantity not yet shipped (never negative, even if over-shipped)."""
        return max(0, self.quantity - self.shipped_quantity)

    @property
    def is_fully_shipped(self) -> bool:
        """Check if all units have been shipped."""
        return self.shipped_quantity >= self.quantity

    @property
    def is_partially_shipped(self) -> bool:
        """Check if some but not all units have been shipped."""
        return 0 < self.shipped_quantity < self.quantity
|
||||
117
app/modules/orders/models/order_item_exception.py
Normal file
117
app/modules/orders/models/order_item_exception.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# app/modules/orders/models/order_item_exception.py
|
||||
"""
|
||||
Order Item Exception model for tracking unmatched products during marketplace imports.
|
||||
|
||||
When a marketplace order contains a GTIN that doesn't match any product in the
|
||||
vendor's catalog, the order is still imported but the item is linked to a
|
||||
placeholder product and an exception is recorded here for resolution.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.database import Base
|
||||
from models.database.base import TimestampMixin
|
||||
|
||||
|
||||
class OrderItemException(Base, TimestampMixin):
    """
    Tracks unmatched order items requiring admin/vendor resolution.

    When a marketplace order is imported and a product cannot be found by GTIN,
    the order item is linked to a placeholder product and this exception record
    is created. The order cannot be confirmed until all exceptions are resolved.
    """

    __tablename__ = "order_item_exceptions"

    id = Column(Integer, primary_key=True, index=True)

    # Link to the order item (one-to-one; row removed when the item is deleted)
    order_item_id = Column(
        Integer,
        ForeignKey("order_items.id", ondelete="CASCADE"),
        nullable=False,
        unique=True,
    )

    # Vendor ID for efficient querying (denormalized from order)
    vendor_id = Column(
        Integer, ForeignKey("vendors.id"), nullable=False, index=True
    )

    # Original data from marketplace (preserved for matching)
    original_gtin = Column(String(50), nullable=True, index=True)
    original_product_name = Column(String(500), nullable=True)
    original_sku = Column(String(100), nullable=True)

    # Exception classification
    # product_not_found: GTIN not in vendor catalog
    # gtin_mismatch: GTIN format issue
    # duplicate_gtin: Multiple products with same GTIN
    exception_type = Column(
        String(50), nullable=False, default="product_not_found"
    )

    # Resolution status
    # pending: Awaiting resolution
    # resolved: Product has been assigned
    # ignored: Marked as ignored (still blocks confirmation)
    status = Column(String(50), nullable=False, default="pending", index=True)

    # Resolution details (populated when resolved)
    resolved_product_id = Column(
        Integer, ForeignKey("products.id"), nullable=True
    )
    resolved_at = Column(DateTime(timezone=True), nullable=True)
    resolved_by = Column(Integer, ForeignKey("users.id"), nullable=True)
    resolution_notes = Column(Text, nullable=True)

    # Relationships
    order_item = relationship("OrderItem", back_populates="exception")
    vendor = relationship("Vendor")
    resolved_product = relationship("Product")
    resolver = relationship("User")

    # Composite indexes for common queries
    __table_args__ = (
        Index("idx_exception_vendor_status", "vendor_id", "status"),
        Index("idx_exception_gtin", "vendor_id", "original_gtin"),
    )

    def __repr__(self):
        return (
            f"<OrderItemException(id={self.id}, "
            f"order_item_id={self.order_item_id}, "
            f"gtin='{self.original_gtin}', "
            f"status='{self.status}')>"
        )

    @property
    def is_pending(self) -> bool:
        """Check if exception is pending resolution."""
        return self.status == "pending"

    @property
    def is_resolved(self) -> bool:
        """Check if exception has been resolved."""
        return self.status == "resolved"

    @property
    def is_ignored(self) -> bool:
        """Check if exception has been ignored."""
        return self.status == "ignored"

    @property
    def blocks_confirmation(self) -> bool:
        """Check if this exception blocks order confirmation."""
        # Both pending and ignored exceptions block confirmation;
        # only an explicit "resolved" status unblocks the order.
        return self.status in ("pending", "ignored")
|
||||
@@ -2,34 +2,133 @@
|
||||
"""
|
||||
Orders module Pydantic schemas.
|
||||
|
||||
Re-exports order-related schemas from their source locations.
|
||||
This module contains the canonical implementations of order-related schemas.
|
||||
"""
|
||||
|
||||
from models.schema.order import (
|
||||
OrderCreate,
|
||||
OrderItemCreate,
|
||||
OrderResponse,
|
||||
OrderItemResponse,
|
||||
OrderListResponse,
|
||||
from app.modules.orders.schemas.order import (
|
||||
# Address schemas
|
||||
AddressSnapshot,
|
||||
AddressSnapshotResponse,
|
||||
# Order item schemas
|
||||
OrderItemCreate,
|
||||
OrderItemExceptionBrief,
|
||||
OrderItemResponse,
|
||||
# Customer schemas
|
||||
CustomerSnapshot,
|
||||
CustomerSnapshotResponse,
|
||||
# Order CRUD schemas
|
||||
OrderCreate,
|
||||
OrderUpdate,
|
||||
OrderTrackingUpdate,
|
||||
OrderItemStateUpdate,
|
||||
# Order response schemas
|
||||
OrderResponse,
|
||||
OrderDetailResponse,
|
||||
OrderListResponse,
|
||||
OrderListItem,
|
||||
# Admin schemas
|
||||
AdminOrderItem,
|
||||
AdminOrderListResponse,
|
||||
AdminOrderStats,
|
||||
AdminOrderStatusUpdate,
|
||||
AdminVendorWithOrders,
|
||||
AdminVendorsWithOrdersResponse,
|
||||
# Letzshop schemas
|
||||
LetzshopOrderImport,
|
||||
LetzshopShippingInfo,
|
||||
LetzshopOrderConfirmItem,
|
||||
LetzshopOrderConfirmRequest,
|
||||
# Shipping schemas
|
||||
MarkAsShippedRequest,
|
||||
ShippingLabelInfo,
|
||||
)
|
||||
from models.schema.invoice import (
|
||||
|
||||
from app.modules.orders.schemas.invoice import (
|
||||
# Invoice settings schemas
|
||||
VendorInvoiceSettingsCreate,
|
||||
VendorInvoiceSettingsUpdate,
|
||||
VendorInvoiceSettingsResponse,
|
||||
# Line item schemas
|
||||
InvoiceLineItem,
|
||||
InvoiceLineItemResponse,
|
||||
# Address schemas
|
||||
InvoiceSellerDetails,
|
||||
InvoiceBuyerDetails,
|
||||
# Invoice CRUD schemas
|
||||
InvoiceCreate,
|
||||
InvoiceManualCreate,
|
||||
InvoiceResponse,
|
||||
InvoiceListResponse,
|
||||
InvoiceStatusUpdate,
|
||||
# Pagination
|
||||
InvoiceListPaginatedResponse,
|
||||
# PDF
|
||||
InvoicePDFGeneratedResponse,
|
||||
InvoiceStatsResponse,
|
||||
# Backward compatibility
|
||||
InvoiceSettingsCreate,
|
||||
InvoiceSettingsUpdate,
|
||||
InvoiceSettingsResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OrderCreate",
|
||||
"OrderItemCreate",
|
||||
"OrderResponse",
|
||||
"OrderItemResponse",
|
||||
"OrderListResponse",
|
||||
# Address schemas
|
||||
"AddressSnapshot",
|
||||
"AddressSnapshotResponse",
|
||||
# Order item schemas
|
||||
"OrderItemCreate",
|
||||
"OrderItemExceptionBrief",
|
||||
"OrderItemResponse",
|
||||
# Customer schemas
|
||||
"CustomerSnapshot",
|
||||
"CustomerSnapshotResponse",
|
||||
# Order CRUD schemas
|
||||
"OrderCreate",
|
||||
"OrderUpdate",
|
||||
"OrderTrackingUpdate",
|
||||
"OrderItemStateUpdate",
|
||||
# Order response schemas
|
||||
"OrderResponse",
|
||||
"OrderDetailResponse",
|
||||
"OrderListResponse",
|
||||
"OrderListItem",
|
||||
# Admin schemas
|
||||
"AdminOrderItem",
|
||||
"AdminOrderListResponse",
|
||||
"AdminOrderStats",
|
||||
"AdminOrderStatusUpdate",
|
||||
"AdminVendorWithOrders",
|
||||
"AdminVendorsWithOrdersResponse",
|
||||
# Letzshop schemas
|
||||
"LetzshopOrderImport",
|
||||
"LetzshopShippingInfo",
|
||||
"LetzshopOrderConfirmItem",
|
||||
"LetzshopOrderConfirmRequest",
|
||||
# Shipping schemas
|
||||
"MarkAsShippedRequest",
|
||||
"ShippingLabelInfo",
|
||||
# Invoice settings schemas
|
||||
"VendorInvoiceSettingsCreate",
|
||||
"VendorInvoiceSettingsUpdate",
|
||||
"VendorInvoiceSettingsResponse",
|
||||
# Line item schemas
|
||||
"InvoiceLineItem",
|
||||
"InvoiceLineItemResponse",
|
||||
# Invoice address schemas
|
||||
"InvoiceSellerDetails",
|
||||
"InvoiceBuyerDetails",
|
||||
# Invoice CRUD schemas
|
||||
"InvoiceCreate",
|
||||
"InvoiceManualCreate",
|
||||
"InvoiceResponse",
|
||||
"InvoiceListResponse",
|
||||
"InvoiceStatusUpdate",
|
||||
# Pagination
|
||||
"InvoiceListPaginatedResponse",
|
||||
# PDF
|
||||
"InvoicePDFGeneratedResponse",
|
||||
"InvoiceStatsResponse",
|
||||
# Backward compatibility
|
||||
"InvoiceSettingsCreate",
|
||||
"InvoiceSettingsUpdate",
|
||||
"InvoiceSettingsResponse",
|
||||
|
||||
316
app/modules/orders/schemas/invoice.py
Normal file
316
app/modules/orders/schemas/invoice.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# app/modules/orders/schemas/invoice.py
|
||||
"""
|
||||
Pydantic schemas for invoice operations.
|
||||
|
||||
Supports invoice settings management and invoice generation.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================================
|
||||
# Invoice Settings Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VendorInvoiceSettingsCreate(BaseModel):
    """Schema for creating vendor invoice settings.

    Company address fields are optional; the country defaults to Luxembourg
    ("LU") and the default VAT rate to the Luxembourg standard rate of 17%.
    """

    company_name: str = Field(..., min_length=1, max_length=255)
    company_address: str | None = Field(None, max_length=255)
    company_city: str | None = Field(None, max_length=100)
    company_postal_code: str | None = Field(None, max_length=20)
    company_country: str = Field(default="LU", min_length=2, max_length=2)  # ISO 3166-1 alpha-2

    vat_number: str | None = Field(None, max_length=50)
    is_vat_registered: bool = True

    # OSS (One-Stop Shop) registration for cross-border EU sales
    is_oss_registered: bool = False
    oss_registration_country: str | None = Field(None, min_length=2, max_length=2)

    # Invoice numbering: prefix + zero-padded sequence number
    invoice_prefix: str = Field(default="INV", max_length=20)
    invoice_number_padding: int = Field(default=5, ge=1, le=10)

    payment_terms: str | None = None
    bank_name: str | None = Field(None, max_length=255)
    bank_iban: str | None = Field(None, max_length=50)
    bank_bic: str | None = Field(None, max_length=20)

    footer_text: str | None = None
    default_vat_rate: Decimal = Field(default=Decimal("17.00"), ge=0, le=100)
|
||||
|
||||
|
||||
class VendorInvoiceSettingsUpdate(BaseModel):
    """Schema for updating vendor invoice settings.

    All fields are optional; only the fields supplied are updated
    (partial-update / PATCH semantics).
    """

    company_name: str | None = Field(None, min_length=1, max_length=255)
    company_address: str | None = Field(None, max_length=255)
    company_city: str | None = Field(None, max_length=100)
    company_postal_code: str | None = Field(None, max_length=20)
    company_country: str | None = Field(None, min_length=2, max_length=2)  # ISO 3166-1 alpha-2

    vat_number: str | None = None
    is_vat_registered: bool | None = None

    is_oss_registered: bool | None = None
    oss_registration_country: str | None = None

    invoice_prefix: str | None = Field(None, max_length=20)
    invoice_number_padding: int | None = Field(None, ge=1, le=10)

    payment_terms: str | None = None
    bank_name: str | None = Field(None, max_length=255)
    bank_iban: str | None = Field(None, max_length=50)
    bank_bic: str | None = Field(None, max_length=20)

    footer_text: str | None = None
    default_vat_rate: Decimal | None = Field(None, ge=0, le=100)
|
||||
|
||||
|
||||
class VendorInvoiceSettingsResponse(BaseModel):
    """Schema for vendor invoice settings response.

    Built directly from the ORM object (``from_attributes=True``).
    Includes ``invoice_next_number``, which is maintained server-side and
    not settable through the create/update schemas.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    vendor_id: int

    company_name: str
    company_address: str | None
    company_city: str | None
    company_postal_code: str | None
    company_country: str

    vat_number: str | None
    is_vat_registered: bool

    is_oss_registered: bool
    oss_registration_country: str | None

    invoice_prefix: str
    invoice_next_number: int  # next sequence number to assign (server-managed)
    invoice_number_padding: int

    payment_terms: str | None
    bank_name: str | None
    bank_iban: str | None
    bank_bic: str | None

    footer_text: str | None
    default_vat_rate: Decimal

    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Invoice Line Item Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvoiceLineItem(BaseModel):
    """Schema for invoice line item.

    Monetary values are integer cents; only the quantity is validated
    (must be at least 1).
    """

    description: str
    quantity: int = Field(..., ge=1)
    unit_price_cents: int
    total_cents: int
    sku: str | None = None
    ean: str | None = None
|
||||
|
||||
|
||||
class InvoiceLineItemResponse(BaseModel):
    """Schema for invoice line item in response.

    Mirrors ``InvoiceLineItem`` without the quantity constraint, plus
    euro convenience accessors.
    """

    description: str
    quantity: int
    unit_price_cents: int
    total_cents: int
    sku: str | None = None
    ean: str | None = None

    # NOTE(review): plain @property values are not included in the serialized
    # model by default — confirm callers read these attributes directly.
    @property
    def unit_price(self) -> float:
        """Unit price in EUR."""
        return self.unit_price_cents / 100

    @property
    def total(self) -> float:
        """Line total in EUR."""
        return self.total_cents / 100
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Invoice Address Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvoiceSellerDetails(BaseModel):
    """Seller details for invoice (snapshotted vendor company data)."""

    company_name: str
    address: str | None = None
    city: str | None = None
    postal_code: str | None = None
    country: str
    vat_number: str | None = None
|
||||
|
||||
|
||||
class InvoiceBuyerDetails(BaseModel):
    """Buyer details for invoice (snapshotted customer data)."""

    name: str
    email: str | None = None
    address: str | None = None
    city: str | None = None
    postal_code: str | None = None
    country: str
    vat_number: str | None = None  # For B2B
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Invoice Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvoiceCreate(BaseModel):
    """Schema for creating an invoice from an order.

    All seller/buyer/line-item data is derived server-side from the order.
    """

    order_id: int
    notes: str | None = None  # internal notes
|
||||
|
||||
|
||||
class InvoiceManualCreate(BaseModel):
    """Schema for creating a manual invoice (without order).

    The caller supplies the buyer and line items explicitly; seller details
    come from the vendor's invoice settings.
    """

    buyer_details: InvoiceBuyerDetails
    line_items: list[InvoiceLineItem]
    notes: str | None = None
    payment_terms: str | None = None
|
||||
|
||||
|
||||
class InvoiceResponse(BaseModel):
|
||||
"""Schema for invoice response."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
vendor_id: int
|
||||
order_id: int | None
|
||||
|
||||
invoice_number: str
|
||||
invoice_date: datetime
|
||||
status: str
|
||||
|
||||
seller_details: dict
|
||||
buyer_details: dict
|
||||
line_items: list[dict]
|
||||
|
||||
vat_regime: str
|
||||
destination_country: str | None
|
||||
vat_rate: Decimal
|
||||
vat_rate_label: str | None
|
||||
|
||||
currency: str
|
||||
subtotal_cents: int
|
||||
vat_amount_cents: int
|
||||
total_cents: int
|
||||
|
||||
payment_terms: str | None
|
||||
bank_details: dict | None
|
||||
footer_text: str | None
|
||||
|
||||
pdf_generated_at: datetime | None
|
||||
pdf_path: str | None
|
||||
|
||||
notes: str | None
|
||||
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def subtotal(self) -> float:
|
||||
return self.subtotal_cents / 100
|
||||
|
||||
@property
|
||||
def vat_amount(self) -> float:
|
||||
return self.vat_amount_cents / 100
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return self.total_cents / 100
|
||||
|
||||
|
||||
class InvoiceListResponse(BaseModel):
|
||||
"""Schema for invoice list response (summary)."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
invoice_number: str
|
||||
invoice_date: datetime
|
||||
status: str
|
||||
currency: str
|
||||
total_cents: int
|
||||
order_id: int | None
|
||||
|
||||
# Buyer name for display
|
||||
buyer_name: str | None = None
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return self.total_cents / 100
|
||||
|
||||
|
||||
class InvoiceStatusUpdate(BaseModel):
|
||||
"""Schema for updating invoice status."""
|
||||
|
||||
status: str = Field(..., pattern="^(draft|issued|paid|cancelled)$")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Paginated Response
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvoiceListPaginatedResponse(BaseModel):
|
||||
"""Paginated invoice list response."""
|
||||
|
||||
items: list[InvoiceListResponse]
|
||||
total: int
|
||||
page: int
|
||||
per_page: int
|
||||
pages: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PDF Response
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class InvoicePDFGeneratedResponse(BaseModel):
|
||||
"""Response for PDF generation."""
|
||||
|
||||
pdf_path: str
|
||||
message: str = "PDF generated successfully"
|
||||
|
||||
|
||||
class InvoiceStatsResponse(BaseModel):
|
||||
"""Invoice statistics response."""
|
||||
|
||||
total_invoices: int
|
||||
total_revenue_cents: int
|
||||
draft_count: int
|
||||
issued_count: int
|
||||
paid_count: int
|
||||
cancelled_count: int
|
||||
|
||||
@property
|
||||
def total_revenue(self) -> float:
|
||||
return self.total_revenue_cents / 100
|
||||
|
||||
|
||||
# Backward compatibility re-exports
|
||||
InvoiceSettingsCreate = VendorInvoiceSettingsCreate
|
||||
InvoiceSettingsUpdate = VendorInvoiceSettingsUpdate
|
||||
InvoiceSettingsResponse = VendorInvoiceSettingsResponse
|
||||
584
app/modules/orders/schemas/order.py
Normal file
584
app/modules/orders/schemas/order.py
Normal file
@@ -0,0 +1,584 @@
|
||||
# app/modules/orders/schemas/order.py
|
||||
"""
|
||||
Pydantic schemas for unified order operations.
|
||||
|
||||
Supports both direct orders and marketplace orders (Letzshop, etc.)
|
||||
with snapshotted customer and address data.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================================
|
||||
# Address Snapshot Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AddressSnapshot(BaseModel):
|
||||
"""Address snapshot for order creation."""
|
||||
|
||||
first_name: str = Field(..., min_length=1, max_length=100)
|
||||
last_name: str = Field(..., min_length=1, max_length=100)
|
||||
company: str | None = Field(None, max_length=200)
|
||||
address_line_1: str = Field(..., min_length=1, max_length=255)
|
||||
address_line_2: str | None = Field(None, max_length=255)
|
||||
city: str = Field(..., min_length=1, max_length=100)
|
||||
postal_code: str = Field(..., min_length=1, max_length=20)
|
||||
country_iso: str = Field(..., min_length=2, max_length=5)
|
||||
|
||||
|
||||
class AddressSnapshotResponse(BaseModel):
|
||||
"""Address snapshot in order response."""
|
||||
|
||||
first_name: str
|
||||
last_name: str
|
||||
company: str | None
|
||||
address_line_1: str
|
||||
address_line_2: str | None
|
||||
city: str
|
||||
postal_code: str
|
||||
country_iso: str
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.first_name} {self.last_name}".strip()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Order Item Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OrderItemCreate(BaseModel):
|
||||
"""Schema for creating an order item."""
|
||||
|
||||
product_id: int
|
||||
quantity: int = Field(..., ge=1)
|
||||
|
||||
|
||||
class OrderItemExceptionBrief(BaseModel):
|
||||
"""Brief exception info for embedding in order item responses."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
original_gtin: str | None
|
||||
original_product_name: str | None
|
||||
exception_type: str
|
||||
status: str
|
||||
resolved_product_id: int | None
|
||||
|
||||
|
||||
class OrderItemResponse(BaseModel):
|
||||
"""Schema for order item response."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
order_id: int
|
||||
product_id: int
|
||||
product_name: str
|
||||
product_sku: str | None
|
||||
gtin: str | None
|
||||
gtin_type: str | None
|
||||
quantity: int
|
||||
unit_price: float
|
||||
total_price: float
|
||||
|
||||
# External references (for marketplace items)
|
||||
external_item_id: str | None = None
|
||||
external_variant_id: str | None = None
|
||||
|
||||
# Item state (for marketplace confirmation flow)
|
||||
item_state: str | None = None
|
||||
|
||||
# Inventory tracking
|
||||
inventory_reserved: bool
|
||||
inventory_fulfilled: bool
|
||||
|
||||
# Exception tracking
|
||||
needs_product_match: bool = False
|
||||
exception: OrderItemExceptionBrief | None = None
|
||||
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def is_confirmed(self) -> bool:
|
||||
"""Check if item has been confirmed (available or unavailable)."""
|
||||
return self.item_state in ("confirmed_available", "confirmed_unavailable")
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if item is confirmed as available."""
|
||||
return self.item_state == "confirmed_available"
|
||||
|
||||
@property
|
||||
def is_declined(self) -> bool:
|
||||
"""Check if item was declined (unavailable)."""
|
||||
return self.item_state == "confirmed_unavailable"
|
||||
|
||||
@property
|
||||
def has_unresolved_exception(self) -> bool:
|
||||
"""Check if item has an unresolved exception blocking confirmation."""
|
||||
if not self.exception:
|
||||
return False
|
||||
return self.exception.status in ("pending", "ignored")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Customer Snapshot Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class CustomerSnapshot(BaseModel):
|
||||
"""Customer snapshot for order creation."""
|
||||
|
||||
first_name: str = Field(..., min_length=1, max_length=100)
|
||||
last_name: str = Field(..., min_length=1, max_length=100)
|
||||
email: str = Field(..., max_length=255)
|
||||
phone: str | None = Field(None, max_length=50)
|
||||
locale: str | None = Field(None, max_length=10)
|
||||
|
||||
|
||||
class CustomerSnapshotResponse(BaseModel):
|
||||
"""Customer snapshot in order response."""
|
||||
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
phone: str | None
|
||||
locale: str | None
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
return f"{self.first_name} {self.last_name}".strip()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Order Create/Update Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OrderCreate(BaseModel):
|
||||
"""Schema for creating an order (direct channel)."""
|
||||
|
||||
customer_id: int | None = None # Optional for guest checkout
|
||||
items: list[OrderItemCreate] = Field(..., min_length=1)
|
||||
|
||||
# Customer info snapshot
|
||||
customer: CustomerSnapshot
|
||||
|
||||
# Addresses (snapshots)
|
||||
shipping_address: AddressSnapshot
|
||||
billing_address: AddressSnapshot | None = None # Use shipping if not provided
|
||||
|
||||
# Optional fields
|
||||
shipping_method: str | None = None
|
||||
customer_notes: str | None = Field(None, max_length=1000)
|
||||
|
||||
# Cart/session info
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class OrderUpdate(BaseModel):
|
||||
"""Schema for updating order status."""
|
||||
|
||||
status: str | None = Field(
|
||||
None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
|
||||
)
|
||||
tracking_number: str | None = None
|
||||
tracking_provider: str | None = None
|
||||
internal_notes: str | None = None
|
||||
|
||||
|
||||
class OrderTrackingUpdate(BaseModel):
|
||||
"""Schema for setting tracking information."""
|
||||
|
||||
tracking_number: str = Field(..., min_length=1, max_length=100)
|
||||
tracking_provider: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class OrderItemStateUpdate(BaseModel):
|
||||
"""Schema for updating item state (marketplace confirmation)."""
|
||||
|
||||
item_id: int
|
||||
state: str = Field(..., pattern="^(confirmed_available|confirmed_unavailable)$")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Order Response Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OrderResponse(BaseModel):
|
||||
"""Schema for order response."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
vendor_id: int
|
||||
customer_id: int
|
||||
order_number: str
|
||||
|
||||
# Channel/Source
|
||||
channel: str
|
||||
external_order_id: str | None = None
|
||||
external_shipment_id: str | None = None
|
||||
external_order_number: str | None = None
|
||||
|
||||
# Status
|
||||
status: str
|
||||
|
||||
# Financial
|
||||
subtotal: float | None
|
||||
tax_amount: float | None
|
||||
shipping_amount: float | None
|
||||
discount_amount: float | None
|
||||
total_amount: float
|
||||
currency: str
|
||||
|
||||
# VAT information
|
||||
vat_regime: str | None = None
|
||||
vat_rate: float | None = None
|
||||
vat_rate_label: str | None = None
|
||||
vat_destination_country: str | None = None
|
||||
|
||||
# Customer snapshot
|
||||
customer_first_name: str
|
||||
customer_last_name: str
|
||||
customer_email: str
|
||||
customer_phone: str | None
|
||||
customer_locale: str | None
|
||||
|
||||
# Shipping address snapshot
|
||||
ship_first_name: str
|
||||
ship_last_name: str
|
||||
ship_company: str | None
|
||||
ship_address_line_1: str
|
||||
ship_address_line_2: str | None
|
||||
ship_city: str
|
||||
ship_postal_code: str
|
||||
ship_country_iso: str
|
||||
|
||||
# Billing address snapshot
|
||||
bill_first_name: str
|
||||
bill_last_name: str
|
||||
bill_company: str | None
|
||||
bill_address_line_1: str
|
||||
bill_address_line_2: str | None
|
||||
bill_city: str
|
||||
bill_postal_code: str
|
||||
bill_country_iso: str
|
||||
|
||||
# Tracking
|
||||
shipping_method: str | None
|
||||
tracking_number: str | None
|
||||
tracking_provider: str | None
|
||||
tracking_url: str | None = None
|
||||
shipment_number: str | None = None
|
||||
shipping_carrier: str | None = None
|
||||
|
||||
# Notes
|
||||
customer_notes: str | None
|
||||
internal_notes: str | None
|
||||
|
||||
# Timestamps
|
||||
order_date: datetime
|
||||
confirmed_at: datetime | None
|
||||
shipped_at: datetime | None
|
||||
delivered_at: datetime | None
|
||||
cancelled_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@property
|
||||
def customer_full_name(self) -> str:
|
||||
return f"{self.customer_first_name} {self.customer_last_name}".strip()
|
||||
|
||||
@property
|
||||
def ship_full_name(self) -> str:
|
||||
return f"{self.ship_first_name} {self.ship_last_name}".strip()
|
||||
|
||||
@property
|
||||
def is_marketplace_order(self) -> bool:
|
||||
return self.channel != "direct"
|
||||
|
||||
|
||||
class OrderDetailResponse(OrderResponse):
|
||||
"""Schema for detailed order response with items."""
|
||||
|
||||
items: list[OrderItemResponse] = []
|
||||
|
||||
# Vendor info (enriched by API)
|
||||
vendor_name: str | None = None
|
||||
vendor_code: str | None = None
|
||||
|
||||
|
||||
class OrderListResponse(BaseModel):
|
||||
"""Schema for paginated order list."""
|
||||
|
||||
orders: list[OrderResponse]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Order List Item (Simplified for list views)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OrderListItem(BaseModel):
|
||||
"""Simplified order item for list views."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
vendor_id: int
|
||||
order_number: str
|
||||
channel: str
|
||||
status: str
|
||||
|
||||
# External references
|
||||
external_order_number: str | None = None
|
||||
|
||||
# Customer
|
||||
customer_full_name: str
|
||||
customer_email: str
|
||||
|
||||
# Financial
|
||||
total_amount: float
|
||||
currency: str
|
||||
|
||||
# Shipping
|
||||
ship_country_iso: str
|
||||
|
||||
# Tracking
|
||||
tracking_number: str | None
|
||||
tracking_provider: str | None
|
||||
tracking_url: str | None = None
|
||||
shipment_number: str | None = None
|
||||
shipping_carrier: str | None = None
|
||||
|
||||
# Item count
|
||||
item_count: int = 0
|
||||
|
||||
# Timestamps
|
||||
order_date: datetime
|
||||
confirmed_at: datetime | None
|
||||
shipped_at: datetime | None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Admin Order Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminOrderItem(BaseModel):
|
||||
"""Order item with vendor info for admin list view."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
vendor_id: int
|
||||
vendor_name: str | None = None
|
||||
vendor_code: str | None = None
|
||||
customer_id: int
|
||||
order_number: str
|
||||
channel: str
|
||||
status: str
|
||||
|
||||
# External references
|
||||
external_order_number: str | None = None
|
||||
external_shipment_id: str | None = None
|
||||
|
||||
# Customer snapshot
|
||||
customer_full_name: str
|
||||
customer_email: str
|
||||
|
||||
# Financial
|
||||
subtotal: float | None
|
||||
tax_amount: float | None
|
||||
shipping_amount: float | None
|
||||
discount_amount: float | None
|
||||
total_amount: float
|
||||
currency: str
|
||||
|
||||
# VAT information
|
||||
vat_regime: str | None = None
|
||||
vat_rate: float | None = None
|
||||
vat_rate_label: str | None = None
|
||||
vat_destination_country: str | None = None
|
||||
|
||||
# Shipping
|
||||
ship_country_iso: str
|
||||
tracking_number: str | None
|
||||
tracking_provider: str | None
|
||||
tracking_url: str | None = None
|
||||
shipment_number: str | None = None
|
||||
shipping_carrier: str | None = None
|
||||
|
||||
# Item count
|
||||
item_count: int = 0
|
||||
|
||||
# Timestamps
|
||||
order_date: datetime
|
||||
confirmed_at: datetime | None
|
||||
shipped_at: datetime | None
|
||||
delivered_at: datetime | None
|
||||
cancelled_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class AdminOrderListResponse(BaseModel):
|
||||
"""Cross-vendor order list for admin."""
|
||||
|
||||
orders: list[AdminOrderItem]
|
||||
total: int
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
|
||||
class AdminOrderStats(BaseModel):
|
||||
"""Order statistics for admin dashboard."""
|
||||
|
||||
total_orders: int = 0
|
||||
pending_orders: int = 0
|
||||
processing_orders: int = 0
|
||||
shipped_orders: int = 0
|
||||
delivered_orders: int = 0
|
||||
cancelled_orders: int = 0
|
||||
refunded_orders: int = 0
|
||||
total_revenue: float = 0.0
|
||||
|
||||
# By channel
|
||||
direct_orders: int = 0
|
||||
letzshop_orders: int = 0
|
||||
|
||||
# Vendors
|
||||
vendors_with_orders: int = 0
|
||||
|
||||
|
||||
class AdminOrderStatusUpdate(BaseModel):
|
||||
"""Admin version of status update with reason."""
|
||||
|
||||
status: str = Field(
|
||||
..., pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$"
|
||||
)
|
||||
tracking_number: str | None = None
|
||||
tracking_provider: str | None = None
|
||||
reason: str | None = Field(None, description="Reason for status change")
|
||||
|
||||
|
||||
class AdminVendorWithOrders(BaseModel):
|
||||
"""Vendor with order count."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
vendor_code: str
|
||||
order_count: int = 0
|
||||
|
||||
|
||||
class AdminVendorsWithOrdersResponse(BaseModel):
|
||||
"""Response for vendors with orders list."""
|
||||
|
||||
vendors: list[AdminVendorWithOrders]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Letzshop-specific Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class LetzshopOrderImport(BaseModel):
|
||||
"""Schema for importing a Letzshop order from shipment data."""
|
||||
|
||||
shipment_id: str
|
||||
order_id: str
|
||||
order_number: str
|
||||
order_date: datetime
|
||||
|
||||
# Customer
|
||||
customer_email: str
|
||||
customer_locale: str | None = None
|
||||
|
||||
# Shipping address
|
||||
ship_first_name: str
|
||||
ship_last_name: str
|
||||
ship_company: str | None = None
|
||||
ship_address_line_1: str
|
||||
ship_address_line_2: str | None = None
|
||||
ship_city: str
|
||||
ship_postal_code: str
|
||||
ship_country_iso: str
|
||||
|
||||
# Billing address
|
||||
bill_first_name: str
|
||||
bill_last_name: str
|
||||
bill_company: str | None = None
|
||||
bill_address_line_1: str
|
||||
bill_address_line_2: str | None = None
|
||||
bill_city: str
|
||||
bill_postal_code: str
|
||||
bill_country_iso: str
|
||||
|
||||
# Totals
|
||||
total_amount: float
|
||||
currency: str = "EUR"
|
||||
|
||||
# State
|
||||
letzshop_state: str # unconfirmed, confirmed, declined
|
||||
|
||||
# Items
|
||||
inventory_units: list[dict]
|
||||
|
||||
# Raw data
|
||||
raw_data: dict | None = None
|
||||
|
||||
|
||||
class LetzshopShippingInfo(BaseModel):
|
||||
"""Shipping info retrieved from Letzshop."""
|
||||
|
||||
tracking_number: str
|
||||
tracking_provider: str
|
||||
shipment_id: str
|
||||
|
||||
|
||||
class LetzshopOrderConfirmItem(BaseModel):
|
||||
"""Schema for confirming/declining a single item."""
|
||||
|
||||
item_id: int
|
||||
external_item_id: str
|
||||
action: str = Field(..., pattern="^(confirm|decline)$")
|
||||
|
||||
|
||||
class LetzshopOrderConfirmRequest(BaseModel):
|
||||
"""Schema for confirming/declining order items."""
|
||||
|
||||
items: list[LetzshopOrderConfirmItem]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mark as Shipped Schemas
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class MarkAsShippedRequest(BaseModel):
|
||||
"""Schema for marking an order as shipped with tracking info."""
|
||||
|
||||
tracking_number: str | None = Field(None, max_length=100)
|
||||
tracking_url: str | None = Field(None, max_length=500)
|
||||
shipping_carrier: str | None = Field(None, max_length=50)
|
||||
|
||||
|
||||
class ShippingLabelInfo(BaseModel):
|
||||
"""Shipping label information for an order."""
|
||||
|
||||
shipment_number: str | None = None
|
||||
shipping_carrier: str | None = None
|
||||
label_url: str | None = None
|
||||
tracking_number: str | None = None
|
||||
tracking_url: str | None = None
|
||||
@@ -2,26 +2,26 @@
|
||||
"""
|
||||
Orders module services.
|
||||
|
||||
Re-exports order-related services from their source locations.
|
||||
This module contains the canonical implementations of order-related services.
|
||||
"""
|
||||
|
||||
from app.services.order_service import (
|
||||
from app.modules.orders.services.order_service import (
|
||||
order_service,
|
||||
OrderService,
|
||||
)
|
||||
from app.services.order_inventory_service import (
|
||||
from app.modules.orders.services.order_inventory_service import (
|
||||
order_inventory_service,
|
||||
OrderInventoryService,
|
||||
)
|
||||
from app.services.order_item_exception_service import (
|
||||
from app.modules.orders.services.order_item_exception_service import (
|
||||
order_item_exception_service,
|
||||
OrderItemExceptionService,
|
||||
)
|
||||
from app.services.invoice_service import (
|
||||
from app.modules.orders.services.invoice_service import (
|
||||
invoice_service,
|
||||
InvoiceService,
|
||||
)
|
||||
from app.services.invoice_pdf_service import (
|
||||
from app.modules.orders.services.invoice_pdf_service import (
|
||||
invoice_pdf_service,
|
||||
InvoicePDFService,
|
||||
)
|
||||
|
||||
150
app/modules/orders/services/invoice_pdf_service.py
Normal file
150
app/modules/orders/services/invoice_pdf_service.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# app/modules/orders/services/invoice_pdf_service.py
|
||||
"""
|
||||
Invoice PDF generation service using WeasyPrint.
|
||||
|
||||
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
|
||||
Stores generated PDFs in the configured storage location.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.orders.models.invoice import Invoice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
|
||||
|
||||
# PDF storage directory (relative to project root)
|
||||
PDF_STORAGE_DIR = Path("storage") / "invoices"
|
||||
|
||||
|
||||
class InvoicePDFService:
|
||||
"""Service for generating invoice PDFs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the PDF service with Jinja2 environment."""
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(TEMPLATE_DIR)),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
def _ensure_storage_dir(self, vendor_id: int) -> Path:
|
||||
"""Ensure the storage directory exists for a vendor."""
|
||||
storage_path = PDF_STORAGE_DIR / str(vendor_id)
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
return storage_path
|
||||
|
||||
def _get_pdf_filename(self, invoice: Invoice) -> str:
|
||||
"""Generate PDF filename for an invoice."""
|
||||
safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
|
||||
return f"{safe_number}.pdf"
|
||||
|
||||
def generate_pdf(
|
||||
self,
|
||||
db: Session,
|
||||
invoice: Invoice,
|
||||
force_regenerate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate PDF for an invoice.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
invoice: Invoice to generate PDF for
|
||||
force_regenerate: If True, regenerate even if PDF already exists
|
||||
|
||||
Returns:
|
||||
Path to the generated PDF file
|
||||
"""
|
||||
# Check if PDF already exists
|
||||
if invoice.pdf_path and not force_regenerate:
|
||||
if Path(invoice.pdf_path).exists():
|
||||
logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
|
||||
return invoice.pdf_path
|
||||
|
||||
# Ensure storage directory exists
|
||||
storage_dir = self._ensure_storage_dir(invoice.vendor_id)
|
||||
pdf_filename = self._get_pdf_filename(invoice)
|
||||
pdf_path = storage_dir / pdf_filename
|
||||
|
||||
# Render HTML template
|
||||
html_content = self._render_html(invoice)
|
||||
|
||||
# Generate PDF using WeasyPrint
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
|
||||
html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
|
||||
html_doc.write_pdf(str(pdf_path))
|
||||
|
||||
logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
|
||||
except ImportError:
|
||||
logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
|
||||
raise RuntimeError("WeasyPrint not installed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
|
||||
raise
|
||||
|
||||
# Update invoice record with PDF path and timestamp
|
||||
invoice.pdf_path = str(pdf_path)
|
||||
invoice.pdf_generated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
def _render_html(self, invoice: Invoice) -> str:
|
||||
"""Render the invoice HTML template."""
|
||||
template = self.env.get_template("invoice.html")
|
||||
|
||||
context = {
|
||||
"invoice": invoice,
|
||||
"seller": invoice.seller_details,
|
||||
"buyer": invoice.buyer_details,
|
||||
"line_items": invoice.line_items,
|
||||
"bank_details": invoice.bank_details,
|
||||
"payment_terms": invoice.payment_terms,
|
||||
"footer_text": invoice.footer_text,
|
||||
"now": datetime.now(UTC),
|
||||
}
|
||||
|
||||
return template.render(**context)
|
||||
|
||||
def get_pdf_path(self, invoice: Invoice) -> str | None:
|
||||
"""Get the PDF path for an invoice if it exists."""
|
||||
if invoice.pdf_path and Path(invoice.pdf_path).exists():
|
||||
return invoice.pdf_path
|
||||
return None
|
||||
|
||||
def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
|
||||
"""Delete the PDF file for an invoice."""
|
||||
if not invoice.pdf_path:
|
||||
return False
|
||||
|
||||
pdf_path = Path(invoice.pdf_path)
|
||||
if pdf_path.exists():
|
||||
try:
|
||||
pdf_path.unlink()
|
||||
logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete PDF {pdf_path}: {e}")
|
||||
return False
|
||||
|
||||
invoice.pdf_path = None
|
||||
invoice.pdf_generated_at = None
|
||||
db.flush()
|
||||
|
||||
return True
|
||||
|
||||
def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
|
||||
"""Force regenerate PDF for an invoice."""
|
||||
return self.generate_pdf(db, invoice, force_regenerate=True)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
invoice_pdf_service = InvoicePDFService()
|
||||
587
app/modules/orders/services/invoice_service.py
Normal file
587
app/modules/orders/services/invoice_service.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# app/modules/orders/services/invoice_service.py
|
||||
"""
|
||||
Invoice service for generating and managing invoices.
|
||||
|
||||
Handles:
|
||||
- Vendor invoice settings management
|
||||
- Invoice generation from orders
|
||||
- VAT calculation (Luxembourg, EU, B2B reverse charge)
|
||||
- Invoice number sequencing
|
||||
- PDF generation (via separate module)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import ValidationException
|
||||
from app.exceptions.invoice import (
|
||||
InvoiceNotFoundException,
|
||||
InvoiceSettingsNotFoundException,
|
||||
OrderNotFoundException,
|
||||
)
|
||||
from app.modules.orders.models.invoice import (
|
||||
Invoice,
|
||||
InvoiceStatus,
|
||||
VATRegime,
|
||||
VendorInvoiceSettings,
|
||||
)
|
||||
from app.modules.orders.models.order import Order
|
||||
from app.modules.orders.schemas.invoice import (
|
||||
VendorInvoiceSettingsCreate,
|
||||
VendorInvoiceSettingsUpdate,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# EU VAT rates by country code (2024 standard rates)
|
||||
EU_VAT_RATES: dict[str, Decimal] = {
|
||||
"AT": Decimal("20.00"),
|
||||
"BE": Decimal("21.00"),
|
||||
"BG": Decimal("20.00"),
|
||||
"HR": Decimal("25.00"),
|
||||
"CY": Decimal("19.00"),
|
||||
"CZ": Decimal("21.00"),
|
||||
"DK": Decimal("25.00"),
|
||||
"EE": Decimal("22.00"),
|
||||
"FI": Decimal("24.00"),
|
||||
"FR": Decimal("20.00"),
|
||||
"DE": Decimal("19.00"),
|
||||
"GR": Decimal("24.00"),
|
||||
"HU": Decimal("27.00"),
|
||||
"IE": Decimal("23.00"),
|
||||
"IT": Decimal("22.00"),
|
||||
"LV": Decimal("21.00"),
|
||||
"LT": Decimal("21.00"),
|
||||
"LU": Decimal("17.00"),
|
||||
"MT": Decimal("18.00"),
|
||||
"NL": Decimal("21.00"),
|
||||
"PL": Decimal("23.00"),
|
||||
"PT": Decimal("23.00"),
|
||||
"RO": Decimal("19.00"),
|
||||
"SK": Decimal("20.00"),
|
||||
"SI": Decimal("22.00"),
|
||||
"ES": Decimal("21.00"),
|
||||
"SE": Decimal("25.00"),
|
||||
}
|
||||
|
||||
LU_VAT_RATES = {
|
||||
"standard": Decimal("17.00"),
|
||||
"intermediate": Decimal("14.00"),
|
||||
"reduced": Decimal("8.00"),
|
||||
"super_reduced": Decimal("3.00"),
|
||||
}
|
||||
|
||||
|
||||
class InvoiceService:
    """Service for invoice operations.

    Responsibilities (all methods operate on a caller-provided Session and
    only flush — commit is left to the caller):
    - EU VAT regime/rate determination (domestic, reverse charge, OSS, origin, exempt)
    - Vendor invoice settings CRUD
    - Sequential invoice number generation
    - Invoice creation from orders, status transitions, statistics
    - Delegation of PDF generation to invoice_pdf_service
    """

    # =========================================================================
    # VAT Calculation
    # =========================================================================

    def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
        """Get standard VAT rate for EU country.

        Returns Decimal("0.00") for any country not in EU_VAT_RATES
        (i.e. non-EU destinations).
        """
        return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))

    def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
        """Get human-readable VAT rate label, e.g. "Luxembourg VAT 17.00%".

        Falls back to the raw ISO code when the country is not in the map.
        """
        country_names = {
            "AT": "Austria", "BE": "Belgium", "BG": "Bulgaria", "HR": "Croatia",
            "CY": "Cyprus", "CZ": "Czech Republic", "DK": "Denmark", "EE": "Estonia",
            "FI": "Finland", "FR": "France", "DE": "Germany", "GR": "Greece",
            "HU": "Hungary", "IE": "Ireland", "IT": "Italy", "LV": "Latvia",
            "LT": "Lithuania", "LU": "Luxembourg", "MT": "Malta", "NL": "Netherlands",
            "PL": "Poland", "PT": "Portugal", "RO": "Romania", "SK": "Slovakia",
            "SI": "Slovenia", "ES": "Spain", "SE": "Sweden",
        }
        country_name = country_names.get(country_iso.upper(), country_iso)
        return f"{country_name} VAT {vat_rate}%"

    def determine_vat_regime(
        self,
        seller_country: str,
        buyer_country: str,
        buyer_vat_number: str | None,
        seller_oss_registered: bool,
    ) -> tuple[VATRegime, Decimal, str | None]:
        """Determine VAT regime and rate for invoice.

        Decision order (first match wins):
        1. Same country            -> DOMESTIC, seller's standard rate
        2. EU B2B (buyer has VAT#) -> REVERSE_CHARGE, 0%
        3. EU B2C, seller in OSS   -> OSS, buyer-country rate
        4. EU B2C, no OSS          -> ORIGIN, seller-country rate
        5. Non-EU buyer            -> EXEMPT, 0% (export)

        Returns:
            (regime, vat_rate, destination_country) where destination_country
            is None only for the DOMESTIC case.
        """
        seller_country = seller_country.upper()
        buyer_country = buyer_country.upper()

        if seller_country == buyer_country:
            vat_rate = self.get_vat_rate_for_country(seller_country)
            return VATRegime.DOMESTIC, vat_rate, None

        if buyer_country in EU_VAT_RATES:
            # Cross-border B2B within the EU: buyer self-accounts for VAT.
            if buyer_vat_number:
                return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country

            if seller_oss_registered:
                vat_rate = self.get_vat_rate_for_country(buyer_country)
                return VATRegime.OSS, vat_rate, buyer_country
            else:
                vat_rate = self.get_vat_rate_for_country(seller_country)
                return VATRegime.ORIGIN, vat_rate, buyer_country

        # Buyer outside the EU: treated as zero-rated export here.
        return VATRegime.EXEMPT, Decimal("0.00"), buyer_country

    # =========================================================================
    # Invoice Settings Management
    # =========================================================================

    def get_settings(
        self, db: Session, vendor_id: int
    ) -> VendorInvoiceSettings | None:
        """Get vendor invoice settings, or None if not yet configured."""
        return (
            db.query(VendorInvoiceSettings)
            .filter(VendorInvoiceSettings.vendor_id == vendor_id)
            .first()
        )

    def get_settings_or_raise(
        self, db: Session, vendor_id: int
    ) -> VendorInvoiceSettings:
        """Get vendor invoice settings or raise InvoiceSettingsNotFoundException."""
        settings = self.get_settings(db, vendor_id)
        if not settings:
            raise InvoiceSettingsNotFoundException(vendor_id)
        return settings

    def create_settings(
        self,
        db: Session,
        vendor_id: int,
        data: VendorInvoiceSettingsCreate,
    ) -> VendorInvoiceSettings:
        """Create vendor invoice settings.

        Raises:
            ValidationException: if settings already exist for this vendor
                (one settings row per vendor).
        """
        existing = self.get_settings(db, vendor_id)
        if existing:
            raise ValidationException(
                "Invoice settings already exist for this vendor"
            )

        settings = VendorInvoiceSettings(
            vendor_id=vendor_id,
            **data.model_dump(),
        )
        db.add(settings)
        db.flush()
        db.refresh(settings)

        logger.info(f"Created invoice settings for vendor {vendor_id}")
        return settings

    def update_settings(
        self,
        db: Session,
        vendor_id: int,
        data: VendorInvoiceSettingsUpdate,
    ) -> VendorInvoiceSettings:
        """Update vendor invoice settings (partial update; unset fields untouched)."""
        settings = self.get_settings_or_raise(db, vendor_id)

        # exclude_unset: only fields the caller actually sent are applied.
        update_data = data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(settings, key, value)

        settings.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(settings)

        logger.info(f"Updated invoice settings for vendor {vendor_id}")
        return settings

    def create_settings_from_vendor(
        self,
        db: Session,
        vendor: Vendor,
    ) -> VendorInvoiceSettings:
        """Create invoice settings seeded from vendor/company info.

        City/postal code are unknown at this point and left None;
        country defaults to "LU". VAT registration is inferred from the
        presence of a tax number.
        """
        company = vendor.company

        settings = VendorInvoiceSettings(
            vendor_id=vendor.id,
            company_name=company.legal_name if company else vendor.name,
            company_address=vendor.effective_business_address,
            company_city=None,
            company_postal_code=None,
            company_country="LU",
            vat_number=vendor.effective_tax_number,
            is_vat_registered=bool(vendor.effective_tax_number),
        )
        db.add(settings)
        db.flush()
        db.refresh(settings)

        logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
        return settings

    # =========================================================================
    # Invoice Number Generation
    # =========================================================================

    def _get_next_invoice_number(
        self, db: Session, settings: VendorInvoiceSettings
    ) -> str:
        """Generate next invoice number and increment counter.

        Format: <prefix><zero-padded counter>. The counter increment is only
        flushed, not committed — it is persisted/rolled back with the caller's
        transaction.
        NOTE(review): no row locking is visible here; concurrent invoice
        creation for the same vendor could read the same counter — confirm
        isolation is handled upstream.
        """
        number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
        invoice_number = f"{settings.invoice_prefix}{number}"

        settings.invoice_next_number += 1
        db.flush()

        return invoice_number

    # =========================================================================
    # Invoice Creation
    # =========================================================================

    def create_invoice_from_order(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
        notes: str | None = None,
    ) -> Invoice:
        """Create an invoice from an order.

        Snapshots seller/buyer details and line items into the invoice row,
        determines the VAT regime from billing country, and assigns the next
        sequential invoice number. The new invoice starts in DRAFT status.

        Raises:
            InvoiceSettingsNotFoundException: vendor has no invoice settings.
            OrderNotFoundException: order does not exist for this vendor.
            ValidationException: an invoice already exists for this order.
        """
        settings = self.get_settings_or_raise(db, vendor_id)

        order = (
            db.query(Order)
            .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
            .first()
        )
        if not order:
            raise OrderNotFoundException(f"Order {order_id} not found")

        # One invoice per order: reject duplicates.
        existing = (
            db.query(Invoice)
            .filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
            .first()
        )
        if existing:
            raise ValidationException(f"Invoice already exists for order {order_id}")

        # VAT is determined from the billing country; buyer VAT number is not
        # captured on the order, so reverse charge cannot be triggered here.
        buyer_country = order.bill_country_iso
        vat_regime, vat_rate, destination_country = self.determine_vat_regime(
            seller_country=settings.company_country,
            buyer_country=buyer_country,
            buyer_vat_number=None,
            seller_oss_registered=settings.is_oss_registered,
        )

        # Snapshot seller identity as of invoice creation time.
        seller_details = {
            "company_name": settings.company_name,
            "address": settings.company_address,
            "city": settings.company_city,
            "postal_code": settings.company_postal_code,
            "country": settings.company_country,
            "vat_number": settings.vat_number,
        }

        # Snapshot buyer identity from the order's billing fields.
        buyer_details = {
            "name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
            "email": order.customer_email,
            "address": order.bill_address_line_1,
            "city": order.bill_city,
            "postal_code": order.bill_postal_code,
            "country": order.bill_country_iso,
            "vat_number": None,
        }
        if order.bill_company:
            buyer_details["company"] = order.bill_company

        line_items = []
        for item in order.items:
            line_items.append({
                "description": item.product_name,
                "quantity": item.quantity,
                "unit_price_cents": item.unit_price_cents,
                "total_cents": item.total_price_cents,
                "sku": item.product_sku,
                "ean": item.gtin,
            })

        subtotal_cents = sum(item["total_cents"] for item in line_items)

        if vat_rate > 0:
            # NOTE(review): converts Decimal to float and truncates via int();
            # Decimal arithmetic with explicit rounding (e.g. ROUND_HALF_UP)
            # would avoid potential off-by-one-cent results — confirm intended.
            vat_amount_cents = int(subtotal_cents * float(vat_rate) / 100)
        else:
            vat_amount_cents = 0

        total_cents = subtotal_cents + vat_amount_cents

        vat_rate_label = None
        if vat_rate > 0:
            if destination_country:
                vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
            else:
                vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)

        invoice_number = self._get_next_invoice_number(db, settings)

        invoice = Invoice(
            vendor_id=vendor_id,
            order_id=order_id,
            invoice_number=invoice_number,
            invoice_date=datetime.now(UTC),
            status=InvoiceStatus.DRAFT.value,
            seller_details=seller_details,
            buyer_details=buyer_details,
            line_items=line_items,
            vat_regime=vat_regime.value,
            destination_country=destination_country,
            vat_rate=vat_rate,
            vat_rate_label=vat_rate_label,
            currency=order.currency,
            subtotal_cents=subtotal_cents,
            vat_amount_cents=vat_amount_cents,
            total_cents=total_cents,
            payment_terms=settings.payment_terms,
            # Bank details only included when an IBAN is configured.
            bank_details={
                "bank_name": settings.bank_name,
                "iban": settings.bank_iban,
                "bic": settings.bank_bic,
            } if settings.bank_iban else None,
            footer_text=settings.footer_text,
            notes=notes,
        )

        db.add(invoice)
        db.flush()
        db.refresh(invoice)

        logger.info(
            f"Created invoice {invoice_number} for order {order_id} "
            f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
        )

        return invoice

    # =========================================================================
    # Invoice Retrieval
    # =========================================================================

    def get_invoice(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice | None:
        """Get invoice by ID, scoped to the vendor; None if not found."""
        return (
            db.query(Invoice)
            .filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
            .first()
        )

    def get_invoice_or_raise(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Get invoice by ID or raise InvoiceNotFoundException."""
        invoice = self.get_invoice(db, vendor_id, invoice_id)
        if not invoice:
            raise InvoiceNotFoundException(invoice_id)
        return invoice

    def get_invoice_by_number(
        self, db: Session, vendor_id: int, invoice_number: str
    ) -> Invoice | None:
        """Get invoice by invoice number, scoped to the vendor."""
        return (
            db.query(Invoice)
            .filter(
                and_(
                    Invoice.invoice_number == invoice_number,
                    Invoice.vendor_id == vendor_id,
                )
            )
            .first()
        )

    def get_invoice_by_order_id(
        self, db: Session, vendor_id: int, order_id: int
    ) -> Invoice | None:
        """Get invoice by order ID, scoped to the vendor."""
        return (
            db.query(Invoice)
            .filter(
                and_(
                    Invoice.order_id == order_id,
                    Invoice.vendor_id == vendor_id,
                )
            )
            .first()
        )

    def list_invoices(
        self,
        db: Session,
        vendor_id: int,
        status: str | None = None,
        page: int = 1,
        per_page: int = 20,
    ) -> tuple[list[Invoice], int]:
        """List invoices for vendor with pagination.

        Returns:
            (invoices for the requested page, total matching count).
            Ordered by invoice_date descending; `page` is 1-based.
        """
        query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)

        if status:
            query = query.filter(Invoice.status == status)

        total = query.count()

        invoices = (
            query.order_by(Invoice.invoice_date.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
            .all()
        )

        return invoices, total

    # =========================================================================
    # Invoice Status Management
    # =========================================================================

    def update_status(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
        new_status: str,
    ) -> Invoice:
        """Update invoice status.

        Raises:
            InvoiceNotFoundException: invoice does not exist for this vendor.
            ValidationException: unknown status value, or the invoice is
                already CANCELLED (terminal state).
        """
        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)

        valid_statuses = [s.value for s in InvoiceStatus]
        if new_status not in valid_statuses:
            raise ValidationException(f"Invalid status: {new_status}")

        # Cancelled is terminal: no further transitions allowed.
        if invoice.status == InvoiceStatus.CANCELLED.value:
            raise ValidationException("Cannot change status of cancelled invoice")

        invoice.status = new_status
        invoice.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(invoice)

        logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
        return invoice

    def mark_as_issued(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Mark invoice as issued (convenience wrapper over update_status)."""
        return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)

    def mark_as_paid(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Mark invoice as paid (convenience wrapper over update_status)."""
        return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)

    def cancel_invoice(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Cancel invoice (terminal; further status changes are rejected)."""
        return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value)

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_invoice_stats(
        self, db: Session, vendor_id: int
    ) -> dict[str, Any]:
        """Get invoice statistics for vendor.

        Revenue counts only ISSUED and PAID invoices (drafts and cancelled
        invoices are excluded from revenue).
        """
        total_count = (
            db.query(func.count(Invoice.id))
            .filter(Invoice.vendor_id == vendor_id)
            .scalar()
            or 0
        )

        total_revenue = (
            db.query(func.sum(Invoice.total_cents))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status.in_([
                        InvoiceStatus.ISSUED.value,
                        InvoiceStatus.PAID.value,
                    ]),
                )
            )
            .scalar()
            or 0
        )

        draft_count = (
            db.query(func.count(Invoice.id))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status == InvoiceStatus.DRAFT.value,
                )
            )
            .scalar()
            or 0
        )

        paid_count = (
            db.query(func.count(Invoice.id))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status == InvoiceStatus.PAID.value,
                )
            )
            .scalar()
            or 0
        )

        return {
            "total_invoices": total_count,
            "total_revenue_cents": total_revenue,
            "total_revenue": total_revenue / 100 if total_revenue else 0,
            "draft_count": draft_count,
            "paid_count": paid_count,
        }

    # =========================================================================
    # PDF Generation
    # =========================================================================

    def generate_pdf(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
        force_regenerate: bool = False,
    ) -> str:
        """Generate PDF for an invoice; returns the generated file's path.

        Import is deferred to avoid a module-level import cycle with
        invoice_pdf_service.
        """
        from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service

        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
        return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)

    def get_pdf_path(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
    ) -> str | None:
        """Get PDF path for an invoice if it exists, else None."""
        from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service

        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
        return invoice_pdf_service.get_pdf_path(invoice)
|
||||
|
||||
|
||||
# Singleton instance — module-level shared service (stateless; all state
# lives in the Session passed to each method).
invoice_service = InvoiceService()
|
||||
591
app/modules/orders/services/order_inventory_service.py
Normal file
591
app/modules/orders/services/order_inventory_service.py
Normal file
@@ -0,0 +1,591 @@
|
||||
# app/modules/orders/services/order_inventory_service.py
|
||||
"""
|
||||
Order-Inventory Integration Service.
|
||||
|
||||
This service orchestrates inventory operations for orders:
|
||||
- Reserve inventory when orders are confirmed
|
||||
- Fulfill (deduct) inventory when orders are shipped
|
||||
- Release reservations when orders are cancelled
|
||||
|
||||
All operations are logged to the inventory_transactions table for audit trail.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
InsufficientInventoryException,
|
||||
InventoryNotFoundException,
|
||||
OrderNotFoundException,
|
||||
ValidationException,
|
||||
)
|
||||
from app.modules.inventory.models.inventory import Inventory
|
||||
from app.modules.inventory.models.inventory_transaction import (
|
||||
InventoryTransaction,
|
||||
TransactionType,
|
||||
)
|
||||
from app.modules.inventory.schemas.inventory import InventoryReserve
|
||||
from app.modules.inventory.services.inventory_service import inventory_service
|
||||
from app.modules.orders.models.order import Order, OrderItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default location code for inventory operations when no specific
# warehouse location is resolved.
DEFAULT_LOCATION = "DEFAULT"
|
||||
|
||||
|
||||
class OrderInventoryService:
|
||||
"""
|
||||
Orchestrate order and inventory operations together.
|
||||
"""
|
||||
|
||||
def get_order_with_items(
|
||||
self, db: Session, vendor_id: int, order_id: int
|
||||
) -> Order:
|
||||
"""Get order with items or raise OrderNotFoundException."""
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
return order
|
||||
|
||||
def _find_inventory_location(
|
||||
self, db: Session, product_id: int, vendor_id: int
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the location with available inventory for a product.
|
||||
"""
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
Inventory.quantity > Inventory.reserved_quantity,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return inventory.location if inventory else None
|
||||
|
||||
def _is_placeholder_product(self, order_item: OrderItem) -> bool:
|
||||
"""Check if the order item uses a placeholder product."""
|
||||
if not order_item.product:
|
||||
return True
|
||||
return order_item.product.gtin == "0000000000000"
|
||||
|
||||
def _log_transaction(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
product_id: int,
|
||||
inventory: Inventory,
|
||||
transaction_type: TransactionType,
|
||||
quantity_change: int,
|
||||
order: Order,
|
||||
reason: str | None = None,
|
||||
) -> InventoryTransaction:
|
||||
"""Create an inventory transaction record for audit trail."""
|
||||
transaction = InventoryTransaction.create_transaction(
|
||||
vendor_id=vendor_id,
|
||||
product_id=product_id,
|
||||
inventory_id=inventory.id if inventory else None,
|
||||
transaction_type=transaction_type,
|
||||
quantity_change=quantity_change,
|
||||
quantity_after=inventory.quantity if inventory else 0,
|
||||
reserved_after=inventory.reserved_quantity if inventory else 0,
|
||||
location=inventory.location if inventory else None,
|
||||
warehouse=inventory.warehouse if inventory else None,
|
||||
order_id=order.id,
|
||||
order_number=order.order_number,
|
||||
reason=reason,
|
||||
created_by="system",
|
||||
)
|
||||
db.add(transaction)
|
||||
return transaction
|
||||
|
||||
def reserve_for_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""Reserve inventory for all items in an order."""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
reserved_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=item.quantity,
|
||||
)
|
||||
updated_inventory = inventory_service.reserve_inventory(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
reserved_count += 1
|
||||
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.RESERVE,
|
||||
quantity_change=0,
|
||||
order=order,
|
||||
reason=f"Reserved for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Reserved {item.quantity} units of product {item.product_id} "
|
||||
f"for order {order.order_number}"
|
||||
)
|
||||
except InsufficientInventoryException:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "insufficient_inventory",
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: reserved {reserved_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"reserved_count": reserved_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def fulfill_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""Fulfill (deduct) inventory when an order is shipped."""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
fulfilled_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
if item.is_fully_shipped:
|
||||
continue
|
||||
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
quantity_to_fulfill = item.remaining_quantity
|
||||
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
if not location:
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if inventory:
|
||||
location = inventory.location
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=quantity_to_fulfill,
|
||||
)
|
||||
updated_inventory = inventory_service.fulfill_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
fulfilled_count += 1
|
||||
|
||||
item.shipped_quantity = item.quantity
|
||||
item.inventory_fulfilled = True
|
||||
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.FULFILL,
|
||||
quantity_change=-quantity_to_fulfill,
|
||||
order=order,
|
||||
reason=f"Fulfilled for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} "
|
||||
f"for order {order.order_number}"
|
||||
)
|
||||
except (InsufficientInventoryException, InventoryNotFoundException) as e:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": str(e),
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: fulfilled {fulfilled_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"fulfilled_count": fulfilled_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def fulfill_item(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
item_id: int,
|
||||
quantity: int | None = None,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""Fulfill (deduct) inventory for a specific order item."""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
item = None
|
||||
for order_item in order.items:
|
||||
if order_item.id == item_id:
|
||||
item = order_item
|
||||
break
|
||||
|
||||
if not item:
|
||||
raise ValidationException(f"Item {item_id} not found in order {order_id}")
|
||||
|
||||
if item.is_fully_shipped:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Item already fully shipped",
|
||||
}
|
||||
|
||||
quantity_to_fulfill = quantity or item.remaining_quantity
|
||||
|
||||
if quantity_to_fulfill > item.remaining_quantity:
|
||||
raise ValidationException(
|
||||
f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining"
|
||||
)
|
||||
|
||||
if quantity_to_fulfill <= 0:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Nothing to fulfill",
|
||||
}
|
||||
|
||||
if self._is_placeholder_product(item):
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Placeholder product - skipped",
|
||||
}
|
||||
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
if not location:
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if inventory:
|
||||
location = inventory.location
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "No inventory found",
|
||||
}
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=quantity_to_fulfill,
|
||||
)
|
||||
updated_inventory = inventory_service.fulfill_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
|
||||
item.shipped_quantity += quantity_to_fulfill
|
||||
|
||||
if item.is_fully_shipped:
|
||||
item.inventory_fulfilled = True
|
||||
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.FULFILL,
|
||||
quantity_change=-quantity_to_fulfill,
|
||||
order=order,
|
||||
reason=f"Partial shipment for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Fulfilled {quantity_to_fulfill} of {item.quantity} units "
|
||||
f"for item {item_id} in order {order.order_number}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": quantity_to_fulfill,
|
||||
"shipped_quantity": item.shipped_quantity,
|
||||
"remaining_quantity": item.remaining_quantity,
|
||||
"is_fully_shipped": item.is_fully_shipped,
|
||||
}
|
||||
|
||||
except (InsufficientInventoryException, InventoryNotFoundException) as e:
|
||||
if skip_missing:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": str(e),
|
||||
}
|
||||
else:
|
||||
raise
|
||||
|
||||
def release_order_reservation(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""Release reserved inventory when an order is cancelled."""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
released_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not inventory:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=inventory.location,
|
||||
quantity=item.quantity,
|
||||
)
|
||||
updated_inventory = inventory_service.release_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
released_count += 1
|
||||
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.RELEASE,
|
||||
quantity_change=0,
|
||||
order=order,
|
||||
reason=f"Released for cancelled order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Released {item.quantity} units of product {item.product_id} "
|
||||
f"for cancelled order {order.order_number}"
|
||||
)
|
||||
except Exception as e:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": str(e),
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: released {released_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"released_count": released_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def handle_status_change(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
old_status: str | None,
|
||||
new_status: str,
|
||||
) -> dict | None:
|
||||
"""Handle inventory operations based on order status changes."""
|
||||
if old_status == new_status:
|
||||
return None
|
||||
|
||||
result = None
|
||||
|
||||
if new_status == "processing":
|
||||
result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True)
|
||||
logger.info(f"Order {order_id} confirmed: inventory reserved")
|
||||
|
||||
elif new_status == "shipped":
|
||||
result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True)
|
||||
logger.info(f"Order {order_id} shipped: inventory fulfilled")
|
||||
|
||||
elif new_status == "partially_shipped":
|
||||
logger.info(
|
||||
f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment"
|
||||
)
|
||||
result = {"order_id": order_id, "status": "partially_shipped"}
|
||||
|
||||
elif new_status == "cancelled":
|
||||
if old_status and old_status not in ("cancelled", "refunded"):
|
||||
result = self.release_order_reservation(
|
||||
db, vendor_id, order_id, skip_missing=True
|
||||
)
|
||||
logger.info(f"Order {order_id} cancelled: reservations released")
|
||||
|
||||
return result
|
||||
|
||||
def get_shipment_status(
    self,
    db: Session,
    vendor_id: int,
    order_id: int,
) -> dict:
    """Build a per-item shipment snapshot for an order.

    Returns a dict with order-level shipping aggregates plus an "items"
    list carrying the shipped/remaining quantities of every order item.
    """
    order = self.get_order_with_items(db, vendor_id, order_id)

    item_rows = [
        {
            "item_id": line.id,
            "product_id": line.product_id,
            "product_name": line.product_name,
            "quantity": line.quantity,
            "shipped_quantity": line.shipped_quantity,
            "remaining_quantity": line.remaining_quantity,
            "is_fully_shipped": line.is_fully_shipped,
            "is_partially_shipped": line.is_partially_shipped,
        }
        for line in order.items
    ]

    return {
        "order_id": order_id,
        "order_number": order.order_number,
        "order_status": order.status,
        "is_fully_shipped": order.is_fully_shipped,
        "is_partially_shipped": order.is_partially_shipped,
        "shipped_item_count": order.shipped_item_count,
        "total_item_count": len(order.items),
        "total_shipped_units": order.total_shipped_units,
        "total_ordered_units": order.total_ordered_units,
        "items": item_rows,
    }
|
||||
|
||||
|
||||
# Module-level singleton shared by routes and sibling services.
order_inventory_service = OrderInventoryService()
|
||||
466
app/modules/orders/services/order_item_exception_service.py
Normal file
466
app/modules/orders/services/order_item_exception_service.py
Normal file
@@ -0,0 +1,466 @@
|
||||
# app/modules/orders/services/order_item_exception_service.py
|
||||
"""
|
||||
Service for managing order item exceptions (unmatched products).
|
||||
|
||||
This service handles:
|
||||
- Creating exceptions when products are not found during order import
|
||||
- Resolving exceptions by assigning products
|
||||
- Auto-matching when new products are imported
|
||||
- Querying and statistics for exceptions
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.exceptions import (
|
||||
ExceptionAlreadyResolvedException,
|
||||
InvalidProductForExceptionException,
|
||||
OrderItemExceptionNotFoundException,
|
||||
ProductNotFoundException,
|
||||
)
|
||||
from app.modules.orders.models.order import Order, OrderItem
|
||||
from app.modules.orders.models.order_item_exception import OrderItemException
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrderItemExceptionService:
|
||||
"""Service for order item exception CRUD and resolution workflow."""
|
||||
|
||||
# =========================================================================
|
||||
# Exception Creation
|
||||
# =========================================================================
|
||||
|
||||
def create_exception(
    self,
    db: Session,
    order_item: OrderItem,
    vendor_id: int,
    original_gtin: str | None,
    original_product_name: str | None,
    original_sku: str | None,
    exception_type: str = "product_not_found",
) -> OrderItemException:
    """Record a pending exception for an order item that could not be matched.

    The original identifiers from the import source (GTIN / name / SKU) are
    stored so the item can later be matched manually or automatically.
    The record is flushed (to obtain its ID) but not committed.
    """
    record = OrderItemException(
        order_item_id=order_item.id,
        vendor_id=vendor_id,
        original_gtin=original_gtin,
        original_product_name=original_product_name,
        original_sku=original_sku,
        exception_type=exception_type,
        status="pending",
    )
    db.add(record)
    db.flush()  # assigns record.id within the caller's transaction

    logger.info(
        f"Created order item exception {record.id} for order item "
        f"{order_item.id}, GTIN: {original_gtin}"
    )
    return record
|
||||
|
||||
# =========================================================================
|
||||
# Exception Retrieval
|
||||
# =========================================================================
|
||||
|
||||
def get_exception_by_id(
    self,
    db: Session,
    exception_id: int,
    vendor_id: int | None = None,
) -> OrderItemException:
    """Fetch a single exception by ID.

    When vendor_id is given the exception must belong to that vendor;
    raises OrderItemExceptionNotFoundException otherwise or when the ID
    does not exist.
    """
    conditions = [OrderItemException.id == exception_id]
    if vendor_id is not None:
        conditions.append(OrderItemException.vendor_id == vendor_id)

    found = db.query(OrderItemException).filter(*conditions).first()
    if found is None:
        raise OrderItemExceptionNotFoundException(exception_id)
    return found
|
||||
|
||||
def get_pending_exceptions(
    self,
    db: Session,
    vendor_id: int | None = None,
    status: str | None = None,
    search: str | None = None,
    skip: int = 0,
    limit: int = 50,
) -> tuple[list[OrderItemException], int]:
    """Return one page of exceptions plus the total matching count.

    Filters: vendor scoping, exact status, and a case-insensitive search
    across original GTIN / product name / SKU and the order number.
    Results are newest-first.
    """
    base = (
        db.query(OrderItemException)
        .join(OrderItem)
        .join(Order)
        .options(
            joinedload(OrderItemException.order_item).joinedload(OrderItem.order)
        )
    )

    if vendor_id is not None:
        base = base.filter(OrderItemException.vendor_id == vendor_id)
    if status:
        base = base.filter(OrderItemException.status == status)
    if search:
        pattern = f"%{search}%"
        base = base.filter(
            or_(
                OrderItemException.original_gtin.ilike(pattern),
                OrderItemException.original_product_name.ilike(pattern),
                OrderItemException.original_sku.ilike(pattern),
                Order.order_number.ilike(pattern),
            )
        )

    # Count before applying pagination so the total reflects all matches.
    total = base.count()

    page = (
        base.order_by(OrderItemException.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return page, total
|
||||
|
||||
def get_exceptions_for_order(
    self,
    db: Session,
    order_id: int,
) -> list[OrderItemException]:
    """List every exception attached to any item of the given order."""
    query = db.query(OrderItemException).join(OrderItem)
    return query.filter(OrderItem.order_id == order_id).all()
|
||||
|
||||
# =========================================================================
|
||||
# Exception Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_exception_stats(
    self,
    db: Session,
    vendor_id: int | None = None,
) -> dict[str, int]:
    """Get exception counts by status.

    Returns a dict with keys "pending", "resolved", "ignored", "total",
    and "orders_with_exceptions" (distinct orders that still have at least
    one pending exception). When vendor_id is given, every count is scoped
    to that vendor.
    """
    # One grouped count per status value.
    query = db.query(
        OrderItemException.status,
        func.count(OrderItemException.id).label("count"),
    )

    if vendor_id is not None:
        query = query.filter(OrderItemException.vendor_id == vendor_id)

    results = query.group_by(OrderItemException.status).all()

    stats = {
        "pending": 0,
        "resolved": 0,
        "ignored": 0,
        "total": 0,
    }

    for status, count in results:
        if status in stats:
            stats[status] = count
        # NOTE(review): "total" accumulates every status row, including any
        # status outside the three known buckets — confirm that is intended.
        stats["total"] += count

    # Distinct orders that still contain at least one pending exception.
    orders_query = (
        db.query(func.count(func.distinct(OrderItem.order_id)))
        .join(OrderItemException)
        .filter(OrderItemException.status == "pending")
    )

    if vendor_id is not None:
        orders_query = orders_query.filter(
            OrderItemException.vendor_id == vendor_id
        )

    stats["orders_with_exceptions"] = orders_query.scalar() or 0

    return stats
|
||||
|
||||
# =========================================================================
|
||||
# Exception Resolution
|
||||
# =========================================================================
|
||||
|
||||
def resolve_exception(
    self,
    db: Session,
    exception_id: int,
    product_id: int,
    resolved_by: int,
    notes: str | None = None,
    vendor_id: int | None = None,
) -> OrderItemException:
    """Resolve an exception by assigning a product.

    Args:
        db: Active database session; changes are flushed, not committed.
        exception_id: ID of the exception to resolve.
        product_id: Product to assign to the unmatched order item.
        resolved_by: ID of the user performing the resolution.
        notes: Optional free-text resolution notes.
        vendor_id: When given, the exception must belong to this vendor.

    Returns:
        The updated OrderItemException with status "resolved".

    Raises:
        OrderItemExceptionNotFoundException: Unknown (or foreign) exception_id.
        ExceptionAlreadyResolvedException: Exception is already resolved.
        ProductNotFoundException: product_id does not exist.
        InvalidProductForExceptionException: Product belongs to a different
            vendor or is inactive.
    """
    exception = self.get_exception_by_id(db, exception_id, vendor_id)

    if exception.status == "resolved":
        raise ExceptionAlreadyResolvedException(exception_id)

    # Validate the chosen product: must exist, belong to the same vendor
    # as the exception, and be active.
    product = db.query(Product).filter(Product.id == product_id).first()
    if not product:
        raise ProductNotFoundException(product_id)

    if product.vendor_id != exception.vendor_id:
        raise InvalidProductForExceptionException(
            product_id, "Product belongs to a different vendor"
        )

    if not product.is_active:
        raise InvalidProductForExceptionException(
            product_id, "Product is not active"
        )

    # Mark resolved and record the audit trail.
    exception.status = "resolved"
    exception.resolved_product_id = product_id
    exception.resolved_at = datetime.now(UTC)
    exception.resolved_by = resolved_by
    exception.resolution_notes = notes

    # Repoint the order item so it no longer needs manual matching.
    order_item = exception.order_item
    order_item.product_id = product_id
    order_item.needs_product_match = False

    if product.marketplace_product:
        order_item.product_name = product.marketplace_product.get_title("en")
        # NOTE(review): the SKU refresh appears tied to the marketplace
        # listing being present — confirm this grouping is intended.
        order_item.product_sku = product.vendor_sku or order_item.product_sku

    db.flush()

    logger.info(
        f"Resolved exception {exception_id} with product {product_id} "
        f"by user {resolved_by}"
    )

    return exception
|
||||
|
||||
def ignore_exception(
    self,
    db: Session,
    exception_id: int,
    resolved_by: int,
    notes: str,
    vendor_id: int | None = None,
) -> OrderItemException:
    """Mark an exception as ignored (kept for audit, no product assigned).

    Raises ExceptionAlreadyResolvedException when the exception has
    already been resolved with a product.
    """
    record = self.get_exception_by_id(db, exception_id, vendor_id)

    # A resolved exception cannot be downgraded to ignored.
    if record.status == "resolved":
        raise ExceptionAlreadyResolvedException(exception_id)

    record.status = "ignored"
    record.resolved_at = datetime.now(UTC)
    record.resolved_by = resolved_by
    record.resolution_notes = notes

    db.flush()
    logger.info(
        f"Ignored exception {exception_id} by user {resolved_by}: {notes}"
    )
    return record
|
||||
|
||||
# =========================================================================
|
||||
# Auto-Matching
|
||||
# =========================================================================
|
||||
|
||||
def auto_match_by_gtin(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
gtin: str,
|
||||
product_id: int,
|
||||
) -> list[OrderItemException]:
|
||||
"""Auto-resolve pending exceptions matching a GTIN."""
|
||||
if not gtin:
|
||||
return []
|
||||
|
||||
pending = (
|
||||
db.query(OrderItemException)
|
||||
.filter(
|
||||
and_(
|
||||
OrderItemException.vendor_id == vendor_id,
|
||||
OrderItemException.original_gtin == gtin,
|
||||
OrderItemException.status == "pending",
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not pending:
|
||||
return []
|
||||
|
||||
product = db.query(Product).filter(Product.id == product_id).first()
|
||||
if not product:
|
||||
logger.warning(f"Product {product_id} not found for auto-match")
|
||||
return []
|
||||
|
||||
resolved = []
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for exception in pending:
|
||||
exception.status = "resolved"
|
||||
exception.resolved_product_id = product_id
|
||||
exception.resolved_at = now
|
||||
exception.resolution_notes = "Auto-matched during product import"
|
||||
|
||||
order_item = exception.order_item
|
||||
order_item.product_id = product_id
|
||||
order_item.needs_product_match = False
|
||||
if product.marketplace_product:
|
||||
order_item.product_name = product.marketplace_product.get_title("en")
|
||||
|
||||
resolved.append(exception)
|
||||
|
||||
if resolved:
|
||||
db.flush()
|
||||
logger.info(
|
||||
f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} "
|
||||
f"with product {product_id}"
|
||||
)
|
||||
|
||||
return resolved
|
||||
|
||||
def auto_match_batch(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
gtin_to_product: dict[str, int],
|
||||
) -> int:
|
||||
"""Batch auto-match multiple GTINs after bulk import."""
|
||||
if not gtin_to_product:
|
||||
return 0
|
||||
|
||||
total_resolved = 0
|
||||
|
||||
for gtin, product_id in gtin_to_product.items():
|
||||
resolved = self.auto_match_by_gtin(db, vendor_id, gtin, product_id)
|
||||
total_resolved += len(resolved)
|
||||
|
||||
return total_resolved
|
||||
|
||||
# =========================================================================
|
||||
# Confirmation Checks
|
||||
# =========================================================================
|
||||
|
||||
def order_has_unresolved_exceptions(
    self,
    db: Session,
    order_id: int,
) -> bool:
    """True when any item of the order still has a pending/ignored exception."""
    return self.get_unresolved_exception_count(db, order_id) > 0
|
||||
|
||||
def get_unresolved_exception_count(
    self,
    db: Session,
    order_id: int,
) -> int:
    """Count the pending/ignored exceptions attached to the order's items."""
    unresolved = (
        db.query(func.count(OrderItemException.id))
        .join(OrderItem)
        .filter(
            and_(
                OrderItem.order_id == order_id,
                OrderItemException.status.in_(["pending", "ignored"]),
            )
        )
        .scalar()
    )
    # scalar() can in principle yield None; normalize to an int.
    return unresolved or 0
|
||||
|
||||
# =========================================================================
|
||||
# Bulk Operations
|
||||
# =========================================================================
|
||||
|
||||
def bulk_resolve_by_gtin(
    self,
    db: Session,
    vendor_id: int,
    gtin: str,
    product_id: int,
    resolved_by: int,
    notes: str | None = None,
) -> int:
    """Bulk resolve all pending exceptions for a GTIN.

    Applies the same product validation as ``resolve_exception`` (product
    must exist, belong to *vendor_id*, and be active), then assigns the
    product to every pending exception carrying this GTIN and repoints
    the affected order items.

    Args:
        db: Active database session; changes are flushed, not committed.
        vendor_id: Vendor whose exceptions are being resolved.
        gtin: GTIN shared by the pending exceptions.
        product_id: Product to assign.
        resolved_by: ID of the user performing the bulk resolution.
        notes: Optional notes; a default message is generated when omitted.

    Returns:
        Number of exceptions resolved.

    Raises:
        ProductNotFoundException: product_id does not exist.
        InvalidProductForExceptionException: Product belongs to a different
            vendor or is inactive.
    """
    product = db.query(Product).filter(Product.id == product_id).first()
    if not product:
        raise ProductNotFoundException(product_id)

    if product.vendor_id != vendor_id:
        raise InvalidProductForExceptionException(
            product_id, "Product belongs to a different vendor"
        )

    # Parity with resolve_exception: never assign an inactive product,
    # even in bulk. Previously this path skipped the active check.
    if not product.is_active:
        raise InvalidProductForExceptionException(
            product_id, "Product is not active"
        )

    pending = (
        db.query(OrderItemException)
        .filter(
            and_(
                OrderItemException.vendor_id == vendor_id,
                OrderItemException.original_gtin == gtin,
                OrderItemException.status == "pending",
            )
        )
        .all()
    )

    now = datetime.now(UTC)
    resolution_notes = notes or f"Bulk resolved for GTIN {gtin}"

    for exception in pending:
        exception.status = "resolved"
        exception.resolved_product_id = product_id
        exception.resolved_at = now
        exception.resolved_by = resolved_by
        exception.resolution_notes = resolution_notes

        # Repoint the order item so it no longer needs manual matching.
        order_item = exception.order_item
        order_item.product_id = product_id
        order_item.needs_product_match = False
        if product.marketplace_product:
            order_item.product_name = product.marketplace_product.get_title("en")

    db.flush()

    logger.info(
        f"Bulk resolved {len(pending)} exceptions for GTIN {gtin} "
        f"with product {product_id} by user {resolved_by}"
    )

    return len(pending)
|
||||
|
||||
|
||||
# Module-level singleton shared by routes and sibling services.
order_item_exception_service = OrderItemExceptionService()
|
||||
1325
app/modules/orders/services/order_service.py
Normal file
1325
app/modules/orders/services/order_service.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user