From de83875d0ad29162ef11ff4e6a21139fdb55cc51 Mon Sep 17 00:00:00 2001 From: Samir Boulahtit Date: Thu, 29 Jan 2026 21:28:56 +0100 Subject: [PATCH] refactor: migrate modules from re-exports to canonical implementations Move actual code implementations into module directories: - orders: 5 services, 4 models, order/invoice schemas - inventory: 3 services, 2 models, 30+ schemas - customers: 3 services, 2 models, customer schemas - messaging: 3 services, 2 models, message/notification schemas - monitoring: background_tasks_service - marketplace: 5+ services including letzshop submodule - dev_tools: code_quality_service, test_runner_service - billing: billing_service - contracts: definition.py Legacy files in app/services/, models/database/, models/schema/ now re-export from canonical module locations for backwards compatibility. Architecture validator passes with 0 errors. Co-Authored-By: Claude Opus 4.5 --- app/modules/analytics/routes/pages/vendor.py | 2 +- app/modules/billing/routes/api/admin.py | 3 +- app/modules/billing/routes/api/vendor.py | 3 +- app/modules/billing/services/__init__.py | 18 + .../billing/services/billing_service.py | 588 +++++++ app/modules/cms/routes/api/vendor.py | 2 +- app/modules/cms/routes/pages/vendor.py | 2 +- app/modules/contracts/definition.py | 24 + app/modules/customers/models/__init__.py | 15 +- app/modules/customers/models/customer.py | 93 + .../customers/models/password_reset_token.py | 91 + app/modules/customers/schemas/__init__.py | 49 +- app/modules/customers/schemas/customer.py | 340 ++++ app/modules/customers/services/__init__.py | 15 +- .../services/admin_customer_service.py | 242 +++ .../services/customer_address_service.py | 314 ++++ .../customers/services/customer_service.py | 659 +++++++ app/modules/dev_tools/services/__init__.py | 8 +- .../services/code_quality_service.py | 820 +++++++++ .../dev_tools/services/test_runner_service.py | 507 ++++++ app/modules/inventory/models/__init__.py | 6 +- app/modules/inventory/models/inventory.py | 61 + .../inventory/models/inventory_transaction.py | 170 ++ app/modules/inventory/schemas/__init__.py | 58 +- app/modules/inventory/schemas/inventory.py | 294 ++++ app/modules/inventory/services/__init__.py | 10 +- .../services/inventory_import_service.py | 250 +++ .../inventory/services/inventory_service.py | 949 ++++++++++ .../services/inventory_transaction_service.py | 431 +++++ app/modules/marketplace/services/__init__.py | 38 +- .../marketplace/services/letzshop/__init__.py | 53 + .../services/letzshop/client_service.py | 1015 +++++++++++ .../services/letzshop/credentials_service.py | 400 +++++ .../services/letzshop/order_service.py | 1136 ++++++++++++ .../services/letzshop/vendor_sync_service.py | 521 ++++++ .../services/letzshop_export_service.py | 338 ++++ .../marketplace_import_job_service.py | 334 ++++ .../services/marketplace_product_service.py | 1075 ++++++++++++ app/modules/messaging/models/__init__.py | 6 +- .../messaging/models/admin_notification.py | 54 + app/modules/messaging/models/message.py | 272 +++ app/modules/messaging/schemas/__init__.py | 101 +- app/modules/messaging/schemas/message.py | 312 ++++ app/modules/messaging/schemas/notification.py | 152 ++ app/modules/messaging/services/__init__.py | 24 +- .../services/admin_notification_service.py | 702 ++++++++ .../services/message_attachment_service.py | 225 +++ .../messaging/services/messaging_service.py | 684 ++++++++ app/modules/monitoring/services/__init__.py | 4 +- .../services/background_tasks_service.py | 194 +++ app/modules/orders/models/__init__.py | 11 +- app/modules/orders/models/invoice.py | 215 +++ app/modules/orders/models/order.py | 406 +++++ .../orders/models/order_item_exception.py | 117 ++ app/modules/orders/schemas/__init__.py | 125 +- app/modules/orders/schemas/invoice.py | 316 ++++ app/modules/orders/schemas/order.py | 584 +++++++ app/modules/orders/services/__init__.py | 12 +- .../orders/services/invoice_pdf_service.py | 150 ++ .../orders/services/invoice_service.py | 587 +++++++ .../services/order_inventory_service.py | 591 +++++++ .../services/order_item_exception_service.py | 466 +++++ app/modules/orders/services/order_service.py | 1325 ++++++++++++++ app/services/admin_customer_service.py | 253 +-- app/services/admin_notification_service.py | 728 +------- app/services/background_tasks_service.py | 207 +-- app/services/billing_service.py | 609 +------ app/services/code_quality_service.py | 845 +-------- app/services/customer_address_service.py | 323 +--- app/services/customer_service.py | 673 +------ app/services/inventory_import_service.py | 259 +-- app/services/inventory_service.py | 962 +--------- app/services/inventory_transaction_service.py | 442 +---- app/services/invoice_pdf_service.py | 175 +- app/services/invoice_service.py | 682 +------- app/services/letzshop/__init__.py | 34 +- app/services/letzshop_export_service.py | 349 +--- .../marketplace_import_job_service.py | 347 +--- app/services/marketplace_product_service.py | 1082 +----------- app/services/message_attachment_service.py | 236 +-- app/services/messaging_service.py | 693 +------- app/services/order_inventory_service.py | 740 +------- app/services/order_item_exception_service.py | 641 +------ app/services/order_service.py | 1541 +---------------- app/services/test_runner_service.py | 524 +----- models/database/customer.py | 101 +- models/database/inventory.py | 66 +- models/database/inventory_transaction.py | 177 +- models/database/invoice.py | 228 +-- models/database/message.py | 287 +-- models/database/order.py | 410 +---- models/database/order_item_exception.py | 124 +- models/database/password_reset_token.py | 94 +- models/schema/customer.py | 394 +---- models/schema/inventory.py | 377 +--- models/schema/invoice.py | 361 +--- models/schema/message.py | 381 +--- models/schema/notification.py | 193 +-- models/schema/order.py | 663 +------ 99 files changed, 19413 insertions(+), 15357 deletions(-) create mode 100644 app/modules/billing/services/billing_service.py create mode 100644 app/modules/contracts/definition.py create mode 100644 app/modules/customers/models/customer.py create mode 100644 app/modules/customers/models/password_reset_token.py create mode 100644 app/modules/customers/schemas/customer.py create mode 100644 app/modules/customers/services/admin_customer_service.py create mode 100644 app/modules/customers/services/customer_address_service.py create mode 100644 app/modules/customers/services/customer_service.py create mode 100644 app/modules/dev_tools/services/code_quality_service.py create mode 100644 app/modules/dev_tools/services/test_runner_service.py create mode 100644 app/modules/inventory/models/inventory.py create mode 100644 app/modules/inventory/models/inventory_transaction.py create mode 100644 app/modules/inventory/schemas/inventory.py create mode 100644 app/modules/inventory/services/inventory_import_service.py create mode 100644 app/modules/inventory/services/inventory_service.py create mode 100644 app/modules/inventory/services/inventory_transaction_service.py create mode 100644 app/modules/marketplace/services/letzshop/__init__.py create mode 100644 app/modules/marketplace/services/letzshop/client_service.py create mode 100644 app/modules/marketplace/services/letzshop/credentials_service.py create mode 100644 app/modules/marketplace/services/letzshop/order_service.py create mode 100644 app/modules/marketplace/services/letzshop/vendor_sync_service.py create mode 100644 app/modules/marketplace/services/letzshop_export_service.py create mode 100644 app/modules/marketplace/services/marketplace_import_job_service.py create mode 100644 app/modules/marketplace/services/marketplace_product_service.py create mode 100644 app/modules/messaging/models/admin_notification.py create mode 100644 app/modules/messaging/models/message.py create mode 100644 app/modules/messaging/schemas/message.py create mode 100644 app/modules/messaging/schemas/notification.py create mode 100644 app/modules/messaging/services/admin_notification_service.py create mode 100644 app/modules/messaging/services/message_attachment_service.py create mode 100644 app/modules/messaging/services/messaging_service.py create mode 100644 app/modules/monitoring/services/background_tasks_service.py create mode 100644 app/modules/orders/models/invoice.py create mode 100644 app/modules/orders/models/order.py create mode 100644 app/modules/orders/models/order_item_exception.py create mode 100644 app/modules/orders/schemas/invoice.py create mode 100644 app/modules/orders/schemas/order.py create mode 100644 app/modules/orders/services/invoice_pdf_service.py create mode 100644 app/modules/orders/services/invoice_service.py create mode 100644 app/modules/orders/services/order_inventory_service.py create mode 100644 app/modules/orders/services/order_item_exception_service.py create mode 100644 app/modules/orders/services/order_service.py diff --git a/app/modules/analytics/routes/pages/vendor.py b/app/modules/analytics/routes/pages/vendor.py index f7038628..c18410e9 100644 --- a/app/modules/analytics/routes/pages/vendor.py +++ b/app/modules/analytics/routes/pages/vendor.py @@ -12,7 +12,7 @@ from fastapi.responses import HTMLResponse from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_from_cookie_or_header, get_db -from app.services.platform_settings_service import platform_settings_service +from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service from app.templates_config import templates from models.database.user import User from models.database.vendor import Vendor diff --git a/app/modules/billing/routes/api/admin.py b/app/modules/billing/routes/api/admin.py index 9d18716f..a76fdc5b 100644 --- a/app/modules/billing/routes/api/admin.py +++ b/app/modules/billing/routes/api/admin.py @@ -17,8 +17,7 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_admin_api, require_module_access from app.core.database import get_db -from app.services.admin_subscription_service import admin_subscription_service -from app.services.subscription_service import subscription_service +from app.modules.billing.services import admin_subscription_service, subscription_service from models.database.user import User from models.schema.billing import ( BillingHistoryListResponse, diff --git a/app/modules/billing/routes/api/vendor.py b/app/modules/billing/routes/api/vendor.py index 0ab4405b..6c8d662c 100644 --- a/app/modules/billing/routes/api/vendor.py +++ b/app/modules/billing/routes/api/vendor.py @@ -19,8 +19,7 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_api, require_module_access from app.core.config import settings from app.core.database import get_db -from app.services.billing_service import billing_service -from app.services.subscription_service import subscription_service +from app.modules.billing.services import billing_service, subscription_service from models.database.user import User logger = logging.getLogger(__name__) diff --git a/app/modules/billing/services/__init__.py b/app/modules/billing/services/__init__.py index fbee9824..2ae7b2e0 100644 --- a/app/modules/billing/services/__init__.py +++ b/app/modules/billing/services/__init__.py @@ -17,6 +17,16 @@ from app.modules.billing.services.admin_subscription_service import ( AdminSubscriptionService, admin_subscription_service, ) +from app.modules.billing.services.billing_service import ( + BillingService, + billing_service, + BillingServiceError, + PaymentSystemNotConfiguredError, + TierNotFoundError, + StripePriceNotConfiguredError, + NoActiveSubscriptionError, + SubscriptionNotCancelledError, +) __all__ = [ "SubscriptionService", @@ -25,4 +35,12 @@ __all__ = [ "stripe_service", "AdminSubscriptionService", "admin_subscription_service", + "BillingService", + "billing_service", + "BillingServiceError", + "PaymentSystemNotConfiguredError", + "TierNotFoundError", + "StripePriceNotConfiguredError", + "NoActiveSubscriptionError", + "SubscriptionNotCancelledError", ] diff --git a/app/modules/billing/services/billing_service.py b/app/modules/billing/services/billing_service.py new file mode 100644 index 00000000..3dde95b2 --- /dev/null +++ b/app/modules/billing/services/billing_service.py @@ -0,0 +1,588 @@ +# app/modules/billing/services/billing_service.py +""" +Billing service for subscription and payment operations. + +Provides: +- Subscription status and usage queries +- Tier management +- Invoice history +- Add-on management +""" + +import logging +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.modules.billing.services.stripe_service import stripe_service +from app.modules.billing.services.subscription_service import subscription_service +from app.modules.billing.models import ( + AddOnProduct, + BillingHistory, + SubscriptionTier, + VendorAddOn, + VendorSubscription, +) +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class BillingServiceError(Exception): + """Base exception for billing service errors.""" + + pass + + +class PaymentSystemNotConfiguredError(BillingServiceError): + """Raised when Stripe is not configured.""" + + def __init__(self): + super().__init__("Payment system not configured") + + +class TierNotFoundError(BillingServiceError): + """Raised when a tier is not found.""" + + def __init__(self, tier_code: str): + self.tier_code = tier_code + super().__init__(f"Tier '{tier_code}' not found") + + +class StripePriceNotConfiguredError(BillingServiceError): + """Raised when Stripe price is not configured for a tier.""" + + def __init__(self, tier_code: str): + self.tier_code = tier_code + super().__init__(f"Stripe price not configured for tier '{tier_code}'") + + +class NoActiveSubscriptionError(BillingServiceError): + """Raised when no active subscription exists.""" + + def __init__(self): + super().__init__("No active subscription found") + + +class SubscriptionNotCancelledError(BillingServiceError): + """Raised when trying to reactivate a non-cancelled subscription.""" + + def __init__(self): + super().__init__("Subscription is not cancelled") + + +class BillingService: + """Service for billing operations.""" + + def get_subscription_with_tier( + self, db: Session, vendor_id: int + ) -> tuple[VendorSubscription, SubscriptionTier | None]: + """ + Get subscription and its tier info. + + Returns: + Tuple of (subscription, tier) where tier may be None + """ + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + + tier = ( + db.query(SubscriptionTier) + .filter(SubscriptionTier.code == subscription.tier) + .first() + ) + + return subscription, tier + + def get_available_tiers( + self, db: Session, current_tier: str + ) -> tuple[list[dict], dict[str, int]]: + """ + Get all available tiers with upgrade/downgrade flags. + + Returns: + Tuple of (tier_list, tier_order_map) + """ + tiers = ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.is_active == True, # noqa: E712 + SubscriptionTier.is_public == True, # noqa: E712 + ) + .order_by(SubscriptionTier.display_order) + .all() + ) + + tier_order = {t.code: t.display_order for t in tiers} + current_order = tier_order.get(current_tier, 0) + + tier_list = [] + for tier in tiers: + tier_list.append({ + "code": tier.code, + "name": tier.name, + "description": tier.description, + "price_monthly_cents": tier.price_monthly_cents, + "price_annual_cents": tier.price_annual_cents, + "orders_per_month": tier.orders_per_month, + "products_limit": tier.products_limit, + "team_members": tier.team_members, + "features": tier.features or [], + "is_current": tier.code == current_tier, + "can_upgrade": tier.display_order > current_order, + "can_downgrade": tier.display_order < current_order, + }) + + return tier_list, tier_order + + def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: + """ + Get a tier by its code. + + Raises: + TierNotFoundError: If tier doesn't exist + """ + tier = ( + db.query(SubscriptionTier) + .filter( + SubscriptionTier.code == tier_code, + SubscriptionTier.is_active == True, # noqa: E712 + ) + .first() + ) + + if not tier: + raise TierNotFoundError(tier_code) + + return tier + + def get_vendor(self, db: Session, vendor_id: int) -> Vendor: + """ + Get vendor by ID. + + Raises: + VendorNotFoundException from app.exceptions + """ + from app.exceptions import VendorNotFoundException + + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(str(vendor_id), identifier_type="id") + + return vendor + + def create_checkout_session( + self, + db: Session, + vendor_id: int, + tier_code: str, + is_annual: bool, + success_url: str, + cancel_url: str, + ) -> dict: + """ + Create a Stripe checkout session. + + Returns: + Dict with checkout_url and session_id + + Raises: + PaymentSystemNotConfiguredError: If Stripe not configured + TierNotFoundError: If tier doesn't exist + StripePriceNotConfiguredError: If price not configured + """ + if not stripe_service.is_configured: + raise PaymentSystemNotConfiguredError() + + vendor = self.get_vendor(db, vendor_id) + tier = self.get_tier_by_code(db, tier_code) + + price_id = ( + tier.stripe_price_annual_id + if is_annual and tier.stripe_price_annual_id + else tier.stripe_price_monthly_id + ) + + if not price_id: + raise StripePriceNotConfiguredError(tier_code) + + # Check if this is a new subscription (for trial) + existing_sub = subscription_service.get_subscription(db, vendor_id) + trial_days = None + if not existing_sub or not existing_sub.stripe_subscription_id: + from app.core.config import settings + trial_days = settings.stripe_trial_days + + session = stripe_service.create_checkout_session( + db=db, + vendor=vendor, + price_id=price_id, + success_url=success_url, + cancel_url=cancel_url, + trial_days=trial_days, + ) + + # Update subscription with tier info + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + subscription.tier = tier_code + subscription.is_annual = is_annual + + return { + "checkout_url": session.url, + "session_id": session.id, + } + + def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict: + """ + Create a Stripe customer portal session. + + Returns: + Dict with portal_url + + Raises: + PaymentSystemNotConfiguredError: If Stripe not configured + NoActiveSubscriptionError: If no subscription with customer ID + """ + if not stripe_service.is_configured: + raise PaymentSystemNotConfiguredError() + + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_customer_id: + raise NoActiveSubscriptionError() + + session = stripe_service.create_portal_session( + customer_id=subscription.stripe_customer_id, + return_url=return_url, + ) + + return {"portal_url": session.url} + + def get_invoices( + self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20 + ) -> tuple[list[BillingHistory], int]: + """ + Get invoice history for a vendor. + + Returns: + Tuple of (invoices, total_count) + """ + query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id) + + total = query.count() + + invoices = ( + query.order_by(BillingHistory.invoice_date.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return invoices, total + + def get_available_addons( + self, db: Session, category: str | None = None + ) -> list[AddOnProduct]: + """Get available add-on products.""" + query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712 + + if category: + query = query.filter(AddOnProduct.category == category) + + return query.order_by(AddOnProduct.display_order).all() + + def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]: + """Get vendor's purchased add-ons.""" + return ( + db.query(VendorAddOn) + .filter(VendorAddOn.vendor_id == vendor_id) + .all() + ) + + def cancel_subscription( + self, db: Session, vendor_id: int, reason: str | None, immediately: bool + ) -> dict: + """ + Cancel a subscription. + + Returns: + Dict with message and effective_date + + Raises: + NoActiveSubscriptionError: If no subscription to cancel + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_subscription_id: + raise NoActiveSubscriptionError() + + if stripe_service.is_configured: + stripe_service.cancel_subscription( + subscription_id=subscription.stripe_subscription_id, + immediately=immediately, + cancellation_reason=reason, + ) + + subscription.cancelled_at = datetime.utcnow() + subscription.cancellation_reason = reason + + effective_date = ( + datetime.utcnow().isoformat() + if immediately + else subscription.period_end.isoformat() + if subscription.period_end + else datetime.utcnow().isoformat() + ) + + return { + "message": "Subscription cancelled successfully", + "effective_date": effective_date, + } + + def reactivate_subscription(self, db: Session, vendor_id: int) -> dict: + """ + Reactivate a cancelled subscription. + + Returns: + Dict with success message + + Raises: + NoActiveSubscriptionError: If no subscription + SubscriptionNotCancelledError: If not cancelled + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_subscription_id: + raise NoActiveSubscriptionError() + + if not subscription.cancelled_at: + raise SubscriptionNotCancelledError() + + if stripe_service.is_configured: + stripe_service.reactivate_subscription(subscription.stripe_subscription_id) + + subscription.cancelled_at = None + subscription.cancellation_reason = None + + return {"message": "Subscription reactivated successfully"} + + def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict: + """ + Get upcoming invoice preview. + + Returns: + Dict with amount_due_cents, currency, next_payment_date, line_items + + Raises: + NoActiveSubscriptionError: If no subscription with customer ID + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_customer_id: + raise NoActiveSubscriptionError() + + if not stripe_service.is_configured: + # Return empty preview if Stripe not configured + return { + "amount_due_cents": 0, + "currency": "EUR", + "next_payment_date": None, + "line_items": [], + } + + invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id) + + if not invoice: + return { + "amount_due_cents": 0, + "currency": "EUR", + "next_payment_date": None, + "line_items": [], + } + + line_items = [] + if invoice.lines and invoice.lines.data: + for line in invoice.lines.data: + line_items.append({ + "description": line.description or "", + "amount_cents": line.amount, + "quantity": line.quantity or 1, + }) + + return { + "amount_due_cents": invoice.amount_due, + "currency": invoice.currency.upper(), + "next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat() + if invoice.next_payment_attempt + else None, + "line_items": line_items, + } + + def change_tier( + self, + db: Session, + vendor_id: int, + new_tier_code: str, + is_annual: bool, + ) -> dict: + """ + Change subscription tier (upgrade/downgrade). + + Returns: + Dict with message, new_tier, effective_immediately + + Raises: + TierNotFoundError: If tier doesn't exist + NoActiveSubscriptionError: If no subscription + StripePriceNotConfiguredError: If price not configured + """ + subscription = subscription_service.get_subscription(db, vendor_id) + + if not subscription or not subscription.stripe_subscription_id: + raise NoActiveSubscriptionError() + + tier = self.get_tier_by_code(db, new_tier_code) + + price_id = ( + tier.stripe_price_annual_id + if is_annual and tier.stripe_price_annual_id + else tier.stripe_price_monthly_id + ) + + if not price_id: + raise StripePriceNotConfiguredError(new_tier_code) + + # Update in Stripe + if stripe_service.is_configured: + stripe_service.update_subscription( + subscription_id=subscription.stripe_subscription_id, + new_price_id=price_id, + ) + + # Update local subscription + old_tier = subscription.tier + subscription.tier = new_tier_code + subscription.tier_id = tier.id + subscription.is_annual = is_annual + subscription.updated_at = datetime.utcnow() + + is_upgrade = self._is_upgrade(db, old_tier, new_tier_code) + + return { + "message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}", + "new_tier": new_tier_code, + "effective_immediately": True, + } + + def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool: + """Check if tier change is an upgrade.""" + old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first() + new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first() + + if not old or not new: + return False + + return new.display_order > old.display_order + + def purchase_addon( + self, + db: Session, + vendor_id: int, + addon_code: str, + domain_name: str | None, + quantity: int, + success_url: str, + cancel_url: str, + ) -> dict: + """ + Create checkout session for add-on purchase. + + Returns: + Dict with checkout_url and session_id + + Raises: + PaymentSystemNotConfiguredError: If Stripe not configured + AddonNotFoundError: If addon doesn't exist + """ + if not stripe_service.is_configured: + raise PaymentSystemNotConfiguredError() + + addon = ( + db.query(AddOnProduct) + .filter( + AddOnProduct.code == addon_code, + AddOnProduct.is_active == True, # noqa: E712 + ) + .first() + ) + + if not addon: + raise BillingServiceError(f"Add-on '{addon_code}' not found") + + if not addon.stripe_price_id: + raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'") + + vendor = self.get_vendor(db, vendor_id) + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + + # Create checkout session for add-on + session = stripe_service.create_checkout_session( + db=db, + vendor=vendor, + price_id=addon.stripe_price_id, + success_url=success_url, + cancel_url=cancel_url, + quantity=quantity, + metadata={ + "addon_code": addon_code, + "domain_name": domain_name or "", + }, + ) + + return { + "checkout_url": session.url, + "session_id": session.id, + } + + def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict: + """ + Cancel a purchased add-on. + + Returns: + Dict with message and addon_code + + Raises: + BillingServiceError: If addon not found or not owned by vendor + """ + vendor_addon = ( + db.query(VendorAddOn) + .filter( + VendorAddOn.id == addon_id, + VendorAddOn.vendor_id == vendor_id, + ) + .first() + ) + + if not vendor_addon: + raise BillingServiceError("Add-on not found") + + addon_code = vendor_addon.addon_product.code + + # Cancel in Stripe if applicable + if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id: + try: + stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id) + except Exception as e: + logger.warning(f"Failed to cancel addon in Stripe: {e}") + + # Mark as cancelled + vendor_addon.status = "cancelled" + vendor_addon.cancelled_at = datetime.utcnow() + + return { + "message": "Add-on cancelled successfully", + "addon_code": addon_code, + } + + +# Create service instance +billing_service = BillingService() diff --git a/app/modules/cms/routes/api/vendor.py b/app/modules/cms/routes/api/vendor.py index 96c293f2..084fe40d 100644 --- a/app/modules/cms/routes/api/vendor.py +++ b/app/modules/cms/routes/api/vendor.py @@ -25,7 +25,7 @@ from app.modules.cms.schemas import ( CMSUsageResponse, ) from app.modules.cms.services import content_page_service -from app.services.vendor_service import VendorService +from app.services.vendor_service import VendorService # noqa: MOD-004 - shared platform service from models.database.user import User vendor_service = VendorService() diff --git a/app/modules/cms/routes/pages/vendor.py b/app/modules/cms/routes/pages/vendor.py index 78da9c95..32e68e4f 100644 --- a/app/modules/cms/routes/pages/vendor.py +++ b/app/modules/cms/routes/pages/vendor.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session from app.api.deps import get_current_vendor_from_cookie_or_header, get_db from app.modules.cms.services import content_page_service -from app.services.platform_settings_service import platform_settings_service +from app.services.platform_settings_service import platform_settings_service # noqa: MOD-004 - shared platform service from app.templates_config import templates from models.database.user import User from models.database.vendor import Vendor diff --git a/app/modules/contracts/definition.py b/app/modules/contracts/definition.py new file mode 100644 index 00000000..b15cfbab --- /dev/null +++ b/app/modules/contracts/definition.py @@ -0,0 +1,24 @@ +# app/modules/contracts/definition.py +""" +Contracts module definition. + +Cross-module contracts and Protocol interfaces. +Infrastructure module - cannot be disabled. +""" + +from app.modules.base import ModuleDefinition + +contracts_module = ModuleDefinition( + code="contracts", + name="Module Contracts", + description="Cross-module contracts using Protocol pattern for type-safe inter-module communication.", + version="1.0.0", + is_core=True, + features=[ + "service_protocols", + "cross_module_interfaces", + ], + menu_items={}, # Infrastructure module - no UI +) + +__all__ = ["contracts_module"] diff --git a/app/modules/customers/models/__init__.py b/app/modules/customers/models/__init__.py index 16c3912e..d2242833 100644 --- a/app/modules/customers/models/__init__.py +++ b/app/modules/customers/models/__init__.py @@ -2,14 +2,23 @@ """ Customers module database models. -Re-exports customer-related models from their source locations. +This is the canonical location for customer models. Module models are +automatically discovered and registered with SQLAlchemy's Base.metadata +at startup. + +Usage: + from app.modules.customers.models import ( + Customer, + CustomerAddress, + PasswordResetToken, + ) """ -from models.database.customer import ( +from app.modules.customers.models.customer import ( Customer, CustomerAddress, ) -from models.database.password_reset_token import PasswordResetToken +from app.modules.customers.models.password_reset_token import PasswordResetToken __all__ = [ "Customer", diff --git a/app/modules/customers/models/customer.py b/app/modules/customers/models/customer.py new file mode 100644 index 00000000..1c97f71d --- /dev/null +++ b/app/modules/customers/models/customer.py @@ -0,0 +1,93 @@ +# app/modules/customers/models/customer.py +""" +Customer database models. + +Provides Customer and CustomerAddress models for vendor-scoped +customer management. +""" + +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + Numeric, + String, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class Customer(Base, TimestampMixin): + """Customer model with vendor isolation.""" + + __tablename__ = "customers" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) + email = Column( + String(255), nullable=False, index=True + ) # Unique within vendor scope + hashed_password = Column(String(255), nullable=False) + first_name = Column(String(100)) + last_name = Column(String(100)) + phone = Column(String(50)) + customer_number = Column( + String(100), nullable=False, index=True + ) # Vendor-specific ID + preferences = Column(JSON, default=dict) + marketing_consent = Column(Boolean, default=False) + last_order_date = Column(DateTime) + total_orders = Column(Integer, default=0) + total_spent = Column(Numeric(10, 2), default=0) + is_active = Column(Boolean, default=True, nullable=False) + + # Language preference (NULL = use vendor storefront_language default) + # Supported: en, fr, de, lb + preferred_language = Column(String(5), nullable=True) + + # Relationships + vendor = relationship("Vendor", back_populates="customers") + addresses = relationship("CustomerAddress", back_populates="customer") + orders = relationship("Order", back_populates="customer") + + def __repr__(self): + return f"" + + @property + def full_name(self): + if self.first_name and self.last_name: + return f"{self.first_name} {self.last_name}" + return self.email + + +class CustomerAddress(Base, TimestampMixin): + """Customer address model for shipping and billing addresses.""" + + __tablename__ = "customer_addresses" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) + customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False) + address_type = Column(String(50), nullable=False) # 'billing', 'shipping' + first_name = Column(String(100), nullable=False) + last_name = Column(String(100), nullable=False) + company = Column(String(200)) + address_line_1 = Column(String(255), nullable=False) + address_line_2 = Column(String(255)) + city = Column(String(100), nullable=False) + postal_code = Column(String(20), nullable=False) + country_name = Column(String(100), nullable=False) + country_iso = Column(String(5), nullable=False) + is_default = Column(Boolean, default=False) + + # Relationships + vendor = relationship("Vendor") + customer = relationship("Customer", back_populates="addresses") + + def __repr__(self): + return f"" diff --git a/app/modules/customers/models/password_reset_token.py b/app/modules/customers/models/password_reset_token.py new file mode 100644 index 00000000..ea762a57 --- /dev/null +++ b/app/modules/customers/models/password_reset_token.py @@ -0,0 +1,91 @@ +# app/modules/customers/models/password_reset_token.py +""" +Password reset token model for customer accounts. + +Security features: +- Tokens are stored as SHA256 hashes, not plaintext +- Tokens expire after 1 hour +- Only one active token per customer (old tokens invalidated on new request) +""" + +import hashlib +import secrets +from datetime import datetime, timedelta + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import Session, relationship + +from app.core.database import Base + + +class PasswordResetToken(Base): + """Password reset token for customer accounts.""" + + __tablename__ = "password_reset_tokens" + + # Token expiry in hours + TOKEN_EXPIRY_HOURS = 1 + + id = Column(Integer, primary_key=True, index=True) + customer_id = Column( + Integer, ForeignKey("customers.id", ondelete="CASCADE"), nullable=False + ) + token_hash = Column(String(64), nullable=False, index=True) + expires_at = Column(DateTime, nullable=False) + used_at = Column(DateTime, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + # Relationships + customer = relationship("Customer") + + def __repr__(self): + return f"" + + @staticmethod + def hash_token(token: str) -> str: + """Hash a token using SHA256.""" + return hashlib.sha256(token.encode()).hexdigest() + + @classmethod + def create_for_customer(cls, db: Session, customer_id: int) -> str: + """Create a new password reset token for a customer. + + Invalidates any existing tokens for the customer. + Returns the plaintext token (to be sent via email). + """ + # Invalidate existing tokens for this customer + db.query(cls).filter( + cls.customer_id == customer_id, + cls.used_at.is_(None), + ).delete() + + # Generate new token + plaintext_token = secrets.token_urlsafe(32) + token_hash = cls.hash_token(plaintext_token) + + # Create token record + token = cls( + customer_id=customer_id, + token_hash=token_hash, + expires_at=datetime.utcnow() + timedelta(hours=cls.TOKEN_EXPIRY_HOURS), + ) + db.add(token) + db.flush() + + return plaintext_token + + @classmethod + def find_valid_token(cls, db: Session, plaintext_token: str) -> "PasswordResetToken | None": + """Find a valid (not expired, not used) token.""" + token_hash = cls.hash_token(plaintext_token) + + return db.query(cls).filter( + cls.token_hash == token_hash, + cls.expires_at > datetime.utcnow(), + cls.used_at.is_(None), + ).first() + + def mark_used(self, db: Session) -> None: + """Mark this token as used.""" + self.used_at = datetime.utcnow() + db.flush() diff --git a/app/modules/customers/schemas/__init__.py b/app/modules/customers/schemas/__init__.py index 3d265f9c..d27f5238 100644 --- a/app/modules/customers/schemas/__init__.py +++ b/app/modules/customers/schemas/__init__.py @@ -2,27 +2,68 @@ """ Customers module Pydantic schemas. -Re-exports customer-related schemas from their source locations. +This is the canonical location for customer schemas. + +Usage: + from app.modules.customers.schemas import ( + CustomerRegister, + CustomerUpdate, + CustomerResponse, + ) """ -from models.schema.customer import ( +from app.modules.customers.schemas.customer import ( + # Registration & Authentication CustomerRegister, CustomerUpdate, - CustomerLogin, + CustomerPasswordChange, + # Customer Response CustomerResponse, CustomerListResponse, + # Address CustomerAddressCreate, CustomerAddressUpdate, CustomerAddressResponse, + CustomerAddressListResponse, + # Preferences + CustomerPreferencesUpdate, + # Vendor Management + CustomerMessageResponse, + VendorCustomerListResponse, + CustomerDetailResponse, + CustomerOrderInfo, + CustomerOrdersResponse, + CustomerStatisticsResponse, + # Admin Management + AdminCustomerItem, + AdminCustomerListResponse, + AdminCustomerDetailResponse, ) __all__ = [ + # Registration & Authentication "CustomerRegister", "CustomerUpdate", - "CustomerLogin", + "CustomerPasswordChange", + # Customer Response "CustomerResponse", "CustomerListResponse", + # Address "CustomerAddressCreate", "CustomerAddressUpdate", "CustomerAddressResponse", + "CustomerAddressListResponse", + # Preferences + "CustomerPreferencesUpdate", + # Vendor Management + "CustomerMessageResponse", + "VendorCustomerListResponse", + "CustomerDetailResponse", + "CustomerOrderInfo", + "CustomerOrdersResponse", + "CustomerStatisticsResponse", + # Admin Management + "AdminCustomerItem", + "AdminCustomerListResponse", + "AdminCustomerDetailResponse", ] diff --git a/app/modules/customers/schemas/customer.py b/app/modules/customers/schemas/customer.py new file mode 100644 index 00000000..7f7d759e --- /dev/null +++ b/app/modules/customers/schemas/customer.py @@ -0,0 +1,340 @@ +# app/modules/customers/schemas/customer.py +""" +Pydantic schemas for customer-related operations. + +Provides schemas for: +- Customer registration and authentication +- Customer profile management +- Customer addresses +- Admin customer management +""" + +from datetime import datetime +from decimal import Decimal + +from pydantic import BaseModel, EmailStr, Field, field_validator + + +# ============================================================================ +# Customer Registration & Authentication +# ============================================================================ + + +class CustomerRegister(BaseModel): + """Schema for customer registration.""" + + email: EmailStr = Field(..., description="Customer email address") + password: str = Field( + ..., min_length=8, description="Password (minimum 8 characters)" + ) + first_name: str = Field(..., min_length=1, max_length=100) + last_name: str = Field(..., min_length=1, max_length=100) + phone: str | None = Field(None, max_length=50) + marketing_consent: bool = Field(default=False) + preferred_language: str | None = Field( + None, description="Preferred language (en, fr, de, lb)" + ) + + @field_validator("email") + @classmethod + def email_lowercase(cls, v: str) -> str: + """Convert email to lowercase.""" + return v.lower() + + @field_validator("password") + @classmethod + def password_strength(cls, v: str) -> str: + """Validate password strength.""" + if len(v) < 8: + raise ValueError("Password must be at least 8 characters") + if not any(char.isdigit() for char in v): + raise ValueError("Password must contain at least one digit") + if not any(char.isalpha() for char in v): + raise ValueError("Password must contain at least one letter") + return v + + +class CustomerUpdate(BaseModel): + """Schema for updating customer profile.""" + + email: EmailStr | None = None + first_name: str | None = Field(None, min_length=1, max_length=100) + last_name: str | None = Field(None, min_length=1, max_length=100) + phone: str | None = Field(None, max_length=50) + marketing_consent: bool | None = None + preferred_language: str | None = Field( + None, description="Preferred language (en, fr, de, lb)" + ) + + @field_validator("email") + @classmethod + def email_lowercase(cls, v: str | None) -> str | None: + """Convert email to lowercase.""" + return v.lower() if v else None + + +class CustomerPasswordChange(BaseModel): + """Schema for customer password change.""" + + current_password: str = Field(..., description="Current password") + new_password: str = Field( + ..., min_length=8, description="New password (minimum 8 characters)" + ) + confirm_password: str = Field(..., description="Confirm new password") + + @field_validator("new_password") + @classmethod + def password_strength(cls, v: str) -> str: + """Validate password strength.""" + if len(v) < 8: + raise ValueError("Password must be at least 8 characters") + if not any(char.isdigit() for char in v): + raise ValueError("Password must contain at least one digit") + if not any(char.isalpha() for char in v): + raise ValueError("Password must contain at least one letter") + return v + + +# ============================================================================ +# Customer Response +# ============================================================================ + + +class CustomerResponse(BaseModel): + """Schema for customer response (excludes password).""" + + id: int + vendor_id: int + email: str + first_name: str | None + last_name: str | None + phone: str | None + customer_number: str + marketing_consent: bool + preferred_language: str | None + last_order_date: datetime | None + total_orders: int + total_spent: Decimal + is_active: bool + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + @property + def full_name(self) -> str: + """Get customer full name.""" + if self.first_name and self.last_name: + return f"{self.first_name} {self.last_name}" + return self.email + + +class CustomerListResponse(BaseModel): + """Schema for paginated customer list.""" + + customers: list[CustomerResponse] + total: int + page: int + per_page: int + total_pages: int + + +# ============================================================================ +# Customer Address +# ============================================================================ + + +class CustomerAddressCreate(BaseModel): + """Schema for creating customer address.""" + + address_type: str = Field(..., pattern="^(billing|shipping)$") + first_name: str = Field(..., min_length=1, max_length=100) + last_name: str = Field(..., min_length=1, max_length=100) + company: str | None = Field(None, max_length=200) + address_line_1: str = Field(..., min_length=1, max_length=255) + address_line_2: str | None = Field(None, max_length=255) + city: str = Field(..., min_length=1, max_length=100) + postal_code: str = Field(..., min_length=1, max_length=20) + country_name: str = Field(..., min_length=2, max_length=100) + country_iso: str = Field(..., min_length=2, max_length=5) + is_default: bool = Field(default=False) + + +class CustomerAddressUpdate(BaseModel): + """Schema for updating customer address.""" + + address_type: str | None = Field(None, pattern="^(billing|shipping)$") + first_name: str | None = Field(None, min_length=1, max_length=100) + last_name: str | None = Field(None, min_length=1, max_length=100) + company: str | None = Field(None, max_length=200) + address_line_1: str | None = Field(None, min_length=1, max_length=255) + address_line_2: str | None = Field(None, max_length=255) + city: str | None = Field(None, min_length=1, max_length=100) + postal_code: str | None = Field(None, min_length=1, max_length=20) + country_name: str | None = Field(None, min_length=2, max_length=100) + country_iso: str | None = Field(None, min_length=2, max_length=5) + is_default: bool | None = None + + +class CustomerAddressResponse(BaseModel): + """Schema for customer address response.""" + + id: int + vendor_id: int + customer_id: int + address_type: str + first_name: str + last_name: str + company: str | None + address_line_1: str + address_line_2: str | None + city: str + postal_code: str + country_name: str + country_iso: str + is_default: bool + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class CustomerAddressListResponse(BaseModel): + """Schema for customer address list response.""" + + addresses: list[CustomerAddressResponse] + total: int + + +# ============================================================================ +# Customer Preferences +# ============================================================================ + + +class CustomerPreferencesUpdate(BaseModel): + """Schema for updating customer preferences.""" + + marketing_consent: bool | None = None + preferred_language: str | None = Field( + None, description="Preferred language (en, fr, de, lb)" + ) + currency: str | None = Field(None, max_length=3) + notification_preferences: dict[str, bool] | None = None + + +# ============================================================================ +# Vendor Customer Management Response Schemas +# ============================================================================ + + +class CustomerMessageResponse(BaseModel): + """Simple message response for customer operations.""" + + message: str + + +class VendorCustomerListResponse(BaseModel): + """Schema for vendor customer list with skip/limit pagination.""" + + customers: list[CustomerResponse] = [] + total: int = 0 + skip: int = 0 + limit: int = 100 + message: str | None = None + + +class CustomerDetailResponse(BaseModel): + """Detailed customer response for vendor management.""" + + id: int | None = None + vendor_id: int | None = None + email: str | None = None + first_name: str | None = None + last_name: str | None = None + phone: str | None = None + customer_number: str | None = None + marketing_consent: bool | None = None + preferred_language: str | None = None + last_order_date: datetime | None = None + total_orders: int | None = None + total_spent: Decimal | None = None + is_active: bool | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + message: str | None = None + + model_config = {"from_attributes": True} + + +class CustomerOrderInfo(BaseModel): + """Basic order info for customer order history.""" + + id: int + order_number: str + status: str + total: Decimal + created_at: datetime + + +class CustomerOrdersResponse(BaseModel): + """Response for customer order history.""" + + orders: list[CustomerOrderInfo] = [] + total: int = 0 + message: str | None = None + + +class CustomerStatisticsResponse(BaseModel): + """Response for customer statistics.""" + + total: int = 0 + active: int = 0 + inactive: int = 0 + with_orders: int = 0 + total_spent: float = 0.0 + total_orders: int = 0 + avg_order_value: float = 0.0 + + +# ============================================================================ +# Admin Customer Management Response Schemas +# ============================================================================ + + +class AdminCustomerItem(BaseModel): + """Admin customer list item with vendor info.""" + + id: int + vendor_id: int + email: str + first_name: str | None = None + last_name: str | None = None + phone: str | None = None + customer_number: str + marketing_consent: bool = False + preferred_language: str | None = None + last_order_date: datetime | None = None + total_orders: int = 0 + total_spent: float = 0.0 + is_active: bool = True + created_at: datetime + updated_at: datetime + vendor_name: str | None = None + vendor_code: str | None = None + + model_config = {"from_attributes": True} + + +class AdminCustomerListResponse(BaseModel): + """Admin paginated customer list with skip/limit.""" + + customers: list[AdminCustomerItem] = [] + total: int = 0 + skip: int = 0 + limit: int = 20 + + +class AdminCustomerDetailResponse(AdminCustomerItem): + """Detailed customer response for admin.""" + + pass diff --git a/app/modules/customers/services/__init__.py b/app/modules/customers/services/__init__.py index 762005df..5907596a 100644 --- a/app/modules/customers/services/__init__.py +++ b/app/modules/customers/services/__init__.py @@ -2,18 +2,25 @@ """ Customers module services. -Re-exports customer-related services from their source locations. +This is the canonical location for customer services. + +Usage: + from app.modules.customers.services import ( + customer_service, + admin_customer_service, + customer_address_service, + ) """ -from app.services.customer_service import ( +from app.modules.customers.services.customer_service import ( customer_service, CustomerService, ) -from app.services.admin_customer_service import ( +from app.modules.customers.services.admin_customer_service import ( admin_customer_service, AdminCustomerService, ) -from app.services.customer_address_service import ( +from app.modules.customers.services.customer_address_service import ( customer_address_service, CustomerAddressService, ) diff --git a/app/modules/customers/services/admin_customer_service.py b/app/modules/customers/services/admin_customer_service.py new file mode 100644 index 00000000..0b48f950 --- /dev/null +++ b/app/modules/customers/services/admin_customer_service.py @@ -0,0 +1,242 @@ +# app/modules/customers/services/admin_customer_service.py +""" +Admin customer management service. + +Handles customer operations for admin users across all vendors. +""" + +import logging +from typing import Any + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.exceptions.customer import CustomerNotFoundException +from app.modules.customers.models import Customer +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class AdminCustomerService: + """Service for admin-level customer management across vendors.""" + + def list_customers( + self, + db: Session, + vendor_id: int | None = None, + search: str | None = None, + is_active: bool | None = None, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[dict[str, Any]], int]: + """ + Get paginated list of customers across all vendors. + + Args: + db: Database session + vendor_id: Optional vendor ID filter + search: Search by email, name, or customer number + is_active: Filter by active status + skip: Number of records to skip + limit: Maximum records to return + + Returns: + Tuple of (customers list, total count) + """ + # Build query + query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id) + + # Apply filters + if vendor_id: + query = query.filter(Customer.vendor_id == vendor_id) + + if search: + search_term = f"%{search}%" + query = query.filter( + (Customer.email.ilike(search_term)) + | (Customer.first_name.ilike(search_term)) + | (Customer.last_name.ilike(search_term)) + | (Customer.customer_number.ilike(search_term)) + ) + + if is_active is not None: + query = query.filter(Customer.is_active == is_active) + + # Get total count + total = query.count() + + # Get paginated results with vendor info + customers = ( + query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code) + .order_by(Customer.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + # Format response + result = [] + for row in customers: + customer = row[0] + vendor_name = row[1] + vendor_code = row[2] + + customer_dict = { + "id": customer.id, + "vendor_id": customer.vendor_id, + "email": customer.email, + "first_name": customer.first_name, + "last_name": customer.last_name, + "phone": customer.phone, + "customer_number": customer.customer_number, + "marketing_consent": customer.marketing_consent, + "preferred_language": customer.preferred_language, + "last_order_date": customer.last_order_date, + "total_orders": customer.total_orders, + "total_spent": float(customer.total_spent) if customer.total_spent else 0, + "is_active": customer.is_active, + "created_at": customer.created_at, + "updated_at": customer.updated_at, + "vendor_name": vendor_name, + "vendor_code": vendor_code, + } + result.append(customer_dict) + + return result, total + + def get_customer_stats( + self, + db: Session, + vendor_id: int | None = None, + ) -> dict[str, Any]: + """ + Get customer statistics. + + Args: + db: Database session + vendor_id: Optional vendor ID filter + + Returns: + Dict with customer statistics + """ + query = db.query(Customer) + + if vendor_id: + query = query.filter(Customer.vendor_id == vendor_id) + + total = query.count() + active = query.filter(Customer.is_active == True).count() # noqa: E712 + inactive = query.filter(Customer.is_active == False).count() # noqa: E712 + with_orders = query.filter(Customer.total_orders > 0).count() + + # Total spent across all customers + total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar() + total_spent = float(total_spent_result) if total_spent_result else 0 + + # Average order value + total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar() + total_orders = int(total_orders_result) if total_orders_result else 0 + avg_order_value = total_spent / total_orders if total_orders > 0 else 0 + + return { + "total": total, + "active": active, + "inactive": inactive, + "with_orders": with_orders, + "total_spent": total_spent, + "total_orders": total_orders, + "avg_order_value": round(avg_order_value, 2), + } + + def get_customer( + self, + db: Session, + customer_id: int, + ) -> dict[str, Any]: + """ + Get customer details by ID. + + Args: + db: Database session + customer_id: Customer ID + + Returns: + Customer dict with vendor info + + Raises: + CustomerNotFoundException: If customer not found + """ + result = ( + db.query(Customer) + .join(Vendor, Customer.vendor_id == Vendor.id) + .add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code) + .filter(Customer.id == customer_id) + .first() + ) + + if not result: + raise CustomerNotFoundException(str(customer_id)) + + customer = result[0] + return { + "id": customer.id, + "vendor_id": customer.vendor_id, + "email": customer.email, + "first_name": customer.first_name, + "last_name": customer.last_name, + "phone": customer.phone, + "customer_number": customer.customer_number, + "marketing_consent": customer.marketing_consent, + "preferred_language": customer.preferred_language, + "last_order_date": customer.last_order_date, + "total_orders": customer.total_orders, + "total_spent": float(customer.total_spent) if customer.total_spent else 0, + "is_active": customer.is_active, + "created_at": customer.created_at, + "updated_at": customer.updated_at, + "vendor_name": result[1], + "vendor_code": result[2], + } + + def toggle_customer_status( + self, + db: Session, + customer_id: int, + admin_email: str, + ) -> dict[str, Any]: + """ + Toggle customer active status. + + Args: + db: Database session + customer_id: Customer ID + admin_email: Admin user email for logging + + Returns: + Dict with customer ID, new status, and message + + Raises: + CustomerNotFoundException: If customer not found + """ + customer = db.query(Customer).filter(Customer.id == customer_id).first() + + if not customer: + raise CustomerNotFoundException(str(customer_id)) + + customer.is_active = not customer.is_active + db.flush() + db.refresh(customer) + + status = "activated" if customer.is_active else "deactivated" + logger.info(f"Customer {customer.email} {status} by admin {admin_email}") + + return { + "id": customer.id, + "is_active": customer.is_active, + "message": f"Customer {status} successfully", + } + + +# Singleton instance +admin_customer_service = AdminCustomerService() diff --git a/app/modules/customers/services/customer_address_service.py b/app/modules/customers/services/customer_address_service.py new file mode 100644 index 00000000..c9c51e85 --- /dev/null +++ b/app/modules/customers/services/customer_address_service.py @@ -0,0 +1,314 @@ +# app/modules/customers/services/customer_address_service.py +""" +Customer Address Service + +Business logic for managing customer addresses with vendor isolation. +""" + +import logging + +from sqlalchemy.orm import Session + +from app.exceptions import ( + AddressLimitExceededException, + AddressNotFoundException, +) +from app.modules.customers.models import CustomerAddress +from app.modules.customers.schemas import CustomerAddressCreate, CustomerAddressUpdate + +logger = logging.getLogger(__name__) + + +class CustomerAddressService: + """Service for managing customer addresses with vendor isolation.""" + + MAX_ADDRESSES_PER_CUSTOMER = 10 + + def list_addresses( + self, db: Session, vendor_id: int, customer_id: int + ) -> list[CustomerAddress]: + """ + Get all addresses for a customer. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + + Returns: + List of customer addresses + """ + return ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc()) + .all() + ) + + def get_address( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> CustomerAddress: + """ + Get a specific address with ownership validation. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Returns: + Customer address + + Raises: + AddressNotFoundException: If address not found or doesn't belong to customer + """ + address = ( + db.query(CustomerAddress) + .filter( + CustomerAddress.id == address_id, + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .first() + ) + + if not address: + raise AddressNotFoundException(address_id) + + return address + + def get_default_address( + self, db: Session, vendor_id: int, customer_id: int, address_type: str + ) -> CustomerAddress | None: + """ + Get the default address for a specific type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_type: 'shipping' or 'billing' + + Returns: + Default address or None if not set + """ + return ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + CustomerAddress.address_type == address_type, + CustomerAddress.is_default == True, # noqa: E712 + ) + .first() + ) + + def create_address( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_data: CustomerAddressCreate, + ) -> CustomerAddress: + """ + Create a new address for a customer. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_data: Address creation data + + Returns: + Created customer address + + Raises: + AddressLimitExceededException: If max addresses reached + """ + # Check address limit + current_count = ( + db.query(CustomerAddress) + .filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + ) + .count() + ) + + if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER: + raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER) + + # If setting as default, clear other defaults of same type + if address_data.is_default: + self._clear_other_defaults( + db, vendor_id, customer_id, address_data.address_type + ) + + # Create the address + address = CustomerAddress( + vendor_id=vendor_id, + customer_id=customer_id, + address_type=address_data.address_type, + first_name=address_data.first_name, + last_name=address_data.last_name, + company=address_data.company, + address_line_1=address_data.address_line_1, + address_line_2=address_data.address_line_2, + city=address_data.city, + postal_code=address_data.postal_code, + country_name=address_data.country_name, + country_iso=address_data.country_iso, + is_default=address_data.is_default, + ) + + db.add(address) + db.flush() + + logger.info( + f"Created address {address.id} for customer {customer_id} " + f"(type={address_data.address_type}, default={address_data.is_default})" + ) + + return address + + def update_address( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_id: int, + address_data: CustomerAddressUpdate, + ) -> CustomerAddress: + """ + Update an existing address. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + address_data: Address update data + + Returns: + Updated customer address + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + # Update only provided fields + update_data = address_data.model_dump(exclude_unset=True) + + # Handle default flag - clear others if setting to default + if update_data.get("is_default") is True: + # Use updated type if provided, otherwise current type + address_type = update_data.get("address_type", address.address_type) + self._clear_other_defaults( + db, vendor_id, customer_id, address_type, exclude_id=address_id + ) + + for field, value in update_data.items(): + setattr(address, field, value) + + db.flush() + + logger.info(f"Updated address {address_id} for customer {customer_id}") + + return address + + def delete_address( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> None: + """ + Delete an address. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + db.delete(address) + db.flush() + + logger.info(f"Deleted address {address_id} for customer {customer_id}") + + def set_default( + self, db: Session, vendor_id: int, customer_id: int, address_id: int + ) -> CustomerAddress: + """ + Set an address as the default for its type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_id: Address ID + + Returns: + Updated customer address + + Raises: + AddressNotFoundException: If address not found + """ + address = self.get_address(db, vendor_id, customer_id, address_id) + + # Clear other defaults of same type + self._clear_other_defaults( + db, vendor_id, customer_id, address.address_type, exclude_id=address_id + ) + + # Set this one as default + address.is_default = True + db.flush() + + logger.info( + f"Set address {address_id} as default {address.address_type} " + f"for customer {customer_id}" + ) + + return address + + def _clear_other_defaults( + self, + db: Session, + vendor_id: int, + customer_id: int, + address_type: str, + exclude_id: int | None = None, + ) -> None: + """ + Clear the default flag on other addresses of the same type. + + Args: + db: Database session + vendor_id: Vendor ID for isolation + customer_id: Customer ID + address_type: 'shipping' or 'billing' + exclude_id: Address ID to exclude from clearing + """ + query = db.query(CustomerAddress).filter( + CustomerAddress.vendor_id == vendor_id, + CustomerAddress.customer_id == customer_id, + CustomerAddress.address_type == address_type, + CustomerAddress.is_default == True, # noqa: E712 + ) + + if exclude_id: + query = query.filter(CustomerAddress.id != exclude_id) + + query.update({"is_default": False}, synchronize_session=False) + + +# Singleton instance +customer_address_service = CustomerAddressService() diff --git a/app/modules/customers/services/customer_service.py b/app/modules/customers/services/customer_service.py new file mode 100644 index 00000000..d12b566a --- /dev/null +++ b/app/modules/customers/services/customer_service.py @@ -0,0 +1,659 @@ +# app/modules/customers/services/customer_service.py +""" +Customer management service. + +Handles customer registration, authentication, and profile management +with complete vendor isolation. +""" + +import logging +from datetime import UTC, datetime, timedelta +from typing import Any + +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from app.exceptions.customer import ( + CustomerNotActiveException, + CustomerNotFoundException, + CustomerValidationException, + DuplicateCustomerEmailException, + InvalidCustomerCredentialsException, + InvalidPasswordResetTokenException, + PasswordTooShortException, +) +from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException +from app.services.auth_service import AuthService +from app.modules.customers.models import Customer, PasswordResetToken +from app.modules.customers.schemas import CustomerRegister, CustomerUpdate +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class CustomerService: + """Service for managing vendor-scoped customers.""" + + def __init__(self): + self.auth_service = AuthService() + + def register_customer( + self, db: Session, vendor_id: int, customer_data: CustomerRegister + ) -> Customer: + """ + Register a new customer for a specific vendor. + + Args: + db: Database session + vendor_id: Vendor ID + customer_data: Customer registration data + + Returns: + Customer: Created customer object + + Raises: + VendorNotFoundException: If vendor doesn't exist + VendorNotActiveException: If vendor is not active + DuplicateCustomerEmailException: If email already exists for this vendor + CustomerValidationException: If customer data is invalid + """ + # Verify vendor exists and is active + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(str(vendor_id), identifier_type="id") + + if not vendor.is_active: + raise VendorNotActiveException(vendor.vendor_code) + + # Check if email already exists for this vendor + existing_customer = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == customer_data.email.lower(), + ) + ) + .first() + ) + + if existing_customer: + raise DuplicateCustomerEmailException( + customer_data.email, vendor.vendor_code + ) + + # Generate unique customer number for this vendor + customer_number = self._generate_customer_number( + db, vendor_id, vendor.vendor_code + ) + + # Hash password + hashed_password = self.auth_service.hash_password(customer_data.password) + + # Create customer + customer = Customer( + vendor_id=vendor_id, + email=customer_data.email.lower(), + hashed_password=hashed_password, + first_name=customer_data.first_name, + last_name=customer_data.last_name, + phone=customer_data.phone, + customer_number=customer_number, + marketing_consent=( + customer_data.marketing_consent + if hasattr(customer_data, "marketing_consent") + else False + ), + is_active=True, + ) + + try: + db.add(customer) + db.flush() + db.refresh(customer) + + logger.info( + f"Customer registered successfully: {customer.email} " + f"(ID: {customer.id}, Number: {customer.customer_number}) " + f"for vendor {vendor.vendor_code}" + ) + + return customer + + except Exception as e: + logger.error(f"Error registering customer: {str(e)}") + raise CustomerValidationException( + message="Failed to register customer", details={"error": str(e)} + ) + + def login_customer( + self, db: Session, vendor_id: int, credentials + ) -> dict[str, Any]: + """ + Authenticate customer and generate JWT token. + + Args: + db: Database session + vendor_id: Vendor ID + credentials: Login credentials (UserLogin schema) + + Returns: + Dict containing customer and token data + + Raises: + VendorNotFoundException: If vendor doesn't exist + InvalidCustomerCredentialsException: If credentials are invalid + CustomerNotActiveException: If customer account is inactive + """ + # Verify vendor exists + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(str(vendor_id), identifier_type="id") + + # Find customer by email (vendor-scoped) + customer = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == credentials.email_or_username.lower(), + ) + ) + .first() + ) + + if not customer: + raise InvalidCustomerCredentialsException() + + # Verify password using auth_manager directly + if not self.auth_service.auth_manager.verify_password( + credentials.password, customer.hashed_password + ): + raise InvalidCustomerCredentialsException() + + # Check if customer is active + if not customer.is_active: + raise CustomerNotActiveException(customer.email) + + # Generate JWT token with customer context + from jose import jwt + + auth_manager = self.auth_service.auth_manager + expires_delta = timedelta(minutes=auth_manager.token_expire_minutes) + expire = datetime.now(UTC) + expires_delta + + payload = { + "sub": str(customer.id), + "email": customer.email, + "vendor_id": vendor_id, + "type": "customer", + "exp": expire, + "iat": datetime.now(UTC), + } + + token = jwt.encode( + payload, auth_manager.secret_key, algorithm=auth_manager.algorithm + ) + + token_data = { + "access_token": token, + "token_type": "bearer", + "expires_in": auth_manager.token_expire_minutes * 60, + } + + logger.info( + f"Customer login successful: {customer.email} " + f"for vendor {vendor.vendor_code}" + ) + + return {"customer": customer, "token_data": token_data} + + def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer: + """ + Get customer by ID with vendor isolation. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + + Returns: + Customer: Customer object + + Raises: + CustomerNotFoundException: If customer not found + """ + customer = ( + db.query(Customer) + .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id)) + .first() + ) + + if not customer: + raise CustomerNotFoundException(str(customer_id)) + + return customer + + def get_customer_by_email( + self, db: Session, vendor_id: int, email: str + ) -> Customer | None: + """ + Get customer by email (vendor-scoped). + + Args: + db: Database session + vendor_id: Vendor ID + email: Customer email + + Returns: + Optional[Customer]: Customer object or None + """ + return ( + db.query(Customer) + .filter( + and_(Customer.vendor_id == vendor_id, Customer.email == email.lower()) + ) + .first() + ) + + def get_vendor_customers( + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 100, + search: str | None = None, + is_active: bool | None = None, + ) -> tuple[list[Customer], int]: + """ + Get all customers for a vendor with filtering and pagination. + + Args: + db: Database session + vendor_id: Vendor ID + skip: Pagination offset + limit: Pagination limit + search: Search in name/email + is_active: Filter by active status + + Returns: + Tuple of (customers, total_count) + """ + from sqlalchemy import or_ + + query = db.query(Customer).filter(Customer.vendor_id == vendor_id) + + if search: + search_pattern = f"%{search}%" + query = query.filter( + or_( + Customer.email.ilike(search_pattern), + Customer.first_name.ilike(search_pattern), + Customer.last_name.ilike(search_pattern), + Customer.customer_number.ilike(search_pattern), + ) + ) + + if is_active is not None: + query = query.filter(Customer.is_active == is_active) + + # Order by most recent first + query = query.order_by(Customer.created_at.desc()) + + total = query.count() + customers = query.offset(skip).limit(limit).all() + + return customers, total + + def get_customer_orders( + self, + db: Session, + vendor_id: int, + customer_id: int, + skip: int = 0, + limit: int = 50, + ) -> tuple[list, int]: + """ + Get orders for a specific customer. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + skip: Pagination offset + limit: Pagination limit + + Returns: + Tuple of (orders, total_count) + + Raises: + CustomerNotFoundException: If customer not found + """ + from models.database.order import Order + + # Verify customer belongs to vendor + self.get_customer(db, vendor_id, customer_id) + + # Get customer orders + query = ( + db.query(Order) + .filter( + Order.customer_id == customer_id, + Order.vendor_id == vendor_id, + ) + .order_by(Order.created_at.desc()) + ) + + total = query.count() + orders = query.offset(skip).limit(limit).all() + + return orders, total + + def get_customer_statistics( + self, db: Session, vendor_id: int, customer_id: int + ) -> dict: + """ + Get detailed statistics for a customer. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + + Returns: + Dict with customer statistics + """ + from sqlalchemy import func + + from models.database.order import Order + + customer = self.get_customer(db, vendor_id, customer_id) + + # Get order statistics + order_stats = ( + db.query( + func.count(Order.id).label("total_orders"), + func.sum(Order.total_cents).label("total_spent_cents"), + func.avg(Order.total_cents).label("avg_order_cents"), + func.max(Order.created_at).label("last_order_date"), + ) + .filter(Order.customer_id == customer_id) + .first() + ) + + total_orders = order_stats.total_orders or 0 + total_spent_cents = order_stats.total_spent_cents or 0 + avg_order_cents = order_stats.avg_order_cents or 0 + + return { + "customer_id": customer_id, + "total_orders": total_orders, + "total_spent": total_spent_cents / 100, # Convert to euros + "average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0, + "last_order_date": order_stats.last_order_date, + "member_since": customer.created_at, + "is_active": customer.is_active, + } + + def toggle_customer_status( + self, db: Session, vendor_id: int, customer_id: int + ) -> Customer: + """ + Toggle customer active status. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + + Returns: + Customer: Updated customer + """ + customer = self.get_customer(db, vendor_id, customer_id) + customer.is_active = not customer.is_active + + db.flush() + db.refresh(customer) + + action = "activated" if customer.is_active else "deactivated" + logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})") + + return customer + + def update_customer( + self, + db: Session, + vendor_id: int, + customer_id: int, + customer_data: CustomerUpdate, + ) -> Customer: + """ + Update customer profile. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + customer_data: Updated customer data + + Returns: + Customer: Updated customer object + + Raises: + CustomerNotFoundException: If customer not found + CustomerValidationException: If update data is invalid + """ + customer = self.get_customer(db, vendor_id, customer_id) + + # Update fields + update_data = customer_data.model_dump(exclude_unset=True) + + for field, value in update_data.items(): + if field == "email" and value: + # Check if new email already exists for this vendor + existing = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == value.lower(), + Customer.id != customer_id, + ) + ) + .first() + ) + + if existing: + raise DuplicateCustomerEmailException(value, "vendor") + + setattr(customer, field, value.lower()) + elif hasattr(customer, field): + setattr(customer, field, value) + + try: + db.flush() + db.refresh(customer) + + logger.info(f"Customer updated: {customer.email} (ID: {customer.id})") + + return customer + + except Exception as e: + logger.error(f"Error updating customer: {str(e)}") + raise CustomerValidationException( + message="Failed to update customer", details={"error": str(e)} + ) + + def deactivate_customer( + self, db: Session, vendor_id: int, customer_id: int + ) -> Customer: + """ + Deactivate customer account. + + Args: + db: Database session + vendor_id: Vendor ID + customer_id: Customer ID + + Returns: + Customer: Deactivated customer object + + Raises: + CustomerNotFoundException: If customer not found + """ + customer = self.get_customer(db, vendor_id, customer_id) + customer.is_active = False + + db.flush() + db.refresh(customer) + + logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})") + + return customer + + def update_customer_stats( + self, db: Session, customer_id: int, order_total: float + ) -> None: + """ + Update customer statistics after order. + + Args: + db: Database session + customer_id: Customer ID + order_total: Order total amount + """ + customer = db.query(Customer).filter(Customer.id == customer_id).first() + + if customer: + customer.total_orders += 1 + customer.total_spent += order_total + customer.last_order_date = datetime.utcnow() + + logger.debug(f"Updated stats for customer {customer.email}") + + def _generate_customer_number( + self, db: Session, vendor_id: int, vendor_code: str + ) -> str: + """ + Generate unique customer number for vendor. + + Format: {VENDOR_CODE}-CUST-{SEQUENCE} + Example: VENDORA-CUST-00001 + + Args: + db: Database session + vendor_id: Vendor ID + vendor_code: Vendor code + + Returns: + str: Unique customer number + """ + # Get count of customers for this vendor + count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count() + + # Generate number with padding + sequence = str(count + 1).zfill(5) + customer_number = f"{vendor_code.upper()}-CUST-{sequence}" + + # Ensure uniqueness (in case of deletions) + while ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.customer_number == customer_number, + ) + ) + .first() + ): + count += 1 + sequence = str(count + 1).zfill(5) + customer_number = f"{vendor_code.upper()}-CUST-{sequence}" + + return customer_number + + def get_customer_for_password_reset( + self, db: Session, vendor_id: int, email: str + ) -> Customer | None: + """ + Get active customer by email for password reset. + + Args: + db: Database session + vendor_id: Vendor ID + email: Customer email + + Returns: + Customer if found and active, None otherwise + """ + return ( + db.query(Customer) + .filter( + Customer.vendor_id == vendor_id, + Customer.email == email.lower(), + Customer.is_active == True, # noqa: E712 + ) + .first() + ) + + def validate_and_reset_password( + self, + db: Session, + vendor_id: int, + reset_token: str, + new_password: str, + ) -> Customer: + """ + Validate reset token and update customer password. + + Args: + db: Database session + vendor_id: Vendor ID + reset_token: Password reset token from email + new_password: New password + + Returns: + Customer: Updated customer + + Raises: + PasswordTooShortException: If password too short + InvalidPasswordResetTokenException: If token invalid/expired + CustomerNotActiveException: If customer not active + """ + # Validate password length + if len(new_password) < 8: + raise PasswordTooShortException(min_length=8) + + # Find valid token + token_record = PasswordResetToken.find_valid_token(db, reset_token) + + if not token_record: + raise InvalidPasswordResetTokenException() + + # Get the customer and verify they belong to this vendor + customer = ( + db.query(Customer) + .filter(Customer.id == token_record.customer_id) + .first() + ) + + if not customer or customer.vendor_id != vendor_id: + raise InvalidPasswordResetTokenException() + + if not customer.is_active: + raise CustomerNotActiveException(customer.email) + + # Hash the new password and update customer + hashed_password = self.auth_service.hash_password(new_password) + customer.hashed_password = hashed_password + + # Mark token as used + token_record.mark_used(db) + + logger.info(f"Password reset completed for customer {customer.id}") + + return customer + + +# Singleton instance +customer_service = CustomerService() diff --git a/app/modules/dev_tools/services/__init__.py b/app/modules/dev_tools/services/__init__.py index 5046720e..95ef0a4b 100644 --- a/app/modules/dev_tools/services/__init__.py +++ b/app/modules/dev_tools/services/__init__.py @@ -2,16 +2,14 @@ """ Dev-Tools module services. -This module re-exports services from their current locations. -In future cleanup phases, the actual service implementations -may be moved here. +Provides code quality scanning and test running functionality. Services: - code_quality_service: Code quality scanning and violation management - test_runner_service: Test execution and results management """ -from app.services.code_quality_service import ( +from app.modules.dev_tools.services.code_quality_service import ( code_quality_service, CodeQualityService, VALIDATOR_ARCHITECTURE, @@ -21,7 +19,7 @@ from app.services.code_quality_service import ( VALIDATOR_SCRIPTS, VALIDATOR_NAMES, ) -from app.services.test_runner_service import ( +from app.modules.dev_tools.services.test_runner_service import ( test_runner_service, TestRunnerService, ) diff --git a/app/modules/dev_tools/services/code_quality_service.py b/app/modules/dev_tools/services/code_quality_service.py new file mode 100644 index 00000000..66c8468d --- /dev/null +++ b/app/modules/dev_tools/services/code_quality_service.py @@ -0,0 +1,820 @@ +""" +Code Quality Service +Business logic for managing code quality scans and violations +Supports multiple validator types: architecture, security, performance +""" + +import json +import logging +import subprocess +from datetime import datetime, UTC + +from sqlalchemy import desc, func +from sqlalchemy.orm import Session + +from app.exceptions import ( + ScanParseException, + ScanTimeoutException, + ViolationNotFoundException, +) +from app.modules.dev_tools.models import ( + ArchitectureScan, + ArchitectureViolation, + ViolationAssignment, + ViolationComment, +) + +logger = logging.getLogger(__name__) + +# Validator type constants +VALIDATOR_ARCHITECTURE = "architecture" +VALIDATOR_SECURITY = "security" +VALIDATOR_PERFORMANCE = "performance" + +VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE] + +# Map validator types to their scripts +VALIDATOR_SCRIPTS = { + VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py", + VALIDATOR_SECURITY: "scripts/validate_security.py", + VALIDATOR_PERFORMANCE: "scripts/validate_performance.py", +} + +# Human-readable names +VALIDATOR_NAMES = { + VALIDATOR_ARCHITECTURE: "Architecture", + VALIDATOR_SECURITY: "Security", + VALIDATOR_PERFORMANCE: "Performance", +} + + +class CodeQualityService: + """Service for managing code quality scans and violations""" + + def run_scan( + self, + db: Session, + triggered_by: str = "manual", + validator_type: str = VALIDATOR_ARCHITECTURE, + ) -> ArchitectureScan: + """ + Run a code quality validator and store results in database + + Args: + db: Database session + triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd') + validator_type: Type of validator ('architecture', 'security', 'performance') + + Returns: + ArchitectureScan object with results + + Raises: + ValueError: If validator_type is invalid + ScanTimeoutException: If validator times out + ScanParseException: If validator output cannot be parsed + """ + if validator_type not in VALID_VALIDATOR_TYPES: + raise ValueError( + f"Invalid validator type: {validator_type}. " + f"Must be one of: {VALID_VALIDATOR_TYPES}" + ) + + script_path = VALIDATOR_SCRIPTS[validator_type] + validator_name = VALIDATOR_NAMES[validator_type] + + logger.info( + f"Starting {validator_name} scan (triggered by: {triggered_by})" + ) + + # Get git commit hash + git_commit = self._get_git_commit_hash() + + # Run validator with JSON output + start_time = datetime.now() + try: + result = subprocess.run( + ["python", script_path, "--json"], + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + except subprocess.TimeoutExpired: + logger.error(f"{validator_name} scan timed out after 5 minutes") + raise ScanTimeoutException(timeout_seconds=300) + + duration = (datetime.now() - start_time).total_seconds() + + # Parse JSON output (get only the JSON part, skip progress messages) + try: + # Find the JSON part in stdout + lines = result.stdout.strip().split("\n") + json_start = -1 + for i, line in enumerate(lines): + if line.strip().startswith("{"): + json_start = i + break + + if json_start == -1: + raise ValueError("No JSON output found") + + json_output = "\n".join(lines[json_start:]) + data = json.loads(json_output) + except (json.JSONDecodeError, ValueError) as e: + logger.error(f"Failed to parse {validator_name} validator output: {e}") + logger.error(f"Stdout: {result.stdout}") + logger.error(f"Stderr: {result.stderr}") + raise ScanParseException(reason=str(e)) + + # Create scan record + scan = ArchitectureScan( + timestamp=datetime.now(), + validator_type=validator_type, + total_files=data.get("files_checked", 0), + total_violations=data.get("total_violations", 0), + errors=data.get("errors", 0), + warnings=data.get("warnings", 0), + duration_seconds=duration, + triggered_by=triggered_by, + git_commit_hash=git_commit, + ) + db.add(scan) + db.flush() # Get scan.id + + # Create violation records + violations_data = data.get("violations", []) + logger.info(f"Creating {len(violations_data)} {validator_name} violation records") + + for v in violations_data: + violation = ArchitectureViolation( + scan_id=scan.id, + validator_type=validator_type, + rule_id=v["rule_id"], + rule_name=v["rule_name"], + severity=v["severity"], + file_path=v["file_path"], + line_number=v["line_number"], + message=v["message"], + context=v.get("context", ""), + suggestion=v.get("suggestion", ""), + status="open", + ) + db.add(violation) + + db.flush() + db.refresh(scan) + + logger.info( + f"{validator_name} scan completed: {scan.total_violations} violations found" + ) + return scan + + def run_all_scans( + self, db: Session, triggered_by: str = "manual" + ) -> list[ArchitectureScan]: + """ + Run all validators and return list of scans + + Args: + db: Database session + triggered_by: Who/what triggered the scan + + Returns: + List of ArchitectureScan objects (one per validator) + """ + results = [] + for validator_type in VALID_VALIDATOR_TYPES: + try: + scan = self.run_scan(db, triggered_by, validator_type) + results.append(scan) + except Exception as e: + logger.error(f"Failed to run {validator_type} scan: {e}") + # Continue with other validators even if one fails + return results + + def get_latest_scan( + self, db: Session, validator_type: str = None + ) -> ArchitectureScan | None: + """ + Get the most recent scan + + Args: + db: Database session + validator_type: Optional filter by validator type + + Returns: + Most recent ArchitectureScan or None + """ + query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)) + + if validator_type: + query = query.filter(ArchitectureScan.validator_type == validator_type) + + return query.first() + + def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]: + """ + Get the most recent scan for each validator type + + Returns: + Dictionary mapping validator_type to its latest scan + """ + result = {} + for vtype in VALID_VALIDATOR_TYPES: + scan = self.get_latest_scan(db, validator_type=vtype) + if scan: + result[vtype] = scan + return result + + def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None: + """Get scan by ID""" + return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first() + + def create_pending_scan( + self, db: Session, validator_type: str, triggered_by: str + ) -> ArchitectureScan: + """ + Create a new scan record with pending status. + + Args: + db: Database session + validator_type: Type of validator (architecture, security, performance) + triggered_by: Who triggered the scan (e.g., "manual:username") + + Returns: + The created ArchitectureScan record with ID populated + """ + scan = ArchitectureScan( + timestamp=datetime.now(UTC), + validator_type=validator_type, + status="pending", + triggered_by=triggered_by, + ) + db.add(scan) + db.flush() # Get scan.id + return scan + + def get_running_scans(self, db: Session) -> list[ArchitectureScan]: + """ + Get all currently running scans (pending or running status). + + Returns: + List of scans with status 'pending' or 'running', newest first + """ + return ( + db.query(ArchitectureScan) + .filter(ArchitectureScan.status.in_(["pending", "running"])) + .order_by(ArchitectureScan.timestamp.desc()) + .all() + ) + + def get_scan_history( + self, db: Session, limit: int = 30, validator_type: str = None + ) -> list[ArchitectureScan]: + """ + Get scan history for trend graphs + + Args: + db: Database session + limit: Maximum number of scans to return + validator_type: Optional filter by validator type + + Returns: + List of ArchitectureScan objects, newest first + """ + query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)) + + if validator_type: + query = query.filter(ArchitectureScan.validator_type == validator_type) + + return query.limit(limit).all() + + def get_violations( + self, + db: Session, + scan_id: int = None, + validator_type: str = None, + severity: str = None, + status: str = None, + rule_id: str = None, + file_path: str = None, + limit: int = 100, + offset: int = 0, + ) -> tuple[list[ArchitectureViolation], int]: + """ + Get violations with filtering and pagination + + Args: + db: Database session + scan_id: Filter by scan ID (if None, use latest scan(s)) + validator_type: Filter by validator type + severity: Filter by severity ('error', 'warning') + status: Filter by status ('open', 'assigned', 'resolved', etc.) + rule_id: Filter by rule ID + file_path: Filter by file path (partial match) + limit: Page size + offset: Page offset + + Returns: + Tuple of (violations list, total count) + """ + # Build query + query = db.query(ArchitectureViolation) + + # If scan_id specified, filter by it + if scan_id is not None: + query = query.filter(ArchitectureViolation.scan_id == scan_id) + else: + # If no scan_id, get violations from latest scan(s) + if validator_type: + # Get latest scan for specific validator type + latest_scan = self.get_latest_scan(db, validator_type) + if not latest_scan: + return [], 0 + query = query.filter(ArchitectureViolation.scan_id == latest_scan.id) + else: + # Get violations from latest scans of all types + latest_scans = self.get_latest_scans_by_type(db) + if not latest_scans: + return [], 0 + scan_ids = [s.id for s in latest_scans.values()] + query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids)) + + # Apply validator_type filter if specified (for scan_id queries) + if validator_type and scan_id is not None: + query = query.filter(ArchitectureViolation.validator_type == validator_type) + + # Apply other filters + if severity: + query = query.filter(ArchitectureViolation.severity == severity) + + if status: + query = query.filter(ArchitectureViolation.status == status) + + if rule_id: + query = query.filter(ArchitectureViolation.rule_id == rule_id) + + if file_path: + query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%")) + + # Get total count + total = query.count() + + # Get page of results + violations = ( + query.order_by( + ArchitectureViolation.severity.desc(), + ArchitectureViolation.validator_type, + ArchitectureViolation.file_path, + ) + .limit(limit) + .offset(offset) + .all() + ) + + return violations, total + + def get_violation_by_id( + self, db: Session, violation_id: int + ) -> ArchitectureViolation | None: + """Get single violation with details""" + return ( + db.query(ArchitectureViolation) + .filter(ArchitectureViolation.id == violation_id) + .first() + ) + + def assign_violation( + self, + db: Session, + violation_id: int, + user_id: int, + assigned_by: int, + due_date: datetime = None, + priority: str = "medium", + ) -> ViolationAssignment: + """ + Assign violation to a developer + + Args: + db: Database session + violation_id: Violation ID + user_id: User to assign to + assigned_by: User who is assigning + due_date: Due date (optional) + priority: Priority level ('low', 'medium', 'high', 'critical') + + Returns: + ViolationAssignment object + """ + # Update violation status + violation = self.get_violation_by_id(db, violation_id) + if violation: + violation.status = "assigned" + violation.assigned_to = user_id + + # Create assignment record + assignment = ViolationAssignment( + violation_id=violation_id, + user_id=user_id, + assigned_by=assigned_by, + due_date=due_date, + priority=priority, + ) + db.add(assignment) + db.flush() + + logger.info(f"Violation {violation_id} assigned to user {user_id}") + return assignment + + def resolve_violation( + self, db: Session, violation_id: int, resolved_by: int, resolution_note: str + ) -> ArchitectureViolation: + """ + Mark violation as resolved + + Args: + db: Database session + violation_id: Violation ID + resolved_by: User who resolved it + resolution_note: Note about resolution + + Returns: + Updated ArchitectureViolation object + """ + violation = self.get_violation_by_id(db, violation_id) + if not violation: + raise ViolationNotFoundException(violation_id) + + violation.status = "resolved" + violation.resolved_at = datetime.now() + violation.resolved_by = resolved_by + violation.resolution_note = resolution_note + + db.flush() + logger.info(f"Violation {violation_id} resolved by user {resolved_by}") + return violation + + def ignore_violation( + self, db: Session, violation_id: int, ignored_by: int, reason: str + ) -> ArchitectureViolation: + """ + Mark violation as ignored/won't fix + + Args: + db: Database session + violation_id: Violation ID + ignored_by: User who ignored it + reason: Reason for ignoring + + Returns: + Updated ArchitectureViolation object + """ + violation = self.get_violation_by_id(db, violation_id) + if not violation: + raise ViolationNotFoundException(violation_id) + + violation.status = "ignored" + violation.resolved_at = datetime.now() + violation.resolved_by = ignored_by + violation.resolution_note = f"Ignored: {reason}" + + db.flush() + logger.info(f"Violation {violation_id} ignored by user {ignored_by}") + return violation + + def add_comment( + self, db: Session, violation_id: int, user_id: int, comment: str + ) -> ViolationComment: + """ + Add comment to violation + + Args: + db: Database session + violation_id: Violation ID + user_id: User posting comment + comment: Comment text + + Returns: + ViolationComment object + """ + comment_obj = ViolationComment( + violation_id=violation_id, user_id=user_id, comment=comment + ) + db.add(comment_obj) + db.flush() + + logger.info(f"Comment added to violation {violation_id} by user {user_id}") + return comment_obj + + def get_dashboard_stats( + self, db: Session, validator_type: str = None + ) -> dict: + """ + Get statistics for dashboard + + Args: + db: Database session + validator_type: Optional filter by validator type. If None, returns combined stats. + + Returns: + Dictionary with various statistics including per-validator breakdown + """ + # Get latest scans by type + latest_scans = self.get_latest_scans_by_type(db) + + if not latest_scans: + return self._empty_dashboard_stats() + + # If specific validator type requested + if validator_type and validator_type in latest_scans: + scan = latest_scans[validator_type] + return self._get_stats_for_scan(db, scan, validator_type) + + # Combined stats across all validators + return self._get_combined_stats(db, latest_scans) + + def _empty_dashboard_stats(self) -> dict: + """Return empty dashboard stats structure""" + return { + "total_violations": 0, + "errors": 0, + "warnings": 0, + "info": 0, + "open": 0, + "assigned": 0, + "resolved": 0, + "ignored": 0, + "technical_debt_score": 100, + "trend": [], + "by_severity": {}, + "by_rule": {}, + "by_module": {}, + "top_files": [], + "last_scan": None, + "by_validator": {}, + } + + def _get_stats_for_scan( + self, db: Session, scan: ArchitectureScan, validator_type: str + ) -> dict: + """Get stats for a single scan/validator type""" + # Get violation counts by status + status_counts = ( + db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id)) + .filter(ArchitectureViolation.scan_id == scan.id) + .group_by(ArchitectureViolation.status) + .all() + ) + status_dict = {status: count for status, count in status_counts} + + # Get violations by severity + severity_counts = ( + db.query( + ArchitectureViolation.severity, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id == scan.id) + .group_by(ArchitectureViolation.severity) + .all() + ) + by_severity = {sev: count for sev, count in severity_counts} + + # Get violations by rule + rule_counts = ( + db.query( + ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id == scan.id) + .group_by(ArchitectureViolation.rule_id) + .all() + ) + by_rule = { + rule: count + for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10] + } + + # Get top violating files + file_counts = ( + db.query( + ArchitectureViolation.file_path, + func.count(ArchitectureViolation.id).label("count"), + ) + .filter(ArchitectureViolation.scan_id == scan.id) + .group_by(ArchitectureViolation.file_path) + .order_by(desc("count")) + .limit(10) + .all() + ) + top_files = [{"file": file, "count": count} for file, count in file_counts] + + # Get violations by module + by_module = self._get_violations_by_module(db, scan.id) + + # Get trend for this validator type + trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type) + trend = [ + { + "timestamp": s.timestamp.isoformat(), + "violations": s.total_violations, + "errors": s.errors, + "warnings": s.warnings, + } + for s in reversed(trend_scans) + ] + + return { + "total_violations": scan.total_violations, + "errors": scan.errors, + "warnings": scan.warnings, + "info": by_severity.get("info", 0), + "open": status_dict.get("open", 0), + "assigned": status_dict.get("assigned", 0), + "resolved": status_dict.get("resolved", 0), + "ignored": status_dict.get("ignored", 0), + "technical_debt_score": self._calculate_score(scan.errors, scan.warnings), + "trend": trend, + "by_severity": by_severity, + "by_rule": by_rule, + "by_module": by_module, + "top_files": top_files, + "last_scan": scan.timestamp.isoformat(), + "validator_type": validator_type, + "by_validator": { + validator_type: { + "total_violations": scan.total_violations, + "errors": scan.errors, + "warnings": scan.warnings, + "last_scan": scan.timestamp.isoformat(), + } + }, + } + + def _get_combined_stats( + self, db: Session, latest_scans: dict[str, ArchitectureScan] + ) -> dict: + """Get combined stats across all validators""" + # Aggregate totals + total_violations = sum(s.total_violations for s in latest_scans.values()) + total_errors = sum(s.errors for s in latest_scans.values()) + total_warnings = sum(s.warnings for s in latest_scans.values()) + + # Get all scan IDs + scan_ids = [s.id for s in latest_scans.values()] + + # Get violation counts by status + status_counts = ( + db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id)) + .filter(ArchitectureViolation.scan_id.in_(scan_ids)) + .group_by(ArchitectureViolation.status) + .all() + ) + status_dict = {status: count for status, count in status_counts} + + # Get violations by severity + severity_counts = ( + db.query( + ArchitectureViolation.severity, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id.in_(scan_ids)) + .group_by(ArchitectureViolation.severity) + .all() + ) + by_severity = {sev: count for sev, count in severity_counts} + + # Get violations by rule (across all validators) + rule_counts = ( + db.query( + ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id) + ) + .filter(ArchitectureViolation.scan_id.in_(scan_ids)) + .group_by(ArchitectureViolation.rule_id) + .all() + ) + by_rule = { + rule: count + for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10] + } + + # Get top violating files + file_counts = ( + db.query( + ArchitectureViolation.file_path, + func.count(ArchitectureViolation.id).label("count"), + ) + .filter(ArchitectureViolation.scan_id.in_(scan_ids)) + .group_by(ArchitectureViolation.file_path) + .order_by(desc("count")) + .limit(10) + .all() + ) + top_files = [{"file": file, "count": count} for file, count in file_counts] + + # Get violations by module + by_module = {} + for scan_id in scan_ids: + module_counts = self._get_violations_by_module(db, scan_id) + for module, count in module_counts.items(): + by_module[module] = by_module.get(module, 0) + count + by_module = dict( + sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10] + ) + + # Per-validator breakdown + by_validator = {} + for vtype, scan in latest_scans.items(): + by_validator[vtype] = { + "total_violations": scan.total_violations, + "errors": scan.errors, + "warnings": scan.warnings, + "last_scan": scan.timestamp.isoformat(), + } + + # Get most recent scan timestamp + most_recent = max(latest_scans.values(), key=lambda s: s.timestamp) + + return { + "total_violations": total_violations, + "errors": total_errors, + "warnings": total_warnings, + "info": by_severity.get("info", 0), + "open": status_dict.get("open", 0), + "assigned": status_dict.get("assigned", 0), + "resolved": status_dict.get("resolved", 0), + "ignored": status_dict.get("ignored", 0), + "technical_debt_score": self._calculate_score(total_errors, total_warnings), + "trend": [], # Combined trend would need special handling + "by_severity": by_severity, + "by_rule": by_rule, + "by_module": by_module, + "top_files": top_files, + "last_scan": most_recent.timestamp.isoformat(), + "by_validator": by_validator, + } + + def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]: + """Extract module from file paths and count violations""" + by_module = {} + violations = ( + db.query(ArchitectureViolation.file_path) + .filter(ArchitectureViolation.scan_id == scan_id) + .all() + ) + + for v in violations: + path_parts = v.file_path.split("/") + if len(path_parts) >= 2: + module = "/".join(path_parts[:2]) + else: + module = path_parts[0] + by_module[module] = by_module.get(module, 0) + 1 + + return dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]) + + def _calculate_score(self, errors: int, warnings: int) -> int: + """Calculate technical debt score (0-100)""" + score = 100 - (errors * 0.5 + warnings * 0.05) + return max(0, min(100, int(score))) + + def calculate_technical_debt_score( + self, db: Session, scan_id: int = None, validator_type: str = None + ) -> int: + """ + Calculate technical debt score (0-100) + + Formula: 100 - (errors * 0.5 + warnings * 0.05) + Capped at 0 minimum + + Args: + db: Database session + scan_id: Scan ID (if None, use latest) + validator_type: Filter by validator type + + Returns: + Score from 0-100 + """ + if scan_id is None: + latest_scan = self.get_latest_scan(db, validator_type) + if not latest_scan: + return 100 + scan_id = latest_scan.id + + scan = self.get_scan_by_id(db, scan_id) + if not scan: + return 100 + + return self._calculate_score(scan.errors, scan.warnings) + + def _get_git_commit_hash(self) -> str | None: + """Get current git commit hash""" + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0: + return result.stdout.strip()[:40] + except Exception: + pass + return None + + +# Singleton instance +code_quality_service = CodeQualityService() diff --git a/app/modules/dev_tools/services/test_runner_service.py b/app/modules/dev_tools/services/test_runner_service.py new file mode 100644 index 00000000..ed3116c5 --- /dev/null +++ b/app/modules/dev_tools/services/test_runner_service.py @@ -0,0 +1,507 @@ +""" +Test Runner Service +Service for running pytest and storing results +""" + +import json +import logging +import re +import subprocess +import tempfile +from datetime import UTC, datetime +from pathlib import Path + +from sqlalchemy import desc, func +from sqlalchemy.orm import Session + +from app.modules.dev_tools.models import TestCollection, TestResult, TestRun + +logger = logging.getLogger(__name__) + + +class TestRunnerService: + """Service for managing pytest test runs""" + + def __init__(self): + self.project_root = Path(__file__).parent.parent.parent.parent.parent + + def create_test_run( + self, + db: Session, + test_path: str = "tests", + triggered_by: str = "manual", + ) -> TestRun: + """Create a test run record without executing tests""" + test_run = TestRun( + timestamp=datetime.now(UTC), + triggered_by=triggered_by, + test_path=test_path, + status="running", + git_commit_hash=self._get_git_commit(), + git_branch=self._get_git_branch(), + ) + db.add(test_run) + db.flush() + return test_run + + def run_tests( + self, + db: Session, + test_path: str = "tests", + triggered_by: str = "manual", + extra_args: list[str] | None = None, + ) -> TestRun: + """ + Run pytest synchronously and store results in database + + Args: + db: Database session + test_path: Path to tests (relative to project root) + triggered_by: Who triggered the run + extra_args: Additional pytest arguments + + Returns: + TestRun object with results + """ + test_run = self.create_test_run(db, test_path, triggered_by) + self._execute_tests(db, test_run, test_path, extra_args) + return test_run + + def _execute_tests( + self, + db: Session, + test_run: TestRun, + test_path: str, + extra_args: list[str] | None, + ) -> None: + """Execute pytest and update the test run record""" + try: + # Build pytest command with JSON output + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + json_report_path = f.name + + pytest_args = [ + "python", + "-m", + "pytest", + test_path, + "--json-report", + f"--json-report-file={json_report_path}", + "-v", + "--tb=short", + ] + + if extra_args: + pytest_args.extend(extra_args) + + test_run.pytest_args = " ".join(pytest_args) + + # Run pytest + start_time = datetime.now(UTC) + result = subprocess.run( + pytest_args, + cwd=str(self.project_root), + capture_output=True, + text=True, + timeout=600, # 10 minute timeout + ) + end_time = datetime.now(UTC) + + test_run.duration_seconds = (end_time - start_time).total_seconds() + + # Parse JSON report + try: + with open(json_report_path) as f: + report = json.load(f) + + self._process_json_report(db, test_run, report) + except (FileNotFoundError, json.JSONDecodeError) as e: + # Fallback to parsing stdout if JSON report failed + logger.warning(f"JSON report unavailable ({e}), parsing stdout") + self._parse_pytest_output(test_run, result.stdout, result.stderr) + finally: + # Clean up temp file + try: + Path(json_report_path).unlink() + except Exception: + pass + + # Set final status + if test_run.failed > 0 or test_run.errors > 0: + test_run.status = "failed" + else: + test_run.status = "passed" + + except subprocess.TimeoutExpired: + test_run.status = "error" + logger.error("Pytest run timed out") + except Exception as e: + test_run.status = "error" + logger.error(f"Error running tests: {e}") + + def _process_json_report(self, db: Session, test_run: TestRun, report: dict): + """Process pytest-json-report output""" + summary = report.get("summary", {}) + + test_run.total_tests = summary.get("total", 0) + test_run.passed = summary.get("passed", 0) + test_run.failed = summary.get("failed", 0) + test_run.errors = summary.get("error", 0) + test_run.skipped = summary.get("skipped", 0) + test_run.xfailed = summary.get("xfailed", 0) + test_run.xpassed = summary.get("xpassed", 0) + + # Process individual test results + tests = report.get("tests", []) + for test in tests: + node_id = test.get("nodeid", "") + outcome = test.get("outcome", "unknown") + + # Parse node_id to get file, class, function + test_file, test_class, test_name = self._parse_node_id(node_id) + + # Get failure details + error_message = None + traceback = None + if outcome in ("failed", "error"): + call_info = test.get("call", {}) + if "longrepr" in call_info: + traceback = call_info["longrepr"] + # Extract error message from traceback + if isinstance(traceback, str): + lines = traceback.strip().split("\n") + if lines: + error_message = lines[-1][:500] # Last line, limited length + + test_result = TestResult( + run_id=test_run.id, + node_id=node_id, + test_name=test_name, + test_file=test_file, + test_class=test_class, + outcome=outcome, + duration_seconds=test.get("duration", 0.0), + error_message=error_message, + traceback=traceback, + markers=test.get("keywords", []), + ) + db.add(test_result) + + def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]: + """Parse pytest node_id into file, class, function""" + # Format: tests/unit/test_foo.py::TestClass::test_method + # or: tests/unit/test_foo.py::test_function + parts = node_id.split("::") + + test_file = parts[0] if parts else "" + test_class = None + test_name = parts[-1] if parts else "" + + if len(parts) == 3: + test_class = parts[1] + elif len(parts) == 2: + # Could be Class::method or file::function + if parts[1].startswith("Test"): + test_class = parts[1] + test_name = parts[1] + + # Handle parametrized tests + if "[" in test_name: + test_name = test_name.split("[")[0] + + return test_file, test_class, test_name + + def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str): + """Fallback parser for pytest text output""" + # Parse summary line like: "10 passed, 2 failed, 1 skipped" + summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)" + + for match in re.finditer(summary_pattern, stdout): + count = int(match.group(1)) + status = match.group(2) + + if status == "passed": + test_run.passed = count + elif status == "failed": + test_run.failed = count + elif status == "error": + test_run.errors = count + elif status == "skipped": + test_run.skipped = count + elif status == "xfailed": + test_run.xfailed = count + elif status == "xpassed": + test_run.xpassed = count + + test_run.total_tests = ( + test_run.passed + + test_run.failed + + test_run.errors + + test_run.skipped + + test_run.xfailed + + test_run.xpassed + ) + + def _get_git_commit(self) -> str | None: + """Get current git commit hash""" + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(self.project_root), + capture_output=True, + text=True, + timeout=5, + ) + return result.stdout.strip()[:40] if result.returncode == 0 else None + except: + return None + + def _get_git_branch(self) -> str | None: + """Get current git branch""" + try: + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + cwd=str(self.project_root), + capture_output=True, + text=True, + timeout=5, + ) + return result.stdout.strip() if result.returncode == 0 else None + except: + return None + + def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]: + """Get recent test run history""" + return db.query(TestRun).order_by(desc(TestRun.timestamp)).limit(limit).all() + + def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None: + """Get a specific test run with results""" + return db.query(TestRun).filter(TestRun.id == run_id).first() + + def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]: + """Get failed tests from a run""" + return ( + db.query(TestResult) + .filter( + TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"]) + ) + .all() + ) + + def get_run_results( + self, db: Session, run_id: int, outcome: str | None = None + ) -> list[TestResult]: + """Get test results for a specific run, optionally filtered by outcome""" + query = db.query(TestResult).filter(TestResult.run_id == run_id) + + if outcome: + query = query.filter(TestResult.outcome == outcome) + + return query.all() + + def get_dashboard_stats(self, db: Session) -> dict: + """Get statistics for the testing dashboard""" + # Get latest run + latest_run = ( + db.query(TestRun) + .filter(TestRun.status != "running") + .order_by(desc(TestRun.timestamp)) + .first() + ) + + # Get test collection info (or calculate from latest run) + collection = ( + db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first() + ) + + # Get trend data (last 10 runs) + trend_runs = ( + db.query(TestRun) + .filter(TestRun.status != "running") + .order_by(desc(TestRun.timestamp)) + .limit(10) + .all() + ) + + # Calculate stats by category from latest run + by_category = {} + if latest_run: + results = ( + db.query(TestResult).filter(TestResult.run_id == latest_run.id).all() + ) + for result in results: + # Categorize by test path + if "unit" in result.test_file: + category = "Unit Tests" + elif "integration" in result.test_file: + category = "Integration Tests" + elif "performance" in result.test_file: + category = "Performance Tests" + elif "system" in result.test_file: + category = "System Tests" + else: + category = "Other" + + if category not in by_category: + by_category[category] = {"total": 0, "passed": 0, "failed": 0} + by_category[category]["total"] += 1 + if result.outcome == "passed": + by_category[category]["passed"] += 1 + elif result.outcome in ("failed", "error"): + by_category[category]["failed"] += 1 + + # Get top failing tests (across recent runs) + top_failing = ( + db.query( + TestResult.test_name, + TestResult.test_file, + func.count(TestResult.id).label("failure_count"), + ) + .filter(TestResult.outcome.in_(["failed", "error"])) + .group_by(TestResult.test_name, TestResult.test_file) + .order_by(desc("failure_count")) + .limit(10) + .all() + ) + + return { + # Current run stats + "total_tests": latest_run.total_tests if latest_run else 0, + "passed": latest_run.passed if latest_run else 0, + "failed": latest_run.failed if latest_run else 0, + "errors": latest_run.errors if latest_run else 0, + "skipped": latest_run.skipped if latest_run else 0, + "pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0, + "duration_seconds": round(latest_run.duration_seconds, 2) + if latest_run + else 0, + "coverage_percent": latest_run.coverage_percent if latest_run else None, + "last_run": latest_run.timestamp.isoformat() if latest_run else None, + "last_run_status": latest_run.status if latest_run else None, + # Collection stats + "total_test_files": collection.total_files if collection else 0, + "collected_tests": collection.total_tests if collection else 0, + "unit_tests": collection.unit_tests if collection else 0, + "integration_tests": collection.integration_tests if collection else 0, + "performance_tests": collection.performance_tests if collection else 0, + "system_tests": collection.system_tests if collection else 0, + "last_collected": collection.collected_at.isoformat() + if collection + else None, + # Trend data + "trend": [ + { + "timestamp": run.timestamp.isoformat(), + "total": run.total_tests, + "passed": run.passed, + "failed": run.failed, + "pass_rate": round(run.pass_rate, 1), + "duration": round(run.duration_seconds, 1), + } + for run in reversed(trend_runs) + ], + # By category + "by_category": by_category, + # Top failing tests + "top_failing": [ + { + "test_name": t.test_name, + "test_file": t.test_file, + "failure_count": t.failure_count, + } + for t in top_failing + ], + } + + def collect_tests(self, db: Session) -> TestCollection: + """Collect test information without running tests""" + collection = TestCollection( + collected_at=datetime.now(UTC), + ) + + try: + # Run pytest --collect-only with JSON report + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as f: + json_report_path = f.name + + result = subprocess.run( + [ + "python", + "-m", + "pytest", + "--collect-only", + "--json-report", + f"--json-report-file={json_report_path}", + "tests", + ], + cwd=str(self.project_root), + capture_output=True, + text=True, + timeout=120, + ) + + # Parse JSON report + json_path = Path(json_report_path) + if json_path.exists(): + with open(json_path) as f: + report = json.load(f) + + # Get total from summary + collection.total_tests = report.get("summary", {}).get("collected", 0) + + # Parse collectors to get test files and counts + test_files = {} + for collector in report.get("collectors", []): + for item in collector.get("result", []): + if item.get("type") == "Function": + node_id = item.get("nodeid", "") + if "::" in node_id: + file_path = node_id.split("::")[0] + if file_path not in test_files: + test_files[file_path] = 0 + test_files[file_path] += 1 + + # Count files and categorize + for file_path, count in test_files.items(): + collection.total_files += 1 + + if "/unit/" in file_path or file_path.startswith("tests/unit"): + collection.unit_tests += count + elif "/integration/" in file_path or file_path.startswith( + "tests/integration" + ): + collection.integration_tests += count + elif "/performance/" in file_path or file_path.startswith( + "tests/performance" + ): + collection.performance_tests += count + elif "/system/" in file_path or file_path.startswith( + "tests/system" + ): + collection.system_tests += count + + collection.test_files = [ + {"file": f, "count": c} + for f, c in sorted(test_files.items(), key=lambda x: -x[1]) + ] + + # Cleanup + json_path.unlink(missing_ok=True) + + logger.info( + f"Collected {collection.total_tests} tests from {collection.total_files} files" + ) + + except Exception as e: + logger.error(f"Error collecting tests: {e}", exc_info=True) + + db.add(collection) + return collection + + +# Singleton instance +test_runner_service = TestRunnerService() diff --git a/app/modules/inventory/models/__init__.py b/app/modules/inventory/models/__init__.py index df37b955..59777bdb 100644 --- a/app/modules/inventory/models/__init__.py +++ b/app/modules/inventory/models/__init__.py @@ -2,11 +2,11 @@ """ Inventory module database models. -Re-exports inventory-related models from their source locations. +This module contains the canonical implementations of inventory-related models. """ -from models.database.inventory import Inventory -from models.database.inventory_transaction import ( +from app.modules.inventory.models.inventory import Inventory +from app.modules.inventory.models.inventory_transaction import ( InventoryTransaction, TransactionType, ) diff --git a/app/modules/inventory/models/inventory.py b/app/modules/inventory/models/inventory.py new file mode 100644 index 00000000..1a413dd9 --- /dev/null +++ b/app/modules/inventory/models/inventory.py @@ -0,0 +1,61 @@ +# app/modules/inventory/models/inventory.py +""" +Inventory model for tracking stock at warehouse/bin locations. + +Each entry represents a quantity of a product at a specific bin location +within a warehouse. Products can be scattered across multiple bins. + +Example: + Warehouse: "strassen" + Bin: "SA-10-02" + Product: GTIN 4007817144145 + Quantity: 3 +""" + +from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class Inventory(Base, TimestampMixin): + __tablename__ = "inventory" + + id = Column(Integer, primary_key=True, index=True) + product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + + # Location: warehouse + bin + warehouse = Column(String, nullable=False, default="strassen", index=True) + bin_location = Column(String, nullable=False, index=True) # e.g., "SA-10-02" + + # Legacy field - kept for backward compatibility, will be removed + location = Column(String, index=True) + + quantity = Column(Integer, nullable=False, default=0) + reserved_quantity = Column(Integer, default=0) + + # Keep GTIN for reference/reporting (matches Product.gtin) + gtin = Column(String, index=True) + + # Relationships + product = relationship("Product", back_populates="inventory_entries") + vendor = relationship("Vendor") + + # Constraints + __table_args__ = ( + UniqueConstraint( + "product_id", "warehouse", "bin_location", name="uq_inventory_product_warehouse_bin" + ), + Index("idx_inventory_vendor_product", "vendor_id", "product_id"), + Index("idx_inventory_warehouse_bin", "warehouse", "bin_location"), + ) + + def __repr__(self): + return f"" + + @property + def available_quantity(self): + """Calculate available quantity (total - reserved).""" + return max(0, self.quantity - self.reserved_quantity) diff --git a/app/modules/inventory/models/inventory_transaction.py b/app/modules/inventory/models/inventory_transaction.py new file mode 100644 index 00000000..430f13a9 --- /dev/null +++ b/app/modules/inventory/models/inventory_transaction.py @@ -0,0 +1,170 @@ +# app/modules/inventory/models/inventory_transaction.py +""" +Inventory Transaction Model - Audit trail for all stock movements. + +This model tracks every change to inventory quantities, providing: +- Complete audit trail for compliance and debugging +- Order-linked transactions for traceability +- Support for different transaction types (reserve, fulfill, adjust, etc.) + +All stock movements should create a transaction record. +""" + +from datetime import UTC, datetime +from enum import Enum + +from sqlalchemy import ( + Column, + DateTime, + Enum as SQLEnum, + ForeignKey, + Index, + Integer, + String, + Text, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base + + +class TransactionType(str, Enum): + """Types of inventory transactions.""" + + # Order-related + RESERVE = "reserve" # Stock reserved for order + FULFILL = "fulfill" # Reserved stock consumed (shipped) + RELEASE = "release" # Reserved stock released (cancelled) + + # Manual adjustments + ADJUST = "adjust" # Manual adjustment (+/-) + SET = "set" # Set to exact quantity + + # Imports + IMPORT = "import" # Initial import/sync + + # Returns + RETURN = "return" # Stock returned from customer + + +class InventoryTransaction(Base): + """ + Audit log for inventory movements. + + Every change to inventory quantity creates a transaction record, + enabling complete traceability of stock levels over time. + """ + + __tablename__ = "inventory_transactions" + + id = Column(Integer, primary_key=True, index=True) + + # Core references + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True) + inventory_id = Column( + Integer, ForeignKey("inventory.id"), nullable=True, index=True + ) + + # Transaction details + transaction_type = Column( + SQLEnum(TransactionType), nullable=False, index=True + ) + quantity_change = Column(Integer, nullable=False) # Positive = add, negative = remove + + # Quantities after transaction (snapshot) + quantity_after = Column(Integer, nullable=False) + reserved_after = Column(Integer, nullable=False, default=0) + + # Location context + location = Column(String, nullable=True) + warehouse = Column(String, nullable=True) + + # Order reference (for order-related transactions) + order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True) + order_number = Column(String, nullable=True) + + # Audit fields + reason = Column(Text, nullable=True) # Human-readable reason + created_by = Column(String, nullable=True) # User/system that created + created_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + nullable=False, + index=True, + ) + + # Relationships + vendor = relationship("Vendor") + product = relationship("Product") + inventory = relationship("Inventory") + order = relationship("Order") + + # Indexes for common queries + __table_args__ = ( + Index("idx_inv_tx_vendor_product", "vendor_id", "product_id"), + Index("idx_inv_tx_vendor_created", "vendor_id", "created_at"), + Index("idx_inv_tx_order", "order_id"), + Index("idx_inv_tx_type_created", "transaction_type", "created_at"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + @classmethod + def create_transaction( + cls, + vendor_id: int, + product_id: int, + transaction_type: TransactionType, + quantity_change: int, + quantity_after: int, + reserved_after: int = 0, + inventory_id: int | None = None, + location: str | None = None, + warehouse: str | None = None, + order_id: int | None = None, + order_number: str | None = None, + reason: str | None = None, + created_by: str | None = None, + ) -> "InventoryTransaction": + """ + Factory method to create a transaction record. + + Args: + vendor_id: Vendor ID + product_id: Product ID + transaction_type: Type of transaction + quantity_change: Change in quantity (positive = add, negative = remove) + quantity_after: Total quantity after this transaction + reserved_after: Reserved quantity after this transaction + inventory_id: Optional inventory record ID + location: Optional location + warehouse: Optional warehouse + order_id: Optional order ID (for order-related transactions) + order_number: Optional order number for display + reason: Optional human-readable reason + created_by: Optional user/system identifier + + Returns: + InventoryTransaction instance (not yet added to session) + """ + return cls( + vendor_id=vendor_id, + product_id=product_id, + inventory_id=inventory_id, + transaction_type=transaction_type, + quantity_change=quantity_change, + quantity_after=quantity_after, + reserved_after=reserved_after, + location=location, + warehouse=warehouse, + order_id=order_id, + order_number=order_number, + reason=reason, + created_by=created_by, + ) diff --git a/app/modules/inventory/schemas/__init__.py b/app/modules/inventory/schemas/__init__.py index 55e1b9aa..cc7c0855 100644 --- a/app/modules/inventory/schemas/__init__.py +++ b/app/modules/inventory/schemas/__init__.py @@ -2,27 +2,77 @@ """ Inventory module Pydantic schemas. -Re-exports inventory-related schemas from their source locations. +This module contains the canonical implementations of inventory-related schemas. """ -from models.schema.inventory import ( +from app.modules.inventory.schemas.inventory import ( + # Base schemas + InventoryBase, InventoryCreate, InventoryAdjust, + InventoryUpdate, InventoryReserve, + # Response schemas InventoryResponse, + InventoryLocationResponse, + ProductInventorySummary, InventoryListResponse, - InventoryTransactionResponse, + InventoryMessageResponse, + InventorySummaryResponse, + # Admin schemas + AdminInventoryCreate, + AdminInventoryAdjust, AdminInventoryItem, + AdminInventoryListResponse, + AdminInventoryStats, AdminLowStockItem, + AdminVendorWithInventory, + AdminVendorsWithInventoryResponse, + AdminInventoryLocationsResponse, + # Transaction schemas + InventoryTransactionResponse, + InventoryTransactionWithProduct, + InventoryTransactionListResponse, + ProductTransactionHistoryResponse, + OrderTransactionHistoryResponse, + # Admin transaction schemas + AdminInventoryTransactionItem, + AdminInventoryTransactionListResponse, + AdminTransactionStatsResponse, ) __all__ = [ + # Base schemas + "InventoryBase", "InventoryCreate", "InventoryAdjust", + "InventoryUpdate", "InventoryReserve", + # Response schemas "InventoryResponse", + "InventoryLocationResponse", + "ProductInventorySummary", "InventoryListResponse", - "InventoryTransactionResponse", + "InventoryMessageResponse", + "InventorySummaryResponse", + # Admin schemas + "AdminInventoryCreate", + "AdminInventoryAdjust", "AdminInventoryItem", + "AdminInventoryListResponse", + "AdminInventoryStats", "AdminLowStockItem", + "AdminVendorWithInventory", + "AdminVendorsWithInventoryResponse", + "AdminInventoryLocationsResponse", + # Transaction schemas + "InventoryTransactionResponse", + "InventoryTransactionWithProduct", + "InventoryTransactionListResponse", + "ProductTransactionHistoryResponse", + "OrderTransactionHistoryResponse", + # Admin transaction schemas + "AdminInventoryTransactionItem", + "AdminInventoryTransactionListResponse", + "AdminTransactionStatsResponse", ] diff --git a/app/modules/inventory/schemas/inventory.py b/app/modules/inventory/schemas/inventory.py new file mode 100644 index 00000000..dd46a848 --- /dev/null +++ b/app/modules/inventory/schemas/inventory.py @@ -0,0 +1,294 @@ +# app/modules/inventory/schemas/inventory.py +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + + +class InventoryBase(BaseModel): + product_id: int = Field(..., description="Product ID in vendor catalog") + location: str = Field(..., description="Storage location") + + +class InventoryCreate(InventoryBase): + """Set exact inventory quantity (replaces existing).""" + + quantity: int = Field(..., description="Exact inventory quantity", ge=0) + + +class InventoryAdjust(InventoryBase): + """Add or remove inventory quantity.""" + + quantity: int = Field( + ..., description="Quantity to add (positive) or remove (negative)" + ) + + +class InventoryUpdate(BaseModel): + """Update inventory fields.""" + + quantity: int | None = Field(None, ge=0) + reserved_quantity: int | None = Field(None, ge=0) + location: str | None = None + + +class InventoryReserve(BaseModel): + """Reserve inventory for orders.""" + + product_id: int + location: str + quantity: int = Field(..., gt=0) + + +class InventoryResponse(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + product_id: int + vendor_id: int + location: str + quantity: int + reserved_quantity: int + gtin: str | None + created_at: datetime + updated_at: datetime + + @property + def available_quantity(self): + return max(0, self.quantity - self.reserved_quantity) + + +class InventoryLocationResponse(BaseModel): + location: str + quantity: int + reserved_quantity: int + available_quantity: int + + +class ProductInventorySummary(BaseModel): + """Inventory summary for a product.""" + + product_id: int + vendor_id: int + product_sku: str | None + product_title: str + total_quantity: int + total_reserved: int + total_available: int + locations: list[InventoryLocationResponse] + + +class InventoryListResponse(BaseModel): + inventories: list[InventoryResponse] + total: int + skip: int + limit: int + + +class InventoryMessageResponse(BaseModel): + """Simple message response for inventory operations.""" + + message: str + + +class InventorySummaryResponse(BaseModel): + """Inventory summary response for marketplace product service.""" + + gtin: str + total_quantity: int + locations: list[InventoryLocationResponse] + + +# ============================================================================ +# Admin Inventory Schemas +# ============================================================================ + + +class AdminInventoryCreate(BaseModel): + """Admin version of inventory create - requires explicit vendor_id.""" + + vendor_id: int = Field(..., description="Target vendor ID") + product_id: int = Field(..., description="Product ID in vendor catalog") + location: str = Field(..., description="Storage location") + quantity: int = Field(..., description="Exact inventory quantity", ge=0) + + +class AdminInventoryAdjust(BaseModel): + """Admin version of inventory adjust - requires explicit vendor_id.""" + + vendor_id: int = Field(..., description="Target vendor ID") + product_id: int = Field(..., description="Product ID in vendor catalog") + location: str = Field(..., description="Storage location") + quantity: int = Field( + ..., description="Quantity to add (positive) or remove (negative)" + ) + reason: str | None = Field(None, description="Reason for adjustment") + + +class AdminInventoryItem(BaseModel): + """Inventory item with vendor info for admin list view.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + product_id: int + vendor_id: int + vendor_name: str | None = None + vendor_code: str | None = None + product_title: str | None = None + product_sku: str | None = None + location: str + quantity: int + reserved_quantity: int + available_quantity: int + gtin: str | None = None + created_at: datetime + updated_at: datetime + + +class AdminInventoryListResponse(BaseModel): + """Cross-vendor inventory list for admin.""" + + inventories: list[AdminInventoryItem] + total: int + skip: int + limit: int + vendor_filter: int | None = None + location_filter: str | None = None + + +class AdminInventoryStats(BaseModel): + """Inventory statistics for admin dashboard.""" + + total_entries: int + total_quantity: int + total_reserved: int + total_available: int + low_stock_count: int + vendors_with_inventory: int + unique_locations: int + + +class AdminLowStockItem(BaseModel): + """Low stock item for admin alerts.""" + + id: int + product_id: int + vendor_id: int + vendor_name: str | None = None + product_title: str | None = None + location: str + quantity: int + reserved_quantity: int + available_quantity: int + + +class AdminVendorWithInventory(BaseModel): + """Vendor with inventory entries.""" + + id: int + name: str + vendor_code: str + + +class AdminVendorsWithInventoryResponse(BaseModel): + """Response for vendors with inventory list.""" + + vendors: list[AdminVendorWithInventory] + + +class AdminInventoryLocationsResponse(BaseModel): + """Response for unique inventory locations.""" + + locations: list[str] + + +# ============================================================================ +# Inventory Transaction Schemas +# ============================================================================ + + +class InventoryTransactionResponse(BaseModel): + """Single inventory transaction record.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + product_id: int + inventory_id: int | None = None + transaction_type: str + quantity_change: int + quantity_after: int + reserved_after: int + location: str | None = None + warehouse: str | None = None + order_id: int | None = None + order_number: str | None = None + reason: str | None = None + created_by: str | None = None + created_at: datetime + + +class InventoryTransactionWithProduct(InventoryTransactionResponse): + """Transaction with product details for list views.""" + + product_title: str | None = None + product_sku: str | None = None + + +class InventoryTransactionListResponse(BaseModel): + """Paginated list of inventory transactions.""" + + transactions: list[InventoryTransactionWithProduct] + total: int + skip: int + limit: int + + +class ProductTransactionHistoryResponse(BaseModel): + """Transaction history for a specific product.""" + + product_id: int + product_title: str | None = None + product_sku: str | None = None + current_quantity: int + current_reserved: int + transactions: list[InventoryTransactionResponse] + total: int + + +class OrderTransactionHistoryResponse(BaseModel): + """Transaction history for a specific order.""" + + order_id: int + order_number: str + transactions: list[InventoryTransactionWithProduct] + + +# ============================================================================ +# Admin Inventory Transaction Schemas +# ============================================================================ + + +class AdminInventoryTransactionItem(InventoryTransactionWithProduct): + """Transaction with vendor details for admin views.""" + + vendor_name: str | None = None + vendor_code: str | None = None + + +class AdminInventoryTransactionListResponse(BaseModel): + """Paginated list of transactions for admin.""" + + transactions: list[AdminInventoryTransactionItem] + total: int + skip: int + limit: int + + +class AdminTransactionStatsResponse(BaseModel): + """Transaction statistics for admin dashboard.""" + + total_transactions: int + transactions_today: int + by_type: dict[str, int] diff --git a/app/modules/inventory/services/__init__.py b/app/modules/inventory/services/__init__.py index 0cd8c671..0217c1e2 100644 --- a/app/modules/inventory/services/__init__.py +++ b/app/modules/inventory/services/__init__.py @@ -2,20 +2,21 @@ """ Inventory module services. -Re-exports inventory-related services from their source locations. +This module contains the canonical implementations of inventory-related services. """ -from app.services.inventory_service import ( +from app.modules.inventory.services.inventory_service import ( inventory_service, InventoryService, ) -from app.services.inventory_transaction_service import ( +from app.modules.inventory.services.inventory_transaction_service import ( inventory_transaction_service, InventoryTransactionService, ) -from app.services.inventory_import_service import ( +from app.modules.inventory.services.inventory_import_service import ( inventory_import_service, InventoryImportService, + ImportResult, ) __all__ = [ @@ -25,4 +26,5 @@ __all__ = [ "InventoryTransactionService", "inventory_import_service", "InventoryImportService", + "ImportResult", ] diff --git a/app/modules/inventory/services/inventory_import_service.py b/app/modules/inventory/services/inventory_import_service.py new file mode 100644 index 00000000..30341324 --- /dev/null +++ b/app/modules/inventory/services/inventory_import_service.py @@ -0,0 +1,250 @@ +# app/modules/inventory/services/inventory_import_service.py +""" +Inventory import service for bulk importing stock from TSV/CSV files. + +Supports two formats: +1. One row per unit (quantity = count of rows): + BIN EAN PRODUCT + SA-10-02 0810050910101 Product Name + SA-10-02 0810050910101 Product Name (2nd unit) + +2. With explicit quantity column: + BIN EAN PRODUCT QUANTITY + SA-10-02 0810050910101 Product Name 12 + +Products are matched by GTIN/EAN to existing vendor products. +""" + +import csv +import io +import logging +from collections import defaultdict +from dataclasses import dataclass, field + +from sqlalchemy.orm import Session + +from app.modules.inventory.models.inventory import Inventory +from models.database.product import Product + +logger = logging.getLogger(__name__) + + +@dataclass +class ImportResult: + """Result of an inventory import operation.""" + + success: bool = True + total_rows: int = 0 + entries_created: int = 0 + entries_updated: int = 0 + quantity_imported: int = 0 + unmatched_gtins: list = field(default_factory=list) + errors: list = field(default_factory=list) + + +class InventoryImportService: + """Service for importing inventory from TSV/CSV files.""" + + def import_from_text( + self, + db: Session, + content: str, + vendor_id: int, + warehouse: str = "strassen", + delimiter: str = "\t", + clear_existing: bool = False, + ) -> ImportResult: + """ + Import inventory from TSV/CSV text content. + + Args: + db: Database session + content: TSV/CSV content as string + vendor_id: Vendor ID for inventory + warehouse: Warehouse name (default: "strassen") + delimiter: Column delimiter (default: tab) + clear_existing: If True, clear existing inventory before import + + Returns: + ImportResult with summary and errors + """ + result = ImportResult() + + try: + # Parse CSV/TSV + reader = csv.DictReader(io.StringIO(content), delimiter=delimiter) + + # Normalize headers (case-insensitive, strip whitespace) + if reader.fieldnames: + reader.fieldnames = [h.strip().upper() for h in reader.fieldnames] + + # Validate required columns + required = {"BIN", "EAN"} + if not reader.fieldnames or not required.issubset(set(reader.fieldnames)): + result.success = False + result.errors.append( + f"Missing required columns. Found: {reader.fieldnames}, Required: {required}" + ) + return result + + has_quantity = "QUANTITY" in reader.fieldnames + + # Group entries by (EAN, BIN) + # Key: (ean, bin) -> quantity + inventory_data: dict[tuple[str, str], int] = defaultdict(int) + product_names: dict[str, str] = {} # EAN -> product name (for logging) + + for row in reader: + result.total_rows += 1 + + ean = row.get("EAN", "").strip() + bin_loc = row.get("BIN", "").strip() + product_name = row.get("PRODUCT", "").strip() + + if not ean or not bin_loc: + result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN") + continue + + # Get quantity + if has_quantity: + try: + qty = int(row.get("QUANTITY", "1").strip()) + except ValueError: + result.errors.append( + f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'" + ) + continue + else: + qty = 1 # Each row = 1 unit + + inventory_data[(ean, bin_loc)] += qty + if product_name: + product_names[ean] = product_name + + # Clear existing inventory if requested + if clear_existing: + db.query(Inventory).filter( + Inventory.vendor_id == vendor_id, + Inventory.warehouse == warehouse, + ).delete() + db.flush() + + # Build EAN to Product mapping for this vendor + products = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.gtin.isnot(None), + ) + .all() + ) + ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin} + + # Track unmatched GTINs + unmatched: dict[str, int] = {} # EAN -> total quantity + + # Process inventory entries + for (ean, bin_loc), quantity in inventory_data.items(): + product = ean_to_product.get(ean) + + if not product: + # Track unmatched + if ean not in unmatched: + unmatched[ean] = 0 + unmatched[ean] += quantity + continue + + # Upsert inventory entry + existing = ( + db.query(Inventory) + .filter( + Inventory.product_id == product.id, + Inventory.warehouse == warehouse, + Inventory.bin_location == bin_loc, + ) + .first() + ) + + if existing: + existing.quantity = quantity + existing.gtin = ean + result.entries_updated += 1 + else: + inv = Inventory( + product_id=product.id, + vendor_id=vendor_id, + warehouse=warehouse, + bin_location=bin_loc, + location=bin_loc, # Legacy field + quantity=quantity, + gtin=ean, + ) + db.add(inv) + result.entries_created += 1 + + result.quantity_imported += quantity + + db.flush() + + # Format unmatched GTINs for result + for ean, qty in unmatched.items(): + product_name = product_names.get(ean, "Unknown") + result.unmatched_gtins.append( + {"gtin": ean, "quantity": qty, "product_name": product_name} + ) + + if result.unmatched_gtins: + logger.warning( + f"Import had {len(result.unmatched_gtins)} unmatched GTINs" + ) + + except Exception as e: + logger.exception("Inventory import failed") + result.success = False + result.errors.append(str(e)) + + return result + + def import_from_file( + self, + db: Session, + file_path: str, + vendor_id: int, + warehouse: str = "strassen", + clear_existing: bool = False, + ) -> ImportResult: + """ + Import inventory from a TSV/CSV file. + + Args: + db: Database session + file_path: Path to TSV/CSV file + vendor_id: Vendor ID for inventory + warehouse: Warehouse name + clear_existing: If True, clear existing inventory before import + + Returns: + ImportResult with summary and errors + """ + try: + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + except Exception as e: + return ImportResult(success=False, errors=[f"Failed to read file: {e}"]) + + # Detect delimiter + first_line = content.split("\n")[0] if content else "" + delimiter = "\t" if "\t" in first_line else "," + + return self.import_from_text( + db=db, + content=content, + vendor_id=vendor_id, + warehouse=warehouse, + delimiter=delimiter, + clear_existing=clear_existing, + ) + + +# Singleton instance +inventory_import_service = InventoryImportService() diff --git a/app/modules/inventory/services/inventory_service.py b/app/modules/inventory/services/inventory_service.py new file mode 100644 index 00000000..5a845cef --- /dev/null +++ b/app/modules/inventory/services/inventory_service.py @@ -0,0 +1,949 @@ +# app/modules/inventory/services/inventory_service.py +import logging +from datetime import UTC, datetime + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.exceptions import ( + InsufficientInventoryException, + InvalidQuantityException, + InventoryNotFoundException, + InventoryValidationException, + ProductNotFoundException, + ValidationException, + VendorNotFoundException, +) +from app.modules.inventory.models.inventory import Inventory +from app.modules.inventory.schemas.inventory import ( + AdminInventoryItem, + AdminInventoryListResponse, + AdminInventoryLocationsResponse, + AdminInventoryStats, + AdminLowStockItem, + AdminVendorsWithInventoryResponse, + AdminVendorWithInventory, + InventoryAdjust, + InventoryCreate, + InventoryLocationResponse, + InventoryReserve, + InventoryUpdate, + ProductInventorySummary, +) +from models.database.product import Product +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class InventoryService: + """Service for inventory operations with vendor isolation.""" + + def set_inventory( + self, db: Session, vendor_id: int, inventory_data: InventoryCreate + ) -> Inventory: + """ + Set exact inventory quantity for a product at a location (replaces existing). + + Args: + db: Database session + vendor_id: Vendor ID (from middleware) + inventory_data: Inventory data + + Returns: + Inventory object + """ + try: + # Validate product belongs to vendor + product = self._get_vendor_product(db, vendor_id, inventory_data.product_id) + + # Validate location + location = self._validate_location(inventory_data.location) + + # Validate quantity + self._validate_quantity(inventory_data.quantity, allow_zero=True) + + # Check if inventory entry exists + existing = self._get_inventory_entry( + db, inventory_data.product_id, location + ) + + if existing: + old_qty = existing.quantity + existing.quantity = inventory_data.quantity + existing.updated_at = datetime.now(UTC) + db.flush() + db.refresh(existing) + + logger.info( + f"Set inventory for product {inventory_data.product_id} at {location}: " + f"{old_qty} → {inventory_data.quantity}" + ) + return existing + # Create new inventory entry + new_inventory = Inventory( + product_id=inventory_data.product_id, + vendor_id=vendor_id, + warehouse="strassen", # Default warehouse + bin_location=location, # Use location as bin location + location=location, # Keep for backward compatibility + quantity=inventory_data.quantity, + gtin=product.marketplace_product.gtin, # Optional reference + ) + db.add(new_inventory) + db.flush() + db.refresh(new_inventory) + + logger.info( + f"Created inventory for product {inventory_data.product_id} at {location}: " + f"{inventory_data.quantity}" + ) + return new_inventory + + except ( + ProductNotFoundException, + InvalidQuantityException, + InventoryValidationException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error setting inventory: {str(e)}") + raise ValidationException("Failed to set inventory") + + def adjust_inventory( + self, db: Session, vendor_id: int, inventory_data: InventoryAdjust + ) -> Inventory: + """ + Adjust inventory by adding or removing quantity. + Positive quantity = add, negative = remove. + + Args: + db: Database session + vendor_id: Vendor ID + inventory_data: Adjustment data + + Returns: + Updated Inventory object + """ + try: + # Validate product belongs to vendor + product = self._get_vendor_product(db, vendor_id, inventory_data.product_id) + + # Validate location + location = self._validate_location(inventory_data.location) + + # Check if inventory exists + existing = self._get_inventory_entry( + db, inventory_data.product_id, location + ) + + if not existing: + # Create new if adding, error if removing + if inventory_data.quantity < 0: + raise InventoryNotFoundException( + f"No inventory found for product {inventory_data.product_id} at {location}" + ) + + # Create with positive quantity + new_inventory = Inventory( + product_id=inventory_data.product_id, + vendor_id=vendor_id, + warehouse="strassen", # Default warehouse + bin_location=location, # Use location as bin location + location=location, # Keep for backward compatibility + quantity=inventory_data.quantity, + gtin=product.marketplace_product.gtin, + ) + db.add(new_inventory) + db.flush() + db.refresh(new_inventory) + + logger.info( + f"Created inventory for product {inventory_data.product_id} at {location}: " + f"+{inventory_data.quantity}" + ) + return new_inventory + + # Adjust existing inventory + old_qty = existing.quantity + new_qty = old_qty + inventory_data.quantity + + # Validate resulting quantity + if new_qty < 0: + raise InsufficientInventoryException( + f"Insufficient inventory. Available: {old_qty}, " + f"Requested removal: {abs(inventory_data.quantity)}" + ) + + existing.quantity = new_qty + existing.updated_at = datetime.now(UTC) + db.flush() + db.refresh(existing) + + logger.info( + f"Adjusted inventory for product {inventory_data.product_id} at {location}: " + f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}" + ) + return existing + + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InventoryValidationException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error adjusting inventory: {str(e)}") + raise ValidationException("Failed to adjust inventory") + + def reserve_inventory( + self, db: Session, vendor_id: int, reserve_data: InventoryReserve + ) -> Inventory: + """ + Reserve inventory for an order (increases reserved_quantity). + + Args: + db: Database session + vendor_id: Vendor ID + reserve_data: Reservation data + + Returns: + Updated Inventory object + """ + try: + # Validate product + product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) + + # Validate location and quantity + location = self._validate_location(reserve_data.location) + self._validate_quantity(reserve_data.quantity, allow_zero=False) + + # Get inventory entry + inventory = self._get_inventory_entry(db, reserve_data.product_id, location) + if not inventory: + raise InventoryNotFoundException( + f"No inventory found for product {reserve_data.product_id} at {location}" + ) + + # Check available quantity + available = inventory.quantity - inventory.reserved_quantity + if available < reserve_data.quantity: + raise InsufficientInventoryException( + f"Insufficient available inventory. Available: {available}, " + f"Requested: {reserve_data.quantity}" + ) + + # Reserve inventory + inventory.reserved_quantity += reserve_data.quantity + inventory.updated_at = datetime.now(UTC) + db.flush() + db.refresh(inventory) + + logger.info( + f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} " + f"at {location}" + ) + return inventory + + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InvalidQuantityException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error reserving inventory: {str(e)}") + raise ValidationException("Failed to reserve inventory") + + def release_reservation( + self, db: Session, vendor_id: int, reserve_data: InventoryReserve + ) -> Inventory: + """ + Release reserved inventory (decreases reserved_quantity). + + Args: + db: Database session + vendor_id: Vendor ID + reserve_data: Reservation data + + Returns: + Updated Inventory object + """ + try: + # Validate product + product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) + + location = self._validate_location(reserve_data.location) + self._validate_quantity(reserve_data.quantity, allow_zero=False) + + inventory = self._get_inventory_entry(db, reserve_data.product_id, location) + if not inventory: + raise InventoryNotFoundException( + f"No inventory found for product {reserve_data.product_id} at {location}" + ) + + # Validate reserved quantity + if inventory.reserved_quantity < reserve_data.quantity: + logger.warning( + f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, " + f"Requested: {reserve_data.quantity}" + ) + inventory.reserved_quantity = 0 + else: + inventory.reserved_quantity -= reserve_data.quantity + + inventory.updated_at = datetime.now(UTC) + db.flush() + db.refresh(inventory) + + logger.info( + f"Released {reserve_data.quantity} units for product {reserve_data.product_id} " + f"at {location}" + ) + return inventory + + except ( + ProductNotFoundException, + InventoryNotFoundException, + InvalidQuantityException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error releasing reservation: {str(e)}") + raise ValidationException("Failed to release reservation") + + def fulfill_reservation( + self, db: Session, vendor_id: int, reserve_data: InventoryReserve + ) -> Inventory: + """ + Fulfill a reservation (decreases both quantity and reserved_quantity). + Use when order is shipped/completed. + + Args: + db: Database session + vendor_id: Vendor ID + reserve_data: Reservation data + + Returns: + Updated Inventory object + """ + try: + product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) + location = self._validate_location(reserve_data.location) + self._validate_quantity(reserve_data.quantity, allow_zero=False) + + inventory = self._get_inventory_entry(db, reserve_data.product_id, location) + if not inventory: + raise InventoryNotFoundException( + f"No inventory found for product {reserve_data.product_id} at {location}" + ) + + # Validate quantities + if inventory.quantity < reserve_data.quantity: + raise InsufficientInventoryException( + f"Insufficient inventory. Available: {inventory.quantity}, " + f"Requested: {reserve_data.quantity}" + ) + + if inventory.reserved_quantity < reserve_data.quantity: + logger.warning( + f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, " + f"Fulfilling: {reserve_data.quantity}" + ) + + # Fulfill (remove from both quantity and reserved) + inventory.quantity -= reserve_data.quantity + inventory.reserved_quantity = max( + 0, inventory.reserved_quantity - reserve_data.quantity + ) + inventory.updated_at = datetime.now(UTC) + db.flush() + db.refresh(inventory) + + logger.info( + f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} " + f"at {location}" + ) + return inventory + + except ( + ProductNotFoundException, + InventoryNotFoundException, + InsufficientInventoryException, + InvalidQuantityException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error fulfilling reservation: {str(e)}") + raise ValidationException("Failed to fulfill reservation") + + def get_product_inventory( + self, db: Session, vendor_id: int, product_id: int + ) -> ProductInventorySummary: + """ + Get inventory summary for a product across all locations. + + Args: + db: Database session + vendor_id: Vendor ID + product_id: Product ID + + Returns: + ProductInventorySummary + """ + try: + product = self._get_vendor_product(db, vendor_id, product_id) + + inventory_entries = ( + db.query(Inventory).filter(Inventory.product_id == product_id).all() + ) + + if not inventory_entries: + return ProductInventorySummary( + product_id=product_id, + vendor_id=vendor_id, + product_sku=product.vendor_sku, + product_title=product.marketplace_product.get_title() or "", + total_quantity=0, + total_reserved=0, + total_available=0, + locations=[], + ) + + total_qty = sum(inv.quantity for inv in inventory_entries) + total_reserved = sum(inv.reserved_quantity for inv in inventory_entries) + total_available = sum(inv.available_quantity for inv in inventory_entries) + + locations = [ + InventoryLocationResponse( + location=inv.location, + quantity=inv.quantity, + reserved_quantity=inv.reserved_quantity, + available_quantity=inv.available_quantity, + ) + for inv in inventory_entries + ] + + return ProductInventorySummary( + product_id=product_id, + vendor_id=vendor_id, + product_sku=product.vendor_sku, + product_title=product.marketplace_product.get_title() or "", + total_quantity=total_qty, + total_reserved=total_reserved, + total_available=total_available, + locations=locations, + ) + + except ProductNotFoundException: + raise + except Exception as e: + logger.error(f"Error getting product inventory: {str(e)}") + raise ValidationException("Failed to retrieve product inventory") + + def get_vendor_inventory( + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 100, + location: str | None = None, + low_stock_threshold: int | None = None, + ) -> list[Inventory]: + """ + Get all inventory for a vendor with filtering. + + Args: + db: Database session + vendor_id: Vendor ID + skip: Pagination offset + limit: Pagination limit + location: Filter by location + low_stock_threshold: Filter items below threshold + + Returns: + List of Inventory objects + """ + try: + query = db.query(Inventory).filter(Inventory.vendor_id == vendor_id) + + if location: + query = query.filter(Inventory.location.ilike(f"%{location}%")) + + if low_stock_threshold is not None: + query = query.filter(Inventory.quantity <= low_stock_threshold) + + return query.offset(skip).limit(limit).all() + + except Exception as e: + logger.error(f"Error getting vendor inventory: {str(e)}") + raise ValidationException("Failed to retrieve vendor inventory") + + def update_inventory( + self, + db: Session, + vendor_id: int, + inventory_id: int, + inventory_update: InventoryUpdate, + ) -> Inventory: + """Update inventory entry.""" + try: + inventory = self._get_inventory_by_id(db, inventory_id) + + # Verify ownership + if inventory.vendor_id != vendor_id: + raise InventoryNotFoundException(f"Inventory {inventory_id} not found") + + # Update fields + if inventory_update.quantity is not None: + self._validate_quantity(inventory_update.quantity, allow_zero=True) + inventory.quantity = inventory_update.quantity + + if inventory_update.reserved_quantity is not None: + self._validate_quantity( + inventory_update.reserved_quantity, allow_zero=True + ) + inventory.reserved_quantity = inventory_update.reserved_quantity + + if inventory_update.location: + inventory.location = self._validate_location(inventory_update.location) + + inventory.updated_at = datetime.now(UTC) + db.flush() + db.refresh(inventory) + + logger.info(f"Updated inventory {inventory_id}") + return inventory + + except ( + InventoryNotFoundException, + InvalidQuantityException, + InventoryValidationException, + ): + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"Error updating inventory: {str(e)}") + raise ValidationException("Failed to update inventory") + + def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool: + """Delete inventory entry.""" + try: + inventory = self._get_inventory_by_id(db, inventory_id) + + # Verify ownership + if inventory.vendor_id != vendor_id: + raise InventoryNotFoundException(f"Inventory {inventory_id} not found") + + db.delete(inventory) + db.flush() + + logger.info(f"Deleted inventory {inventory_id}") + return True + + except InventoryNotFoundException: + raise + except Exception as e: + db.rollback() + logger.error(f"Error deleting inventory: {str(e)}") + raise ValidationException("Failed to delete inventory") + + # ========================================================================= + # Admin Methods (cross-vendor operations) + # ========================================================================= + + def get_all_inventory_admin( + self, + db: Session, + skip: int = 0, + limit: int = 50, + vendor_id: int | None = None, + location: str | None = None, + low_stock: int | None = None, + search: str | None = None, + ) -> AdminInventoryListResponse: + """ + Get inventory across all vendors with filtering (admin only). + + Args: + db: Database session + skip: Pagination offset + limit: Pagination limit + vendor_id: Filter by vendor + location: Filter by location + low_stock: Filter items below threshold + search: Search by product title or SKU + + Returns: + AdminInventoryListResponse + """ + query = db.query(Inventory).join(Product).join(Vendor) + + # Apply filters + if vendor_id is not None: + query = query.filter(Inventory.vendor_id == vendor_id) + + if location: + query = query.filter(Inventory.location.ilike(f"%{location}%")) + + if low_stock is not None: + query = query.filter(Inventory.quantity <= low_stock) + + if search: + from models.database.marketplace_product import MarketplaceProduct + from models.database.marketplace_product_translation import ( + MarketplaceProductTranslation, + ) + + query = ( + query.join(MarketplaceProduct) + .outerjoin(MarketplaceProductTranslation) + .filter( + (MarketplaceProductTranslation.title.ilike(f"%{search}%")) + | (Product.vendor_sku.ilike(f"%{search}%")) + ) + ) + + # Get total count before pagination + total = query.count() + + # Apply pagination + inventories = query.offset(skip).limit(limit).all() + + # Build response with vendor/product info + items = [] + for inv in inventories: + product = inv.product + vendor = inv.vendor + title = None + if product and product.marketplace_product: + title = product.marketplace_product.get_title() + + items.append( + AdminInventoryItem( + id=inv.id, + product_id=inv.product_id, + vendor_id=inv.vendor_id, + vendor_name=vendor.name if vendor else None, + vendor_code=vendor.vendor_code if vendor else None, + product_title=title, + product_sku=product.vendor_sku if product else None, + location=inv.location, + quantity=inv.quantity, + reserved_quantity=inv.reserved_quantity, + available_quantity=inv.available_quantity, + gtin=inv.gtin, + created_at=inv.created_at, + updated_at=inv.updated_at, + ) + ) + + return AdminInventoryListResponse( + inventories=items, + total=total, + skip=skip, + limit=limit, + vendor_filter=vendor_id, + location_filter=location, + ) + + def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats: + """Get platform-wide inventory statistics (admin only).""" + # Total entries + total_entries = db.query(func.count(Inventory.id)).scalar() or 0 + + # Aggregate quantities + totals = db.query( + func.sum(Inventory.quantity).label("total_qty"), + func.sum(Inventory.reserved_quantity).label("total_reserved"), + ).first() + + total_quantity = totals.total_qty or 0 + total_reserved = totals.total_reserved or 0 + total_available = total_quantity - total_reserved + + # Low stock count (default threshold: 10) + low_stock_count = ( + db.query(func.count(Inventory.id)) + .filter(Inventory.quantity <= 10) + .scalar() + or 0 + ) + + # Vendors with inventory + vendors_with_inventory = ( + db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0 + ) + + # Unique locations + unique_locations = ( + db.query(func.count(func.distinct(Inventory.location))).scalar() or 0 + ) + + return AdminInventoryStats( + total_entries=total_entries, + total_quantity=total_quantity, + total_reserved=total_reserved, + total_available=total_available, + low_stock_count=low_stock_count, + vendors_with_inventory=vendors_with_inventory, + unique_locations=unique_locations, + ) + + def get_low_stock_items_admin( + self, + db: Session, + threshold: int = 10, + vendor_id: int | None = None, + limit: int = 50, + ) -> list[AdminLowStockItem]: + """Get items with low stock levels (admin only).""" + query = ( + db.query(Inventory) + .join(Product) + .join(Vendor) + .filter(Inventory.quantity <= threshold) + ) + + if vendor_id is not None: + query = query.filter(Inventory.vendor_id == vendor_id) + + # Order by quantity ascending (most critical first) + query = query.order_by(Inventory.quantity.asc()) + + inventories = query.limit(limit).all() + + items = [] + for inv in inventories: + product = inv.product + vendor = inv.vendor + title = None + if product and product.marketplace_product: + title = product.marketplace_product.get_title() + + items.append( + AdminLowStockItem( + id=inv.id, + product_id=inv.product_id, + vendor_id=inv.vendor_id, + vendor_name=vendor.name if vendor else None, + product_title=title, + location=inv.location, + quantity=inv.quantity, + reserved_quantity=inv.reserved_quantity, + available_quantity=inv.available_quantity, + ) + ) + + return items + + def get_vendors_with_inventory_admin( + self, db: Session + ) -> AdminVendorsWithInventoryResponse: + """Get list of vendors that have inventory entries (admin only).""" + # noqa: SVC-005 - Admin function, intentionally cross-vendor + # Use subquery to avoid DISTINCT on JSON columns (PostgreSQL can't compare JSON) + vendor_ids_subquery = ( + db.query(Inventory.vendor_id) + .distinct() + .subquery() + ) + vendors = ( + db.query(Vendor) + .filter(Vendor.id.in_(db.query(vendor_ids_subquery.c.vendor_id))) + .order_by(Vendor.name) + .all() + ) + + return AdminVendorsWithInventoryResponse( + vendors=[ + AdminVendorWithInventory( + id=v.id, name=v.name, vendor_code=v.vendor_code + ) + for v in vendors + ] + ) + + def get_inventory_locations_admin( + self, db: Session, vendor_id: int | None = None + ) -> AdminInventoryLocationsResponse: + """Get list of unique inventory locations (admin only).""" + query = db.query(func.distinct(Inventory.location)) + + if vendor_id is not None: + query = query.filter(Inventory.vendor_id == vendor_id) + + locations = [loc[0] for loc in query.all()] + + return AdminInventoryLocationsResponse(locations=sorted(locations)) + + def get_vendor_inventory_admin( + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 50, + location: str | None = None, + low_stock: int | None = None, + ) -> AdminInventoryListResponse: + """Get inventory for a specific vendor (admin only).""" + # Verify vendor exists + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(f"Vendor {vendor_id} not found") + + # Use the existing method + inventories = self.get_vendor_inventory( + db=db, + vendor_id=vendor_id, + skip=skip, + limit=limit, + location=location, + low_stock_threshold=low_stock, + ) + + # Build response with product info + items = [] + for inv in inventories: + product = inv.product + title = None + if product and product.marketplace_product: + title = product.marketplace_product.get_title() + + items.append( + AdminInventoryItem( + id=inv.id, + product_id=inv.product_id, + vendor_id=inv.vendor_id, + vendor_name=vendor.name, + vendor_code=vendor.vendor_code, + product_title=title, + product_sku=product.vendor_sku if product else None, + location=inv.location, + quantity=inv.quantity, + reserved_quantity=inv.reserved_quantity, + available_quantity=inv.available_quantity, + gtin=inv.gtin, + created_at=inv.created_at, + updated_at=inv.updated_at, + ) + ) + + # Get total count for pagination + total_query = db.query(func.count(Inventory.id)).filter( + Inventory.vendor_id == vendor_id + ) + if location: + total_query = total_query.filter(Inventory.location.ilike(f"%{location}%")) + if low_stock is not None: + total_query = total_query.filter(Inventory.quantity <= low_stock) + total = total_query.scalar() or 0 + + return AdminInventoryListResponse( + inventories=items, + total=total, + skip=skip, + limit=limit, + vendor_filter=vendor_id, + location_filter=location, + ) + + def get_product_inventory_admin( + self, db: Session, product_id: int + ) -> ProductInventorySummary: + """Get inventory summary for a product (admin only - no vendor check).""" + product = db.query(Product).filter(Product.id == product_id).first() + if not product: + raise ProductNotFoundException(f"Product {product_id} not found") + + # Use existing method with the product's vendor_id + return self.get_product_inventory(db, product.vendor_id, product_id) + + def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor: + """Verify vendor exists and return it.""" + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + raise VendorNotFoundException(f"Vendor {vendor_id} not found") + return vendor + + def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory: + """Get inventory by ID (admin only - returns inventory with vendor_id).""" + inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first() + if not inventory: + raise InventoryNotFoundException(f"Inventory {inventory_id} not found") + return inventory + + # ========================================================================= + # Private helper methods + # ========================================================================= + + def _get_vendor_product( + self, db: Session, vendor_id: int, product_id: int + ) -> Product: + """Get product and verify it belongs to vendor.""" + product = ( + db.query(Product) + .filter(Product.id == product_id, Product.vendor_id == vendor_id) + .first() + ) + + if not product: + raise ProductNotFoundException( + f"Product {product_id} not found in your catalog" + ) + + return product + + def _get_inventory_entry( + self, db: Session, product_id: int, location: str + ) -> Inventory | None: + """Get inventory entry by product and location.""" + return ( + db.query(Inventory) + .filter(Inventory.product_id == product_id, Inventory.location == location) + .first() + ) + + def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory: + """Get inventory by ID or raise exception.""" + inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first() + if not inventory: + raise InventoryNotFoundException(f"Inventory {inventory_id} not found") + return inventory + + def _validate_location(self, location: str) -> str: + """Validate and normalize location.""" + if not location or not location.strip(): + raise InventoryValidationException("Location is required") + return location.strip().upper() + + def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None: + """Validate quantity value.""" + if quantity is None: + raise InvalidQuantityException("Quantity is required") + + if not isinstance(quantity, int): + raise InvalidQuantityException("Quantity must be an integer") + + if quantity < 0: + raise InvalidQuantityException("Quantity cannot be negative") + + if not allow_zero and quantity == 0: + raise InvalidQuantityException("Quantity must be positive") + + +# Create service instance +inventory_service = InventoryService() diff --git a/app/modules/inventory/services/inventory_transaction_service.py b/app/modules/inventory/services/inventory_transaction_service.py new file mode 100644 index 00000000..4543d279 --- /dev/null +++ b/app/modules/inventory/services/inventory_transaction_service.py @@ -0,0 +1,431 @@ +# app/modules/inventory/services/inventory_transaction_service.py +""" +Inventory Transaction Service. + +Provides query operations for inventory transaction history. +All transaction WRITES are handled by OrderInventoryService. +This service handles transaction READS for reporting and auditing. +""" + +import logging +from sqlalchemy import func +from sqlalchemy.orm import Session + +from app.exceptions import OrderNotFoundException, ProductNotFoundException +from app.modules.inventory.models.inventory import Inventory +from app.modules.inventory.models.inventory_transaction import InventoryTransaction +from models.database.order import Order +from models.database.product import Product + +logger = logging.getLogger(__name__) + + +class InventoryTransactionService: + """Service for querying inventory transaction history.""" + + def get_vendor_transactions( + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 50, + product_id: int | None = None, + transaction_type: str | None = None, + ) -> tuple[list[dict], int]: + """ + Get inventory transactions for a vendor with optional filters. + + Args: + db: Database session + vendor_id: Vendor ID + skip: Pagination offset + limit: Pagination limit + product_id: Optional product filter + transaction_type: Optional transaction type filter + + Returns: + Tuple of (transactions with product details, total count) + """ + # Build query + query = db.query(InventoryTransaction).filter( + InventoryTransaction.vendor_id == vendor_id + ) + + # Apply filters + if product_id: + query = query.filter(InventoryTransaction.product_id == product_id) + if transaction_type: + query = query.filter( + InventoryTransaction.transaction_type == transaction_type + ) + + # Get total count + total = query.count() + + # Get transactions with pagination (newest first) + transactions = ( + query.order_by(InventoryTransaction.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + # Build result with product details + result = [] + for tx in transactions: + product = db.query(Product).filter(Product.id == tx.product_id).first() + product_title = None + product_sku = None + if product: + product_sku = product.vendor_sku + if product.marketplace_product: + product_title = product.marketplace_product.get_title() + + result.append( + { + "id": tx.id, + "vendor_id": tx.vendor_id, + "product_id": tx.product_id, + "inventory_id": tx.inventory_id, + "transaction_type": ( + tx.transaction_type.value if tx.transaction_type else None + ), + "quantity_change": tx.quantity_change, + "quantity_after": tx.quantity_after, + "reserved_after": tx.reserved_after, + "location": tx.location, + "warehouse": tx.warehouse, + "order_id": tx.order_id, + "order_number": tx.order_number, + "reason": tx.reason, + "created_by": tx.created_by, + "created_at": tx.created_at, + "product_title": product_title, + "product_sku": product_sku, + } + ) + + return result, total + + def get_product_history( + self, + db: Session, + vendor_id: int, + product_id: int, + limit: int = 50, + ) -> dict: + """ + Get transaction history for a specific product. + + Args: + db: Database session + vendor_id: Vendor ID + product_id: Product ID + limit: Max transactions to return + + Returns: + Dict with product info, current inventory, and transactions + + Raises: + ProductNotFoundException: If product not found or doesn't belong to vendor + """ + # Get product details + product = ( + db.query(Product) + .filter(Product.id == product_id, Product.vendor_id == vendor_id) + .first() + ) + + if not product: + raise ProductNotFoundException( + f"Product {product_id} not found in vendor catalog" + ) + + product_title = None + product_sku = product.vendor_sku + if product.marketplace_product: + product_title = product.marketplace_product.get_title() + + # Get current inventory + inventory = ( + db.query(Inventory) + .filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id) + .first() + ) + + current_quantity = inventory.quantity if inventory else 0 + current_reserved = inventory.reserved_quantity if inventory else 0 + + # Get transactions + transactions = ( + db.query(InventoryTransaction) + .filter( + InventoryTransaction.vendor_id == vendor_id, + InventoryTransaction.product_id == product_id, + ) + .order_by(InventoryTransaction.created_at.desc()) + .limit(limit) + .all() + ) + + total = ( + db.query(func.count(InventoryTransaction.id)) + .filter( + InventoryTransaction.vendor_id == vendor_id, + InventoryTransaction.product_id == product_id, + ) + .scalar() + or 0 + ) + + return { + "product_id": product_id, + "product_title": product_title, + "product_sku": product_sku, + "current_quantity": current_quantity, + "current_reserved": current_reserved, + "transactions": [ + { + "id": tx.id, + "vendor_id": tx.vendor_id, + "product_id": tx.product_id, + "inventory_id": tx.inventory_id, + "transaction_type": ( + tx.transaction_type.value if tx.transaction_type else None + ), + "quantity_change": tx.quantity_change, + "quantity_after": tx.quantity_after, + "reserved_after": tx.reserved_after, + "location": tx.location, + "warehouse": tx.warehouse, + "order_id": tx.order_id, + "order_number": tx.order_number, + "reason": tx.reason, + "created_by": tx.created_by, + "created_at": tx.created_at, + } + for tx in transactions + ], + "total": total, + } + + def get_order_history( + self, + db: Session, + vendor_id: int, + order_id: int, + ) -> dict: + """ + Get all inventory transactions for a specific order. + + Args: + db: Database session + vendor_id: Vendor ID + order_id: Order ID + + Returns: + Dict with order info and transactions + + Raises: + OrderNotFoundException: If order not found or doesn't belong to vendor + """ + # Verify order belongs to vendor + order = ( + db.query(Order) + .filter(Order.id == order_id, Order.vendor_id == vendor_id) + .first() + ) + + if not order: + raise OrderNotFoundException(f"Order {order_id} not found") + + # Get transactions for this order + transactions = ( + db.query(InventoryTransaction) + .filter(InventoryTransaction.order_id == order_id) + .order_by(InventoryTransaction.created_at.asc()) + .all() + ) + + # Build result with product details + result = [] + for tx in transactions: + product = db.query(Product).filter(Product.id == tx.product_id).first() + product_title = None + product_sku = None + if product: + product_sku = product.vendor_sku + if product.marketplace_product: + product_title = product.marketplace_product.get_title() + + result.append( + { + "id": tx.id, + "vendor_id": tx.vendor_id, + "product_id": tx.product_id, + "inventory_id": tx.inventory_id, + "transaction_type": ( + tx.transaction_type.value if tx.transaction_type else None + ), + "quantity_change": tx.quantity_change, + "quantity_after": tx.quantity_after, + "reserved_after": tx.reserved_after, + "location": tx.location, + "warehouse": tx.warehouse, + "order_id": tx.order_id, + "order_number": tx.order_number, + "reason": tx.reason, + "created_by": tx.created_by, + "created_at": tx.created_at, + "product_title": product_title, + "product_sku": product_sku, + } + ) + + return { + "order_id": order_id, + "order_number": order.order_number, + "transactions": result, + } + + + # ========================================================================= + # Admin Methods (cross-vendor operations) + # ========================================================================= + + def get_all_transactions_admin( + self, + db: Session, + skip: int = 0, + limit: int = 50, + vendor_id: int | None = None, + product_id: int | None = None, + transaction_type: str | None = None, + order_id: int | None = None, + ) -> tuple[list[dict], int]: + """ + Get inventory transactions across all vendors (admin only). + + Args: + db: Database session + skip: Pagination offset + limit: Pagination limit + vendor_id: Optional vendor filter + product_id: Optional product filter + transaction_type: Optional transaction type filter + order_id: Optional order filter + + Returns: + Tuple of (transactions with details, total count) + """ + from models.database.vendor import Vendor + + # Build query + query = db.query(InventoryTransaction) + + # Apply filters + if vendor_id: + query = query.filter(InventoryTransaction.vendor_id == vendor_id) + if product_id: + query = query.filter(InventoryTransaction.product_id == product_id) + if transaction_type: + query = query.filter( + InventoryTransaction.transaction_type == transaction_type + ) + if order_id: + query = query.filter(InventoryTransaction.order_id == order_id) + + # Get total count + total = query.count() + + # Get transactions with pagination (newest first) + transactions = ( + query.order_by(InventoryTransaction.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + # Build result with vendor and product details + result = [] + for tx in transactions: + vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first() + product = db.query(Product).filter(Product.id == tx.product_id).first() + + product_title = None + product_sku = None + if product: + product_sku = product.vendor_sku + if product.marketplace_product: + product_title = product.marketplace_product.get_title() + + result.append( + { + "id": tx.id, + "vendor_id": tx.vendor_id, + "vendor_name": vendor.name if vendor else None, + "vendor_code": vendor.vendor_code if vendor else None, + "product_id": tx.product_id, + "inventory_id": tx.inventory_id, + "transaction_type": ( + tx.transaction_type.value if tx.transaction_type else None + ), + "quantity_change": tx.quantity_change, + "quantity_after": tx.quantity_after, + "reserved_after": tx.reserved_after, + "location": tx.location, + "warehouse": tx.warehouse, + "order_id": tx.order_id, + "order_number": tx.order_number, + "reason": tx.reason, + "created_by": tx.created_by, + "created_at": tx.created_at, + "product_title": product_title, + "product_sku": product_sku, + } + ) + + return result, total + + def get_transaction_stats_admin(self, db: Session) -> dict: + """ + Get transaction statistics across the platform (admin only). + + Returns: + Dict with transaction counts by type + """ + from sqlalchemy import func as sql_func + + # Count by transaction type + type_counts = ( + db.query( + InventoryTransaction.transaction_type, + sql_func.count(InventoryTransaction.id).label("count"), + ) + .group_by(InventoryTransaction.transaction_type) + .all() + ) + + # Total transactions + total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0 + + # Transactions today + from datetime import UTC, datetime, timedelta + + today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0) + today_count = ( + db.query(sql_func.count(InventoryTransaction.id)) + .filter(InventoryTransaction.created_at >= today_start) + .scalar() + or 0 + ) + + return { + "total_transactions": total, + "transactions_today": today_count, + "by_type": {tc.transaction_type.value: tc.count for tc in type_counts}, + } + + +# Create service instance +inventory_transaction_service = InventoryTransactionService() diff --git a/app/modules/marketplace/services/__init__.py b/app/modules/marketplace/services/__init__.py index 6eac1b3e..e609fc70 100644 --- a/app/modules/marketplace/services/__init__.py +++ b/app/modules/marketplace/services/__init__.py @@ -2,45 +2,35 @@ """ Marketplace module services. -Re-exports Letzshop and marketplace services from their current locations. -Services remain in app/services/ for now to avoid breaking existing imports. - -Usage: - from app.modules.marketplace.services import ( - letzshop_export_service, - marketplace_import_job_service, - marketplace_product_service, - ) - from app.modules.marketplace.services.letzshop import ( - LetzshopClient, - LetzshopCredentialsService, - LetzshopOrderService, - LetzshopVendorSyncService, - ) +This module contains the canonical implementations of marketplace-related services. """ -# Re-export from existing locations for convenience -from app.services.letzshop_export_service import ( +# Main marketplace services +from app.modules.marketplace.services.letzshop_export_service import ( LetzshopExportService, letzshop_export_service, ) -from app.services.marketplace_import_job_service import ( +from app.modules.marketplace.services.marketplace_import_job_service import ( MarketplaceImportJobService, marketplace_import_job_service, ) -from app.services.marketplace_product_service import ( +from app.modules.marketplace.services.marketplace_product_service import ( MarketplaceProductService, marketplace_product_service, ) -# Letzshop submodule re-exports -from app.services.letzshop import ( +# Letzshop submodule services +from app.modules.marketplace.services.letzshop import ( LetzshopClient, LetzshopClientError, ) -from app.services.letzshop.credentials_service import LetzshopCredentialsService -from app.services.letzshop.order_service import LetzshopOrderService -from app.services.letzshop.vendor_sync_service import LetzshopVendorSyncService +from app.modules.marketplace.services.letzshop.credentials_service import ( + LetzshopCredentialsService, +) +from app.modules.marketplace.services.letzshop.order_service import LetzshopOrderService +from app.modules.marketplace.services.letzshop.vendor_sync_service import ( + LetzshopVendorSyncService, +) __all__ = [ # Export service diff --git a/app/modules/marketplace/services/letzshop/__init__.py b/app/modules/marketplace/services/letzshop/__init__.py new file mode 100644 index 00000000..d22c257a --- /dev/null +++ b/app/modules/marketplace/services/letzshop/__init__.py @@ -0,0 +1,53 @@ +# app/modules/marketplace/services/letzshop/__init__.py +""" +Letzshop marketplace integration services. + +Provides: +- GraphQL client for API communication +- Credential management service +- Order import service +- Fulfillment sync service +- Vendor directory sync service +""" + +from .client_service import ( + LetzshopAPIError, + LetzshopAuthError, + LetzshopClient, + LetzshopClientError, + LetzshopConnectionError, +) +from .credentials_service import ( + CredentialsError, + CredentialsNotFoundError, + LetzshopCredentialsService, +) +from .order_service import ( + LetzshopOrderService, + OrderNotFoundError, + VendorNotFoundError, +) +from .vendor_sync_service import ( + LetzshopVendorSyncService, + get_vendor_sync_service, +) + +__all__ = [ + # Client + "LetzshopClient", + "LetzshopClientError", + "LetzshopAuthError", + "LetzshopAPIError", + "LetzshopConnectionError", + # Credentials + "LetzshopCredentialsService", + "CredentialsError", + "CredentialsNotFoundError", + # Order Service + "LetzshopOrderService", + "OrderNotFoundError", + "VendorNotFoundError", + # Vendor Sync Service + "LetzshopVendorSyncService", + "get_vendor_sync_service", +] diff --git a/app/modules/marketplace/services/letzshop/client_service.py b/app/modules/marketplace/services/letzshop/client_service.py new file mode 100644 index 00000000..d58957aa --- /dev/null +++ b/app/modules/marketplace/services/letzshop/client_service.py @@ -0,0 +1,1015 @@ +# app/services/letzshop/client_service.py +""" +GraphQL client for Letzshop marketplace API. + +Handles authentication, request formatting, and error handling +for all Letzshop API operations. +""" + +import logging +import time +from typing import Any, Callable + +import requests + +logger = logging.getLogger(__name__) + +# Default API endpoint +DEFAULT_ENDPOINT = "https://letzshop.lu/graphql" + + +class LetzshopClientError(Exception): + """Base exception for Letzshop client errors.""" + + def __init__(self, message: str, response_data: dict | None = None): + super().__init__(message) + self.message = message + self.response_data = response_data + + +class LetzshopAuthError(LetzshopClientError): + """Raised when authentication fails.""" + + +class LetzshopAPIError(LetzshopClientError): + """Raised when the API returns an error response.""" + + +class LetzshopConnectionError(LetzshopClientError): + """Raised when connection to the API fails.""" + + +# ============================================================================ +# GraphQL Queries +# ============================================================================ + +QUERY_SHIPMENTS_UNCONFIRMED = """ +query { + shipments(state: unconfirmed) { + nodes { + id + number + state + order { + id + number + email + total + completedAt + locale + shipAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + billAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + } + inventoryUnits { + id + state + variant { + id + sku + mpn + price + tradeId { + number + parser + } + product { + name { + en + fr + de + } + _brand { + ... on Brand { + name + } + } + } + } + } + tracking { + code + provider + } + data { + __typename + } + } + } +} +""" + +QUERY_SHIPMENTS_CONFIRMED = """ +query { + shipments(state: confirmed) { + nodes { + id + number + state + order { + id + number + email + total + completedAt + locale + shipAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + billAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + } + inventoryUnits { + id + state + variant { + id + sku + mpn + price + tradeId { + number + parser + } + product { + name { + en + fr + de + } + _brand { + ... on Brand { + name + } + } + } + } + } + tracking { + code + provider + } + data { + __typename + } + } + } +} +""" + +QUERY_SHIPMENT_BY_ID = """ +query GetShipment($id: ID!) { + node(id: $id) { + ... on Shipment { + id + number + state + order { + id + number + email + total + completedAt + locale + shipAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + billAddress { + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country { + name { en fr de } + iso + } + } + } + inventoryUnits { + id + state + variant { + id + sku + mpn + price + tradeId { + number + parser + } + product { + name { + en + fr + de + } + _brand { + ... on Brand { + name + } + } + } + } + } + tracking { + code + provider + } + data { + __typename + } + } + } +} +""" + +# ============================================================================ +# Paginated Queries (for historical import) +# ============================================================================ + +# Note: Using string formatting for state since Letzshop has issues with enum variables +# Note: tracking field removed - causes 'demodulize' server error on some shipments +QUERY_SHIPMENTS_PAGINATED_TEMPLATE = """ +query GetShipmentsPaginated($first: Int!, $after: String) {{ + shipments(state: {state}, first: $first, after: $after) {{ + pageInfo {{ + hasNextPage + endCursor + }} + nodes {{ + id + number + state + order {{ + id + number + email + total + completedAt + locale + shipAddress {{ + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country {{ + iso + }} + }} + billAddress {{ + firstName + lastName + company + streetName + streetNumber + city + zipCode + phone + country {{ + iso + }} + }} + }} + inventoryUnits {{ + id + state + variant {{ + id + sku + mpn + price + tradeId {{ + number + parser + }} + product {{ + name {{ + en + fr + de + }} + }} + }} + }} + data {{ + __typename + }} + }} + }} +}} +""" + +# ============================================================================ +# GraphQL Queries - Vendor Directory (Public) +# ============================================================================ + +QUERY_VENDORS_PAGINATED = """ +query GetVendorsPaginated($first: Int!, $after: String) { + vendors(first: $first, after: $after) { + pageInfo { + hasNextPage + endCursor + } + totalCount + nodes { + id + slug + name + active + companyName + legalName + email + phone + fax + homepage + description { en fr de } + location { + street + number + city + zipcode + country { iso } + } + lat + lng + vendorCategories { name { en fr de } } + backgroundImage { url } + socialMediaLinks { url } + openingHours { en fr de } + representative + representativeTitle + } + } +} +""" + +QUERY_VENDOR_BY_SLUG = """ +query GetVendorBySlug($slug: String!) { + vendor(slug: $slug) { + id + slug + name + active + companyName + legalName + email + phone + fax + homepage + description { en fr de } + location { + street + number + city + zipcode + country { iso } + } + lat + lng + vendorCategories { name { en fr de } } + backgroundImage { url } + socialMediaLinks { url } + openingHours { en fr de } + representative + representativeTitle + } +} +""" + +# ============================================================================ +# GraphQL Mutations +# ============================================================================ + +MUTATION_CONFIRM_INVENTORY_UNITS = """ +mutation ConfirmInventoryUnits($input: ConfirmInventoryUnitsInput!) { + confirmInventoryUnits(input: $input) { + inventoryUnits { + id + state + } + errors { + id + code + message + } + } +} +""" + +MUTATION_REJECT_INVENTORY_UNITS = """ +mutation RejectInventoryUnits($input: RejectInventoryUnitsInput!) { + returnInventoryUnits(input: $input) { + inventoryUnits { + id + state + } + errors { + id + code + message + } + } +} +""" + +MUTATION_SET_SHIPMENT_TRACKING = """ +mutation SetShipmentTracking($input: SetShipmentTrackingInput!) { + setShipmentTracking(input: $input) { + shipment { + id + tracking { + code + provider + } + } + errors { + code + message + } + } +} +""" + + +class LetzshopClient: + """ + GraphQL client for Letzshop marketplace API. + + Usage: + client = LetzshopClient(api_key="your-api-key") + shipments = client.get_shipments(state="unconfirmed") + """ + + def __init__( + self, + api_key: str, + endpoint: str = DEFAULT_ENDPOINT, + timeout: int = 30, + ): + """ + Initialize the Letzshop client. + + Args: + api_key: The Letzshop API key (Bearer token). + endpoint: The GraphQL endpoint URL. + timeout: Request timeout in seconds. + """ + self.api_key = api_key + self.endpoint = endpoint + self.timeout = timeout + self._session: requests.Session | None = None + + @property + def session(self) -> requests.Session: + """Get or create a requests session.""" + if self._session is None: + self._session = requests.Session() + self._session.headers.update( + { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + ) + return self._session + + def close(self) -> None: + """Close the HTTP session.""" + if self._session is not None: + self._session.close() + self._session = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def _execute_public( + self, + query: str, + variables: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Execute a GraphQL query without authentication (for public queries). + + Args: + query: The GraphQL query string. + variables: Optional variables for the query. + + Returns: + The response data from the API. + + Raises: + LetzshopAPIError: If the API returns an error. + LetzshopConnectionError: If the request fails. + """ + payload = {"query": query} + if variables: + payload["variables"] = variables + + logger.debug(f"Executing public GraphQL request to {self.endpoint}") + + try: + # Use a simple request without Authorization header + response = requests.post( + self.endpoint, + json=payload, + headers={"Content-Type": "application/json"}, + timeout=self.timeout, + ) + except requests.exceptions.Timeout as e: + raise LetzshopConnectionError(f"Request timed out: {e}") from e + except requests.exceptions.ConnectionError as e: + raise LetzshopConnectionError(f"Connection failed: {e}") from e + except requests.exceptions.RequestException as e: + raise LetzshopConnectionError(f"Request failed: {e}") from e + + # Handle HTTP-level errors + if response.status_code >= 500: + raise LetzshopAPIError( + f"Letzshop server error (HTTP {response.status_code})", + response_data={"status_code": response.status_code}, + ) + + # Parse JSON response + try: + data = response.json() + except ValueError as e: + raise LetzshopAPIError( + f"Invalid JSON response: {response.text[:200]}" + ) from e + + logger.debug(f"GraphQL response: {data}") + + # Handle GraphQL errors + if "errors" in data: + errors = data["errors"] + error_messages = [e.get("message", str(e)) for e in errors] + raise LetzshopAPIError( + f"GraphQL errors: {'; '.join(error_messages)}", + response_data=data, + ) + + return data.get("data", {}) + + def _execute( + self, + query: str, + variables: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Execute a GraphQL query or mutation. + + Args: + query: The GraphQL query or mutation string. + variables: Optional variables for the query. + + Returns: + The response data from the API. + + Raises: + LetzshopAuthError: If authentication fails. + LetzshopAPIError: If the API returns an error. + LetzshopConnectionError: If the request fails. + """ + payload = {"query": query} + if variables: + payload["variables"] = variables + + logger.debug(f"Executing GraphQL request to {self.endpoint}") + + try: + response = self.session.post( + self.endpoint, + json=payload, + timeout=self.timeout, + ) + except requests.exceptions.Timeout as e: + raise LetzshopConnectionError(f"Request timed out: {e}") from e + except requests.exceptions.ConnectionError as e: + raise LetzshopConnectionError(f"Connection failed: {e}") from e + except requests.exceptions.RequestException as e: + raise LetzshopConnectionError(f"Request failed: {e}") from e + + # Handle HTTP-level errors + if response.status_code == 401: + raise LetzshopAuthError( + "Authentication failed. Please check your API key.", + response_data={"status_code": 401}, + ) + if response.status_code == 403: + raise LetzshopAuthError( + "Access forbidden. Your API key may not have the required permissions.", + response_data={"status_code": 403}, + ) + if response.status_code >= 500: + raise LetzshopAPIError( + f"Letzshop server error (HTTP {response.status_code})", + response_data={"status_code": response.status_code}, + ) + + # Parse JSON response + try: + data = response.json() + except ValueError as e: + raise LetzshopAPIError( + f"Invalid JSON response: {response.text[:200]}" + ) from e + + logger.debug(f"GraphQL response: {data}") + + # Check for GraphQL errors + if "errors" in data and data["errors"]: + error_messages = [ + err.get("message", "Unknown error") for err in data["errors"] + ] + logger.warning(f"GraphQL errors received: {data['errors']}") + raise LetzshopAPIError( + f"GraphQL errors: {'; '.join(error_messages)}", + response_data=data, + ) + + return data.get("data", {}) + + # ======================================================================== + # Connection Testing + # ======================================================================== + + def test_connection(self) -> tuple[bool, float, str | None]: + """ + Test the connection to Letzshop API. + + Returns: + Tuple of (success, response_time_ms, error_message). + """ + test_query = """ + query TestConnection { + __typename + } + """ + + start_time = time.time() + + try: + self._execute(test_query) + elapsed_ms = (time.time() - start_time) * 1000 + return True, elapsed_ms, None + except LetzshopClientError as e: + elapsed_ms = (time.time() - start_time) * 1000 + return False, elapsed_ms, str(e) + + # ======================================================================== + # Shipment Queries + # ======================================================================== + + def get_shipments( + self, + state: str = "unconfirmed", + ) -> list[dict[str, Any]]: + """ + Get shipments from Letzshop. + + Args: + state: State filter ("unconfirmed" or "confirmed"). + + Returns: + List of shipment data dictionaries. + """ + # Use pre-built queries with inline state values + # (Letzshop's GraphQL has issues with enum variables) + if state == "confirmed": + query = QUERY_SHIPMENTS_CONFIRMED + else: + query = QUERY_SHIPMENTS_UNCONFIRMED + + logger.debug(f"Fetching shipments with state: {state}") + data = self._execute(query) + logger.debug(f"Shipments response data keys: {data.keys() if data else 'None'}") + shipments_data = data.get("shipments", {}) + nodes = shipments_data.get("nodes", []) + logger.info(f"Got {len(nodes)} {state} shipments from Letzshop API") + return nodes + + def get_unconfirmed_shipments(self) -> list[dict[str, Any]]: + """Get all unconfirmed shipments.""" + return self.get_shipments(state="unconfirmed") + + def get_shipment_by_id(self, shipment_id: str) -> dict[str, Any] | None: + """ + Get a single shipment by its ID. + + Args: + shipment_id: The Letzshop shipment ID. + + Returns: + Shipment data or None if not found. + """ + data = self._execute(QUERY_SHIPMENT_BY_ID, {"id": shipment_id}) + return data.get("node") + + def get_all_shipments_paginated( + self, + state: str = "confirmed", + page_size: int = 50, + max_pages: int | None = None, + progress_callback: Callable[[int, int], None] | None = None, + ) -> list[dict[str, Any]]: + """ + Fetch all shipments with pagination support. + + Args: + state: State filter ("unconfirmed" or "confirmed"). + page_size: Number of shipments per page (default 50). + max_pages: Maximum number of pages to fetch (None = all). + progress_callback: Optional callback(page, total_fetched) for progress updates. + + Returns: + List of all shipment data dictionaries. + """ + query = QUERY_SHIPMENTS_PAGINATED_TEMPLATE.format(state=state) + all_shipments = [] + cursor = None + page = 0 + + while True: + page += 1 + variables = {"first": page_size} + if cursor: + variables["after"] = cursor + + logger.info(f"Fetching {state} shipments page {page} (cursor: {cursor})") + + try: + data = self._execute(query, variables) + except LetzshopAPIError as e: + # Log error but return what we have so far + logger.error(f"Error fetching page {page}: {e}") + break + + shipments_data = data.get("shipments", {}) + nodes = shipments_data.get("nodes", []) + page_info = shipments_data.get("pageInfo", {}) + + all_shipments.extend(nodes) + + if progress_callback: + progress_callback(page, len(all_shipments)) + + logger.info(f"Page {page}: fetched {len(nodes)} shipments, total: {len(all_shipments)}") + + # Check if there are more pages + if not page_info.get("hasNextPage"): + logger.info(f"Reached last page. Total shipments: {len(all_shipments)}") + break + + cursor = page_info.get("endCursor") + + # Check max pages limit + if max_pages and page >= max_pages: + logger.info(f"Reached max pages limit ({max_pages}). Total shipments: {len(all_shipments)}") + break + + return all_shipments + + # ======================================================================== + # Fulfillment Mutations + # ======================================================================== + + def confirm_inventory_units( + self, + inventory_unit_ids: list[str], + ) -> dict[str, Any]: + """ + Confirm inventory units for fulfillment. + + Args: + inventory_unit_ids: List of inventory unit IDs to confirm. + + Returns: + Response data including confirmed units and any errors. + """ + variables = { + "input": { + "inventoryUnitIds": inventory_unit_ids, + } + } + + data = self._execute(MUTATION_CONFIRM_INVENTORY_UNITS, variables) + return data.get("confirmInventoryUnits", {}) + + def reject_inventory_units( + self, + inventory_unit_ids: list[str], + ) -> dict[str, Any]: + """ + Reject/return inventory units. + + Args: + inventory_unit_ids: List of inventory unit IDs to reject. + + Returns: + Response data including rejected units and any errors. + """ + variables = { + "input": { + "inventoryUnitIds": inventory_unit_ids, + } + } + + data = self._execute(MUTATION_REJECT_INVENTORY_UNITS, variables) + return data.get("returnInventoryUnits", {}) + + def set_shipment_tracking( + self, + shipment_id: str, + tracking_code: str, + tracking_provider: str, + ) -> dict[str, Any]: + """ + Set tracking information for a shipment. + + Args: + shipment_id: The Letzshop shipment ID. + tracking_code: The tracking number. + tracking_provider: The carrier code (e.g., "dhl", "ups"). + + Returns: + Response data including updated shipment and any errors. + """ + variables = { + "input": { + "shipmentId": shipment_id, + "tracking": { + "code": tracking_code, + "provider": tracking_provider, + }, + } + } + + data = self._execute(MUTATION_SET_SHIPMENT_TRACKING, variables) + return data.get("setShipmentTracking", {}) + + # ======================================================================== + # Vendor Directory Queries (Public - No Auth Required) + # ======================================================================== + + def get_all_vendors_paginated( + self, + page_size: int = 50, + max_pages: int | None = None, + progress_callback: Callable[[int, int, int], None] | None = None, + ) -> list[dict[str, Any]]: + """ + Fetch all vendors from Letzshop marketplace directory. + + This uses the public GraphQL API (no authentication required). + + Args: + page_size: Number of vendors per page (default 50). + max_pages: Maximum number of pages to fetch (None = all). + progress_callback: Optional callback(page, total_fetched, total_count) + for progress updates. + + Returns: + List of all vendor data dictionaries. + """ + all_vendors = [] + cursor = None + page = 0 + total_count = None + + while True: + page += 1 + variables = {"first": page_size} + if cursor: + variables["after"] = cursor + + logger.info(f"Fetching vendors page {page} (cursor: {cursor})") + + try: + # Use public endpoint (no authentication required) + data = self._execute_public(QUERY_VENDORS_PAGINATED, variables) + except LetzshopAPIError as e: + logger.error(f"Error fetching vendors page {page}: {e}") + break + + vendors_data = data.get("vendors", {}) + nodes = vendors_data.get("nodes", []) + page_info = vendors_data.get("pageInfo", {}) + + if total_count is None: + total_count = vendors_data.get("totalCount", 0) + logger.info(f"Total vendors in Letzshop: {total_count}") + + all_vendors.extend(nodes) + + if progress_callback: + progress_callback(page, len(all_vendors), total_count) + + logger.info( + f"Page {page}: fetched {len(nodes)} vendors, " + f"total: {len(all_vendors)}/{total_count}" + ) + + # Check if there are more pages + if not page_info.get("hasNextPage"): + logger.info(f"Reached last page. Total vendors: {len(all_vendors)}") + break + + cursor = page_info.get("endCursor") + + # Check max pages limit + if max_pages and page >= max_pages: + logger.info( + f"Reached max pages limit ({max_pages}). " + f"Total vendors: {len(all_vendors)}" + ) + break + + return all_vendors + + def get_vendor_by_slug(self, slug: str) -> dict[str, Any] | None: + """ + Get a single vendor by their URL slug. + + Args: + slug: The vendor's URL slug (e.g., "nicks-diecast-corner"). + + Returns: + Vendor data dictionary or None if not found. + """ + try: + # Use public endpoint (no authentication required) + data = self._execute_public(QUERY_VENDOR_BY_SLUG, {"slug": slug}) + return data.get("vendor") + except LetzshopAPIError as e: + logger.warning(f"Vendor not found with slug '{slug}': {e}") + return None diff --git a/app/modules/marketplace/services/letzshop/credentials_service.py b/app/modules/marketplace/services/letzshop/credentials_service.py new file mode 100644 index 00000000..1528cb07 --- /dev/null +++ b/app/modules/marketplace/services/letzshop/credentials_service.py @@ -0,0 +1,400 @@ +# app/services/letzshop/credentials_service.py +""" +Letzshop credentials management service. + +Handles secure storage and retrieval of per-vendor Letzshop API credentials. +""" + +import logging +from datetime import UTC, datetime + +from sqlalchemy.orm import Session + +from app.utils.encryption import decrypt_value, encrypt_value, mask_api_key +from models.database.letzshop import VendorLetzshopCredentials + +from .client_service import LetzshopClient + +logger = logging.getLogger(__name__) + +# Default Letzshop GraphQL endpoint +DEFAULT_ENDPOINT = "https://letzshop.lu/graphql" + + +class CredentialsError(Exception): + """Base exception for credentials errors.""" + + +class CredentialsNotFoundError(CredentialsError): + """Raised when credentials are not found for a vendor.""" + + +class LetzshopCredentialsService: + """ + Service for managing Letzshop API credentials. + + Provides secure storage and retrieval of encrypted API keys, + connection testing, and sync status updates. + """ + + def __init__(self, db: Session): + """ + Initialize the credentials service. + + Args: + db: SQLAlchemy database session. + """ + self.db = db + + # ======================================================================== + # CRUD Operations + # ======================================================================== + + def get_credentials(self, vendor_id: int) -> VendorLetzshopCredentials | None: + """ + Get Letzshop credentials for a vendor. + + Args: + vendor_id: The vendor ID. + + Returns: + VendorLetzshopCredentials or None if not found. + """ + return ( + self.db.query(VendorLetzshopCredentials) + .filter(VendorLetzshopCredentials.vendor_id == vendor_id) + .first() + ) + + def get_credentials_or_raise(self, vendor_id: int) -> VendorLetzshopCredentials: + """ + Get Letzshop credentials for a vendor or raise an exception. + + Args: + vendor_id: The vendor ID. + + Returns: + VendorLetzshopCredentials. + + Raises: + CredentialsNotFoundError: If credentials are not found. + """ + credentials = self.get_credentials(vendor_id) + if credentials is None: + raise CredentialsNotFoundError( + f"Letzshop credentials not found for vendor {vendor_id}" + ) + return credentials + + def create_credentials( + self, + vendor_id: int, + api_key: str, + api_endpoint: str | None = None, + auto_sync_enabled: bool = False, + sync_interval_minutes: int = 15, + ) -> VendorLetzshopCredentials: + """ + Create Letzshop credentials for a vendor. + + Args: + vendor_id: The vendor ID. + api_key: The Letzshop API key (will be encrypted). + api_endpoint: Custom API endpoint (optional). + auto_sync_enabled: Whether to enable automatic sync. + sync_interval_minutes: Sync interval in minutes. + + Returns: + Created VendorLetzshopCredentials. + """ + # Encrypt the API key + encrypted_key = encrypt_value(api_key) + + credentials = VendorLetzshopCredentials( + vendor_id=vendor_id, + api_key_encrypted=encrypted_key, + api_endpoint=api_endpoint or DEFAULT_ENDPOINT, + auto_sync_enabled=auto_sync_enabled, + sync_interval_minutes=sync_interval_minutes, + ) + + self.db.add(credentials) + self.db.flush() + + logger.info(f"Created Letzshop credentials for vendor {vendor_id}") + return credentials + + def update_credentials( + self, + vendor_id: int, + api_key: str | None = None, + api_endpoint: str | None = None, + auto_sync_enabled: bool | None = None, + sync_interval_minutes: int | None = None, + ) -> VendorLetzshopCredentials: + """ + Update Letzshop credentials for a vendor. + + Args: + vendor_id: The vendor ID. + api_key: New API key (optional, will be encrypted if provided). + api_endpoint: New API endpoint (optional). + auto_sync_enabled: New auto-sync setting (optional). + sync_interval_minutes: New sync interval (optional). + + Returns: + Updated VendorLetzshopCredentials. + + Raises: + CredentialsNotFoundError: If credentials are not found. + """ + credentials = self.get_credentials_or_raise(vendor_id) + + if api_key is not None: + credentials.api_key_encrypted = encrypt_value(api_key) + if api_endpoint is not None: + credentials.api_endpoint = api_endpoint + if auto_sync_enabled is not None: + credentials.auto_sync_enabled = auto_sync_enabled + if sync_interval_minutes is not None: + credentials.sync_interval_minutes = sync_interval_minutes + + self.db.flush() + + logger.info(f"Updated Letzshop credentials for vendor {vendor_id}") + return credentials + + def delete_credentials(self, vendor_id: int) -> bool: + """ + Delete Letzshop credentials for a vendor. + + Args: + vendor_id: The vendor ID. + + Returns: + True if deleted, False if not found. + """ + credentials = self.get_credentials(vendor_id) + if credentials is None: + return False + + self.db.delete(credentials) + self.db.flush() + + logger.info(f"Deleted Letzshop credentials for vendor {vendor_id}") + return True + + def upsert_credentials( + self, + vendor_id: int, + api_key: str, + api_endpoint: str | None = None, + auto_sync_enabled: bool = False, + sync_interval_minutes: int = 15, + ) -> VendorLetzshopCredentials: + """ + Create or update Letzshop credentials for a vendor. + + Args: + vendor_id: The vendor ID. + api_key: The Letzshop API key (will be encrypted). + api_endpoint: Custom API endpoint (optional). + auto_sync_enabled: Whether to enable automatic sync. + sync_interval_minutes: Sync interval in minutes. + + Returns: + Created or updated VendorLetzshopCredentials. + """ + existing = self.get_credentials(vendor_id) + + if existing: + return self.update_credentials( + vendor_id=vendor_id, + api_key=api_key, + api_endpoint=api_endpoint, + auto_sync_enabled=auto_sync_enabled, + sync_interval_minutes=sync_interval_minutes, + ) + + return self.create_credentials( + vendor_id=vendor_id, + api_key=api_key, + api_endpoint=api_endpoint, + auto_sync_enabled=auto_sync_enabled, + sync_interval_minutes=sync_interval_minutes, + ) + + # ======================================================================== + # Key Decryption and Client Creation + # ======================================================================== + + def get_decrypted_api_key(self, vendor_id: int) -> str: + """ + Get the decrypted API key for a vendor. + + Args: + vendor_id: The vendor ID. + + Returns: + Decrypted API key. + + Raises: + CredentialsNotFoundError: If credentials are not found. + """ + credentials = self.get_credentials_or_raise(vendor_id) + return decrypt_value(credentials.api_key_encrypted) + + def get_masked_api_key(self, vendor_id: int) -> str: + """ + Get a masked version of the API key for display. + + Args: + vendor_id: The vendor ID. + + Returns: + Masked API key (e.g., "sk-a***************"). + + Raises: + CredentialsNotFoundError: If credentials are not found. + """ + api_key = self.get_decrypted_api_key(vendor_id) + return mask_api_key(api_key) + + def create_client(self, vendor_id: int) -> LetzshopClient: + """ + Create a Letzshop client for a vendor. + + Args: + vendor_id: The vendor ID. + + Returns: + Configured LetzshopClient. + + Raises: + CredentialsNotFoundError: If credentials are not found. + """ + credentials = self.get_credentials_or_raise(vendor_id) + api_key = decrypt_value(credentials.api_key_encrypted) + + return LetzshopClient( + api_key=api_key, + endpoint=credentials.api_endpoint, + ) + + # ======================================================================== + # Connection Testing + # ======================================================================== + + def test_connection(self, vendor_id: int) -> tuple[bool, float | None, str | None]: + """ + Test the connection for a vendor's credentials. + + Args: + vendor_id: The vendor ID. + + Returns: + Tuple of (success, response_time_ms, error_message). + """ + try: + with self.create_client(vendor_id) as client: + return client.test_connection() + except CredentialsNotFoundError: + return False, None, "Letzshop credentials not configured" + except Exception as e: + logger.error(f"Connection test failed for vendor {vendor_id}: {e}") + return False, None, str(e) + + def test_api_key( + self, + api_key: str, + api_endpoint: str | None = None, + ) -> tuple[bool, float | None, str | None]: + """ + Test an API key without saving it. + + Args: + api_key: The API key to test. + api_endpoint: Optional custom endpoint. + + Returns: + Tuple of (success, response_time_ms, error_message). + """ + try: + with LetzshopClient( + api_key=api_key, + endpoint=api_endpoint or DEFAULT_ENDPOINT, + ) as client: + return client.test_connection() + except Exception as e: + logger.error(f"API key test failed: {e}") + return False, None, str(e) + + # ======================================================================== + # Sync Status Updates + # ======================================================================== + + def update_sync_status( + self, + vendor_id: int, + status: str, + error: str | None = None, + ) -> VendorLetzshopCredentials | None: + """ + Update the last sync status for a vendor. + + Args: + vendor_id: The vendor ID. + status: Sync status (success, failed, partial). + error: Error message if sync failed. + + Returns: + Updated credentials or None if not found. + """ + credentials = self.get_credentials(vendor_id) + if credentials is None: + return None + + credentials.last_sync_at = datetime.now(UTC) + credentials.last_sync_status = status + credentials.last_sync_error = error + + self.db.flush() + + return credentials + + # ======================================================================== + # Status Helpers + # ======================================================================== + + def is_configured(self, vendor_id: int) -> bool: + """Check if Letzshop is configured for a vendor.""" + return self.get_credentials(vendor_id) is not None + + def get_status(self, vendor_id: int) -> dict: + """ + Get the Letzshop integration status for a vendor. + + Args: + vendor_id: The vendor ID. + + Returns: + Status dictionary with configuration and sync info. + """ + credentials = self.get_credentials(vendor_id) + + if credentials is None: + return { + "is_configured": False, + "is_connected": False, + "last_sync_at": None, + "last_sync_status": None, + "auto_sync_enabled": False, + } + + return { + "is_configured": True, + "is_connected": credentials.last_sync_status == "success", + "last_sync_at": credentials.last_sync_at, + "last_sync_status": credentials.last_sync_status, + "auto_sync_enabled": credentials.auto_sync_enabled, + } diff --git a/app/modules/marketplace/services/letzshop/order_service.py b/app/modules/marketplace/services/letzshop/order_service.py new file mode 100644 index 00000000..07be04a2 --- /dev/null +++ b/app/modules/marketplace/services/letzshop/order_service.py @@ -0,0 +1,1136 @@ +# app/services/letzshop/order_service.py +""" +Letzshop order service for handling order-related database operations. + +This service handles Letzshop-specific order operations while using the +unified Order model. All Letzshop orders are stored in the `orders` table +with `channel='letzshop'`. +""" + +import logging +from datetime import UTC, datetime +from typing import Any, Callable + +from sqlalchemy import String, and_, func, or_ +from sqlalchemy.orm import Session + +from app.services.order_service import order_service as unified_order_service +from app.services.subscription_service import subscription_service +from models.database.letzshop import ( + LetzshopFulfillmentQueue, + LetzshopHistoricalImportJob, + LetzshopSyncLog, + VendorLetzshopCredentials, +) +from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.order import Order, OrderItem +from models.database.product import Product +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +class VendorNotFoundError(Exception): + """Raised when a vendor is not found.""" + + +class OrderNotFoundError(Exception): + """Raised when an order is not found.""" + + +class LetzshopOrderService: + """Service for Letzshop order database operations using unified Order model.""" + + def __init__(self, db: Session): + self.db = db + + # ========================================================================= + # Vendor Operations + # ========================================================================= + + def get_vendor(self, vendor_id: int) -> Vendor | None: + """Get vendor by ID.""" + return self.db.query(Vendor).filter(Vendor.id == vendor_id).first() + + def get_vendor_or_raise(self, vendor_id: int) -> Vendor: + """Get vendor by ID or raise VendorNotFoundError.""" + vendor = self.get_vendor(vendor_id) + if vendor is None: + raise VendorNotFoundError(f"Vendor with ID {vendor_id} not found") + return vendor + + def list_vendors_with_letzshop_status( + self, + skip: int = 0, + limit: int = 100, + configured_only: bool = False, + ) -> tuple[list[dict[str, Any]], int]: + """ + List vendors with their Letzshop integration status. + + Returns a tuple of (vendor_overviews, total_count). + """ + query = self.db.query(Vendor).filter(Vendor.is_active == True) # noqa: E712 + + if configured_only: + query = query.join( + VendorLetzshopCredentials, + Vendor.id == VendorLetzshopCredentials.vendor_id, + ) + + total = query.count() + vendors = query.order_by(Vendor.name).offset(skip).limit(limit).all() + + vendor_overviews = [] + for vendor in vendors: + credentials = ( + self.db.query(VendorLetzshopCredentials) + .filter(VendorLetzshopCredentials.vendor_id == vendor.id) + .first() + ) + + # Count Letzshop orders from unified orders table + pending_orders = 0 + total_orders = 0 + if credentials: + pending_orders = ( + self.db.query(func.count(Order.id)) + .filter( + Order.vendor_id == vendor.id, + Order.channel == "letzshop", + Order.status == "pending", + ) + .scalar() + or 0 + ) + total_orders = ( + self.db.query(func.count(Order.id)) + .filter( + Order.vendor_id == vendor.id, + Order.channel == "letzshop", + ) + .scalar() + or 0 + ) + + vendor_overviews.append( + { + "vendor_id": vendor.id, + "vendor_name": vendor.name, + "vendor_code": vendor.vendor_code, + "is_configured": credentials is not None, + "auto_sync_enabled": credentials.auto_sync_enabled + if credentials + else False, + "last_sync_at": credentials.last_sync_at if credentials else None, + "last_sync_status": credentials.last_sync_status + if credentials + else None, + "pending_orders": pending_orders, + "total_orders": total_orders, + } + ) + + return vendor_overviews, total + + # ========================================================================= + # Order Operations (using unified Order model) + # ========================================================================= + + def get_order(self, vendor_id: int, order_id: int) -> Order | None: + """Get a Letzshop order by ID for a specific vendor.""" + return ( + self.db.query(Order) + .filter( + Order.id == order_id, + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .first() + ) + + def get_order_or_raise(self, vendor_id: int, order_id: int) -> Order: + """Get a Letzshop order or raise OrderNotFoundError.""" + order = self.get_order(vendor_id, order_id) + if order is None: + raise OrderNotFoundError(f"Order {order_id} not found") + return order + + def get_order_by_shipment_id( + self, vendor_id: int, shipment_id: str + ) -> Order | None: + """Get a Letzshop order by external shipment ID.""" + return ( + self.db.query(Order) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + Order.external_shipment_id == shipment_id, + ) + .first() + ) + + def get_order_by_id(self, order_id: int) -> Order | None: + """Get a Letzshop order by its database ID.""" + return ( + self.db.query(Order) + .filter( + Order.id == order_id, + Order.channel == "letzshop", + ) + .first() + ) + + def list_orders( + self, + vendor_id: int | None = None, + skip: int = 0, + limit: int = 50, + status: str | None = None, + has_declined_items: bool | None = None, + search: str | None = None, + ) -> tuple[list[Order], int]: + """ + List Letzshop orders for a vendor (or all vendors). + + Args: + vendor_id: Vendor ID to filter by. If None, returns all vendors. + skip: Number of records to skip. + limit: Maximum number of records to return. + status: Filter by order status (pending, processing, shipped, etc.) + has_declined_items: If True, only return orders with declined items. + search: Search by order number, customer name, or email. + + Returns a tuple of (orders, total_count). + """ + query = self.db.query(Order).filter( + Order.channel == "letzshop", + ) + + # Filter by vendor if specified + if vendor_id is not None: + query = query.filter(Order.vendor_id == vendor_id) + + if status: + query = query.filter(Order.status == status) + + if search: + search_term = f"%{search}%" + query = query.filter( + or_( + Order.order_number.ilike(search_term), + Order.external_order_number.ilike(search_term), + Order.customer_email.ilike(search_term), + Order.customer_first_name.ilike(search_term), + Order.customer_last_name.ilike(search_term), + ) + ) + + # Filter for orders with declined items + if has_declined_items is True: + # Subquery to find orders with declined items + declined_order_ids = ( + self.db.query(OrderItem.order_id) + .filter(OrderItem.item_state == "confirmed_unavailable") + .subquery() + ) + query = query.filter(Order.id.in_(declined_order_ids)) + + total = query.count() + orders = ( + query.order_by(Order.order_date.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return orders, total + + def get_order_stats(self, vendor_id: int | None = None) -> dict[str, int]: + """ + Get order counts by status for Letzshop orders. + + Args: + vendor_id: Vendor ID to filter by. If None, returns stats for all vendors. + + Returns: + Dict with counts for each status. + """ + query = self.db.query( + Order.status, + func.count(Order.id).label("count"), + ).filter(Order.channel == "letzshop") + + if vendor_id is not None: + query = query.filter(Order.vendor_id == vendor_id) + + status_counts = query.group_by(Order.status).all() + + stats = { + "pending": 0, + "processing": 0, + "shipped": 0, + "delivered": 0, + "cancelled": 0, + "refunded": 0, + "total": 0, + } + for status, count in status_counts: + if status in stats: + stats[status] = count + stats["total"] += count + + # Count orders with declined items + declined_query = ( + self.db.query(func.count(func.distinct(OrderItem.order_id))) + .join(Order, OrderItem.order_id == Order.id) + .filter( + Order.channel == "letzshop", + OrderItem.item_state == "confirmed_unavailable", + ) + ) + if vendor_id is not None: + declined_query = declined_query.filter(Order.vendor_id == vendor_id) + + stats["has_declined_items"] = declined_query.scalar() or 0 + + return stats + + def create_order( + self, + vendor_id: int, + shipment_data: dict[str, Any], + ) -> Order: + """ + Create a new Letzshop order from shipment data. + + Uses the unified order service to create the order. + """ + return unified_order_service.create_letzshop_order( + db=self.db, + vendor_id=vendor_id, + shipment_data=shipment_data, + ) + + def update_order_from_shipment( + self, + order: Order, + shipment_data: dict[str, Any], + ) -> Order: + """Update an existing order from shipment data.""" + order_data = shipment_data.get("order", {}) + + # Map Letzshop state to status + letzshop_state = shipment_data.get("state", "unconfirmed") + state_mapping = { + "unconfirmed": "pending", + "confirmed": "processing", + "declined": "cancelled", + } + new_status = state_mapping.get(letzshop_state, "processing") + + # Update status if changed + if order.status != new_status: + order.status = new_status + now = datetime.now(UTC) + if new_status == "processing": + order.confirmed_at = now + elif new_status == "cancelled": + order.cancelled_at = now + + # Update external data + order.external_data = shipment_data + + # Update locale if not set + if not order.customer_locale and order_data.get("locale"): + order.customer_locale = order_data.get("locale") + + # Update order_date if not set + if not order.order_date: + completed_at_str = order_data.get("completedAt") + if completed_at_str: + try: + if completed_at_str.endswith("Z"): + completed_at_str = completed_at_str[:-1] + "+00:00" + order.order_date = datetime.fromisoformat(completed_at_str) + except (ValueError, TypeError): + pass + + # Update inventory unit states in order items + inventory_units_data = shipment_data.get("inventoryUnits", []) + if isinstance(inventory_units_data, dict): + inventory_units_data = inventory_units_data.get("nodes", []) + + for unit in inventory_units_data: + unit_id = unit.get("id") + unit_state = unit.get("state") + if unit_id and unit_state: + # Find and update the corresponding order item + item = ( + self.db.query(OrderItem) + .filter( + OrderItem.order_id == order.id, + OrderItem.external_item_id == unit_id, + ) + .first() + ) + if item: + item.item_state = unit_state + + order.updated_at = datetime.now(UTC) + return order + + def mark_order_confirmed(self, order: Order) -> Order: + """Mark an order as confirmed (processing).""" + order.confirmed_at = datetime.now(UTC) + order.status = "processing" + order.updated_at = datetime.now(UTC) + return order + + def mark_order_rejected(self, order: Order) -> Order: + """Mark an order as rejected (cancelled).""" + order.cancelled_at = datetime.now(UTC) + order.status = "cancelled" + order.updated_at = datetime.now(UTC) + return order + + def update_inventory_unit_state( + self, order: Order, item_id: str, state: str + ) -> Order: + """ + Update the state of a single order item. + + Args: + order: The order containing the item. + item_id: The external item ID (Letzshop inventory unit ID). + state: The new state (confirmed_available, confirmed_unavailable). + + Returns: + The updated order. + """ + # Find and update the item + item = ( + self.db.query(OrderItem) + .filter( + OrderItem.order_id == order.id, + OrderItem.external_item_id == item_id, + ) + .first() + ) + + if item: + item.item_state = state + item.updated_at = datetime.now(UTC) + + # Check if all items are now processed + all_items = ( + self.db.query(OrderItem) + .filter(OrderItem.order_id == order.id) + .all() + ) + + all_confirmed = all( + i.item_state in ("confirmed_available", "confirmed_unavailable", "returned") + for i in all_items + ) + + if all_confirmed: + has_available = any( + i.item_state == "confirmed_available" for i in all_items + ) + all_unavailable = all( + i.item_state == "confirmed_unavailable" for i in all_items + ) + + now = datetime.now(UTC) + if all_unavailable: + order.status = "cancelled" + order.cancelled_at = now + elif has_available: + order.status = "processing" + order.confirmed_at = now + + order.updated_at = now + + return order + + def set_order_tracking( + self, + order: Order, + tracking_number: str, + tracking_provider: str, + ) -> Order: + """Set tracking information for an order.""" + order.tracking_number = tracking_number + order.tracking_provider = tracking_provider + order.shipped_at = datetime.now(UTC) + order.status = "shipped" + order.updated_at = datetime.now(UTC) + return order + + def get_orders_without_tracking( + self, + vendor_id: int, + limit: int = 100, + ) -> list[Order]: + """Get orders that have been confirmed but don't have tracking info.""" + return ( + self.db.query(Order) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + Order.status == "processing", # Confirmed orders + Order.tracking_number.is_(None), + Order.external_shipment_id.isnot(None), # Has shipment ID + ) + .limit(limit) + .all() + ) + + def update_tracking_from_shipment_data( + self, + order: Order, + shipment_data: dict[str, Any], + ) -> bool: + """ + Update order tracking from Letzshop shipment data. + + Args: + order: The order to update. + shipment_data: Raw shipment data from Letzshop API. + + Returns: + True if tracking was updated, False otherwise. + """ + tracking_data = shipment_data.get("tracking") or {} + tracking_number = tracking_data.get("code") or tracking_data.get("number") + + if not tracking_number: + return False + + tracking_provider = tracking_data.get("provider") + # Handle carrier object format: tracking { carrier { name code } } + if not tracking_provider and tracking_data.get("carrier"): + carrier = tracking_data.get("carrier", {}) + tracking_provider = carrier.get("code") or carrier.get("name") + + order.tracking_number = tracking_number + order.tracking_provider = tracking_provider + order.updated_at = datetime.now(UTC) + + logger.info( + f"Updated tracking for order {order.order_number}: " + f"{tracking_provider} {tracking_number}" + ) + return True + + def get_order_items(self, order: Order) -> list[OrderItem]: + """Get all items for an order.""" + return ( + self.db.query(OrderItem) + .filter(OrderItem.order_id == order.id) + .all() + ) + + # ========================================================================= + # Sync Log Operations + # ========================================================================= + + def list_sync_logs( + self, + vendor_id: int, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[LetzshopSyncLog], int]: + """List sync logs for a vendor.""" + query = self.db.query(LetzshopSyncLog).filter( + LetzshopSyncLog.vendor_id == vendor_id + ) + total = query.count() + logs = ( + query.order_by(LetzshopSyncLog.started_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return logs, total + + # ========================================================================= + # Fulfillment Queue Operations + # ========================================================================= + + def list_fulfillment_queue( + self, + vendor_id: int, + skip: int = 0, + limit: int = 50, + status: str | None = None, + ) -> tuple[list[LetzshopFulfillmentQueue], int]: + """List fulfillment queue items for a vendor.""" + query = self.db.query(LetzshopFulfillmentQueue).filter( + LetzshopFulfillmentQueue.vendor_id == vendor_id + ) + + if status: + query = query.filter(LetzshopFulfillmentQueue.status == status) + + total = query.count() + items = ( + query.order_by(LetzshopFulfillmentQueue.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return items, total + + def add_to_fulfillment_queue( + self, + vendor_id: int, + order_id: int, + operation: str, + payload: dict[str, Any], + ) -> LetzshopFulfillmentQueue: + """Add an operation to the fulfillment queue.""" + queue_item = LetzshopFulfillmentQueue( + vendor_id=vendor_id, + order_id=order_id, + operation=operation, + payload=payload, + status="pending", + ) + self.db.add(queue_item) + return queue_item + + # ========================================================================= + # Unified Jobs Operations + # ========================================================================= + + def list_letzshop_jobs( + self, + vendor_id: int | None = None, + job_type: str | None = None, + status: str | None = None, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[dict[str, Any]], int]: + """ + List unified Letzshop-related jobs for a vendor or all vendors. + + Combines product imports, historical order imports, and order syncs. + If vendor_id is None, returns jobs across all vendors. + """ + jobs = [] + + # Fetch vendor info - for single vendor or build lookup for all vendors + if vendor_id: + vendor = self.get_vendor(vendor_id) + vendor_lookup = {vendor_id: (vendor.name if vendor else None, vendor.vendor_code if vendor else None)} + else: + # Build lookup for all vendors when showing all jobs + from models.database.vendor import Vendor + vendors = self.db.query(Vendor.id, Vendor.name, Vendor.vendor_code).all() + vendor_lookup = {v.id: (v.name, v.vendor_code) for v in vendors} + + # Historical order imports from letzshop_historical_import_jobs + if job_type in (None, "historical_import"): + hist_query = self.db.query(LetzshopHistoricalImportJob) + if vendor_id: + hist_query = hist_query.filter( + LetzshopHistoricalImportJob.vendor_id == vendor_id, + ) + if status: + hist_query = hist_query.filter( + LetzshopHistoricalImportJob.status == status + ) + + hist_jobs = hist_query.order_by( + LetzshopHistoricalImportJob.created_at.desc() + ).all() + + for job in hist_jobs: + v_name, v_code = vendor_lookup.get(job.vendor_id, (None, None)) + jobs.append( + { + "id": job.id, + "type": "historical_import", + "status": job.status, + "created_at": job.created_at, + "started_at": job.started_at, + "completed_at": job.completed_at, + "records_processed": job.orders_processed or 0, + "records_succeeded": (job.orders_imported or 0) + + (job.orders_updated or 0), + "records_failed": job.orders_skipped or 0, + "vendor_id": job.vendor_id, + "vendor_name": v_name, + "vendor_code": v_code, + "current_phase": job.current_phase, + "error_message": job.error_message, + } + ) + + # Product imports from marketplace_import_jobs + if job_type in (None, "import"): + import_query = self.db.query(MarketplaceImportJob).filter( + MarketplaceImportJob.marketplace == "Letzshop", + ) + if vendor_id: + import_query = import_query.filter( + MarketplaceImportJob.vendor_id == vendor_id, + ) + if status: + import_query = import_query.filter( + MarketplaceImportJob.status == status + ) + + import_jobs = import_query.order_by( + MarketplaceImportJob.created_at.desc() + ).all() + + for job in import_jobs: + v_name, v_code = vendor_lookup.get(job.vendor_id, (None, None)) + jobs.append( + { + "id": job.id, + "type": "import", + "status": job.status, + "created_at": job.created_at, + "started_at": job.started_at, + "completed_at": job.completed_at, + "records_processed": job.total_processed or 0, + "records_succeeded": (job.imported_count or 0) + + (job.updated_count or 0), + "records_failed": job.error_count or 0, + "vendor_id": job.vendor_id, + "vendor_name": v_name, + "vendor_code": v_code, + } + ) + + # Order syncs from letzshop_sync_logs + if job_type in (None, "order_sync"): + sync_query = self.db.query(LetzshopSyncLog).filter( + LetzshopSyncLog.operation_type == "order_import", + ) + if vendor_id: + sync_query = sync_query.filter(LetzshopSyncLog.vendor_id == vendor_id) + if status: + sync_query = sync_query.filter(LetzshopSyncLog.status == status) + + sync_logs = sync_query.order_by(LetzshopSyncLog.created_at.desc()).all() + + for log in sync_logs: + v_name, v_code = vendor_lookup.get(log.vendor_id, (None, None)) + jobs.append( + { + "id": log.id, + "type": "order_sync", + "status": log.status, + "created_at": log.created_at, + "started_at": log.started_at, + "completed_at": log.completed_at, + "records_processed": log.records_processed or 0, + "records_succeeded": log.records_succeeded or 0, + "records_failed": log.records_failed or 0, + "vendor_id": log.vendor_id, + "vendor_name": v_name, + "vendor_code": v_code, + "error_details": log.error_details, + } + ) + + # Product exports from letzshop_sync_logs + if job_type in (None, "export"): + export_query = self.db.query(LetzshopSyncLog).filter( + LetzshopSyncLog.operation_type == "product_export", + ) + if vendor_id: + export_query = export_query.filter(LetzshopSyncLog.vendor_id == vendor_id) + if status: + export_query = export_query.filter(LetzshopSyncLog.status == status) + + export_logs = export_query.order_by( + LetzshopSyncLog.created_at.desc() + ).all() + + for log in export_logs: + v_name, v_code = vendor_lookup.get(log.vendor_id, (None, None)) + jobs.append( + { + "id": log.id, + "type": "export", + "status": log.status, + "created_at": log.created_at, + "started_at": log.started_at, + "completed_at": log.completed_at, + "records_processed": log.records_processed or 0, + "records_succeeded": log.records_succeeded or 0, + "records_failed": log.records_failed or 0, + "vendor_id": log.vendor_id, + "vendor_name": v_name, + "vendor_code": v_code, + "error_details": log.error_details, + } + ) + + # Sort all jobs by created_at descending + jobs.sort(key=lambda x: x["created_at"], reverse=True) + + total = len(jobs) + jobs = jobs[skip : skip + limit] + + return jobs, total + + # ========================================================================= + # Historical Import Operations + # ========================================================================= + + def import_historical_shipments( + self, + vendor_id: int, + shipments: list[dict[str, Any]], + match_products: bool = True, + progress_callback: Callable[[int, int, int, int], None] | None = None, + ) -> dict[str, Any]: + """ + Import historical shipments into the unified orders table. + + Args: + vendor_id: Vendor ID to import for. + shipments: List of shipment data from Letzshop API. + match_products: Whether to match GTIN to local products. + progress_callback: Optional callback(processed, imported, updated, skipped) + + Returns: + Dict with import statistics. + """ + stats = { + "total": len(shipments), + "imported": 0, + "updated": 0, + "skipped": 0, + "errors": 0, + "limit_exceeded": 0, + "products_matched": 0, + "products_not_found": 0, + "eans_processed": set(), + "eans_matched": set(), + "eans_not_found": set(), + "error_messages": [], + } + + # Get subscription usage upfront for batch efficiency + usage = subscription_service.get_usage(self.db, vendor_id) + orders_remaining = usage.orders_remaining # None = unlimited + + for i, shipment in enumerate(shipments): + shipment_id = shipment.get("id") + if not shipment_id: + continue + + # Check if order already exists + existing_order = self.get_order_by_shipment_id(vendor_id, shipment_id) + + if existing_order: + # Check if we need to update + letzshop_state = shipment.get("state") + state_mapping = { + "unconfirmed": "pending", + "confirmed": "processing", + "declined": "cancelled", + } + expected_status = state_mapping.get(letzshop_state, "processing") + + needs_update = False + if existing_order.status != expected_status: + self.update_order_from_shipment(existing_order, shipment) + needs_update = True + + # Update order_date if missing + if not existing_order.order_date: + order_data = shipment.get("order", {}) + completed_at_str = order_data.get("completedAt") + if completed_at_str: + try: + if completed_at_str.endswith("Z"): + completed_at_str = completed_at_str[:-1] + "+00:00" + existing_order.order_date = datetime.fromisoformat( + completed_at_str + ) + needs_update = True + except (ValueError, TypeError): + pass + + if needs_update: + self.db.commit() # noqa: SVC-006 - background task needs incremental commits + stats["updated"] += 1 + else: + stats["skipped"] += 1 + else: + # Check tier limit before creating order + if orders_remaining is not None and orders_remaining <= 0: + stats["limit_exceeded"] += 1 + stats["error_messages"].append( + f"Shipment {shipment_id}: Order limit reached" + ) + continue + + # Create new order using unified service + try: + self.create_order(vendor_id, shipment) + self.db.commit() # noqa: SVC-006 - background task needs incremental commits + stats["imported"] += 1 + + # Decrement remaining count for batch efficiency + if orders_remaining is not None: + orders_remaining -= 1 + + except Exception as e: + self.db.rollback() # Rollback failed order + stats["errors"] += 1 + stats["error_messages"].append( + f"Shipment {shipment_id}: {str(e)}" + ) + logger.error(f"Error importing shipment {shipment_id}: {e}") + + # Process GTINs for matching + if match_products: + inventory_units = shipment.get("inventoryUnits", []) + for unit in inventory_units: + variant = unit.get("variant", {}) or {} + trade_id = variant.get("tradeId") or {} + gtin = trade_id.get("number") + + if gtin: + stats["eans_processed"].add(gtin) + + # Report progress + if progress_callback and ((i + 1) % 10 == 0 or i == len(shipments) - 1): + progress_callback( + i + 1, + stats["imported"], + stats["updated"], + stats["skipped"], + ) + + # Match GTINs to local products + if match_products and stats["eans_processed"]: + matched, not_found = self._match_gtins_to_products( + vendor_id, list(stats["eans_processed"]) + ) + stats["eans_matched"] = matched + stats["eans_not_found"] = not_found + stats["products_matched"] = len(matched) + stats["products_not_found"] = len(not_found) + + # Convert sets to lists for JSON serialization + stats["eans_processed"] = list(stats["eans_processed"]) + stats["eans_matched"] = list(stats["eans_matched"]) + stats["eans_not_found"] = list(stats["eans_not_found"]) + + return stats + + def _match_gtins_to_products( + self, + vendor_id: int, + gtins: list[str], + ) -> tuple[set[str], set[str]]: + """Match GTIN codes to local products.""" + if not gtins: + return set(), set() + + products = ( + self.db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.gtin.in_(gtins), + ) + .all() + ) + + matched_gtins = {p.gtin for p in products if p.gtin} + not_found_gtins = set(gtins) - matched_gtins + + logger.info( + f"GTIN matching: {len(matched_gtins)} matched, " + f"{len(not_found_gtins)} not found" + ) + return matched_gtins, not_found_gtins + + def get_products_by_gtins( + self, + vendor_id: int, + gtins: list[str], + ) -> dict[str, Product]: + """Get products by their GTIN codes.""" + if not gtins: + return {} + + products = ( + self.db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.gtin.in_(gtins), + ) + .all() + ) + + return {p.gtin: p for p in products if p.gtin} + + def get_historical_import_summary( + self, + vendor_id: int, + ) -> dict[str, Any]: + """Get summary of Letzshop orders for a vendor.""" + # Count orders by status + status_counts = ( + self.db.query( + Order.status, + func.count(Order.id).label("count"), + ) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .group_by(Order.status) + .all() + ) + + # Count orders by locale + locale_counts = ( + self.db.query( + Order.customer_locale, + func.count(Order.id).label("count"), + ) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .group_by(Order.customer_locale) + .all() + ) + + # Count orders by country + country_counts = ( + self.db.query( + Order.ship_country_iso, + func.count(Order.id).label("count"), + ) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .group_by(Order.ship_country_iso) + .all() + ) + + # Total orders + total_orders = ( + self.db.query(func.count(Order.id)) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .scalar() + or 0 + ) + + # Unique customers + unique_customers = ( + self.db.query(func.count(func.distinct(Order.customer_email))) + .filter( + Order.vendor_id == vendor_id, + Order.channel == "letzshop", + ) + .scalar() + or 0 + ) + + return { + "total_orders": total_orders, + "unique_customers": unique_customers, + "orders_by_status": {status: count for status, count in status_counts}, + "orders_by_locale": { + locale or "unknown": count for locale, count in locale_counts + }, + "orders_by_country": { + country or "unknown": count for country, count in country_counts + }, + } + + # ========================================================================= + # Historical Import Job Operations + # ========================================================================= + + def get_running_historical_import_job( + self, + vendor_id: int, + ) -> LetzshopHistoricalImportJob | None: + """Get any running historical import job for a vendor.""" + return ( + self.db.query(LetzshopHistoricalImportJob) + .filter( + LetzshopHistoricalImportJob.vendor_id == vendor_id, + LetzshopHistoricalImportJob.status.in_( + ["pending", "fetching", "processing"] + ), + ) + .first() + ) + + def create_historical_import_job( + self, + vendor_id: int, + user_id: int, + ) -> LetzshopHistoricalImportJob: + """Create a new historical import job.""" + job = LetzshopHistoricalImportJob( + vendor_id=vendor_id, + user_id=user_id, + status="pending", + ) + self.db.add(job) + self.db.commit() # noqa: SVC-006 - job must be visible immediately before background task starts + self.db.refresh(job) + return job + + def get_historical_import_job_by_id( + self, + vendor_id: int, + job_id: int, + ) -> LetzshopHistoricalImportJob | None: + """Get a historical import job by ID.""" + return ( + self.db.query(LetzshopHistoricalImportJob) + .filter( + LetzshopHistoricalImportJob.id == job_id, + LetzshopHistoricalImportJob.vendor_id == vendor_id, + ) + .first() + ) + + def update_job_celery_task_id( + self, + job_id: int, + celery_task_id: str, + ) -> bool: + """ + Update the Celery task ID for a historical import job. + + Args: + job_id: The job ID to update. + celery_task_id: The Celery task ID to set. + + Returns: + True if updated successfully, False if job not found. + """ + job = ( + self.db.query(LetzshopHistoricalImportJob) + .filter(LetzshopHistoricalImportJob.id == job_id) + .first() + ) + if job: + job.celery_task_id = celery_task_id + self.db.commit() # noqa: SVC-006 - Called from API endpoint + return True + return False diff --git a/app/modules/marketplace/services/letzshop/vendor_sync_service.py b/app/modules/marketplace/services/letzshop/vendor_sync_service.py new file mode 100644 index 00000000..7fa963bb --- /dev/null +++ b/app/modules/marketplace/services/letzshop/vendor_sync_service.py @@ -0,0 +1,521 @@ +# app/services/letzshop/vendor_sync_service.py +""" +Service for syncing Letzshop vendor directory to local cache. + +Fetches vendor data from Letzshop's public GraphQL API and stores it +in the letzshop_vendor_cache table for fast lookups during signup. +""" + +import logging +from datetime import UTC, datetime +from typing import Any, Callable + +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.orm import Session + +from app.services.letzshop.client_service import LetzshopClient +from models.database.letzshop import LetzshopVendorCache + +logger = logging.getLogger(__name__) + + +class LetzshopVendorSyncService: + """ + Service for syncing Letzshop vendor directory. + + Usage: + service = LetzshopVendorSyncService(db) + stats = service.sync_all_vendors() + """ + + def __init__(self, db: Session): + """Initialize the sync service.""" + self.db = db + + def sync_all_vendors( + self, + progress_callback: Callable[[int, int, int], None] | None = None, + max_pages: int | None = None, + ) -> dict[str, Any]: + """ + Sync all vendors from Letzshop to local cache. + + Args: + progress_callback: Optional callback(page, fetched, total) for progress. + + Returns: + Dictionary with sync statistics. + """ + stats = { + "started_at": datetime.now(UTC), + "total_fetched": 0, + "created": 0, + "updated": 0, + "errors": 0, + "error_details": [], + } + + logger.info("Starting Letzshop vendor directory sync...") + + # Create client (no API key needed for public vendor data) + client = LetzshopClient(api_key="") + + try: + # Fetch all vendors + vendors = client.get_all_vendors_paginated( + page_size=50, + max_pages=max_pages, + progress_callback=progress_callback, + ) + + stats["total_fetched"] = len(vendors) + logger.info(f"Fetched {len(vendors)} vendors from Letzshop") + + # Process each vendor + for vendor_data in vendors: + try: + result = self._upsert_vendor(vendor_data) + if result == "created": + stats["created"] += 1 + elif result == "updated": + stats["updated"] += 1 + except Exception as e: + stats["errors"] += 1 + error_info = { + "vendor_id": vendor_data.get("id"), + "slug": vendor_data.get("slug"), + "error": str(e), + } + stats["error_details"].append(error_info) + logger.error(f"Error processing vendor {vendor_data.get('slug')}: {e}") + + # Commit all changes + self.db.commit() + logger.info( + f"Sync complete: {stats['created']} created, " + f"{stats['updated']} updated, {stats['errors']} errors" + ) + + except Exception as e: + self.db.rollback() + logger.error(f"Vendor sync failed: {e}") + stats["error"] = str(e) + raise + + finally: + client.close() + + stats["completed_at"] = datetime.now(UTC) + stats["duration_seconds"] = ( + stats["completed_at"] - stats["started_at"] + ).total_seconds() + + return stats + + def _upsert_vendor(self, vendor_data: dict[str, Any]) -> str: + """ + Insert or update a vendor in the cache. + + Args: + vendor_data: Raw vendor data from Letzshop API. + + Returns: + "created" or "updated" indicating the operation performed. + """ + letzshop_id = vendor_data.get("id") + slug = vendor_data.get("slug") + + if not letzshop_id or not slug: + raise ValueError("Vendor missing required id or slug") + + # Parse the vendor data + parsed = self._parse_vendor_data(vendor_data) + + # Check if exists + existing = ( + self.db.query(LetzshopVendorCache) + .filter(LetzshopVendorCache.letzshop_id == letzshop_id) + .first() + ) + + if existing: + # Update existing record (preserve claimed status) + for key, value in parsed.items(): + if key not in ("claimed_by_vendor_id", "claimed_at"): + setattr(existing, key, value) + existing.last_synced_at = datetime.now(UTC) + return "updated" + else: + # Create new record + cache_entry = LetzshopVendorCache( + **parsed, + last_synced_at=datetime.now(UTC), + ) + self.db.add(cache_entry) + return "created" + + def _parse_vendor_data(self, data: dict[str, Any]) -> dict[str, Any]: + """ + Parse raw Letzshop vendor data into cache model fields. + + Args: + data: Raw vendor data from Letzshop API. + + Returns: + Dictionary of parsed fields for LetzshopVendorCache. + """ + # Extract location + location = data.get("location") or {} + country = location.get("country") or {} + + # Extract descriptions + description = data.get("description") or {} + + # Extract opening hours + opening_hours = data.get("openingHours") or {} + + # Extract categories (list of translated name objects) + categories = [] + for cat in data.get("vendorCategories") or []: + cat_name = cat.get("name") or {} + # Prefer English, fallback to French or German + name = cat_name.get("en") or cat_name.get("fr") or cat_name.get("de") + if name: + categories.append(name) + + # Extract social media URLs + social_links = [] + for link in data.get("socialMediaLinks") or []: + url = link.get("url") + if url: + social_links.append(url) + + # Extract background image + bg_image = data.get("backgroundImage") or {} + + return { + "letzshop_id": data.get("id"), + "slug": data.get("slug"), + "name": data.get("name"), + "company_name": data.get("companyName") or data.get("legalName"), + "is_active": data.get("active", True), + # Descriptions + "description_en": description.get("en"), + "description_fr": description.get("fr"), + "description_de": description.get("de"), + # Contact + "email": data.get("email"), + "phone": data.get("phone"), + "fax": data.get("fax"), + "website": data.get("homepage"), + # Location + "street": location.get("street"), + "street_number": location.get("number"), + "city": location.get("city"), + "zipcode": location.get("zipcode"), + "country_iso": country.get("iso", "LU"), + "latitude": str(data.get("lat")) if data.get("lat") else None, + "longitude": str(data.get("lng")) if data.get("lng") else None, + # Categories and media + "categories": categories, + "background_image_url": bg_image.get("url"), + "social_media_links": social_links, + # Opening hours + "opening_hours_en": opening_hours.get("en"), + "opening_hours_fr": opening_hours.get("fr"), + "opening_hours_de": opening_hours.get("de"), + # Representative + "representative_name": data.get("representative"), + "representative_title": data.get("representativeTitle"), + # Raw data for reference + "raw_data": data, + } + + def sync_single_vendor(self, slug: str) -> LetzshopVendorCache | None: + """ + Sync a single vendor by slug. + + Useful for on-demand refresh when a user looks up a vendor. + + Args: + slug: The vendor's URL slug. + + Returns: + The updated/created cache entry, or None if not found. + """ + client = LetzshopClient(api_key="") + + try: + vendor_data = client.get_vendor_by_slug(slug) + + if not vendor_data: + logger.warning(f"Vendor not found on Letzshop: {slug}") + return None + + result = self._upsert_vendor(vendor_data) + self.db.commit() + + logger.info(f"Single vendor sync: {slug} ({result})") + + return ( + self.db.query(LetzshopVendorCache) + .filter(LetzshopVendorCache.slug == slug) + .first() + ) + + finally: + client.close() + + def get_cached_vendor(self, slug: str) -> LetzshopVendorCache | None: + """ + Get a vendor from cache by slug. + + Args: + slug: The vendor's URL slug. + + Returns: + Cache entry or None if not found. + """ + return ( + self.db.query(LetzshopVendorCache) + .filter(LetzshopVendorCache.slug == slug.lower()) + .first() + ) + + def search_cached_vendors( + self, + search: str | None = None, + city: str | None = None, + category: str | None = None, + only_unclaimed: bool = False, + page: int = 1, + limit: int = 20, + ) -> tuple[list[LetzshopVendorCache], int]: + """ + Search cached vendors with filters. + + Args: + search: Search term for name. + city: Filter by city. + category: Filter by category. + only_unclaimed: Only return vendors not yet claimed. + page: Page number (1-indexed). + limit: Items per page. + + Returns: + Tuple of (vendors list, total count). + """ + query = self.db.query(LetzshopVendorCache).filter( + LetzshopVendorCache.is_active == True # noqa: E712 + ) + + if search: + search_term = f"%{search.lower()}%" + query = query.filter( + func.lower(LetzshopVendorCache.name).like(search_term) + ) + + if city: + query = query.filter( + func.lower(LetzshopVendorCache.city) == city.lower() + ) + + if category: + # Search in JSON array + query = query.filter( + LetzshopVendorCache.categories.contains([category]) + ) + + if only_unclaimed: + query = query.filter( + LetzshopVendorCache.claimed_by_vendor_id.is_(None) + ) + + # Get total count + total = query.count() + + # Apply pagination + offset = (page - 1) * limit + vendors = ( + query.order_by(LetzshopVendorCache.name) + .offset(offset) + .limit(limit) + .all() + ) + + return vendors, total + + def get_sync_stats(self) -> dict[str, Any]: + """ + Get statistics about the vendor cache. + + Returns: + Dictionary with cache statistics. + """ + total = self.db.query(LetzshopVendorCache).count() + active = ( + self.db.query(LetzshopVendorCache) + .filter(LetzshopVendorCache.is_active == True) # noqa: E712 + .count() + ) + claimed = ( + self.db.query(LetzshopVendorCache) + .filter(LetzshopVendorCache.claimed_by_vendor_id.isnot(None)) + .count() + ) + + # Get last sync time + last_synced = ( + self.db.query(func.max(LetzshopVendorCache.last_synced_at)).scalar() + ) + + # Get unique cities + cities = ( + self.db.query(LetzshopVendorCache.city) + .filter(LetzshopVendorCache.city.isnot(None)) + .distinct() + .count() + ) + + return { + "total_vendors": total, + "active_vendors": active, + "claimed_vendors": claimed, + "unclaimed_vendors": active - claimed, + "unique_cities": cities, + "last_synced_at": last_synced.isoformat() if last_synced else None, + } + + def mark_vendor_claimed( + self, + letzshop_slug: str, + vendor_id: int, + ) -> bool: + """ + Mark a Letzshop vendor as claimed by a platform vendor. + + Args: + letzshop_slug: The Letzshop vendor slug. + vendor_id: The platform vendor ID that claimed it. + + Returns: + True if successful, False if vendor not found. + """ + cache_entry = self.get_cached_vendor(letzshop_slug) + + if not cache_entry: + return False + + cache_entry.claimed_by_vendor_id = vendor_id + cache_entry.claimed_at = datetime.now(UTC) + self.db.commit() + + logger.info(f"Vendor {letzshop_slug} claimed by vendor_id={vendor_id}") + return True + + def create_vendor_from_cache( + self, + letzshop_slug: str, + company_id: int, + ) -> dict[str, Any]: + """ + Create a platform vendor from a cached Letzshop vendor. + + Args: + letzshop_slug: The Letzshop vendor slug. + company_id: The company ID to create the vendor under. + + Returns: + Dictionary with created vendor info. + + Raises: + ValueError: If vendor not found, already claimed, or company not found. + """ + import random + + from sqlalchemy import func + + from app.services.admin_service import admin_service + from models.database.company import Company + from models.database.vendor import Vendor + from models.schema.vendor import VendorCreate + + # Get cache entry + cache_entry = self.get_cached_vendor(letzshop_slug) + if not cache_entry: + raise ValueError(f"Letzshop vendor '{letzshop_slug}' not found in cache") + + if cache_entry.is_claimed: + raise ValueError( + f"Letzshop vendor '{cache_entry.name}' is already claimed " + f"by vendor ID {cache_entry.claimed_by_vendor_id}" + ) + + # Verify company exists + company = self.db.query(Company).filter(Company.id == company_id).first() + if not company: + raise ValueError(f"Company with ID {company_id} not found") + + # Generate vendor code from slug + vendor_code = letzshop_slug.upper().replace("-", "_")[:20] + + # Check if vendor code already exists + existing = ( + self.db.query(Vendor) + .filter(func.upper(Vendor.vendor_code) == vendor_code) + .first() + ) + if existing: + vendor_code = f"{vendor_code[:16]}_{random.randint(100, 999)}" + + # Generate subdomain from slug + subdomain = letzshop_slug.lower().replace("_", "-")[:30] + existing_subdomain = ( + self.db.query(Vendor) + .filter(func.lower(Vendor.subdomain) == subdomain) + .first() + ) + if existing_subdomain: + subdomain = f"{subdomain[:26]}-{random.randint(100, 999)}" + + # Create vendor data from cache + address = f"{cache_entry.street or ''} {cache_entry.street_number or ''}".strip() + vendor_data = VendorCreate( + name=cache_entry.name, + vendor_code=vendor_code, + subdomain=subdomain, + company_id=company_id, + email=cache_entry.email or company.email, + phone=cache_entry.phone, + description=cache_entry.description_en or cache_entry.description_fr or "", + city=cache_entry.city, + country=cache_entry.country_iso or "LU", + website=cache_entry.website, + address_line_1=address or None, + postal_code=cache_entry.zipcode, + ) + + # Create vendor + vendor = admin_service.create_vendor(self.db, vendor_data) + + # Mark the Letzshop vendor as claimed (commits internally) # noqa: SVC-006 + self.mark_vendor_claimed(letzshop_slug, vendor.id) + + logger.info( + f"Created vendor {vendor.vendor_code} from Letzshop vendor {letzshop_slug}" + ) + + return { + "id": vendor.id, + "vendor_code": vendor.vendor_code, + "name": vendor.name, + "subdomain": vendor.subdomain, + "company_id": vendor.company_id, + } + + +# Singleton-style function for easy access +def get_vendor_sync_service(db: Session) -> LetzshopVendorSyncService: + """Get a vendor sync service instance.""" + return LetzshopVendorSyncService(db) diff --git a/app/modules/marketplace/services/letzshop_export_service.py b/app/modules/marketplace/services/letzshop_export_service.py new file mode 100644 index 00000000..5e4dee29 --- /dev/null +++ b/app/modules/marketplace/services/letzshop_export_service.py @@ -0,0 +1,338 @@ +# app/services/letzshop_export_service.py +""" +Service for exporting products to Letzshop CSV format. + +Generates Google Shopping compatible CSV files for Letzshop marketplace. +""" + +import csv +import io +import logging +from datetime import UTC, datetime + +from sqlalchemy.orm import Session, joinedload + +from models.database.letzshop import LetzshopSyncLog +from models.database.marketplace_product import MarketplaceProduct +from models.database.product import Product + +logger = logging.getLogger(__name__) + +# Letzshop CSV columns in order +LETZSHOP_CSV_COLUMNS = [ + "id", + "title", + "description", + "link", + "image_link", + "additional_image_link", + "availability", + "price", + "sale_price", + "brand", + "gtin", + "mpn", + "google_product_category", + "product_type", + "condition", + "adult", + "multipack", + "is_bundle", + "age_group", + "color", + "gender", + "material", + "pattern", + "size", + "size_type", + "size_system", + "item_group_id", + "custom_label_0", + "custom_label_1", + "custom_label_2", + "custom_label_3", + "custom_label_4", + "identifier_exists", + "unit_pricing_measure", + "unit_pricing_base_measure", + "shipping", + "atalanda:tax_rate", + "atalanda:quantity", + "atalanda:boost_sort", + "atalanda:delivery_method", +] + + +class LetzshopExportService: + """Service for exporting products to Letzshop CSV format.""" + + def __init__(self, default_tax_rate: float = 17.0): + """ + Initialize the export service. + + Args: + default_tax_rate: Default VAT rate for Luxembourg (17%) + """ + self.default_tax_rate = default_tax_rate + + def export_vendor_products( + self, + db: Session, + vendor_id: int, + language: str = "en", + include_inactive: bool = False, + ) -> str: + """ + Export all products for a vendor in Letzshop CSV format. + + Args: + db: Database session + vendor_id: Vendor ID to export products for + language: Language for title/description (en, fr, de) + include_inactive: Whether to include inactive products + + Returns: + CSV string content + """ + # Query products for this vendor with their marketplace product data + query = ( + db.query(Product) + .filter(Product.vendor_id == vendor_id) + .options( + joinedload(Product.marketplace_product).joinedload( + MarketplaceProduct.translations + ) + ) + ) + + if not include_inactive: + query = query.filter(Product.is_active == True) + + products = query.all() + + logger.info( + f"Exporting {len(products)} products for vendor {vendor_id} in {language}" + ) + + return self._generate_csv(products, language) + + def export_marketplace_products( + self, + db: Session, + marketplace: str = "Letzshop", + language: str = "en", + limit: int | None = None, + ) -> str: + """ + Export marketplace products directly (admin use). + + Args: + db: Database session + marketplace: Filter by marketplace source + language: Language for title/description + limit: Optional limit on number of products + + Returns: + CSV string content + """ + query = ( + db.query(MarketplaceProduct) + .filter(MarketplaceProduct.is_active == True) + .options(joinedload(MarketplaceProduct.translations)) + ) + + if marketplace: + query = query.filter( + MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") + ) + + if limit: + query = query.limit(limit) + + products = query.all() + + logger.info( + f"Exporting {len(products)} marketplace products for {marketplace} in {language}" + ) + + return self._generate_csv_from_marketplace_products(products, language) + + def _generate_csv(self, products: list[Product], language: str) -> str: + """Generate CSV from vendor Product objects.""" + output = io.StringIO() + writer = csv.DictWriter( + output, + fieldnames=LETZSHOP_CSV_COLUMNS, + delimiter="\t", + quoting=csv.QUOTE_MINIMAL, + ) + writer.writeheader() + + for product in products: + if product.marketplace_product: + row = self._product_to_row(product, language) + writer.writerow(row) + + return output.getvalue() + + def _generate_csv_from_marketplace_products( + self, products: list[MarketplaceProduct], language: str + ) -> str: + """Generate CSV from MarketplaceProduct objects directly.""" + output = io.StringIO() + writer = csv.DictWriter( + output, + fieldnames=LETZSHOP_CSV_COLUMNS, + delimiter="\t", + quoting=csv.QUOTE_MINIMAL, + ) + writer.writeheader() + + for mp in products: + row = self._marketplace_product_to_row(mp, language) + writer.writerow(row) + + return output.getvalue() + + def _product_to_row(self, product: Product, language: str) -> dict: + """Convert a Product (with MarketplaceProduct) to a CSV row.""" + mp = product.marketplace_product + return self._marketplace_product_to_row( + mp, language, vendor_sku=product.vendor_sku + ) + + def _marketplace_product_to_row( + self, + mp: MarketplaceProduct, + language: str, + vendor_sku: str | None = None, + ) -> dict: + """Convert a MarketplaceProduct to a CSV row dict.""" + # Get localized title and description + title = mp.get_title(language) or "" + description = mp.get_description(language) or "" + + # Format price with currency + price = "" + if mp.price_numeric: + price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}" + elif mp.price: + price = mp.price + + # Format sale price + sale_price = "" + if mp.sale_price_numeric: + sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}" + elif mp.sale_price: + sale_price = mp.sale_price + + # Additional images - join with comma if multiple + additional_images = "" + if mp.additional_images: + additional_images = ",".join(mp.additional_images) + elif mp.additional_image_link: + additional_images = mp.additional_image_link + + # Determine identifier_exists + identifier_exists = mp.identifier_exists + if not identifier_exists: + identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no" + + return { + "id": vendor_sku or mp.marketplace_product_id, + "title": title, + "description": description, + "link": mp.link or mp.source_url or "", + "image_link": mp.image_link or "", + "additional_image_link": additional_images, + "availability": mp.availability or "in stock", + "price": price, + "sale_price": sale_price, + "brand": mp.brand or "", + "gtin": mp.gtin or "", + "mpn": mp.mpn or "", + "google_product_category": mp.google_product_category or "", + "product_type": mp.product_type_raw or "", + "condition": mp.condition or "new", + "adult": mp.adult or "no", + "multipack": str(mp.multipack) if mp.multipack else "", + "is_bundle": mp.is_bundle or "no", + "age_group": mp.age_group or "", + "color": mp.color or "", + "gender": mp.gender or "", + "material": mp.material or "", + "pattern": mp.pattern or "", + "size": mp.size or "", + "size_type": mp.size_type or "", + "size_system": mp.size_system or "", + "item_group_id": mp.item_group_id or "", + "custom_label_0": mp.custom_label_0 or "", + "custom_label_1": mp.custom_label_1 or "", + "custom_label_2": mp.custom_label_2 or "", + "custom_label_3": mp.custom_label_3 or "", + "custom_label_4": mp.custom_label_4 or "", + "identifier_exists": identifier_exists, + "unit_pricing_measure": mp.unit_pricing_measure or "", + "unit_pricing_base_measure": mp.unit_pricing_base_measure or "", + "shipping": mp.shipping or "", + "atalanda:tax_rate": str(self.default_tax_rate), + "atalanda:quantity": "", # Would need inventory data + "atalanda:boost_sort": "", + "atalanda:delivery_method": "", + } + + def log_export( + self, + db: Session, + vendor_id: int, + started_at: datetime, + completed_at: datetime, + files_processed: int, + files_succeeded: int, + files_failed: int, + products_exported: int, + triggered_by: str, + error_details: dict | None = None, + ) -> LetzshopSyncLog: + """ + Log an export operation to the sync log. + + Args: + db: Database session + vendor_id: Vendor ID + started_at: When the export started + completed_at: When the export completed + files_processed: Number of language files to export (e.g., 3) + files_succeeded: Number of files successfully exported + files_failed: Number of files that failed + products_exported: Total products in the export + triggered_by: Who triggered the export (e.g., "admin:123") + error_details: Optional error details if any failures + + Returns: + Created LetzshopSyncLog entry + """ + sync_log = LetzshopSyncLog( + vendor_id=vendor_id, + operation_type="product_export", + direction="outbound", + status="completed" if files_failed == 0 else "partial", + records_processed=files_processed, + records_succeeded=files_succeeded, + records_failed=files_failed, + started_at=started_at, + completed_at=completed_at, + duration_seconds=int((completed_at - started_at).total_seconds()), + triggered_by=triggered_by, + error_details={ + "products_exported": products_exported, + **(error_details or {}), + } if products_exported or error_details else None, + ) + db.add(sync_log) + db.flush() + return sync_log + + +# Singleton instance +letzshop_export_service = LetzshopExportService() diff --git a/app/modules/marketplace/services/marketplace_import_job_service.py b/app/modules/marketplace/services/marketplace_import_job_service.py new file mode 100644 index 00000000..a36973c2 --- /dev/null +++ b/app/modules/marketplace/services/marketplace_import_job_service.py @@ -0,0 +1,334 @@ +# app/services/marketplace_import_job_service.py +import logging + +from sqlalchemy.orm import Session + +from app.exceptions import ( + ImportJobNotFoundException, + ImportJobNotOwnedException, + ValidationException, +) +from models.database.marketplace_import_job import ( + MarketplaceImportError, + MarketplaceImportJob, +) +from models.database.user import User +from models.database.vendor import Vendor +from models.schema.marketplace_import_job import ( + AdminMarketplaceImportJobResponse, + MarketplaceImportJobRequest, + MarketplaceImportJobResponse, +) + +logger = logging.getLogger(__name__) + + +class MarketplaceImportJobService: + """Service class for Marketplace operations.""" + + def create_import_job( + self, + db: Session, + request: MarketplaceImportJobRequest, + vendor: Vendor, # CHANGED: Vendor object from middleware + user: User, + ) -> MarketplaceImportJob: + """ + Create a new marketplace import job. + + Args: + db: Database session + request: Import request data + vendor: Vendor object (from middleware) + user: User creating the job + + Returns: + Created MarketplaceImportJob object + """ + try: + # Create marketplace import job record + import_job = MarketplaceImportJob( + status="pending", + source_url=request.source_url, + marketplace=request.marketplace, + language=request.language, + vendor_id=vendor.id, + user_id=user.id, + ) + + db.add(import_job) + db.flush() + db.refresh(import_job) + + logger.info( + f"Created marketplace import job {import_job.id}: " + f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) " + f"by user {user.username}" + ) + + return import_job + + except Exception as e: + logger.error(f"Error creating import job: {str(e)}") + raise ValidationException("Failed to create import job") + + def get_import_job_by_id( + self, db: Session, job_id: int, user: User + ) -> MarketplaceImportJob: + """Get a marketplace import job by ID with access control.""" + try: + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) + + if not job: + raise ImportJobNotFoundException(job_id) + + # Users can only see their own jobs, admins can see all + if user.role != "admin" and job.user_id != user.id: + raise ImportJobNotOwnedException(job_id, user.id) + + return job + + except (ImportJobNotFoundException, ImportJobNotOwnedException): + raise + except Exception as e: + logger.error(f"Error getting import job {job_id}: {str(e)}") + raise ValidationException("Failed to retrieve import job") + + def get_import_job_for_vendor( + self, db: Session, job_id: int, vendor_id: int + ) -> MarketplaceImportJob: + """ + Get a marketplace import job by ID with vendor access control. + + Validates that the job belongs to the specified vendor. + + Args: + db: Database session + job_id: Import job ID + vendor_id: Vendor ID from token (to verify ownership) + + Raises: + ImportJobNotFoundException: If job not found + UnauthorizedVendorAccessException: If job doesn't belong to vendor + """ + from app.exceptions import UnauthorizedVendorAccessException + + try: + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) + + if not job: + raise ImportJobNotFoundException(job_id) + + # Verify job belongs to vendor (service layer validation) + if job.vendor_id != vendor_id: + raise UnauthorizedVendorAccessException( + vendor_code=str(vendor_id), + user_id=0, # Not user-specific, but vendor mismatch + ) + + return job + + except (ImportJobNotFoundException, UnauthorizedVendorAccessException): + raise + except Exception as e: + logger.error( + f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}" + ) + raise ValidationException("Failed to retrieve import job") + + def get_import_jobs( + self, + db: Session, + vendor: Vendor, # ADDED: Vendor filter + user: User, + marketplace: str | None = None, + skip: int = 0, + limit: int = 50, + ) -> list[MarketplaceImportJob]: + """Get marketplace import jobs for a specific vendor.""" + try: + query = db.query(MarketplaceImportJob).filter( + MarketplaceImportJob.vendor_id == vendor.id + ) + + # Users can only see their own jobs, admins can see all vendor jobs + if user.role != "admin": + query = query.filter(MarketplaceImportJob.user_id == user.id) + + # Apply marketplace filter + if marketplace: + query = query.filter( + MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") + ) + + # Order by creation date (newest first) and apply pagination + jobs = ( + query.order_by( + MarketplaceImportJob.created_at.desc(), + MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp + ) + .offset(skip) + .limit(limit) + .all() + ) + + return jobs + + except Exception as e: + logger.error(f"Error getting import jobs: {str(e)}") + raise ValidationException("Failed to retrieve import jobs") + + def convert_to_response_model( + self, job: MarketplaceImportJob + ) -> MarketplaceImportJobResponse: + """Convert database model to API response model.""" + return MarketplaceImportJobResponse( + job_id=job.id, + status=job.status, + marketplace=job.marketplace, + language=job.language, + vendor_id=job.vendor_id, + vendor_code=job.vendor.vendor_code if job.vendor else None, + vendor_name=job.vendor.name if job.vendor else None, + source_url=job.source_url, + imported=job.imported_count or 0, + updated=job.updated_count or 0, + total_processed=job.total_processed or 0, + error_count=job.error_count or 0, + error_message=job.error_message, + created_at=job.created_at, + started_at=job.started_at, + completed_at=job.completed_at, + ) + + def convert_to_admin_response_model( + self, job: MarketplaceImportJob + ) -> AdminMarketplaceImportJobResponse: + """Convert database model to admin API response model with extra fields.""" + return AdminMarketplaceImportJobResponse( + id=job.id, + job_id=job.id, + status=job.status, + marketplace=job.marketplace, + language=job.language, + vendor_id=job.vendor_id, + vendor_code=job.vendor.vendor_code if job.vendor else None, + vendor_name=job.vendor.name if job.vendor else None, + source_url=job.source_url, + imported=job.imported_count or 0, + updated=job.updated_count or 0, + total_processed=job.total_processed or 0, + error_count=job.error_count or 0, + error_message=job.error_message, + error_details=[], + created_at=job.created_at, + started_at=job.started_at, + completed_at=job.completed_at, + created_by_name=job.user.username if job.user else None, + ) + + def get_all_import_jobs_paginated( + self, + db: Session, + marketplace: str | None = None, + status: str | None = None, + page: int = 1, + limit: int = 100, + ) -> tuple[list[MarketplaceImportJob], int]: + """Get all marketplace import jobs with pagination (for admin).""" + try: + query = db.query(MarketplaceImportJob) + + if marketplace: + query = query.filter( + MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") + ) + if status: + query = query.filter(MarketplaceImportJob.status == status) + + total = query.count() + skip = (page - 1) * limit + jobs = ( + query.order_by( + MarketplaceImportJob.created_at.desc(), + MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp + ) + .offset(skip) + .limit(limit) + .all() + ) + + return jobs, total + + except Exception as e: + logger.error(f"Error getting all import jobs: {str(e)}") + raise ValidationException("Failed to retrieve import jobs") + + def get_import_job_by_id_admin( + self, db: Session, job_id: int + ) -> MarketplaceImportJob: + """Get a marketplace import job by ID (admin - no access control).""" + job = ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.id == job_id) + .first() + ) + if not job: + raise ImportJobNotFoundException(job_id) + return job + + def get_import_job_errors( + self, + db: Session, + job_id: int, + error_type: str | None = None, + page: int = 1, + limit: int = 50, + ) -> tuple[list[MarketplaceImportError], int]: + """ + Get import errors for a specific job with pagination. + + Args: + db: Database session + job_id: Import job ID + error_type: Optional filter by error type + page: Page number (1-indexed) + limit: Number of items per page + + Returns: + Tuple of (list of errors, total count) + """ + try: + query = db.query(MarketplaceImportError).filter( + MarketplaceImportError.import_job_id == job_id + ) + + if error_type: + query = query.filter(MarketplaceImportError.error_type == error_type) + + total = query.count() + + offset = (page - 1) * limit + errors = ( + query.order_by(MarketplaceImportError.row_number) + .offset(offset) + .limit(limit) + .all() + ) + + return errors, total + + except Exception as e: + logger.error(f"Error getting import job errors for job {job_id}: {str(e)}") + raise ValidationException("Failed to retrieve import errors") + + +marketplace_import_job_service = MarketplaceImportJobService() diff --git a/app/modules/marketplace/services/marketplace_product_service.py b/app/modules/marketplace/services/marketplace_product_service.py new file mode 100644 index 00000000..720bcc37 --- /dev/null +++ b/app/modules/marketplace/services/marketplace_product_service.py @@ -0,0 +1,1075 @@ +# app/services/marketplace_product_service.py +""" +MarketplaceProduct service for managing product operations and data processing. + +This module provides classes and functions for: +- MarketplaceProduct CRUD operations with validation +- Advanced product filtering and search +- Inventory information integration +- CSV export functionality + +Note: Title and description are now stored in MarketplaceProductTranslation table. +Use get_title(language) and get_description(language) methods on the model. +""" + +import csv +import logging +from collections.abc import Generator +from datetime import UTC, datetime +from io import StringIO + +from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, joinedload + +from app.exceptions import ( + InvalidMarketplaceProductDataException, + MarketplaceProductAlreadyExistsException, + MarketplaceProductNotFoundException, + MarketplaceProductValidationException, + ValidationException, +) +from app.utils.data_processing import GTINProcessor, PriceProcessor +from models.database.inventory import Inventory +from models.database.marketplace_product import MarketplaceProduct +from models.database.marketplace_product_translation import ( + MarketplaceProductTranslation, +) +from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse +from models.schema.marketplace_product import ( + MarketplaceProductCreate, + MarketplaceProductUpdate, +) + +logger = logging.getLogger(__name__) + + +class MarketplaceProductService: + """Service class for MarketplaceProduct operations following the application's service pattern.""" + + def __init__(self): + """Class constructor.""" + self.gtin_processor = GTINProcessor() + self.price_processor = PriceProcessor() + + def create_product( + self, + db: Session, + product_data: MarketplaceProductCreate, + title: str | None = None, + description: str | None = None, + language: str = "en", + ) -> MarketplaceProduct: + """Create a new product with validation. + + Args: + db: Database session + product_data: Product data from schema + title: Product title (stored in translations table) + description: Product description (stored in translations table) + language: Language code for translation (default: 'en') + + Returns: + Created MarketplaceProduct instance + """ + try: + # Process and validate GTIN if provided + if product_data.gtin: + normalized_gtin = self.gtin_processor.normalize(product_data.gtin) + if not normalized_gtin: + raise InvalidMarketplaceProductDataException( + "Invalid GTIN format", field="gtin" + ) + product_data.gtin = normalized_gtin + + # Process price if provided + if product_data.price: + try: + parsed_price, currency = self.price_processor.parse_price_currency( + product_data.price + ) + if parsed_price: + product_data.price = parsed_price + product_data.currency = currency + except ValueError as e: + # Convert ValueError to domain-specific exception + raise InvalidMarketplaceProductDataException(str(e), field="price") + + # Set default marketplace if not provided + if not product_data.marketplace: + product_data.marketplace = "Letzshop" + + # Validate required fields + if ( + not product_data.marketplace_product_id + or not product_data.marketplace_product_id.strip() + ): + raise MarketplaceProductValidationException( + "MarketplaceProduct ID is required", field="marketplace_product_id" + ) + + # Create the product (without title/description - those go in translations) + product_dict = product_data.model_dump() + # Remove any title/description if present in schema (for backwards compatibility) + product_dict.pop("title", None) + product_dict.pop("description", None) + + db_product = MarketplaceProduct(**product_dict) + db.add(db_product) + db.flush() # Get the ID + + # Create translation if title is provided + if title and title.strip(): + translation = MarketplaceProductTranslation( + marketplace_product_id=db_product.id, + language=language, + title=title.strip(), + description=description.strip() if description else None, + ) + db.add(translation) + + db.flush() + db.refresh(db_product) + + logger.info(f"Created product {db_product.marketplace_product_id}") + return db_product + + except ( + InvalidMarketplaceProductDataException, + MarketplaceProductValidationException, + ): + raise # Re-raise custom exceptions + except IntegrityError as e: + logger.error(f"Database integrity error: {str(e)}") + if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower(): + raise MarketplaceProductAlreadyExistsException( + product_data.marketplace_product_id + ) + raise MarketplaceProductValidationException( + "Data integrity constraint violation" + ) + except Exception as e: + logger.error(f"Error creating product: {str(e)}") + raise ValidationException("Failed to create product") + + def get_product_by_id( + self, db: Session, marketplace_product_id: str + ) -> MarketplaceProduct | None: + """Get a product by its ID.""" + try: + return ( + db.query(MarketplaceProduct) + .options(joinedload(MarketplaceProduct.translations)) + .filter( + MarketplaceProduct.marketplace_product_id == marketplace_product_id + ) + .first() + ) + except Exception as e: + logger.error(f"Error getting product {marketplace_product_id}: {str(e)}") + return None + + def get_product_by_id_or_raise( + self, db: Session, marketplace_product_id: str + ) -> MarketplaceProduct: + """ + Get a product by its ID or raise exception. + + Args: + db: Database session + marketplace_product_id: MarketplaceProduct ID to find + + Returns: + MarketplaceProduct object + + Raises: + MarketplaceProductNotFoundException: If product doesn't exist + """ + product = self.get_product_by_id(db, marketplace_product_id) + if not product: + raise MarketplaceProductNotFoundException(marketplace_product_id) + return product + + def get_products_with_filters( + self, + db: Session, + skip: int = 0, + limit: int = 100, + brand: str | None = None, + category: str | None = None, + availability: str | None = None, + marketplace: str | None = None, + vendor_name: str | None = None, + search: str | None = None, + language: str = "en", + ) -> tuple[list[MarketplaceProduct], int]: + """ + Get products with filtering and pagination. + + Args: + db: Database session + skip: Number of records to skip + limit: Maximum records to return + brand: Brand filter + category: Category filter + availability: Availability filter + marketplace: Marketplace filter + vendor_name: Vendor name filter + search: Search term (searches in translations too) + language: Language for search (default: 'en') + + Returns: + Tuple of (products_list, total_count) + """ + try: + query = db.query(MarketplaceProduct).options( + joinedload(MarketplaceProduct.translations) + ) + + # Apply filters + if brand: + query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%")) + if category: + query = query.filter( + MarketplaceProduct.google_product_category.ilike(f"%{category}%") + ) + if availability: + query = query.filter(MarketplaceProduct.availability == availability) + if marketplace: + query = query.filter( + MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") + ) + if vendor_name: + query = query.filter( + MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") + ) + if search: + # Search in marketplace, vendor_name, brand, and translations + search_term = f"%{search}%" + # Use subquery to get distinct IDs (PostgreSQL can't compare JSON for DISTINCT) + id_subquery = ( + db.query(MarketplaceProduct.id) + .outerjoin(MarketplaceProductTranslation) + .filter( + or_( + MarketplaceProduct.marketplace.ilike(search_term), + MarketplaceProduct.vendor_name.ilike(search_term), + MarketplaceProduct.brand.ilike(search_term), + MarketplaceProduct.gtin.ilike(search_term), + MarketplaceProduct.marketplace_product_id.ilike(search_term), + MarketplaceProductTranslation.title.ilike(search_term), + MarketplaceProductTranslation.description.ilike(search_term), + ) + ) + .distinct() + .subquery() + ) + query = query.filter(MarketplaceProduct.id.in_( + db.query(id_subquery.c.id) + )) + + total = query.count() + products = query.offset(skip).limit(limit).all() + + return products, total + + except Exception as e: + logger.error(f"Error getting products with filters: {str(e)}") + raise ValidationException("Failed to retrieve products") + + def update_product( + self, + db: Session, + marketplace_product_id: str, + product_update: MarketplaceProductUpdate, + title: str | None = None, + description: str | None = None, + language: str = "en", + ) -> MarketplaceProduct: + """Update product with validation. + + Args: + db: Database session + marketplace_product_id: ID of product to update + product_update: Product update data from schema + title: Updated title (stored in translations table) + description: Updated description (stored in translations table) + language: Language code for translation (default: 'en') + + Returns: + Updated MarketplaceProduct instance + """ + try: + product = self.get_product_by_id_or_raise(db, marketplace_product_id) + + # Update fields + update_data = product_update.model_dump(exclude_unset=True) + + # Remove title/description from update data (handled separately) + update_data.pop("title", None) + update_data.pop("description", None) + + # Validate GTIN if being updated + if "gtin" in update_data and update_data["gtin"]: + normalized_gtin = self.gtin_processor.normalize(update_data["gtin"]) + if not normalized_gtin: + raise InvalidMarketplaceProductDataException( + "Invalid GTIN format", field="gtin" + ) + update_data["gtin"] = normalized_gtin + + # Process price if being updated + if "price" in update_data and update_data["price"]: + try: + parsed_price, currency = self.price_processor.parse_price_currency( + update_data["price"] + ) + if parsed_price: + update_data["price"] = parsed_price + update_data["currency"] = currency + except ValueError as e: + # Convert ValueError to domain-specific exception + raise InvalidMarketplaceProductDataException(str(e), field="price") + + # Apply updates to product + for key, value in update_data.items(): + if hasattr(product, key): + setattr(product, key, value) + + product.updated_at = datetime.now(UTC) + + # Update or create translation if title/description provided + if title is not None or description is not None: + self._update_or_create_translation( + db, product, title, description, language + ) + + db.flush() + db.refresh(product) + + logger.info(f"Updated product {marketplace_product_id}") + return product + + except ( + MarketplaceProductNotFoundException, + InvalidMarketplaceProductDataException, + MarketplaceProductValidationException, + ): + raise # Re-raise custom exceptions + except Exception as e: + logger.error(f"Error updating product {marketplace_product_id}: {str(e)}") + raise ValidationException("Failed to update product") + + def _update_or_create_translation( + self, + db: Session, + product: MarketplaceProduct, + title: str | None, + description: str | None, + language: str, + ) -> None: + """Update existing translation or create new one.""" + existing = ( + db.query(MarketplaceProductTranslation) + .filter( + MarketplaceProductTranslation.marketplace_product_id == product.id, + MarketplaceProductTranslation.language == language, + ) + .first() + ) + + if existing: + if title is not None: + existing.title = title.strip() if title else existing.title + if description is not None: + existing.description = description.strip() if description else None + existing.updated_at = datetime.now(UTC) + else: + # Only create if we have a title + if title and title.strip(): + new_translation = MarketplaceProductTranslation( + marketplace_product_id=product.id, + language=language, + title=title.strip(), + description=description.strip() if description else None, + ) + db.add(new_translation) + + def delete_product(self, db: Session, marketplace_product_id: str) -> bool: + """ + Delete product and associated inventory. + + Args: + db: Database session + marketplace_product_id: MarketplaceProduct ID to delete + + Returns: + True if deletion successful + + Raises: + MarketplaceProductNotFoundException: If product doesn't exist + """ + try: + product = self.get_product_by_id_or_raise(db, marketplace_product_id) + + # Delete associated inventory entries if GTIN exists + if product.gtin: + db.query(Inventory).filter(Inventory.gtin == product.gtin).delete() + + # Translations will be cascade deleted + db.delete(product) + db.flush() + + logger.info(f"Deleted product {marketplace_product_id}") + return True + + except MarketplaceProductNotFoundException: + raise # Re-raise custom exceptions + except Exception as e: + logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}") + raise ValidationException("Failed to delete product") + + def get_inventory_info( + self, db: Session, gtin: str + ) -> InventorySummaryResponse | None: + """ + Get inventory information for a product by GTIN. + + Args: + db: Database session + gtin: GTIN to look up inventory for + + Returns: + InventorySummaryResponse if inventory found, None otherwise + """ + try: + # noqa: SVC-005 - Admin/internal function for inventory lookup by GTIN + inventory_entries = db.query(Inventory).filter(Inventory.gtin == gtin).all() + if not inventory_entries: + return None + + total_quantity = sum(entry.quantity for entry in inventory_entries) + locations = [ + InventoryLocationResponse( + location=entry.location, + quantity=entry.quantity, + reserved_quantity=entry.reserved_quantity or 0, + available_quantity=entry.quantity - (entry.reserved_quantity or 0), + ) + for entry in inventory_entries + ] + + return InventorySummaryResponse( + gtin=gtin, total_quantity=total_quantity, locations=locations + ) + + except Exception as e: + logger.error(f"Error getting inventory info for GTIN {gtin}: {str(e)}") + return None + + def generate_csv_export( + self, + db: Session, + marketplace: str | None = None, + vendor_name: str | None = None, + language: str = "en", + ) -> Generator[str, None, None]: + """ + Generate CSV export with streaming for memory efficiency and proper CSV escaping. + + Args: + db: Database session + marketplace: Optional marketplace filter + vendor_name: Optional vendor name filter + language: Language code for title/description (default: 'en') + + Yields: + CSV content as strings with proper escaping + """ + try: + # Create a StringIO buffer for CSV writing + output = StringIO() + writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL) + + # Write header row + headers = [ + "marketplace_product_id", + "title", + "description", + "link", + "image_link", + "availability", + "price", + "currency", + "brand", + "gtin", + "marketplace", + "vendor_name", + ] + writer.writerow(headers) + yield output.getvalue() + + # Clear buffer for reuse + output.seek(0) + output.truncate(0) + + batch_size = 1000 + offset = 0 + + while True: + query = db.query(MarketplaceProduct).options( + joinedload(MarketplaceProduct.translations) + ) + + # Apply marketplace filters + if marketplace: + query = query.filter( + MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") + ) + if vendor_name: + query = query.filter( + MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") + ) + + products = query.offset(offset).limit(batch_size).all() + if not products: + break + + for product in products: + # Get title and description from translations + title = product.get_title(language) or "" + description = product.get_description(language) or "" + + # Create CSV row with proper escaping + row_data = [ + product.marketplace_product_id or "", + title, + description, + product.link or "", + product.image_link or "", + product.availability or "", + product.price or "", + product.currency or "", + product.brand or "", + product.gtin or "", + product.marketplace or "", + product.vendor_name or "", + ] + + writer.writerow(row_data) + yield output.getvalue() + + # Clear buffer for next row + output.seek(0) + output.truncate(0) + + offset += batch_size + + except Exception as e: + logger.error(f"Error generating CSV export: {str(e)}") + raise ValidationException("Failed to generate CSV export") + + def product_exists(self, db: Session, marketplace_product_id: str) -> bool: + """Check if product exists by ID.""" + try: + return ( + db.query(MarketplaceProduct) + .filter( + MarketplaceProduct.marketplace_product_id == marketplace_product_id + ) + .first() + is not None + ) + except Exception as e: + logger.error(f"Error checking if product exists: {str(e)}") + return False + + # Private helper methods + def _validate_product_data(self, product_data: dict) -> None: + """Validate product data structure.""" + required_fields = ["marketplace_product_id"] + + for field in required_fields: + if field not in product_data or not product_data[field]: + raise MarketplaceProductValidationException( + f"{field} is required", field=field + ) + + def _normalize_product_data(self, product_data: dict) -> dict: + """Normalize and clean product data.""" + normalized = product_data.copy() + + # Trim whitespace from string fields + string_fields = [ + "marketplace_product_id", + "brand", + "marketplace", + "vendor_name", + ] + for field in string_fields: + if field in normalized and normalized[field]: + normalized[field] = normalized[field].strip() + + return normalized + + # ========================================================================= + # Admin-specific methods for marketplace product management + # ========================================================================= + + def get_admin_products( + self, + db: Session, + skip: int = 0, + limit: int = 50, + search: str | None = None, + marketplace: str | None = None, + vendor_name: str | None = None, + availability: str | None = None, + is_active: bool | None = None, + is_digital: bool | None = None, + language: str = "en", + ) -> tuple[list[dict], int]: + """ + Get marketplace products for admin with search and filtering. + + Returns: + Tuple of (products list as dicts, total count) + """ + query = db.query(MarketplaceProduct).options( + joinedload(MarketplaceProduct.translations) + ) + + if search: + search_term = f"%{search}%" + # Use subquery to get distinct IDs (PostgreSQL can't compare JSON for DISTINCT) + id_subquery = ( + db.query(MarketplaceProduct.id) + .outerjoin(MarketplaceProductTranslation) + .filter( + or_( + MarketplaceProductTranslation.title.ilike(search_term), + MarketplaceProduct.gtin.ilike(search_term), + MarketplaceProduct.sku.ilike(search_term), + MarketplaceProduct.brand.ilike(search_term), + MarketplaceProduct.mpn.ilike(search_term), + MarketplaceProduct.marketplace_product_id.ilike(search_term), + ) + ) + .distinct() + .subquery() + ) + query = query.filter(MarketplaceProduct.id.in_( + db.query(id_subquery.c.id) + )) + + if marketplace: + query = query.filter(MarketplaceProduct.marketplace == marketplace) + + if vendor_name: + query = query.filter( + MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") + ) + + if availability: + query = query.filter(MarketplaceProduct.availability == availability) + + if is_active is not None: + query = query.filter(MarketplaceProduct.is_active == is_active) + + if is_digital is not None: + query = query.filter(MarketplaceProduct.is_digital == is_digital) + + total = query.count() + + products = ( + query.order_by(MarketplaceProduct.updated_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + result = [] + for product in products: + title = product.get_title(language) + result.append(self._build_admin_product_item(product, title)) + + return result, total + + def get_admin_product_stats( + self, + db: Session, + marketplace: str | None = None, + vendor_name: str | None = None, + ) -> dict: + """Get product statistics for admin dashboard. + + Args: + db: Database session + marketplace: Optional filter by marketplace (e.g., 'Letzshop') + vendor_name: Optional filter by vendor name + """ + from sqlalchemy import func + + # Build base filter + base_filters = [] + if marketplace: + base_filters.append(MarketplaceProduct.marketplace == marketplace) + if vendor_name: + base_filters.append(MarketplaceProduct.vendor_name == vendor_name) + + base_query = db.query(func.count(MarketplaceProduct.id)) + if base_filters: + base_query = base_query.filter(*base_filters) + + total = base_query.scalar() or 0 + + active_query = db.query(func.count(MarketplaceProduct.id)).filter( + MarketplaceProduct.is_active == True # noqa: E712 + ) + if base_filters: + active_query = active_query.filter(*base_filters) + active = active_query.scalar() or 0 + inactive = total - active + + digital_query = db.query(func.count(MarketplaceProduct.id)).filter( + MarketplaceProduct.is_digital == True # noqa: E712 + ) + if base_filters: + digital_query = digital_query.filter(*base_filters) + digital = digital_query.scalar() or 0 + physical = total - digital + + marketplace_query = db.query( + MarketplaceProduct.marketplace, + func.count(MarketplaceProduct.id), + ) + if base_filters: + marketplace_query = marketplace_query.filter(*base_filters) + marketplace_counts = marketplace_query.group_by( + MarketplaceProduct.marketplace + ).all() + by_marketplace = {mp or "unknown": count for mp, count in marketplace_counts} + + return { + "total": total, + "active": active, + "inactive": inactive, + "digital": digital, + "physical": physical, + "by_marketplace": by_marketplace, + } + + def get_marketplaces_list(self, db: Session) -> list[str]: + """Get list of unique marketplaces in the product catalog.""" + marketplaces = ( + db.query(MarketplaceProduct.marketplace) + .distinct() + .filter(MarketplaceProduct.marketplace.isnot(None)) + .all() + ) + return [m[0] for m in marketplaces if m[0]] + + def get_source_vendors_list(self, db: Session) -> list[str]: + """Get list of unique vendor names in the product catalog.""" + vendors = ( + db.query(MarketplaceProduct.vendor_name) + .distinct() + .filter(MarketplaceProduct.vendor_name.isnot(None)) + .all() + ) + return [v[0] for v in vendors if v[0]] + + def get_admin_product_detail(self, db: Session, product_id: int) -> dict: + """Get detailed product information by database ID.""" + product = ( + db.query(MarketplaceProduct) + .options(joinedload(MarketplaceProduct.translations)) + .filter(MarketplaceProduct.id == product_id) + .first() + ) + + if not product: + raise MarketplaceProductNotFoundException( + f"Marketplace product with ID {product_id} not found" + ) + + translations = {} + for t in product.translations: + translations[t.language] = { + "title": t.title, + "description": t.description, + "short_description": t.short_description, + } + + return { + "id": product.id, + "marketplace_product_id": product.marketplace_product_id, + "gtin": product.gtin, + "mpn": product.mpn, + "sku": product.sku, + "brand": product.brand, + "marketplace": product.marketplace, + "vendor_name": product.vendor_name, + "source_url": product.source_url, + "price": product.price, + "price_numeric": product.price_numeric, + "sale_price": product.sale_price, + "sale_price_numeric": product.sale_price_numeric, + "currency": product.currency, + "availability": product.availability, + "condition": product.condition, + "image_link": product.image_link, + "additional_images": product.additional_images, + "is_active": product.is_active, + "is_digital": product.is_digital, + "product_type_enum": product.product_type_enum, + "platform": product.platform, + "google_product_category": product.google_product_category, + "category_path": product.category_path, + "color": product.color, + "size": product.size, + "weight": product.weight, + "weight_unit": product.weight_unit, + "translations": translations, + "created_at": product.created_at.isoformat() + if product.created_at + else None, + "updated_at": product.updated_at.isoformat() + if product.updated_at + else None, + } + + def copy_to_vendor_catalog( + self, + db: Session, + marketplace_product_ids: list[int], + vendor_id: int, + skip_existing: bool = True, + ) -> dict: + """ + Copy marketplace products to a vendor's catalog. + + Creates independent vendor products with ALL fields copied from the + marketplace product. Each vendor product is a standalone entity - no + field inheritance or fallback logic. The marketplace_product_id FK is + kept for "view original source" feature. + + Also copies ALL translations from the marketplace product. + + Returns: + Dict with copied, skipped, failed counts and details + """ + from models.database.product import Product + from models.database.product_translation import ProductTranslation + from models.database.vendor import Vendor + + vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() + if not vendor: + from app.exceptions import VendorNotFoundException + + raise VendorNotFoundException(str(vendor_id), identifier_type="id") + + marketplace_products = ( + db.query(MarketplaceProduct) + .options(joinedload(MarketplaceProduct.translations)) + .filter(MarketplaceProduct.id.in_(marketplace_product_ids)) + .all() + ) + + if not marketplace_products: + raise MarketplaceProductNotFoundException("No marketplace products found") + + # Check product limit from subscription + from app.services.subscription_service import subscription_service + from sqlalchemy import func + + current_products = ( + db.query(func.count(Product.id)) + .filter(Product.vendor_id == vendor_id) + .scalar() + or 0 + ) + + subscription = subscription_service.get_or_create_subscription(db, vendor_id) + products_limit = subscription.products_limit + remaining_capacity = ( + products_limit - current_products if products_limit is not None else None + ) + + copied = 0 + skipped = 0 + failed = 0 + limit_reached = False + details = [] + + for mp in marketplace_products: + # Check if we've hit the product limit + if remaining_capacity is not None and copied >= remaining_capacity: + limit_reached = True + details.append( + { + "id": mp.id, + "status": "skipped", + "reason": "Product limit reached", + } + ) + skipped += 1 + continue + try: + existing = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.marketplace_product_id == mp.id, + ) + .first() + ) + + if existing: + skipped += 1 + details.append( + { + "id": mp.id, + "status": "skipped", + "reason": "Already exists in catalog", + } + ) + continue + + # Create vendor product with ALL fields copied from marketplace + product = Product( + vendor_id=vendor_id, + marketplace_product_id=mp.id, + # === Vendor settings (defaults) === + is_active=True, + is_featured=False, + # === Product identifiers === + gtin=mp.gtin, + gtin_type=mp.gtin_type if hasattr(mp, "gtin_type") else None, + # === Pricing (copy from marketplace) === + price_cents=mp.price_cents, + sale_price_cents=mp.sale_price_cents, + currency=mp.currency or "EUR", + # === Product info === + brand=mp.brand, + condition=mp.condition, + availability=mp.availability, + # === Media === + primary_image_url=mp.image_link, + additional_images=mp.additional_images, + # === Digital product fields === + download_url=mp.download_url if hasattr(mp, "download_url") else None, + license_type=mp.license_type if hasattr(mp, "license_type") else None, + ) + + db.add(product) + db.flush() # Get product.id for translations + + # Copy ALL translations from marketplace product + translations_copied = 0 + for mpt in mp.translations: + product_translation = ProductTranslation( + product_id=product.id, + language=mpt.language, + title=mpt.title, + description=mpt.description, + short_description=mpt.short_description, + meta_title=mpt.meta_title, + meta_description=mpt.meta_description, + url_slug=mpt.url_slug, + ) + db.add(product_translation) + translations_copied += 1 + + copied += 1 + details.append({ + "id": mp.id, + "status": "copied", + "gtin": mp.gtin, + "translations_copied": translations_copied, + }) + + except Exception as e: + logger.error(f"Failed to copy product {mp.id}: {str(e)}") + failed += 1 + details.append({"id": mp.id, "status": "failed", "reason": str(e)}) + + db.flush() + + # Auto-match pending order item exceptions + # Collect GTINs and their product IDs from newly copied products + from app.services.order_item_exception_service import ( + order_item_exception_service, + ) + + gtin_to_product: dict[str, int] = {} + for detail in details: + if detail.get("status") == "copied" and detail.get("gtin"): + # Find the product we just created + product = ( + db.query(Product) + .filter( + Product.vendor_id == vendor_id, + Product.gtin == detail["gtin"], + ) + .first() + ) + if product: + gtin_to_product[detail["gtin"]] = product.id + + auto_matched = 0 + if gtin_to_product: + auto_matched = order_item_exception_service.auto_match_batch( + db, vendor_id, gtin_to_product + ) + if auto_matched: + logger.info( + f"Auto-matched {auto_matched} order item exceptions " + f"during product copy to vendor {vendor_id}" + ) + + logger.info( + f"Copied {copied} products to vendor {vendor.name} " + f"(skipped: {skipped}, failed: {failed}, auto_matched: {auto_matched})" + ) + + return { + "copied": copied, + "skipped": skipped, + "failed": failed, + "auto_matched": auto_matched, + "limit_reached": limit_reached, + "details": details if len(details) <= 100 else None, + } + + def _build_admin_product_item( + self, product: MarketplaceProduct, title: str | None + ) -> dict: + """Build a product list item dict for admin view.""" + return { + "id": product.id, + "marketplace_product_id": product.marketplace_product_id, + "title": title, + "brand": product.brand, + "gtin": product.gtin, + "sku": product.sku, + "marketplace": product.marketplace, + "vendor_name": product.vendor_name, + "price_numeric": product.price_numeric, + "currency": product.currency, + "availability": product.availability, + "image_link": product.image_link, + "is_active": product.is_active, + "is_digital": product.is_digital, + "product_type_enum": product.product_type_enum, + "created_at": product.created_at.isoformat() + if product.created_at + else None, + "updated_at": product.updated_at.isoformat() + if product.updated_at + else None, + } + + +# Create service instance +marketplace_product_service = MarketplaceProductService() diff --git a/app/modules/messaging/models/__init__.py b/app/modules/messaging/models/__init__.py index 93b7a816..e95a6174 100644 --- a/app/modules/messaging/models/__init__.py +++ b/app/modules/messaging/models/__init__.py @@ -2,10 +2,10 @@ """ Messaging module database models. -Re-exports messaging-related models from their source locations. +This module contains the canonical implementations of messaging-related models. """ -from models.database.message import ( +from app.modules.messaging.models.message import ( Conversation, ConversationParticipant, ConversationType, @@ -13,7 +13,7 @@ from models.database.message import ( MessageAttachment, ParticipantType, ) -from models.database.admin import AdminNotification +from app.modules.messaging.models.admin_notification import AdminNotification __all__ = [ "Conversation", diff --git a/app/modules/messaging/models/admin_notification.py b/app/modules/messaging/models/admin_notification.py new file mode 100644 index 00000000..77092f8f --- /dev/null +++ b/app/modules/messaging/models/admin_notification.py @@ -0,0 +1,54 @@ +# app/modules/messaging/models/admin_notification.py +""" +Admin notification database model. + +This model handles admin-specific notifications for system alerts and warnings. +""" + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + JSON, + String, + Text, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class AdminNotification(Base, TimestampMixin): + """ + Admin-specific notifications for system alerts and warnings. + + Different from vendor/customer notifications - these are for platform + administrators to track system health and issues requiring attention. + """ + + __tablename__ = "admin_notifications" + + id = Column(Integer, primary_key=True, index=True) + type = Column( + String(50), nullable=False, index=True + ) # system_alert, vendor_issue, import_failure + priority = Column( + String(20), default="normal", index=True + ) # low, normal, high, critical + title = Column(String(200), nullable=False) + message = Column(Text, nullable=False) + is_read = Column(Boolean, default=False, index=True) + read_at = Column(DateTime, nullable=True) + read_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) + action_required = Column(Boolean, default=False, index=True) + action_url = Column(String(500)) # Link to relevant admin page + notification_metadata = Column(JSON) # Additional contextual data + + # Relationships + read_by = relationship("User", foreign_keys=[read_by_user_id]) + + def __repr__(self): + return f"" diff --git a/app/modules/messaging/models/message.py b/app/modules/messaging/models/message.py new file mode 100644 index 00000000..f3d16516 --- /dev/null +++ b/app/modules/messaging/models/message.py @@ -0,0 +1,272 @@ +# app/modules/messaging/models/message.py +""" +Messaging system database models. + +Supports three communication channels: +- Admin <-> Vendor +- Vendor <-> Customer +- Admin <-> Customer + +Multi-tenant isolation is enforced via vendor_id for conversations +involving customers. +""" + +import enum +from datetime import datetime + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + ForeignKey, + Index, + Integer, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class ConversationType(str, enum.Enum): + """Defines the three supported conversation channels.""" + + ADMIN_VENDOR = "admin_vendor" + VENDOR_CUSTOMER = "vendor_customer" + ADMIN_CUSTOMER = "admin_customer" + + +class ParticipantType(str, enum.Enum): + """Type of participant in a conversation.""" + + ADMIN = "admin" # User with role="admin" + VENDOR = "vendor" # User with role="vendor" (via VendorUser) + CUSTOMER = "customer" # Customer model + + +def _enum_values(enum_class): + """Extract enum values for SQLAlchemy Enum column.""" + return [e.value for e in enum_class] + + +class Conversation(Base, TimestampMixin): + """ + Represents a threaded conversation between participants. + + Multi-tenancy: vendor_id is required for vendor_customer and admin_customer + conversations to ensure customer data isolation. + """ + + __tablename__ = "conversations" + + id = Column(Integer, primary_key=True, index=True) + + # Conversation type determines participant structure + conversation_type = Column( + Enum(ConversationType, values_callable=_enum_values), + nullable=False, + index=True, + ) + + # Subject line for the conversation thread + subject = Column(String(500), nullable=False) + + # For vendor_customer and admin_customer conversations + # Required for multi-tenant data isolation + vendor_id = Column( + Integer, + ForeignKey("vendors.id"), + nullable=True, + index=True, + ) + + # Status flags + is_closed = Column(Boolean, default=False, nullable=False) + closed_at = Column(DateTime, nullable=True) + closed_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True) + closed_by_id = Column(Integer, nullable=True) + + # Last activity tracking for sorting + last_message_at = Column(DateTime, nullable=True, index=True) + message_count = Column(Integer, default=0, nullable=False) + + # Relationships + vendor = relationship("Vendor", foreign_keys=[vendor_id]) + participants = relationship( + "ConversationParticipant", + back_populates="conversation", + cascade="all, delete-orphan", + ) + messages = relationship( + "Message", + back_populates="conversation", + cascade="all, delete-orphan", + order_by="Message.created_at", + ) + + # Indexes for common queries + __table_args__ = ( + Index("ix_conversations_type_vendor", "conversation_type", "vendor_id"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class ConversationParticipant(Base, TimestampMixin): + """ + Links participants (users or customers) to conversations. + + Polymorphic relationship: + - participant_type="admin" or "vendor": references users.id + - participant_type="customer": references customers.id + """ + + __tablename__ = "conversation_participants" + + id = Column(Integer, primary_key=True, index=True) + conversation_id = Column( + Integer, + ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # Polymorphic participant reference + participant_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False) + participant_id = Column(Integer, nullable=False, index=True) + + # For vendor participants, track which vendor they represent + vendor_id = Column( + Integer, + ForeignKey("vendors.id"), + nullable=True, + ) + + # Unread tracking per participant + unread_count = Column(Integer, default=0, nullable=False) + last_read_at = Column(DateTime, nullable=True) + + # Notification preferences for this conversation + email_notifications = Column(Boolean, default=True, nullable=False) + muted = Column(Boolean, default=False, nullable=False) + + # Relationships + conversation = relationship("Conversation", back_populates="participants") + vendor = relationship("Vendor", foreign_keys=[vendor_id]) + + __table_args__ = ( + UniqueConstraint( + "conversation_id", + "participant_type", + "participant_id", + name="uq_conversation_participant", + ), + Index( + "ix_participant_lookup", + "participant_type", + "participant_id", + ), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class Message(Base, TimestampMixin): + """ + Individual message within a conversation thread. + + Sender polymorphism follows same pattern as participant. + """ + + __tablename__ = "messages" + + id = Column(Integer, primary_key=True, index=True) + conversation_id = Column( + Integer, + ForeignKey("conversations.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # Polymorphic sender reference + sender_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False) + sender_id = Column(Integer, nullable=False, index=True) + + # Message content + content = Column(Text, nullable=False) + + # System messages (e.g., "conversation closed") + is_system_message = Column(Boolean, default=False, nullable=False) + + # Soft delete for moderation + is_deleted = Column(Boolean, default=False, nullable=False) + deleted_at = Column(DateTime, nullable=True) + deleted_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True) + deleted_by_id = Column(Integer, nullable=True) + + # Relationships + conversation = relationship("Conversation", back_populates="messages") + attachments = relationship( + "MessageAttachment", + back_populates="message", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("ix_messages_conversation_created", "conversation_id", "created_at"), + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class MessageAttachment(Base, TimestampMixin): + """ + File attachments for messages. + + Files are stored in platform storage (local/S3) with references here. + """ + + __tablename__ = "message_attachments" + + id = Column(Integer, primary_key=True, index=True) + message_id = Column( + Integer, + ForeignKey("messages.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + # File metadata + filename = Column(String(255), nullable=False) + original_filename = Column(String(255), nullable=False) + file_path = Column(String(1000), nullable=False) # Storage path + file_size = Column(Integer, nullable=False) # Size in bytes + mime_type = Column(String(100), nullable=False) + + # For image attachments + is_image = Column(Boolean, default=False, nullable=False) + image_width = Column(Integer, nullable=True) + image_height = Column(Integer, nullable=True) + thumbnail_path = Column(String(1000), nullable=True) + + # Relationships + message = relationship("Message", back_populates="attachments") + + def __repr__(self) -> str: + return f"" diff --git a/app/modules/messaging/schemas/__init__.py b/app/modules/messaging/schemas/__init__.py index 05217d2f..6f30d2cd 100644 --- a/app/modules/messaging/schemas/__init__.py +++ b/app/modules/messaging/schemas/__init__.py @@ -2,27 +2,106 @@ """ Messaging module Pydantic schemas. -Re-exports messaging-related schemas from their source locations. +This module contains the canonical implementations of messaging-related schemas. """ -from models.schema.message import ( - ConversationResponse, - ConversationListResponse, - MessageResponse, - MessageCreate, +from app.modules.messaging.schemas.message import ( + # Attachment schemas AttachmentResponse, + # Message schemas + MessageCreate, + MessageResponse, + # Participant schemas + ParticipantInfo, + ParticipantResponse, + # Conversation schemas + ConversationCreate, + ConversationSummary, + ConversationDetailResponse, + ConversationListResponse, + ConversationResponse, + # Unread count + UnreadCountResponse, + # Notification preferences + NotificationPreferencesUpdate, + # Conversation actions + CloseConversationResponse, + ReopenConversationResponse, + MarkReadResponse, + # Recipient selection + RecipientOption, + RecipientListResponse, + # Admin schemas + AdminConversationSummary, + AdminConversationListResponse, + AdminMessageStats, ) -from models.schema.notification import ( + +from app.modules.messaging.schemas.notification import ( + # Response schemas + MessageResponse as NotificationMessageResponse, + UnreadCountResponse as NotificationUnreadCountResponse, + # Notification schemas NotificationResponse, NotificationListResponse, + # Settings schemas + NotificationSettingsResponse, + NotificationSettingsUpdate, + # Template schemas + NotificationTemplateResponse, + NotificationTemplateListResponse, + NotificationTemplateUpdate, + # Test notification + TestNotificationRequest, + # Alert statistics + AlertStatisticsResponse, ) __all__ = [ - "ConversationResponse", - "ConversationListResponse", - "MessageResponse", - "MessageCreate", + # Attachment schemas "AttachmentResponse", + # Message schemas + "MessageCreate", + "MessageResponse", + # Participant schemas + "ParticipantInfo", + "ParticipantResponse", + # Conversation schemas + "ConversationCreate", + "ConversationSummary", + "ConversationDetailResponse", + "ConversationListResponse", + "ConversationResponse", + # Unread count + "UnreadCountResponse", + # Notification preferences + "NotificationPreferencesUpdate", + # Conversation actions + "CloseConversationResponse", + "ReopenConversationResponse", + "MarkReadResponse", + # Recipient selection + "RecipientOption", + "RecipientListResponse", + # Admin schemas + "AdminConversationSummary", + "AdminConversationListResponse", + "AdminMessageStats", + # Notification response schemas + "NotificationMessageResponse", + "NotificationUnreadCountResponse", + # Notification schemas "NotificationResponse", "NotificationListResponse", + # Settings schemas + "NotificationSettingsResponse", + "NotificationSettingsUpdate", + # Template schemas + "NotificationTemplateResponse", + "NotificationTemplateListResponse", + "NotificationTemplateUpdate", + # Test notification + "TestNotificationRequest", + # Alert statistics + "AlertStatisticsResponse", ] diff --git a/app/modules/messaging/schemas/message.py b/app/modules/messaging/schemas/message.py new file mode 100644 index 00000000..de4269a6 --- /dev/null +++ b/app/modules/messaging/schemas/message.py @@ -0,0 +1,312 @@ +# app/modules/messaging/schemas/message.py +""" +Pydantic schemas for the messaging system. + +Supports three communication channels: +- Admin <-> Vendor +- Vendor <-> Customer +- Admin <-> Customer +""" + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + +from app.modules.messaging.models.message import ConversationType, ParticipantType + + +# ============================================================================ +# Attachment Schemas +# ============================================================================ + + +class AttachmentResponse(BaseModel): + """Schema for message attachment in responses.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + filename: str + original_filename: str + file_size: int + mime_type: str + is_image: bool + image_width: int | None = None + image_height: int | None = None + download_url: str | None = None + thumbnail_url: str | None = None + + @property + def file_size_display(self) -> str: + """Human-readable file size.""" + if self.file_size < 1024: + return f"{self.file_size} B" + elif self.file_size < 1024 * 1024: + return f"{self.file_size / 1024:.1f} KB" + else: + return f"{self.file_size / 1024 / 1024:.1f} MB" + + +# ============================================================================ +# Message Schemas +# ============================================================================ + + +class MessageCreate(BaseModel): + """Schema for sending a new message.""" + + content: str = Field(..., min_length=1, max_length=10000) + + +class MessageResponse(BaseModel): + """Schema for a single message in responses.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + conversation_id: int + sender_type: ParticipantType + sender_id: int + content: str + is_system_message: bool + is_deleted: bool + created_at: datetime + + # Enriched sender info (populated by API) + sender_name: str | None = None + sender_email: str | None = None + + # Attachments + attachments: list[AttachmentResponse] = [] + + +# ============================================================================ +# Participant Schemas +# ============================================================================ + + +class ParticipantInfo(BaseModel): + """Schema for participant information.""" + + id: int + type: ParticipantType + name: str + email: str | None = None + avatar_url: str | None = None + + +class ParticipantResponse(BaseModel): + """Schema for conversation participant in responses.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + participant_type: ParticipantType + participant_id: int + unread_count: int + last_read_at: datetime | None + email_notifications: bool + muted: bool + + # Enriched info (populated by API) + participant_info: ParticipantInfo | None = None + + +# ============================================================================ +# Conversation Schemas +# ============================================================================ + + +class ConversationCreate(BaseModel): + """Schema for creating a new conversation.""" + + conversation_type: ConversationType + subject: str = Field(..., min_length=1, max_length=500) + recipient_type: ParticipantType + recipient_id: int + vendor_id: int | None = None + initial_message: str | None = Field(None, min_length=1, max_length=10000) + + +class ConversationSummary(BaseModel): + """Schema for conversation in list views.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + conversation_type: ConversationType + subject: str + vendor_id: int | None = None + is_closed: bool + closed_at: datetime | None + last_message_at: datetime | None + message_count: int + created_at: datetime + + # Unread count for current user (from participant) + unread_count: int = 0 + + # Other participant info (enriched by API) + other_participant: ParticipantInfo | None = None + + # Last message preview + last_message_preview: str | None = None + + +class ConversationDetailResponse(BaseModel): + """Schema for full conversation detail with messages.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + conversation_type: ConversationType + subject: str + vendor_id: int | None = None + is_closed: bool + closed_at: datetime | None + closed_by_type: ParticipantType | None = None + closed_by_id: int | None = None + last_message_at: datetime | None + message_count: int + created_at: datetime + updated_at: datetime + + # Participants with enriched info + participants: list[ParticipantResponse] = [] + + # Messages ordered by created_at + messages: list[MessageResponse] = [] + + # Current user's unread count + unread_count: int = 0 + + # Vendor info if applicable + vendor_name: str | None = None + + +class ConversationListResponse(BaseModel): + """Schema for paginated conversation list.""" + + conversations: list[ConversationSummary] + total: int + total_unread: int + skip: int + limit: int + + +# Backward compatibility alias +ConversationResponse = ConversationDetailResponse + + +# ============================================================================ +# Unread Count Schemas +# ============================================================================ + + +class UnreadCountResponse(BaseModel): + """Schema for unread message count (for header badge).""" + + total_unread: int + + +# ============================================================================ +# Notification Preferences Schemas +# ============================================================================ + + +class NotificationPreferencesUpdate(BaseModel): + """Schema for updating notification preferences.""" + + email_notifications: bool | None = None + muted: bool | None = None + + +# ============================================================================ +# Conversation Action Schemas +# ============================================================================ + + +class CloseConversationResponse(BaseModel): + """Response after closing a conversation.""" + + success: bool + message: str + conversation_id: int + + +class ReopenConversationResponse(BaseModel): + """Response after reopening a conversation.""" + + success: bool + message: str + conversation_id: int + + +class MarkReadResponse(BaseModel): + """Response after marking conversation as read.""" + + success: bool + conversation_id: int + unread_count: int + + +# ============================================================================ +# Recipient Selection Schemas (for compose modal) +# ============================================================================ + + +class RecipientOption(BaseModel): + """Schema for a selectable recipient in compose modal.""" + + id: int + type: ParticipantType + name: str + email: str | None = None + vendor_id: int | None = None # For vendor users + vendor_name: str | None = None + + +class RecipientListResponse(BaseModel): + """Schema for list of available recipients.""" + + recipients: list[RecipientOption] + total: int + + +# ============================================================================ +# Admin-specific Schemas +# ============================================================================ + + +class AdminConversationSummary(ConversationSummary): + """Extended conversation summary with vendor info for admin views.""" + + vendor_name: str | None = None + vendor_code: str | None = None + + +class AdminConversationListResponse(BaseModel): + """Schema for admin conversation list with vendor info.""" + + conversations: list[AdminConversationSummary] + total: int + total_unread: int + skip: int + limit: int + + +class AdminMessageStats(BaseModel): + """Messaging statistics for admin dashboard.""" + + total_conversations: int = 0 + open_conversations: int = 0 + closed_conversations: int = 0 + total_messages: int = 0 + + # By type + admin_vendor_conversations: int = 0 + vendor_customer_conversations: int = 0 + admin_customer_conversations: int = 0 + + # Unread + unread_admin: int = 0 diff --git a/app/modules/messaging/schemas/notification.py b/app/modules/messaging/schemas/notification.py new file mode 100644 index 00000000..6a56d5fa --- /dev/null +++ b/app/modules/messaging/schemas/notification.py @@ -0,0 +1,152 @@ +# app/modules/messaging/schemas/notification.py +""" +Notification Pydantic schemas for API validation and responses. + +This module provides schemas for: +- Vendor notifications (list, read, delete) +- Notification settings management +- Notification email templates +- Unread counts and statistics +""" + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + +# ============================================================================ +# SHARED RESPONSE SCHEMAS +# ============================================================================ + + +class MessageResponse(BaseModel): + """Generic message response for simple operations.""" + + message: str + + +class UnreadCountResponse(BaseModel): + """Response for unread notification count.""" + + unread_count: int + message: str | None = None + + +# ============================================================================ +# NOTIFICATION SCHEMAS +# ============================================================================ + + +class NotificationResponse(BaseModel): + """Single notification response.""" + + id: int + type: str + title: str + message: str + is_read: bool + read_at: datetime | None = None + priority: str = "normal" + action_url: str | None = None + metadata: dict[str, Any] | None = None + created_at: datetime + + model_config = {"from_attributes": True} + + +class NotificationListResponse(BaseModel): + """Paginated list of notifications.""" + + notifications: list[NotificationResponse] = [] + total: int = 0 + unread_count: int = 0 + message: str | None = None + + +# ============================================================================ +# NOTIFICATION SETTINGS SCHEMAS +# ============================================================================ + + +class NotificationSettingsResponse(BaseModel): + """Notification preferences response.""" + + email_notifications: bool = True + in_app_notifications: bool = True + notification_types: dict[str, bool] = Field(default_factory=dict) + message: str | None = None + + +class NotificationSettingsUpdate(BaseModel): + """Request model for updating notification settings.""" + + email_notifications: bool | None = None + in_app_notifications: bool | None = None + notification_types: dict[str, bool] | None = None + + +# ============================================================================ +# NOTIFICATION TEMPLATE SCHEMAS +# ============================================================================ + + +class NotificationTemplateResponse(BaseModel): + """Single notification template response.""" + + id: int + name: str + type: str + subject: str + body_html: str | None = None + body_text: str | None = None + variables: list[str] = Field(default_factory=list) + is_active: bool = True + created_at: datetime + updated_at: datetime | None = None + + model_config = {"from_attributes": True} + + +class NotificationTemplateListResponse(BaseModel): + """List of notification templates.""" + + templates: list[NotificationTemplateResponse] = [] + message: str | None = None + + +class NotificationTemplateUpdate(BaseModel): + """Request model for updating notification template.""" + + subject: str | None = Field(None, max_length=200) + body_html: str | None = None + body_text: str | None = None + is_active: bool | None = None + + +# ============================================================================ +# TEST NOTIFICATION SCHEMA +# ============================================================================ + + +class TestNotificationRequest(BaseModel): + """Request model for sending test notification.""" + + template_id: int | None = Field(None, description="Template to use") + email: str | None = Field(None, description="Override recipient email") + notification_type: str = Field( + default="test", description="Type of notification to send" + ) + + +# ============================================================================ +# ADMIN ALERT STATISTICS SCHEMA +# ============================================================================ + + +class AlertStatisticsResponse(BaseModel): + """Response for alert statistics.""" + + total_alerts: int = 0 + active_alerts: int = 0 + critical_alerts: int = 0 + resolved_today: int = 0 diff --git a/app/modules/messaging/services/__init__.py b/app/modules/messaging/services/__init__.py index 10869b53..0d528774 100644 --- a/app/modules/messaging/services/__init__.py +++ b/app/modules/messaging/services/__init__.py @@ -2,24 +2,29 @@ """ Messaging module services. -Re-exports messaging-related services from their source locations. +This module contains the canonical implementations of messaging-related services. """ -from app.services.messaging_service import ( +from app.modules.messaging.services.messaging_service import ( messaging_service, MessagingService, ) -from app.services.message_attachment_service import ( +from app.modules.messaging.services.message_attachment_service import ( message_attachment_service, MessageAttachmentService, ) -from app.services.admin_notification_service import ( +from app.modules.messaging.services.admin_notification_service import ( admin_notification_service, AdminNotificationService, + platform_alert_service, + PlatformAlertService, + # Constants + NotificationType, + Priority, + AlertType, + Severity, ) -# Note: notification_service is a placeholder - not yet implemented - __all__ = [ "messaging_service", "MessagingService", @@ -27,4 +32,11 @@ __all__ = [ "MessageAttachmentService", "admin_notification_service", "AdminNotificationService", + "platform_alert_service", + "PlatformAlertService", + # Constants + "NotificationType", + "Priority", + "AlertType", + "Severity", ] diff --git a/app/modules/messaging/services/admin_notification_service.py b/app/modules/messaging/services/admin_notification_service.py new file mode 100644 index 00000000..8491432d --- /dev/null +++ b/app/modules/messaging/services/admin_notification_service.py @@ -0,0 +1,702 @@ +# app/modules/messaging/services/admin_notification_service.py +""" +Admin notification service. + +Provides functionality for: +- Creating and managing admin notifications +- Managing platform alerts +- Notification statistics and queries +""" + +import logging +from datetime import datetime, timedelta +from typing import Any + +from sqlalchemy import and_, case, func +from sqlalchemy.orm import Session + +from app.modules.messaging.models.admin_notification import AdminNotification +from models.database.admin import PlatformAlert +from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# NOTIFICATION TYPES +# ============================================================================ + +class NotificationType: + """Notification type constants.""" + + SYSTEM_ALERT = "system_alert" + IMPORT_FAILURE = "import_failure" + EXPORT_FAILURE = "export_failure" + ORDER_SYNC_FAILURE = "order_sync_failure" + VENDOR_ISSUE = "vendor_issue" + CUSTOMER_MESSAGE = "customer_message" + VENDOR_MESSAGE = "vendor_message" + SECURITY_ALERT = "security_alert" + PERFORMANCE_ALERT = "performance_alert" + ORDER_EXCEPTION = "order_exception" + CRITICAL_ERROR = "critical_error" + + +class Priority: + """Priority level constants.""" + + LOW = "low" + NORMAL = "normal" + HIGH = "high" + CRITICAL = "critical" + + +class AlertType: + """Platform alert type constants.""" + + SECURITY = "security" + PERFORMANCE = "performance" + CAPACITY = "capacity" + INTEGRATION = "integration" + DATABASE = "database" + SYSTEM = "system" + + +class Severity: + """Alert severity constants.""" + + INFO = "info" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + + +# ============================================================================ +# ADMIN NOTIFICATION SERVICE +# ============================================================================ + + +class AdminNotificationService: + """Service for managing admin notifications.""" + + def create_notification( + self, + db: Session, + notification_type: str, + title: str, + message: str, + priority: str = Priority.NORMAL, + action_required: bool = False, + action_url: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> AdminNotification: + """ + Create a new admin notification. + + Args: + db: Database session + notification_type: Type of notification + title: Notification title + message: Notification message + priority: Priority level (low, normal, high, critical) + action_required: Whether action is required + action_url: URL to relevant admin page + metadata: Additional contextual data + + Returns: + Created AdminNotification + """ + notification = AdminNotification( + type=notification_type, + title=title, + message=message, + priority=priority, + action_required=action_required, + action_url=action_url, + notification_metadata=metadata, + ) + db.add(notification) + db.flush() + + logger.info( + f"Created notification: {notification_type} - {title} (priority: {priority})" + ) + + return notification + + def create_from_schema( + self, + db: Session, + data: AdminNotificationCreate, + ) -> AdminNotification: + """Create notification from Pydantic schema.""" + return self.create_notification( + db=db, + notification_type=data.type, + title=data.title, + message=data.message, + priority=data.priority, + action_required=data.action_required, + action_url=data.action_url, + metadata=data.metadata, + ) + + def get_notifications( + self, + db: Session, + priority: str | None = None, + is_read: bool | None = None, + notification_type: str | None = None, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[AdminNotification], int, int]: + """ + Get paginated admin notifications. + + Returns: + Tuple of (notifications, total_count, unread_count) + """ + query = db.query(AdminNotification) + + # Apply filters + if priority: + query = query.filter(AdminNotification.priority == priority) + + if is_read is not None: + query = query.filter(AdminNotification.is_read == is_read) + + if notification_type: + query = query.filter(AdminNotification.type == notification_type) + + # Get counts + total = query.count() + unread_count = ( + db.query(AdminNotification) + .filter(AdminNotification.is_read == False) # noqa: E712 + .count() + ) + + # Get paginated results ordered by priority and date + priority_order = case( + (AdminNotification.priority == "critical", 1), + (AdminNotification.priority == "high", 2), + (AdminNotification.priority == "normal", 3), + (AdminNotification.priority == "low", 4), + else_=5, + ) + + notifications = ( + query.order_by( + AdminNotification.is_read, # Unread first + priority_order, + AdminNotification.created_at.desc(), + ) + .offset(skip) + .limit(limit) + .all() + ) + + return notifications, total, unread_count + + def get_unread_count(self, db: Session) -> int: + """Get count of unread notifications.""" + return ( + db.query(AdminNotification) + .filter(AdminNotification.is_read == False) # noqa: E712 + .count() + ) + + def get_recent_notifications( + self, + db: Session, + limit: int = 5, + ) -> list[AdminNotification]: + """Get recent unread notifications for header dropdown.""" + priority_order = case( + (AdminNotification.priority == "critical", 1), + (AdminNotification.priority == "high", 2), + (AdminNotification.priority == "normal", 3), + (AdminNotification.priority == "low", 4), + else_=5, + ) + + return ( + db.query(AdminNotification) + .filter(AdminNotification.is_read == False) # noqa: E712 + .order_by(priority_order, AdminNotification.created_at.desc()) + .limit(limit) + .all() + ) + + def mark_as_read( + self, + db: Session, + notification_id: int, + user_id: int, + ) -> AdminNotification | None: + """Mark a notification as read.""" + notification = ( + db.query(AdminNotification) + .filter(AdminNotification.id == notification_id) + .first() + ) + + if notification and not notification.is_read: + notification.is_read = True + notification.read_at = datetime.utcnow() + notification.read_by_user_id = user_id + db.flush() + + return notification + + def mark_all_as_read( + self, + db: Session, + user_id: int, + ) -> int: + """Mark all unread notifications as read. Returns count of updated.""" + now = datetime.utcnow() + count = ( + db.query(AdminNotification) + .filter(AdminNotification.is_read == False) # noqa: E712 + .update( + { + AdminNotification.is_read: True, + AdminNotification.read_at: now, + AdminNotification.read_by_user_id: user_id, + } + ) + ) + db.flush() + return count + + def delete_old_notifications( + self, + db: Session, + days: int = 30, + ) -> int: + """Delete notifications older than specified days.""" + cutoff = datetime.utcnow() - timedelta(days=days) + count = ( + db.query(AdminNotification) + .filter( + and_( + AdminNotification.is_read == True, # noqa: E712 + AdminNotification.created_at < cutoff, + ) + ) + .delete() + ) + db.flush() + return count + + def delete_notification( + self, + db: Session, + notification_id: int, + ) -> bool: + """ + Delete a notification by ID. + + Returns: + True if notification was deleted, False if not found. + """ + notification = ( + db.query(AdminNotification) + .filter(AdminNotification.id == notification_id) + .first() + ) + + if notification: + db.delete(notification) + db.flush() + logger.info(f"Deleted notification {notification_id}") + return True + + return False + + # ========================================================================= + # CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES + # ========================================================================= + + def notify_import_failure( + self, + db: Session, + vendor_name: str, + job_id: int, + error_message: str, + vendor_id: int | None = None, + ) -> AdminNotification: + """Create notification for import job failure.""" + return self.create_notification( + db=db, + notification_type=NotificationType.IMPORT_FAILURE, + title=f"Import Failed: {vendor_name}", + message=error_message, + priority=Priority.HIGH, + action_required=True, + action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs" + if vendor_id + else "/admin/marketplace", + metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id}, + ) + + def notify_order_sync_failure( + self, + db: Session, + vendor_name: str, + error_message: str, + vendor_id: int | None = None, + ) -> AdminNotification: + """Create notification for order sync failure.""" + return self.create_notification( + db=db, + notification_type=NotificationType.ORDER_SYNC_FAILURE, + title=f"Order Sync Failed: {vendor_name}", + message=error_message, + priority=Priority.HIGH, + action_required=True, + action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs" + if vendor_id + else "/admin/marketplace/letzshop", + metadata={"vendor_name": vendor_name, "vendor_id": vendor_id}, + ) + + def notify_order_exception( + self, + db: Session, + vendor_name: str, + order_number: str, + exception_count: int, + vendor_id: int | None = None, + ) -> AdminNotification: + """Create notification for order item exceptions.""" + return self.create_notification( + db=db, + notification_type=NotificationType.ORDER_EXCEPTION, + title=f"Order Exception: {order_number}", + message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})", + priority=Priority.NORMAL, + action_required=True, + action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions" + if vendor_id + else "/admin/marketplace/letzshop", + metadata={ + "vendor_name": vendor_name, + "order_number": order_number, + "exception_count": exception_count, + "vendor_id": vendor_id, + }, + ) + + def notify_critical_error( + self, + db: Session, + error_type: str, + error_message: str, + details: dict[str, Any] | None = None, + ) -> AdminNotification: + """Create notification for critical application errors.""" + return self.create_notification( + db=db, + notification_type=NotificationType.CRITICAL_ERROR, + title=f"Critical Error: {error_type}", + message=error_message, + priority=Priority.CRITICAL, + action_required=True, + action_url="/admin/logs", + metadata=details, + ) + + def notify_vendor_issue( + self, + db: Session, + vendor_name: str, + issue_type: str, + message: str, + vendor_id: int | None = None, + ) -> AdminNotification: + """Create notification for vendor-related issues.""" + return self.create_notification( + db=db, + notification_type=NotificationType.VENDOR_ISSUE, + title=f"Vendor Issue: {vendor_name}", + message=message, + priority=Priority.HIGH, + action_required=True, + action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors", + metadata={ + "vendor_name": vendor_name, + "issue_type": issue_type, + "vendor_id": vendor_id, + }, + ) + + def notify_security_alert( + self, + db: Session, + title: str, + message: str, + details: dict[str, Any] | None = None, + ) -> AdminNotification: + """Create notification for security-related alerts.""" + return self.create_notification( + db=db, + notification_type=NotificationType.SECURITY_ALERT, + title=title, + message=message, + priority=Priority.CRITICAL, + action_required=True, + action_url="/admin/audit", + metadata=details, + ) + + +# ============================================================================ +# PLATFORM ALERT SERVICE +# ============================================================================ + + +class PlatformAlertService: + """Service for managing platform-wide alerts.""" + + def create_alert( + self, + db: Session, + alert_type: str, + severity: str, + title: str, + description: str | None = None, + affected_vendors: list[int] | None = None, + affected_systems: list[str] | None = None, + auto_generated: bool = True, + ) -> PlatformAlert: + """Create a new platform alert.""" + now = datetime.utcnow() + + alert = PlatformAlert( + alert_type=alert_type, + severity=severity, + title=title, + description=description, + affected_vendors=affected_vendors, + affected_systems=affected_systems, + auto_generated=auto_generated, + first_occurred_at=now, + last_occurred_at=now, + ) + db.add(alert) + db.flush() + + logger.info(f"Created platform alert: {alert_type} - {title} ({severity})") + + return alert + + def create_from_schema( + self, + db: Session, + data: PlatformAlertCreate, + ) -> PlatformAlert: + """Create alert from Pydantic schema.""" + return self.create_alert( + db=db, + alert_type=data.alert_type, + severity=data.severity, + title=data.title, + description=data.description, + affected_vendors=data.affected_vendors, + affected_systems=data.affected_systems, + auto_generated=data.auto_generated, + ) + + def get_alerts( + self, + db: Session, + severity: str | None = None, + alert_type: str | None = None, + is_resolved: bool | None = None, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[PlatformAlert], int, int, int]: + """ + Get paginated platform alerts. + + Returns: + Tuple of (alerts, total_count, active_count, critical_count) + """ + query = db.query(PlatformAlert) + + # Apply filters + if severity: + query = query.filter(PlatformAlert.severity == severity) + + if alert_type: + query = query.filter(PlatformAlert.alert_type == alert_type) + + if is_resolved is not None: + query = query.filter(PlatformAlert.is_resolved == is_resolved) + + # Get counts + total = query.count() + active_count = ( + db.query(PlatformAlert) + .filter(PlatformAlert.is_resolved == False) # noqa: E712 + .count() + ) + critical_count = ( + db.query(PlatformAlert) + .filter( + and_( + PlatformAlert.is_resolved == False, # noqa: E712 + PlatformAlert.severity == Severity.CRITICAL, + ) + ) + .count() + ) + + # Get paginated results + severity_order = case( + (PlatformAlert.severity == "critical", 1), + (PlatformAlert.severity == "error", 2), + (PlatformAlert.severity == "warning", 3), + (PlatformAlert.severity == "info", 4), + else_=5, + ) + + alerts = ( + query.order_by( + PlatformAlert.is_resolved, # Unresolved first + severity_order, + PlatformAlert.last_occurred_at.desc(), + ) + .offset(skip) + .limit(limit) + .all() + ) + + return alerts, total, active_count, critical_count + + def resolve_alert( + self, + db: Session, + alert_id: int, + user_id: int, + resolution_notes: str | None = None, + ) -> PlatformAlert | None: + """Resolve a platform alert.""" + alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first() + + if alert and not alert.is_resolved: + alert.is_resolved = True + alert.resolved_at = datetime.utcnow() + alert.resolved_by_user_id = user_id + alert.resolution_notes = resolution_notes + db.flush() + + logger.info(f"Resolved platform alert {alert_id}") + + return alert + + def get_statistics(self, db: Session) -> dict[str, int]: + """Get alert statistics.""" + today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) + + total = db.query(PlatformAlert).count() + active = ( + db.query(PlatformAlert) + .filter(PlatformAlert.is_resolved == False) # noqa: E712 + .count() + ) + critical = ( + db.query(PlatformAlert) + .filter( + and_( + PlatformAlert.is_resolved == False, # noqa: E712 + PlatformAlert.severity == Severity.CRITICAL, + ) + ) + .count() + ) + resolved_today = ( + db.query(PlatformAlert) + .filter( + and_( + PlatformAlert.is_resolved == True, # noqa: E712 + PlatformAlert.resolved_at >= today_start, + ) + ) + .count() + ) + + return { + "total_alerts": total, + "active_alerts": active, + "critical_alerts": critical, + "resolved_today": resolved_today, + } + + def increment_occurrence( + self, + db: Session, + alert_id: int, + ) -> PlatformAlert | None: + """Increment occurrence count for repeated alert.""" + alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first() + + if alert: + alert.occurrence_count += 1 + alert.last_occurred_at = datetime.utcnow() + db.flush() + + return alert + + def find_similar_active_alert( + self, + db: Session, + alert_type: str, + title: str, + ) -> PlatformAlert | None: + """Find an active alert with same type and title.""" + return ( + db.query(PlatformAlert) + .filter( + and_( + PlatformAlert.alert_type == alert_type, + PlatformAlert.title == title, + PlatformAlert.is_resolved == False, # noqa: E712 + ) + ) + .first() + ) + + def create_or_increment_alert( + self, + db: Session, + alert_type: str, + severity: str, + title: str, + description: str | None = None, + affected_vendors: list[int] | None = None, + affected_systems: list[str] | None = None, + ) -> PlatformAlert: + """Create alert or increment occurrence if similar exists.""" + existing = self.find_similar_active_alert(db, alert_type, title) + + if existing: + self.increment_occurrence(db, existing.id) + return existing + + return self.create_alert( + db=db, + alert_type=alert_type, + severity=severity, + title=title, + description=description, + affected_vendors=affected_vendors, + affected_systems=affected_systems, + ) + + +# Singleton instances +admin_notification_service = AdminNotificationService() +platform_alert_service = PlatformAlertService() diff --git a/app/modules/messaging/services/message_attachment_service.py b/app/modules/messaging/services/message_attachment_service.py new file mode 100644 index 00000000..0fd3e568 --- /dev/null +++ b/app/modules/messaging/services/message_attachment_service.py @@ -0,0 +1,225 @@ +# app/modules/messaging/services/message_attachment_service.py +""" +Attachment handling service for messaging system. + +Handles file upload, validation, storage, and retrieval. +""" + +import logging +import os +import uuid +from datetime import datetime +from pathlib import Path + +from fastapi import UploadFile +from sqlalchemy.orm import Session + +from app.services.admin_settings_service import admin_settings_service + +logger = logging.getLogger(__name__) + +# Allowed MIME types for attachments +ALLOWED_MIME_TYPES = { + # Images + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + # Documents + "application/pdf", + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + # Archives + "application/zip", + # Text + "text/plain", + "text/csv", +} + +IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"} + +# Default max file size in MB +DEFAULT_MAX_FILE_SIZE_MB = 10 + + +class MessageAttachmentService: + """Service for handling message attachments.""" + + def __init__(self, storage_base: str = "uploads/messages"): + self.storage_base = storage_base + + def get_max_file_size_bytes(self, db: Session) -> int: + """Get maximum file size from platform settings.""" + max_mb = admin_settings_service.get_setting_value( + db, + "message_attachment_max_size_mb", + default=DEFAULT_MAX_FILE_SIZE_MB, + ) + try: + max_mb = int(max_mb) + except (TypeError, ValueError): + max_mb = DEFAULT_MAX_FILE_SIZE_MB + return max_mb * 1024 * 1024 # Convert to bytes + + def validate_file_type(self, mime_type: str) -> bool: + """Check if file type is allowed.""" + return mime_type in ALLOWED_MIME_TYPES + + def is_image(self, mime_type: str) -> bool: + """Check if file is an image.""" + return mime_type in IMAGE_MIME_TYPES + + async def validate_and_store( + self, + db: Session, + file: UploadFile, + conversation_id: int, + ) -> dict: + """ + Validate and store an uploaded file. + + Returns dict with file metadata for MessageAttachment creation. + + Raises: + ValueError: If file type or size is invalid + """ + # Validate MIME type + content_type = file.content_type or "application/octet-stream" + if not self.validate_file_type(content_type): + raise ValueError( + f"File type '{content_type}' not allowed. " + "Allowed types: images (JPEG, PNG, GIF, WebP), " + "PDF, Office documents, ZIP, text files." + ) + + # Read file content + content = await file.read() + file_size = len(content) + + # Validate file size + max_size = self.get_max_file_size_bytes(db) + if file_size > max_size: + raise ValueError( + f"File size {file_size / 1024 / 1024:.1f}MB exceeds " + f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB" + ) + + # Generate unique filename + original_filename = file.filename or "attachment" + ext = Path(original_filename).suffix.lower() + unique_filename = f"{uuid.uuid4().hex}{ext}" + + # Create storage path: uploads/messages/YYYY/MM/conversation_id/filename + now = datetime.utcnow() + relative_path = os.path.join( + self.storage_base, + str(now.year), + f"{now.month:02d}", + str(conversation_id), + ) + + # Ensure directory exists + os.makedirs(relative_path, exist_ok=True) + + # Full file path + file_path = os.path.join(relative_path, unique_filename) + + # Write file + with open(file_path, "wb") as f: + f.write(content) + + # Prepare metadata + is_image = self.is_image(content_type) + metadata = { + "filename": unique_filename, + "original_filename": original_filename, + "file_path": file_path, + "file_size": file_size, + "mime_type": content_type, + "is_image": is_image, + } + + # Generate thumbnail for images + if is_image: + thumbnail_data = self._create_thumbnail(content, file_path) + metadata.update(thumbnail_data) + + logger.info( + f"Stored attachment {unique_filename} for conversation {conversation_id} " + f"({file_size} bytes, type: {content_type})" + ) + + return metadata + + def _create_thumbnail(self, content: bytes, original_path: str) -> dict: + """Create thumbnail for image attachments.""" + try: + from PIL import Image + import io + + img = Image.open(io.BytesIO(content)) + width, height = img.size + + # Create thumbnail + img.thumbnail((200, 200)) + + thumb_path = original_path.replace(".", "_thumb.") + img.save(thumb_path) + + return { + "image_width": width, + "image_height": height, + "thumbnail_path": thumb_path, + } + except ImportError: + logger.warning("PIL not installed, skipping thumbnail generation") + return {} + except Exception as e: + logger.error(f"Failed to create thumbnail: {e}") + return {} + + def delete_attachment( + self, file_path: str, thumbnail_path: str | None = None + ) -> bool: + """Delete attachment files from storage.""" + try: + if os.path.exists(file_path): + os.remove(file_path) + logger.info(f"Deleted attachment file: {file_path}") + + if thumbnail_path and os.path.exists(thumbnail_path): + os.remove(thumbnail_path) + logger.info(f"Deleted thumbnail: {thumbnail_path}") + + return True + except Exception as e: + logger.error(f"Failed to delete attachment {file_path}: {e}") + return False + + def get_download_url(self, file_path: str) -> str: + """ + Get download URL for an attachment. + + For local storage, returns a relative path that can be served + by the static file handler or a dedicated download endpoint. + """ + # Convert local path to URL path + # Assumes files are served from /static/uploads or similar + return f"/static/{file_path}" + + def get_file_content(self, file_path: str) -> bytes | None: + """Read file content from storage.""" + try: + if os.path.exists(file_path): + with open(file_path, "rb") as f: + return f.read() + return None + except Exception as e: + logger.error(f"Failed to read file {file_path}: {e}") + return None + + +# Singleton instance +message_attachment_service = MessageAttachmentService() diff --git a/app/modules/messaging/services/messaging_service.py b/app/modules/messaging/services/messaging_service.py new file mode 100644 index 00000000..6e128468 --- /dev/null +++ b/app/modules/messaging/services/messaging_service.py @@ -0,0 +1,684 @@ +# app/modules/messaging/services/messaging_service.py +""" +Messaging service for conversation and message management. + +Provides functionality for: +- Creating conversations between different participant types +- Sending messages with attachments +- Managing read status and unread counts +- Conversation listing with filters +- Multi-tenant data isolation +""" + +import logging +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import and_, func, or_ +from sqlalchemy.orm import Session, joinedload + +from app.modules.messaging.models.message import ( + Conversation, + ConversationParticipant, + ConversationType, + Message, + MessageAttachment, + ParticipantType, +) +from models.database.customer import Customer +from models.database.user import User + +logger = logging.getLogger(__name__) + + +class MessagingService: + """Service for managing conversations and messages.""" + + # ========================================================================= + # CONVERSATION MANAGEMENT + # ========================================================================= + + def create_conversation( + self, + db: Session, + conversation_type: ConversationType, + subject: str, + initiator_type: ParticipantType, + initiator_id: int, + recipient_type: ParticipantType, + recipient_id: int, + vendor_id: int | None = None, + initial_message: str | None = None, + ) -> Conversation: + """ + Create a new conversation between two participants. + + Args: + db: Database session + conversation_type: Type of conversation channel + subject: Conversation subject line + initiator_type: Type of initiating participant + initiator_id: ID of initiating participant + recipient_type: Type of receiving participant + recipient_id: ID of receiving participant + vendor_id: Required for vendor_customer/admin_customer types + initial_message: Optional first message content + + Returns: + Created Conversation object + """ + # Validate vendor_id requirement + if conversation_type in [ + ConversationType.VENDOR_CUSTOMER, + ConversationType.ADMIN_CUSTOMER, + ]: + if not vendor_id: + raise ValueError( + f"vendor_id required for {conversation_type.value} conversations" + ) + + # Create conversation + conversation = Conversation( + conversation_type=conversation_type, + subject=subject, + vendor_id=vendor_id, + ) + db.add(conversation) + db.flush() + + # Add participants + initiator_vendor_id = ( + vendor_id if initiator_type == ParticipantType.VENDOR else None + ) + recipient_vendor_id = ( + vendor_id if recipient_type == ParticipantType.VENDOR else None + ) + + initiator = ConversationParticipant( + conversation_id=conversation.id, + participant_type=initiator_type, + participant_id=initiator_id, + vendor_id=initiator_vendor_id, + unread_count=0, # Initiator has read their own message + ) + recipient = ConversationParticipant( + conversation_id=conversation.id, + participant_type=recipient_type, + participant_id=recipient_id, + vendor_id=recipient_vendor_id, + unread_count=1 if initial_message else 0, + ) + + db.add(initiator) + db.add(recipient) + db.flush() + + # Add initial message if provided + if initial_message: + self.send_message( + db=db, + conversation_id=conversation.id, + sender_type=initiator_type, + sender_id=initiator_id, + content=initial_message, + _skip_unread_update=True, # Already set above + ) + + logger.info( + f"Created {conversation_type.value} conversation {conversation.id}: " + f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}" + ) + + return conversation + + def get_conversation( + self, + db: Session, + conversation_id: int, + participant_type: ParticipantType, + participant_id: int, + ) -> Conversation | None: + """ + Get conversation if participant has access. + + Validates that the requester is a participant. + """ + conversation = ( + db.query(Conversation) + .options( + joinedload(Conversation.participants), + joinedload(Conversation.messages).joinedload(Message.attachments), + ) + .filter(Conversation.id == conversation_id) + .first() + ) + + if not conversation: + return None + + # Verify participant access + has_access = any( + p.participant_type == participant_type + and p.participant_id == participant_id + for p in conversation.participants + ) + + if not has_access: + logger.warning( + f"Access denied to conversation {conversation_id} for " + f"{participant_type.value}:{participant_id}" + ) + return None + + return conversation + + def list_conversations( + self, + db: Session, + participant_type: ParticipantType, + participant_id: int, + vendor_id: int | None = None, + conversation_type: ConversationType | None = None, + is_closed: bool | None = None, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[Conversation], int, int]: + """ + List conversations for a participant with filters. + + Returns: + Tuple of (conversations, total_count, total_unread) + """ + # Base query: conversations where user is a participant + query = ( + db.query(Conversation) + .join(ConversationParticipant) + .filter( + and_( + ConversationParticipant.participant_type == participant_type, + ConversationParticipant.participant_id == participant_id, + ) + ) + ) + + # Multi-tenant filter for vendor users + if participant_type == ParticipantType.VENDOR and vendor_id: + query = query.filter(ConversationParticipant.vendor_id == vendor_id) + + # Customer vendor isolation + if participant_type == ParticipantType.CUSTOMER and vendor_id: + query = query.filter(Conversation.vendor_id == vendor_id) + + # Type filter + if conversation_type: + query = query.filter(Conversation.conversation_type == conversation_type) + + # Status filter + if is_closed is not None: + query = query.filter(Conversation.is_closed == is_closed) + + # Get total count + total = query.count() + + # Get total unread across all conversations + unread_query = db.query( + func.sum(ConversationParticipant.unread_count) + ).filter( + and_( + ConversationParticipant.participant_type == participant_type, + ConversationParticipant.participant_id == participant_id, + ) + ) + + if participant_type == ParticipantType.VENDOR and vendor_id: + unread_query = unread_query.filter( + ConversationParticipant.vendor_id == vendor_id + ) + + total_unread = unread_query.scalar() or 0 + + # Get paginated results, ordered by last activity + conversations = ( + query.options(joinedload(Conversation.participants)) + .order_by(Conversation.last_message_at.desc().nullslast()) + .offset(skip) + .limit(limit) + .all() + ) + + return conversations, total, total_unread + + def close_conversation( + self, + db: Session, + conversation_id: int, + closer_type: ParticipantType, + closer_id: int, + ) -> Conversation | None: + """Close a conversation thread.""" + conversation = self.get_conversation( + db, conversation_id, closer_type, closer_id + ) + + if not conversation: + return None + + conversation.is_closed = True + conversation.closed_at = datetime.now(UTC) + conversation.closed_by_type = closer_type + conversation.closed_by_id = closer_id + + # Add system message + self.send_message( + db=db, + conversation_id=conversation_id, + sender_type=closer_type, + sender_id=closer_id, + content="Conversation closed", + is_system_message=True, + ) + + db.flush() + return conversation + + def reopen_conversation( + self, + db: Session, + conversation_id: int, + opener_type: ParticipantType, + opener_id: int, + ) -> Conversation | None: + """Reopen a closed conversation.""" + conversation = self.get_conversation( + db, conversation_id, opener_type, opener_id + ) + + if not conversation: + return None + + conversation.is_closed = False + conversation.closed_at = None + conversation.closed_by_type = None + conversation.closed_by_id = None + + # Add system message + self.send_message( + db=db, + conversation_id=conversation_id, + sender_type=opener_type, + sender_id=opener_id, + content="Conversation reopened", + is_system_message=True, + ) + + db.flush() + return conversation + + # ========================================================================= + # MESSAGE MANAGEMENT + # ========================================================================= + + def send_message( + self, + db: Session, + conversation_id: int, + sender_type: ParticipantType, + sender_id: int, + content: str, + attachments: list[dict[str, Any]] | None = None, + is_system_message: bool = False, + _skip_unread_update: bool = False, + ) -> Message: + """ + Send a message in a conversation. + + Args: + db: Database session + conversation_id: Target conversation ID + sender_type: Type of sender + sender_id: ID of sender + content: Message text content + attachments: List of attachment dicts with file metadata + is_system_message: Whether this is a system-generated message + _skip_unread_update: Internal flag to skip unread increment + + Returns: + Created Message object + """ + # Create message + message = Message( + conversation_id=conversation_id, + sender_type=sender_type, + sender_id=sender_id, + content=content, + is_system_message=is_system_message, + ) + db.add(message) + db.flush() + + # Add attachments if any + if attachments: + for att_data in attachments: + attachment = MessageAttachment( + message_id=message.id, + filename=att_data["filename"], + original_filename=att_data["original_filename"], + file_path=att_data["file_path"], + file_size=att_data["file_size"], + mime_type=att_data["mime_type"], + is_image=att_data.get("is_image", False), + image_width=att_data.get("image_width"), + image_height=att_data.get("image_height"), + thumbnail_path=att_data.get("thumbnail_path"), + ) + db.add(attachment) + + # Update conversation metadata + conversation = ( + db.query(Conversation).filter(Conversation.id == conversation_id).first() + ) + + if conversation: + conversation.last_message_at = datetime.now(UTC) + conversation.message_count += 1 + + # Update unread counts for other participants + if not _skip_unread_update: + db.query(ConversationParticipant).filter( + and_( + ConversationParticipant.conversation_id == conversation_id, + or_( + ConversationParticipant.participant_type != sender_type, + ConversationParticipant.participant_id != sender_id, + ), + ) + ).update( + { + ConversationParticipant.unread_count: ConversationParticipant.unread_count + + 1 + } + ) + + db.flush() + + logger.info( + f"Message {message.id} sent in conversation {conversation_id} " + f"by {sender_type.value}:{sender_id}" + ) + + return message + + def delete_message( + self, + db: Session, + message_id: int, + deleter_type: ParticipantType, + deleter_id: int, + ) -> Message | None: + """Soft delete a message (for moderation).""" + message = db.query(Message).filter(Message.id == message_id).first() + + if not message: + return None + + # Verify deleter has access to conversation + conversation = self.get_conversation( + db, message.conversation_id, deleter_type, deleter_id + ) + if not conversation: + return None + + message.is_deleted = True + message.deleted_at = datetime.now(UTC) + message.deleted_by_type = deleter_type + message.deleted_by_id = deleter_id + + db.flush() + return message + + def mark_conversation_read( + self, + db: Session, + conversation_id: int, + reader_type: ParticipantType, + reader_id: int, + ) -> bool: + """Mark all messages in conversation as read for participant.""" + result = ( + db.query(ConversationParticipant) + .filter( + and_( + ConversationParticipant.conversation_id == conversation_id, + ConversationParticipant.participant_type == reader_type, + ConversationParticipant.participant_id == reader_id, + ) + ) + .update( + { + ConversationParticipant.unread_count: 0, + ConversationParticipant.last_read_at: datetime.now(UTC), + } + ) + ) + db.flush() + return result > 0 + + def get_unread_count( + self, + db: Session, + participant_type: ParticipantType, + participant_id: int, + vendor_id: int | None = None, + ) -> int: + """Get total unread message count for a participant.""" + query = db.query(func.sum(ConversationParticipant.unread_count)).filter( + and_( + ConversationParticipant.participant_type == participant_type, + ConversationParticipant.participant_id == participant_id, + ) + ) + + if vendor_id: + query = query.filter(ConversationParticipant.vendor_id == vendor_id) + + return query.scalar() or 0 + + # ========================================================================= + # PARTICIPANT HELPERS + # ========================================================================= + + def get_participant_info( + self, + db: Session, + participant_type: ParticipantType, + participant_id: int, + ) -> dict[str, Any] | None: + """Get display info for a participant (name, email, avatar).""" + if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]: + user = db.query(User).filter(User.id == participant_id).first() + if user: + return { + "id": user.id, + "type": participant_type.value, + "name": f"{user.first_name or ''} {user.last_name or ''}".strip() + or user.username, + "email": user.email, + "avatar_url": None, # Could add avatar support later + } + elif participant_type == ParticipantType.CUSTOMER: + customer = db.query(Customer).filter(Customer.id == participant_id).first() + if customer: + return { + "id": customer.id, + "type": participant_type.value, + "name": f"{customer.first_name or ''} {customer.last_name or ''}".strip() + or customer.email, + "email": customer.email, + "avatar_url": None, + } + return None + + def get_other_participant( + self, + conversation: Conversation, + my_type: ParticipantType, + my_id: int, + ) -> ConversationParticipant | None: + """Get the other participant in a conversation.""" + for p in conversation.participants: + if p.participant_type != my_type or p.participant_id != my_id: + return p + return None + + # ========================================================================= + # NOTIFICATION PREFERENCES + # ========================================================================= + + def update_notification_preferences( + self, + db: Session, + conversation_id: int, + participant_type: ParticipantType, + participant_id: int, + email_notifications: bool | None = None, + muted: bool | None = None, + ) -> bool: + """Update notification preferences for a participant in a conversation.""" + updates = {} + if email_notifications is not None: + updates[ConversationParticipant.email_notifications] = email_notifications + if muted is not None: + updates[ConversationParticipant.muted] = muted + + if not updates: + return False + + result = ( + db.query(ConversationParticipant) + .filter( + and_( + ConversationParticipant.conversation_id == conversation_id, + ConversationParticipant.participant_type == participant_type, + ConversationParticipant.participant_id == participant_id, + ) + ) + .update(updates) + ) + db.flush() + return result > 0 + + # ========================================================================= + # RECIPIENT QUERIES + # ========================================================================= + + def get_vendor_recipients( + self, + db: Session, + vendor_id: int | None = None, + search: str | None = None, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[dict], int]: + """ + Get list of vendor users as potential recipients. + + Args: + db: Database session + vendor_id: Optional vendor ID filter + search: Search term for name/email + skip: Pagination offset + limit: Max results + + Returns: + Tuple of (recipients list, total count) + """ + from models.database.vendor import VendorUser + + query = ( + db.query(User, VendorUser) + .join(VendorUser, User.id == VendorUser.user_id) + .filter(User.is_active == True) # noqa: E712 + ) + + if vendor_id: + query = query.filter(VendorUser.vendor_id == vendor_id) + + if search: + search_pattern = f"%{search}%" + query = query.filter( + (User.username.ilike(search_pattern)) + | (User.email.ilike(search_pattern)) + | (User.first_name.ilike(search_pattern)) + | (User.last_name.ilike(search_pattern)) + ) + + total = query.count() + results = query.offset(skip).limit(limit).all() + + recipients = [] + for user, vendor_user in results: + name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username + recipients.append({ + "id": user.id, + "type": ParticipantType.VENDOR, + "name": name, + "email": user.email, + "vendor_id": vendor_user.vendor_id, + "vendor_name": vendor_user.vendor.name if vendor_user.vendor else None, + }) + + return recipients, total + + def get_customer_recipients( + self, + db: Session, + vendor_id: int | None = None, + search: str | None = None, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[dict], int]: + """ + Get list of customers as potential recipients. + + Args: + db: Database session + vendor_id: Optional vendor ID filter (required for vendor users) + search: Search term for name/email + skip: Pagination offset + limit: Max results + + Returns: + Tuple of (recipients list, total count) + """ + query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712 + + if vendor_id: + query = query.filter(Customer.vendor_id == vendor_id) + + if search: + search_pattern = f"%{search}%" + query = query.filter( + (Customer.email.ilike(search_pattern)) + | (Customer.first_name.ilike(search_pattern)) + | (Customer.last_name.ilike(search_pattern)) + ) + + total = query.count() + results = query.offset(skip).limit(limit).all() + + recipients = [] + for customer in results: + name = f"{customer.first_name or ''} {customer.last_name or ''}".strip() + recipients.append({ + "id": customer.id, + "type": ParticipantType.CUSTOMER, + "name": name or customer.email, + "email": customer.email, + "vendor_id": customer.vendor_id, + }) + + return recipients, total + + +# Singleton instance +messaging_service = MessagingService() diff --git a/app/modules/monitoring/services/__init__.py b/app/modules/monitoring/services/__init__.py index 8cd42179..5134b2b9 100644 --- a/app/modules/monitoring/services/__init__.py +++ b/app/modules/monitoring/services/__init__.py @@ -2,10 +2,10 @@ """ Monitoring module services. -Re-exports monitoring-related services from their source locations. +This module contains the canonical implementations of monitoring-related services. """ -from app.services.background_tasks_service import ( +from app.modules.monitoring.services.background_tasks_service import ( background_tasks_service, BackgroundTasksService, ) diff --git a/app/modules/monitoring/services/background_tasks_service.py b/app/modules/monitoring/services/background_tasks_service.py new file mode 100644 index 00000000..525b8566 --- /dev/null +++ b/app/modules/monitoring/services/background_tasks_service.py @@ -0,0 +1,194 @@ +# app/modules/monitoring/services/background_tasks_service.py +""" +Background Tasks Service +Service for monitoring background tasks across the system +""" + +from datetime import UTC, datetime + +from sqlalchemy import case, desc, func +from sqlalchemy.orm import Session + +from models.database.architecture_scan import ArchitectureScan +from models.database.marketplace_import_job import MarketplaceImportJob +from models.database.test_run import TestRun + + +class BackgroundTasksService: + """Service for monitoring background tasks""" + + def get_import_jobs( + self, db: Session, status: str | None = None, limit: int = 50 + ) -> list[MarketplaceImportJob]: + """Get import jobs with optional status filter""" + query = db.query(MarketplaceImportJob) + if status: + query = query.filter(MarketplaceImportJob.status == status) + return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all() + + def get_test_runs( + self, db: Session, status: str | None = None, limit: int = 50 + ) -> list[TestRun]: + """Get test runs with optional status filter""" + query = db.query(TestRun) + if status: + query = query.filter(TestRun.status == status) + return query.order_by(desc(TestRun.timestamp)).limit(limit).all() + + def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]: + """Get currently running import jobs""" + return ( + db.query(MarketplaceImportJob) + .filter(MarketplaceImportJob.status == "processing") + .all() + ) + + def get_running_test_runs(self, db: Session) -> list[TestRun]: + """Get currently running test runs""" + # noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped + return db.query(TestRun).filter(TestRun.status == "running").all() + + def get_import_stats(self, db: Session) -> dict: + """Get import job statistics""" + today_start = datetime.now(UTC).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + + stats = db.query( + func.count(MarketplaceImportJob.id).label("total"), + func.sum( + case((MarketplaceImportJob.status == "processing", 1), else_=0) + ).label("running"), + func.sum( + case( + ( + MarketplaceImportJob.status.in_( + ["completed", "completed_with_errors"] + ), + 1, + ), + else_=0, + ) + ).label("completed"), + func.sum( + case((MarketplaceImportJob.status == "failed", 1), else_=0) + ).label("failed"), + ).first() + + today_count = ( + db.query(func.count(MarketplaceImportJob.id)) + .filter(MarketplaceImportJob.created_at >= today_start) + .scalar() + or 0 + ) + + return { + "total": stats.total or 0, + "running": stats.running or 0, + "completed": stats.completed or 0, + "failed": stats.failed or 0, + "today": today_count, + } + + def get_test_run_stats(self, db: Session) -> dict: + """Get test run statistics""" + today_start = datetime.now(UTC).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + + stats = db.query( + func.count(TestRun.id).label("total"), + func.sum(case((TestRun.status == "running", 1), else_=0)).label( + "running" + ), + func.sum(case((TestRun.status == "passed", 1), else_=0)).label( + "completed" + ), + func.sum( + case((TestRun.status.in_(["failed", "error"]), 1), else_=0) + ).label("failed"), + func.avg(TestRun.duration_seconds).label("avg_duration"), + ).first() + + today_count = ( + db.query(func.count(TestRun.id)) + .filter(TestRun.timestamp >= today_start) + .scalar() + or 0 + ) + + return { + "total": stats.total or 0, + "running": stats.running or 0, + "completed": stats.completed or 0, + "failed": stats.failed or 0, + "today": today_count, + "avg_duration": round(stats.avg_duration or 0, 1), + } + + def get_code_quality_scans( + self, db: Session, status: str | None = None, limit: int = 50 + ) -> list[ArchitectureScan]: + """Get code quality scans with optional status filter""" + query = db.query(ArchitectureScan) + if status: + query = query.filter(ArchitectureScan.status == status) + return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all() + + def get_running_scans(self, db: Session) -> list[ArchitectureScan]: + """Get currently running code quality scans""" + return ( + db.query(ArchitectureScan) + .filter(ArchitectureScan.status.in_(["pending", "running"])) + .all() + ) + + def get_scan_stats(self, db: Session) -> dict: + """Get code quality scan statistics""" + today_start = datetime.now(UTC).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + + stats = db.query( + func.count(ArchitectureScan.id).label("total"), + func.sum( + case( + (ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0 + ) + ).label("running"), + func.sum( + case( + ( + ArchitectureScan.status.in_( + ["completed", "completed_with_warnings"] + ), + 1, + ), + else_=0, + ) + ).label("completed"), + func.sum( + case((ArchitectureScan.status == "failed", 1), else_=0) + ).label("failed"), + func.avg(ArchitectureScan.duration_seconds).label("avg_duration"), + ).first() + + today_count = ( + db.query(func.count(ArchitectureScan.id)) + .filter(ArchitectureScan.timestamp >= today_start) + .scalar() + or 0 + ) + + return { + "total": stats.total or 0, + "running": stats.running or 0, + "completed": stats.completed or 0, + "failed": stats.failed or 0, + "today": today_count, + "avg_duration": round(stats.avg_duration or 0, 1), + } + + +# Singleton instance +background_tasks_service = BackgroundTasksService() diff --git a/app/modules/orders/models/__init__.py b/app/modules/orders/models/__init__.py index d57f3e28..93c0a710 100644 --- a/app/modules/orders/models/__init__.py +++ b/app/modules/orders/models/__init__.py @@ -2,15 +2,12 @@ """ Orders module database models. -Re-exports order-related models from their source locations. +This module contains the canonical implementations of order-related models. """ -from models.database.order import ( - Order, - OrderItem, -) -from models.database.order_item_exception import OrderItemException -from models.database.invoice import ( +from app.modules.orders.models.order import Order, OrderItem +from app.modules.orders.models.order_item_exception import OrderItemException +from app.modules.orders.models.invoice import ( Invoice, InvoiceStatus, VATRegime, diff --git a/app/modules/orders/models/invoice.py b/app/modules/orders/models/invoice.py new file mode 100644 index 00000000..f06eb0ff --- /dev/null +++ b/app/modules/orders/models/invoice.py @@ -0,0 +1,215 @@ +# app/modules/orders/models/invoice.py +""" +Invoice database models for the OMS. + +Provides models for: +- VendorInvoiceSettings: Per-vendor invoice configuration (company details, VAT, numbering) +- Invoice: Invoice records with snapshots of seller/buyer details +""" + +import enum + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, +) +from sqlalchemy.dialects.sqlite import JSON +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class VendorInvoiceSettings(Base, TimestampMixin): + """ + Per-vendor invoice configuration. + + Stores company details, VAT number, invoice numbering preferences, + and payment information for invoice generation. + + One-to-one relationship with Vendor. + """ + + __tablename__ = "vendor_invoice_settings" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column( + Integer, ForeignKey("vendors.id"), unique=True, nullable=False, index=True + ) + + # Legal company details for invoice header + company_name = Column(String(255), nullable=False) # Legal name for invoices + company_address = Column(String(255), nullable=True) # Street address + company_city = Column(String(100), nullable=True) + company_postal_code = Column(String(20), nullable=True) + company_country = Column(String(2), nullable=False, default="LU") # ISO country code + + # VAT information + vat_number = Column(String(50), nullable=True) # e.g., "LU12345678" + is_vat_registered = Column(Boolean, default=True, nullable=False) + + # OSS (One-Stop-Shop) for EU VAT + is_oss_registered = Column(Boolean, default=False, nullable=False) + oss_registration_country = Column(String(2), nullable=True) # ISO country code + + # Invoice numbering + invoice_prefix = Column(String(20), default="INV", nullable=False) + invoice_next_number = Column(Integer, default=1, nullable=False) + invoice_number_padding = Column(Integer, default=5, nullable=False) # e.g., INV00001 + + # Payment information + payment_terms = Column(Text, nullable=True) # e.g., "Payment due within 30 days" + bank_name = Column(String(255), nullable=True) + bank_iban = Column(String(50), nullable=True) + bank_bic = Column(String(20), nullable=True) + + # Invoice footer + footer_text = Column(Text, nullable=True) # Custom footer text + + # Default VAT rate for Luxembourg invoices (17% standard) + default_vat_rate = Column(Numeric(5, 2), default=17.00, nullable=False) + + # Relationships + vendor = relationship("Vendor", back_populates="invoice_settings") + + def __repr__(self): + return f"" + + def get_next_invoice_number(self) -> str: + """Generate the next invoice number and increment counter.""" + number = str(self.invoice_next_number).zfill(self.invoice_number_padding) + return f"{self.invoice_prefix}{number}" + + +class InvoiceStatus(str, enum.Enum): + """Invoice status enumeration.""" + + DRAFT = "draft" + ISSUED = "issued" + PAID = "paid" + CANCELLED = "cancelled" + + +class VATRegime(str, enum.Enum): + """VAT regime for invoice calculation.""" + + DOMESTIC = "domestic" # Same country as seller + OSS = "oss" # EU cross-border with OSS registration + REVERSE_CHARGE = "reverse_charge" # B2B with valid VAT number + ORIGIN = "origin" # Cross-border without OSS (use origin VAT) + EXEMPT = "exempt" # VAT exempt + + +class Invoice(Base, TimestampMixin): + """ + Invoice record with snapshots of seller/buyer details. + + Stores complete invoice data including snapshots of seller and buyer + details at time of creation for audit purposes. + """ + + __tablename__ = "invoices" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True) + + # Invoice identification + invoice_number = Column(String(50), nullable=False) + invoice_date = Column(DateTime(timezone=True), nullable=False) + + # Status + status = Column(String(20), default=InvoiceStatus.DRAFT.value, nullable=False) + + # Seller details snapshot (captured at invoice creation) + seller_details = Column(JSON, nullable=False) + # Structure: { + # "company_name": str, + # "address": str, + # "city": str, + # "postal_code": str, + # "country": str, + # "vat_number": str | None + # } + + # Buyer details snapshot (captured at invoice creation) + buyer_details = Column(JSON, nullable=False) + # Structure: { + # "name": str, + # "email": str, + # "address": str, + # "city": str, + # "postal_code": str, + # "country": str, + # "vat_number": str | None (for B2B) + # } + + # Line items snapshot + line_items = Column(JSON, nullable=False) + # Structure: [{ + # "description": str, + # "quantity": int, + # "unit_price_cents": int, + # "total_cents": int, + # "sku": str | None, + # "ean": str | None + # }] + + # VAT information + vat_regime = Column(String(20), default=VATRegime.DOMESTIC.value, nullable=False) + destination_country = Column(String(2), nullable=True) # For OSS invoices + vat_rate = Column(Numeric(5, 2), nullable=False) # e.g., 17.00 for 17% + vat_rate_label = Column(String(50), nullable=True) # e.g., "Luxembourg Standard VAT" + + # Amounts (stored in cents for precision) + currency = Column(String(3), default="EUR", nullable=False) + subtotal_cents = Column(Integer, nullable=False) # Before VAT + vat_amount_cents = Column(Integer, nullable=False) # VAT amount + total_cents = Column(Integer, nullable=False) # After VAT + + # Payment information + payment_terms = Column(Text, nullable=True) + bank_details = Column(JSON, nullable=True) # IBAN, BIC snapshot + footer_text = Column(Text, nullable=True) + + # PDF storage + pdf_generated_at = Column(DateTime(timezone=True), nullable=True) + pdf_path = Column(String(500), nullable=True) # Path to stored PDF + + # Notes + notes = Column(Text, nullable=True) # Internal notes + + # Relationships + vendor = relationship("Vendor", back_populates="invoices") + order = relationship("Order", back_populates="invoices") + + __table_args__ = ( + Index("idx_invoice_vendor_number", "vendor_id", "invoice_number", unique=True), + Index("idx_invoice_vendor_date", "vendor_id", "invoice_date"), + Index("idx_invoice_status", "vendor_id", "status"), + ) + + def __repr__(self): + return f"" + + @property + def subtotal(self) -> float: + """Get subtotal in EUR.""" + return self.subtotal_cents / 100 + + @property + def vat_amount(self) -> float: + """Get VAT amount in EUR.""" + return self.vat_amount_cents / 100 + + @property + def total(self) -> float: + """Get total in EUR.""" + return self.total_cents / 100 diff --git a/app/modules/orders/models/order.py b/app/modules/orders/models/order.py new file mode 100644 index 00000000..66354910 --- /dev/null +++ b/app/modules/orders/models/order.py @@ -0,0 +1,406 @@ +# app/modules/orders/models/order.py +""" +Unified Order model for all sales channels. + +Supports: +- Direct orders (from vendor's own storefront) +- Marketplace orders (Letzshop, etc.) + +Design principles: +- Customer/address data is snapshotted at order time (preserves history) +- customer_id FK links to Customer record (may be inactive for marketplace imports) +- channel field distinguishes order source +- external_* fields store marketplace-specific references + +Money values are stored as integer cents (e.g., €105.91 = 10591). +See docs/architecture/money-handling.md for details. +""" + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, +) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.modules.orders.models.order_item_exception import OrderItemException +from sqlalchemy.dialects.sqlite import JSON +from sqlalchemy.orm import relationship + +from app.core.database import Base +from app.utils.money import cents_to_euros, euros_to_cents +from models.database.base import TimestampMixin + + +class Order(Base, TimestampMixin): + """ + Unified order model for all sales channels. + + Stores orders from direct sales and marketplaces (Letzshop, etc.) + with snapshotted customer and address data. + + All monetary amounts are stored as integer cents for precision. + """ + + __tablename__ = "orders" + + id = Column(Integer, primary_key=True, index=True) + vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) + customer_id = Column( + Integer, ForeignKey("customers.id"), nullable=False, index=True + ) + order_number = Column(String(100), nullable=False, unique=True, index=True) + + # === Channel/Source === + channel = Column( + String(50), default="direct", nullable=False, index=True + ) # direct, letzshop + + # External references (for marketplace orders) + external_order_id = Column( + String(100), nullable=True, index=True + ) # Marketplace order ID + external_shipment_id = Column( + String(100), nullable=True, index=True + ) # Marketplace shipment ID + external_order_number = Column(String(100), nullable=True) # Marketplace order # + external_data = Column(JSON, nullable=True) # Raw marketplace data for debugging + + # === Status === + # pending: awaiting confirmation + # processing: confirmed, being prepared + # shipped: shipped with tracking + # delivered: delivered to customer + # cancelled: order cancelled/declined + # refunded: order refunded + status = Column(String(50), nullable=False, default="pending", index=True) + + # === Financials (stored as integer cents) === + subtotal_cents = Column(Integer, nullable=True) # May not be available from marketplace + tax_amount_cents = Column(Integer, nullable=True) + shipping_amount_cents = Column(Integer, nullable=True) + discount_amount_cents = Column(Integer, nullable=True) + total_amount_cents = Column(Integer, nullable=False) + currency = Column(String(10), default="EUR") + + # === VAT Information === + # VAT regime: domestic, oss, reverse_charge, origin, exempt + vat_regime = Column(String(20), nullable=True) + # VAT rate as percentage (e.g., 17.00 for 17%) + vat_rate = Column(Numeric(5, 2), nullable=True) + # Human-readable VAT label (e.g., "Luxembourg VAT 17%") + vat_rate_label = Column(String(100), nullable=True) + # Destination country for cross-border sales (ISO code) + vat_destination_country = Column(String(2), nullable=True) + + # === Customer Snapshot (preserved at order time) === + customer_first_name = Column(String(100), nullable=False) + customer_last_name = Column(String(100), nullable=False) + customer_email = Column(String(255), nullable=False) + customer_phone = Column(String(50), nullable=True) + customer_locale = Column(String(10), nullable=True) # en, fr, de, lb + + # === Shipping Address Snapshot === + ship_first_name = Column(String(100), nullable=False) + ship_last_name = Column(String(100), nullable=False) + ship_company = Column(String(200), nullable=True) + ship_address_line_1 = Column(String(255), nullable=False) + ship_address_line_2 = Column(String(255), nullable=True) + ship_city = Column(String(100), nullable=False) + ship_postal_code = Column(String(20), nullable=False) + ship_country_iso = Column(String(5), nullable=False) + + # === Billing Address Snapshot === + bill_first_name = Column(String(100), nullable=False) + bill_last_name = Column(String(100), nullable=False) + bill_company = Column(String(200), nullable=True) + bill_address_line_1 = Column(String(255), nullable=False) + bill_address_line_2 = Column(String(255), nullable=True) + bill_city = Column(String(100), nullable=False) + bill_postal_code = Column(String(20), nullable=False) + bill_country_iso = Column(String(5), nullable=False) + + # === Tracking === + shipping_method = Column(String(100), nullable=True) + tracking_number = Column(String(100), nullable=True) + tracking_provider = Column(String(100), nullable=True) + tracking_url = Column(String(500), nullable=True) # Full tracking URL + shipment_number = Column(String(100), nullable=True) # Carrier shipment number (e.g., H74683403433) + shipping_carrier = Column(String(50), nullable=True) # Carrier code (greco, colissimo, etc.) + + # === Notes === + customer_notes = Column(Text, nullable=True) + internal_notes = Column(Text, nullable=True) + + # === Timestamps === + order_date = Column( + DateTime(timezone=True), nullable=False + ) # When customer placed order + confirmed_at = Column(DateTime(timezone=True), nullable=True) + shipped_at = Column(DateTime(timezone=True), nullable=True) + delivered_at = Column(DateTime(timezone=True), nullable=True) + cancelled_at = Column(DateTime(timezone=True), nullable=True) + + # === Relationships === + vendor = relationship("Vendor") + customer = relationship("Customer", back_populates="orders") + items = relationship( + "OrderItem", back_populates="order", cascade="all, delete-orphan" + ) + invoices = relationship( + "Invoice", back_populates="order", cascade="all, delete-orphan" + ) + + # Composite indexes for common queries + __table_args__ = ( + Index("idx_order_vendor_status", "vendor_id", "status"), + Index("idx_order_vendor_channel", "vendor_id", "channel"), + Index("idx_order_vendor_date", "vendor_id", "order_date"), + ) + + def __repr__(self): + return f"" + + # === PRICE PROPERTIES (Euro convenience accessors) === + + @property + def subtotal(self) -> float | None: + """Get subtotal in euros.""" + if self.subtotal_cents is not None: + return cents_to_euros(self.subtotal_cents) + return None + + @subtotal.setter + def subtotal(self, value: float | None): + """Set subtotal from euros.""" + self.subtotal_cents = euros_to_cents(value) if value is not None else None + + @property + def tax_amount(self) -> float | None: + """Get tax amount in euros.""" + if self.tax_amount_cents is not None: + return cents_to_euros(self.tax_amount_cents) + return None + + @tax_amount.setter + def tax_amount(self, value: float | None): + """Set tax amount from euros.""" + self.tax_amount_cents = euros_to_cents(value) if value is not None else None + + @property + def shipping_amount(self) -> float | None: + """Get shipping amount in euros.""" + if self.shipping_amount_cents is not None: + return cents_to_euros(self.shipping_amount_cents) + return None + + @shipping_amount.setter + def shipping_amount(self, value: float | None): + """Set shipping amount from euros.""" + self.shipping_amount_cents = euros_to_cents(value) if value is not None else None + + @property + def discount_amount(self) -> float | None: + """Get discount amount in euros.""" + if self.discount_amount_cents is not None: + return cents_to_euros(self.discount_amount_cents) + return None + + @discount_amount.setter + def discount_amount(self, value: float | None): + """Set discount amount from euros.""" + self.discount_amount_cents = euros_to_cents(value) if value is not None else None + + @property + def total_amount(self) -> float: + """Get total amount in euros.""" + return cents_to_euros(self.total_amount_cents) + + @total_amount.setter + def total_amount(self, value: float): + """Set total amount from euros.""" + self.total_amount_cents = euros_to_cents(value) + + # === NAME PROPERTIES === + + @property + def customer_full_name(self) -> str: + """Customer full name from snapshot.""" + return f"{self.customer_first_name} {self.customer_last_name}".strip() + + @property + def ship_full_name(self) -> str: + """Shipping address full name.""" + return f"{self.ship_first_name} {self.ship_last_name}".strip() + + @property + def bill_full_name(self) -> str: + """Billing address full name.""" + return f"{self.bill_first_name} {self.bill_last_name}".strip() + + @property + def is_marketplace_order(self) -> bool: + """Check if this is a marketplace order.""" + return self.channel != "direct" + + @property + def is_fully_shipped(self) -> bool: + """Check if all items are fully shipped.""" + if not self.items: + return False + return all(item.is_fully_shipped for item in self.items) + + @property + def is_partially_shipped(self) -> bool: + """Check if some items are shipped but not all.""" + if not self.items: + return False + has_shipped = any(item.shipped_quantity > 0 for item in self.items) + all_shipped = all(item.is_fully_shipped for item in self.items) + return has_shipped and not all_shipped + + @property + def shipped_item_count(self) -> int: + """Count of fully shipped items.""" + return sum(1 for item in self.items if item.is_fully_shipped) + + @property + def total_shipped_units(self) -> int: + """Total quantity shipped across all items.""" + return sum(item.shipped_quantity for item in self.items) + + @property + def total_ordered_units(self) -> int: + """Total quantity ordered across all items.""" + return sum(item.quantity for item in self.items) + + +class OrderItem(Base, TimestampMixin): + """ + Individual items in an order. + + Stores product snapshot at time of order plus external references + for marketplace items. + + All monetary amounts are stored as integer cents for precision. + """ + + __tablename__ = "order_items" + + id = Column(Integer, primary_key=True, index=True) + order_id = Column(Integer, ForeignKey("orders.id"), nullable=False, index=True) + product_id = Column(Integer, ForeignKey("products.id"), nullable=False) + + # === Product Snapshot (preserved at order time) === + product_name = Column(String(255), nullable=False) + product_sku = Column(String(100), nullable=True) + gtin = Column(String(50), nullable=True) # EAN/UPC/ISBN etc. + gtin_type = Column(String(20), nullable=True) # ean13, upc, isbn, etc. + + # === Pricing (stored as integer cents) === + quantity = Column(Integer, nullable=False) + unit_price_cents = Column(Integer, nullable=False) + total_price_cents = Column(Integer, nullable=False) + + # === External References (for marketplace items) === + external_item_id = Column(String(100), nullable=True) # e.g., Letzshop inventory unit ID + external_variant_id = Column(String(100), nullable=True) # e.g., Letzshop variant ID + + # === Item State (for marketplace confirmation flow) === + # confirmed_available: item confirmed and available + # confirmed_unavailable: item confirmed but not available (declined) + item_state = Column(String(50), nullable=True) + + # === Inventory Tracking === + inventory_reserved = Column(Boolean, default=False) + inventory_fulfilled = Column(Boolean, default=False) + + # === Shipment Tracking === + shipped_quantity = Column(Integer, default=0, nullable=False) # Units shipped so far + + # === Exception Tracking === + # True if product was not found by GTIN during import (linked to placeholder) + needs_product_match = Column(Boolean, default=False, index=True) + + # === Relationships === + order = relationship("Order", back_populates="items") + product = relationship("Product") + exception = relationship( + "OrderItemException", + back_populates="order_item", + uselist=False, + cascade="all, delete-orphan", + ) + + def __repr__(self): + return f"" + + # === PRICE PROPERTIES (Euro convenience accessors) === + + @property + def unit_price(self) -> float: + """Get unit price in euros.""" + return cents_to_euros(self.unit_price_cents) + + @unit_price.setter + def unit_price(self, value: float): + """Set unit price from euros.""" + self.unit_price_cents = euros_to_cents(value) + + @property + def total_price(self) -> float: + """Get total price in euros.""" + return cents_to_euros(self.total_price_cents) + + @total_price.setter + def total_price(self, value: float): + """Set total price from euros.""" + self.total_price_cents = euros_to_cents(value) + + # === STATUS PROPERTIES === + + @property + def is_confirmed(self) -> bool: + """Check if item has been confirmed (available or unavailable).""" + return self.item_state in ("confirmed_available", "confirmed_unavailable") + + @property + def is_available(self) -> bool: + """Check if item is confirmed as available.""" + return self.item_state == "confirmed_available" + + @property + def is_declined(self) -> bool: + """Check if item was declined (unavailable).""" + return self.item_state == "confirmed_unavailable" + + @property + def has_unresolved_exception(self) -> bool: + """Check if item has an unresolved exception blocking confirmation.""" + if not self.exception: + return False + return self.exception.blocks_confirmation + + # === SHIPMENT PROPERTIES === + + @property + def remaining_quantity(self) -> int: + """Quantity not yet shipped.""" + return max(0, self.quantity - self.shipped_quantity) + + @property + def is_fully_shipped(self) -> bool: + """Check if all units have been shipped.""" + return self.shipped_quantity >= self.quantity + + @property + def is_partially_shipped(self) -> bool: + """Check if some but not all units have been shipped.""" + return 0 < self.shipped_quantity < self.quantity diff --git a/app/modules/orders/models/order_item_exception.py b/app/modules/orders/models/order_item_exception.py new file mode 100644 index 00000000..fa9548f6 --- /dev/null +++ b/app/modules/orders/models/order_item_exception.py @@ -0,0 +1,117 @@ +# app/modules/orders/models/order_item_exception.py +""" +Order Item Exception model for tracking unmatched products during marketplace imports. + +When a marketplace order contains a GTIN that doesn't match any product in the +vendor's catalog, the order is still imported but the item is linked to a +placeholder product and an exception is recorded here for resolution. +""" + +from sqlalchemy import ( + Column, + DateTime, + ForeignKey, + Index, + Integer, + String, + Text, +) +from sqlalchemy.orm import relationship + +from app.core.database import Base +from models.database.base import TimestampMixin + + +class OrderItemException(Base, TimestampMixin): + """ + Tracks unmatched order items requiring admin/vendor resolution. + + When a marketplace order is imported and a product cannot be found by GTIN, + the order item is linked to a placeholder product and this exception record + is created. The order cannot be confirmed until all exceptions are resolved. + """ + + __tablename__ = "order_item_exceptions" + + id = Column(Integer, primary_key=True, index=True) + + # Link to the order item (one-to-one) + order_item_id = Column( + Integer, + ForeignKey("order_items.id", ondelete="CASCADE"), + nullable=False, + unique=True, + ) + + # Vendor ID for efficient querying (denormalized from order) + vendor_id = Column( + Integer, ForeignKey("vendors.id"), nullable=False, index=True + ) + + # Original data from marketplace (preserved for matching) + original_gtin = Column(String(50), nullable=True, index=True) + original_product_name = Column(String(500), nullable=True) + original_sku = Column(String(100), nullable=True) + + # Exception classification + # product_not_found: GTIN not in vendor catalog + # gtin_mismatch: GTIN format issue + # duplicate_gtin: Multiple products with same GTIN + exception_type = Column( + String(50), nullable=False, default="product_not_found" + ) + + # Resolution status + # pending: Awaiting resolution + # resolved: Product has been assigned + # ignored: Marked as ignored (still blocks confirmation) + status = Column(String(50), nullable=False, default="pending", index=True) + + # Resolution details (populated when resolved) + resolved_product_id = Column( + Integer, ForeignKey("products.id"), nullable=True + ) + resolved_at = Column(DateTime(timezone=True), nullable=True) + resolved_by = Column(Integer, ForeignKey("users.id"), nullable=True) + resolution_notes = Column(Text, nullable=True) + + # Relationships + order_item = relationship("OrderItem", back_populates="exception") + vendor = relationship("Vendor") + resolved_product = relationship("Product") + resolver = relationship("User") + + # Composite indexes for common queries + __table_args__ = ( + Index("idx_exception_vendor_status", "vendor_id", "status"), + Index("idx_exception_gtin", "vendor_id", "original_gtin"), + ) + + def __repr__(self): + return ( + f"" + ) + + @property + def is_pending(self) -> bool: + """Check if exception is pending resolution.""" + return self.status == "pending" + + @property + def is_resolved(self) -> bool: + """Check if exception has been resolved.""" + return self.status == "resolved" + + @property + def is_ignored(self) -> bool: + """Check if exception has been ignored.""" + return self.status == "ignored" + + @property + def blocks_confirmation(self) -> bool: + """Check if this exception blocks order confirmation.""" + # Both pending and ignored exceptions block confirmation + return self.status in ("pending", "ignored") diff --git a/app/modules/orders/schemas/__init__.py b/app/modules/orders/schemas/__init__.py index 9f9c1749..8b9119f2 100644 --- a/app/modules/orders/schemas/__init__.py +++ b/app/modules/orders/schemas/__init__.py @@ -2,34 +2,133 @@ """ Orders module Pydantic schemas. -Re-exports order-related schemas from their source locations. +This module contains the canonical implementations of order-related schemas. """ -from models.schema.order import ( - OrderCreate, - OrderItemCreate, - OrderResponse, - OrderItemResponse, - OrderListResponse, +from app.modules.orders.schemas.order import ( + # Address schemas AddressSnapshot, + AddressSnapshotResponse, + # Order item schemas + OrderItemCreate, + OrderItemExceptionBrief, + OrderItemResponse, + # Customer schemas CustomerSnapshot, + CustomerSnapshotResponse, + # Order CRUD schemas + OrderCreate, + OrderUpdate, + OrderTrackingUpdate, + OrderItemStateUpdate, + # Order response schemas + OrderResponse, + OrderDetailResponse, + OrderListResponse, + OrderListItem, + # Admin schemas + AdminOrderItem, + AdminOrderListResponse, + AdminOrderStats, + AdminOrderStatusUpdate, + AdminVendorWithOrders, + AdminVendorsWithOrdersResponse, + # Letzshop schemas + LetzshopOrderImport, + LetzshopShippingInfo, + LetzshopOrderConfirmItem, + LetzshopOrderConfirmRequest, + # Shipping schemas + MarkAsShippedRequest, + ShippingLabelInfo, ) -from models.schema.invoice import ( + +from app.modules.orders.schemas.invoice import ( + # Invoice settings schemas + VendorInvoiceSettingsCreate, + VendorInvoiceSettingsUpdate, + VendorInvoiceSettingsResponse, + # Line item schemas + InvoiceLineItem, + InvoiceLineItemResponse, + # Address schemas + InvoiceSellerDetails, + InvoiceBuyerDetails, + # Invoice CRUD schemas + InvoiceCreate, + InvoiceManualCreate, InvoiceResponse, + InvoiceListResponse, + InvoiceStatusUpdate, + # Pagination + InvoiceListPaginatedResponse, + # PDF + InvoicePDFGeneratedResponse, + InvoiceStatsResponse, + # Backward compatibility InvoiceSettingsCreate, InvoiceSettingsUpdate, InvoiceSettingsResponse, ) __all__ = [ - "OrderCreate", - "OrderItemCreate", - "OrderResponse", - "OrderItemResponse", - "OrderListResponse", + # Address schemas "AddressSnapshot", + "AddressSnapshotResponse", + # Order item schemas + "OrderItemCreate", + "OrderItemExceptionBrief", + "OrderItemResponse", + # Customer schemas "CustomerSnapshot", + "CustomerSnapshotResponse", + # Order CRUD schemas + "OrderCreate", + "OrderUpdate", + "OrderTrackingUpdate", + "OrderItemStateUpdate", + # Order response schemas + "OrderResponse", + "OrderDetailResponse", + "OrderListResponse", + "OrderListItem", + # Admin schemas + "AdminOrderItem", + "AdminOrderListResponse", + "AdminOrderStats", + "AdminOrderStatusUpdate", + "AdminVendorWithOrders", + "AdminVendorsWithOrdersResponse", + # Letzshop schemas + "LetzshopOrderImport", + "LetzshopShippingInfo", + "LetzshopOrderConfirmItem", + "LetzshopOrderConfirmRequest", + # Shipping schemas + "MarkAsShippedRequest", + "ShippingLabelInfo", + # Invoice settings schemas + "VendorInvoiceSettingsCreate", + "VendorInvoiceSettingsUpdate", + "VendorInvoiceSettingsResponse", + # Line item schemas + "InvoiceLineItem", + "InvoiceLineItemResponse", + # Invoice address schemas + "InvoiceSellerDetails", + "InvoiceBuyerDetails", + # Invoice CRUD schemas + "InvoiceCreate", + "InvoiceManualCreate", "InvoiceResponse", + "InvoiceListResponse", + "InvoiceStatusUpdate", + # Pagination + "InvoiceListPaginatedResponse", + # PDF + "InvoicePDFGeneratedResponse", + "InvoiceStatsResponse", + # Backward compatibility "InvoiceSettingsCreate", "InvoiceSettingsUpdate", "InvoiceSettingsResponse", diff --git a/app/modules/orders/schemas/invoice.py b/app/modules/orders/schemas/invoice.py new file mode 100644 index 00000000..37291df4 --- /dev/null +++ b/app/modules/orders/schemas/invoice.py @@ -0,0 +1,316 @@ +# app/modules/orders/schemas/invoice.py +""" +Pydantic schemas for invoice operations. + +Supports invoice settings management and invoice generation. +""" + +from datetime import datetime +from decimal import Decimal + +from pydantic import BaseModel, ConfigDict, Field + +# ============================================================================ +# Invoice Settings Schemas +# ============================================================================ + + +class VendorInvoiceSettingsCreate(BaseModel): + """Schema for creating vendor invoice settings.""" + + company_name: str = Field(..., min_length=1, max_length=255) + company_address: str | None = Field(None, max_length=255) + company_city: str | None = Field(None, max_length=100) + company_postal_code: str | None = Field(None, max_length=20) + company_country: str = Field(default="LU", min_length=2, max_length=2) + + vat_number: str | None = Field(None, max_length=50) + is_vat_registered: bool = True + + is_oss_registered: bool = False + oss_registration_country: str | None = Field(None, min_length=2, max_length=2) + + invoice_prefix: str = Field(default="INV", max_length=20) + invoice_number_padding: int = Field(default=5, ge=1, le=10) + + payment_terms: str | None = None + bank_name: str | None = Field(None, max_length=255) + bank_iban: str | None = Field(None, max_length=50) + bank_bic: str | None = Field(None, max_length=20) + + footer_text: str | None = None + default_vat_rate: Decimal = Field(default=Decimal("17.00"), ge=0, le=100) + + +class VendorInvoiceSettingsUpdate(BaseModel): + """Schema for updating vendor invoice settings.""" + + company_name: str | None = Field(None, min_length=1, max_length=255) + company_address: str | None = Field(None, max_length=255) + company_city: str | None = Field(None, max_length=100) + company_postal_code: str | None = Field(None, max_length=20) + company_country: str | None = Field(None, min_length=2, max_length=2) + + vat_number: str | None = None + is_vat_registered: bool | None = None + + is_oss_registered: bool | None = None + oss_registration_country: str | None = None + + invoice_prefix: str | None = Field(None, max_length=20) + invoice_number_padding: int | None = Field(None, ge=1, le=10) + + payment_terms: str | None = None + bank_name: str | None = Field(None, max_length=255) + bank_iban: str | None = Field(None, max_length=50) + bank_bic: str | None = Field(None, max_length=20) + + footer_text: str | None = None + default_vat_rate: Decimal | None = Field(None, ge=0, le=100) + + +class VendorInvoiceSettingsResponse(BaseModel): + """Schema for vendor invoice settings response.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + + company_name: str + company_address: str | None + company_city: str | None + company_postal_code: str | None + company_country: str + + vat_number: str | None + is_vat_registered: bool + + is_oss_registered: bool + oss_registration_country: str | None + + invoice_prefix: str + invoice_next_number: int + invoice_number_padding: int + + payment_terms: str | None + bank_name: str | None + bank_iban: str | None + bank_bic: str | None + + footer_text: str | None + default_vat_rate: Decimal + + created_at: datetime + updated_at: datetime + + +# ============================================================================ +# Invoice Line Item Schemas +# ============================================================================ + + +class InvoiceLineItem(BaseModel): + """Schema for invoice line item.""" + + description: str + quantity: int = Field(..., ge=1) + unit_price_cents: int + total_cents: int + sku: str | None = None + ean: str | None = None + + +class InvoiceLineItemResponse(BaseModel): + """Schema for invoice line item in response.""" + + description: str + quantity: int + unit_price_cents: int + total_cents: int + sku: str | None = None + ean: str | None = None + + @property + def unit_price(self) -> float: + return self.unit_price_cents / 100 + + @property + def total(self) -> float: + return self.total_cents / 100 + + +# ============================================================================ +# Invoice Address Schemas +# ============================================================================ + + +class InvoiceSellerDetails(BaseModel): + """Seller details for invoice.""" + + company_name: str + address: str | None = None + city: str | None = None + postal_code: str | None = None + country: str + vat_number: str | None = None + + +class InvoiceBuyerDetails(BaseModel): + """Buyer details for invoice.""" + + name: str + email: str | None = None + address: str | None = None + city: str | None = None + postal_code: str | None = None + country: str + vat_number: str | None = None # For B2B + + +# ============================================================================ +# Invoice Schemas +# ============================================================================ + + +class InvoiceCreate(BaseModel): + """Schema for creating an invoice from an order.""" + + order_id: int + notes: str | None = None + + +class InvoiceManualCreate(BaseModel): + """Schema for creating a manual invoice (without order).""" + + buyer_details: InvoiceBuyerDetails + line_items: list[InvoiceLineItem] + notes: str | None = None + payment_terms: str | None = None + + +class InvoiceResponse(BaseModel): + """Schema for invoice response.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + order_id: int | None + + invoice_number: str + invoice_date: datetime + status: str + + seller_details: dict + buyer_details: dict + line_items: list[dict] + + vat_regime: str + destination_country: str | None + vat_rate: Decimal + vat_rate_label: str | None + + currency: str + subtotal_cents: int + vat_amount_cents: int + total_cents: int + + payment_terms: str | None + bank_details: dict | None + footer_text: str | None + + pdf_generated_at: datetime | None + pdf_path: str | None + + notes: str | None + + created_at: datetime + updated_at: datetime + + @property + def subtotal(self) -> float: + return self.subtotal_cents / 100 + + @property + def vat_amount(self) -> float: + return self.vat_amount_cents / 100 + + @property + def total(self) -> float: + return self.total_cents / 100 + + +class InvoiceListResponse(BaseModel): + """Schema for invoice list response (summary).""" + + model_config = ConfigDict(from_attributes=True) + + id: int + invoice_number: str + invoice_date: datetime + status: str + currency: str + total_cents: int + order_id: int | None + + # Buyer name for display + buyer_name: str | None = None + + @property + def total(self) -> float: + return self.total_cents / 100 + + +class InvoiceStatusUpdate(BaseModel): + """Schema for updating invoice status.""" + + status: str = Field(..., pattern="^(draft|issued|paid|cancelled)$") + + +# ============================================================================ +# Paginated Response +# ============================================================================ + + +class InvoiceListPaginatedResponse(BaseModel): + """Paginated invoice list response.""" + + items: list[InvoiceListResponse] + total: int + page: int + per_page: int + pages: int + + +# ============================================================================ +# PDF Response +# ============================================================================ + + +class InvoicePDFGeneratedResponse(BaseModel): + """Response for PDF generation.""" + + pdf_path: str + message: str = "PDF generated successfully" + + +class InvoiceStatsResponse(BaseModel): + """Invoice statistics response.""" + + total_invoices: int + total_revenue_cents: int + draft_count: int + issued_count: int + paid_count: int + cancelled_count: int + + @property + def total_revenue(self) -> float: + return self.total_revenue_cents / 100 + + +# Backward compatibility re-exports +InvoiceSettingsCreate = VendorInvoiceSettingsCreate +InvoiceSettingsUpdate = VendorInvoiceSettingsUpdate +InvoiceSettingsResponse = VendorInvoiceSettingsResponse diff --git a/app/modules/orders/schemas/order.py b/app/modules/orders/schemas/order.py new file mode 100644 index 00000000..96372bff --- /dev/null +++ b/app/modules/orders/schemas/order.py @@ -0,0 +1,584 @@ +# app/modules/orders/schemas/order.py +""" +Pydantic schemas for unified order operations. + +Supports both direct orders and marketplace orders (Letzshop, etc.) +with snapshotted customer and address data. +""" + +from datetime import datetime + +from pydantic import BaseModel, ConfigDict, Field + +# ============================================================================ +# Address Snapshot Schemas +# ============================================================================ + + +class AddressSnapshot(BaseModel): + """Address snapshot for order creation.""" + + first_name: str = Field(..., min_length=1, max_length=100) + last_name: str = Field(..., min_length=1, max_length=100) + company: str | None = Field(None, max_length=200) + address_line_1: str = Field(..., min_length=1, max_length=255) + address_line_2: str | None = Field(None, max_length=255) + city: str = Field(..., min_length=1, max_length=100) + postal_code: str = Field(..., min_length=1, max_length=20) + country_iso: str = Field(..., min_length=2, max_length=5) + + +class AddressSnapshotResponse(BaseModel): + """Address snapshot in order response.""" + + first_name: str + last_name: str + company: str | None + address_line_1: str + address_line_2: str | None + city: str + postal_code: str + country_iso: str + + @property + def full_name(self) -> str: + return f"{self.first_name} {self.last_name}".strip() + + +# ============================================================================ +# Order Item Schemas +# ============================================================================ + + +class OrderItemCreate(BaseModel): + """Schema for creating an order item.""" + + product_id: int + quantity: int = Field(..., ge=1) + + +class OrderItemExceptionBrief(BaseModel): + """Brief exception info for embedding in order item responses.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + original_gtin: str | None + original_product_name: str | None + exception_type: str + status: str + resolved_product_id: int | None + + +class OrderItemResponse(BaseModel): + """Schema for order item response.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + order_id: int + product_id: int + product_name: str + product_sku: str | None + gtin: str | None + gtin_type: str | None + quantity: int + unit_price: float + total_price: float + + # External references (for marketplace items) + external_item_id: str | None = None + external_variant_id: str | None = None + + # Item state (for marketplace confirmation flow) + item_state: str | None = None + + # Inventory tracking + inventory_reserved: bool + inventory_fulfilled: bool + + # Exception tracking + needs_product_match: bool = False + exception: OrderItemExceptionBrief | None = None + + created_at: datetime + updated_at: datetime + + @property + def is_confirmed(self) -> bool: + """Check if item has been confirmed (available or unavailable).""" + return self.item_state in ("confirmed_available", "confirmed_unavailable") + + @property + def is_available(self) -> bool: + """Check if item is confirmed as available.""" + return self.item_state == "confirmed_available" + + @property + def is_declined(self) -> bool: + """Check if item was declined (unavailable).""" + return self.item_state == "confirmed_unavailable" + + @property + def has_unresolved_exception(self) -> bool: + """Check if item has an unresolved exception blocking confirmation.""" + if not self.exception: + return False + return self.exception.status in ("pending", "ignored") + + +# ============================================================================ +# Customer Snapshot Schemas +# ============================================================================ + + +class CustomerSnapshot(BaseModel): + """Customer snapshot for order creation.""" + + first_name: str = Field(..., min_length=1, max_length=100) + last_name: str = Field(..., min_length=1, max_length=100) + email: str = Field(..., max_length=255) + phone: str | None = Field(None, max_length=50) + locale: str | None = Field(None, max_length=10) + + +class CustomerSnapshotResponse(BaseModel): + """Customer snapshot in order response.""" + + first_name: str + last_name: str + email: str + phone: str | None + locale: str | None + + @property + def full_name(self) -> str: + return f"{self.first_name} {self.last_name}".strip() + + +# ============================================================================ +# Order Create/Update Schemas +# ============================================================================ + + +class OrderCreate(BaseModel): + """Schema for creating an order (direct channel).""" + + customer_id: int | None = None # Optional for guest checkout + items: list[OrderItemCreate] = Field(..., min_length=1) + + # Customer info snapshot + customer: CustomerSnapshot + + # Addresses (snapshots) + shipping_address: AddressSnapshot + billing_address: AddressSnapshot | None = None # Use shipping if not provided + + # Optional fields + shipping_method: str | None = None + customer_notes: str | None = Field(None, max_length=1000) + + # Cart/session info + session_id: str | None = None + + +class OrderUpdate(BaseModel): + """Schema for updating order status.""" + + status: str | None = Field( + None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" + ) + tracking_number: str | None = None + tracking_provider: str | None = None + internal_notes: str | None = None + + +class OrderTrackingUpdate(BaseModel): + """Schema for setting tracking information.""" + + tracking_number: str = Field(..., min_length=1, max_length=100) + tracking_provider: str = Field(..., min_length=1, max_length=100) + + +class OrderItemStateUpdate(BaseModel): + """Schema for updating item state (marketplace confirmation).""" + + item_id: int + state: str = Field(..., pattern="^(confirmed_available|confirmed_unavailable)$") + + +# ============================================================================ +# Order Response Schemas +# ============================================================================ + + +class OrderResponse(BaseModel): + """Schema for order response.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + customer_id: int + order_number: str + + # Channel/Source + channel: str + external_order_id: str | None = None + external_shipment_id: str | None = None + external_order_number: str | None = None + + # Status + status: str + + # Financial + subtotal: float | None + tax_amount: float | None + shipping_amount: float | None + discount_amount: float | None + total_amount: float + currency: str + + # VAT information + vat_regime: str | None = None + vat_rate: float | None = None + vat_rate_label: str | None = None + vat_destination_country: str | None = None + + # Customer snapshot + customer_first_name: str + customer_last_name: str + customer_email: str + customer_phone: str | None + customer_locale: str | None + + # Shipping address snapshot + ship_first_name: str + ship_last_name: str + ship_company: str | None + ship_address_line_1: str + ship_address_line_2: str | None + ship_city: str + ship_postal_code: str + ship_country_iso: str + + # Billing address snapshot + bill_first_name: str + bill_last_name: str + bill_company: str | None + bill_address_line_1: str + bill_address_line_2: str | None + bill_city: str + bill_postal_code: str + bill_country_iso: str + + # Tracking + shipping_method: str | None + tracking_number: str | None + tracking_provider: str | None + tracking_url: str | None = None + shipment_number: str | None = None + shipping_carrier: str | None = None + + # Notes + customer_notes: str | None + internal_notes: str | None + + # Timestamps + order_date: datetime + confirmed_at: datetime | None + shipped_at: datetime | None + delivered_at: datetime | None + cancelled_at: datetime | None + created_at: datetime + updated_at: datetime + + @property + def customer_full_name(self) -> str: + return f"{self.customer_first_name} {self.customer_last_name}".strip() + + @property + def ship_full_name(self) -> str: + return f"{self.ship_first_name} {self.ship_last_name}".strip() + + @property + def is_marketplace_order(self) -> bool: + return self.channel != "direct" + + +class OrderDetailResponse(OrderResponse): + """Schema for detailed order response with items.""" + + items: list[OrderItemResponse] = [] + + # Vendor info (enriched by API) + vendor_name: str | None = None + vendor_code: str | None = None + + +class OrderListResponse(BaseModel): + """Schema for paginated order list.""" + + orders: list[OrderResponse] + total: int + skip: int + limit: int + + +# ============================================================================ +# Order List Item (Simplified for list views) +# ============================================================================ + + +class OrderListItem(BaseModel): + """Simplified order item for list views.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + order_number: str + channel: str + status: str + + # External references + external_order_number: str | None = None + + # Customer + customer_full_name: str + customer_email: str + + # Financial + total_amount: float + currency: str + + # Shipping + ship_country_iso: str + + # Tracking + tracking_number: str | None + tracking_provider: str | None + tracking_url: str | None = None + shipment_number: str | None = None + shipping_carrier: str | None = None + + # Item count + item_count: int = 0 + + # Timestamps + order_date: datetime + confirmed_at: datetime | None + shipped_at: datetime | None + + +# ============================================================================ +# Admin Order Schemas +# ============================================================================ + + +class AdminOrderItem(BaseModel): + """Order item with vendor info for admin list view.""" + + model_config = ConfigDict(from_attributes=True) + + id: int + vendor_id: int + vendor_name: str | None = None + vendor_code: str | None = None + customer_id: int + order_number: str + channel: str + status: str + + # External references + external_order_number: str | None = None + external_shipment_id: str | None = None + + # Customer snapshot + customer_full_name: str + customer_email: str + + # Financial + subtotal: float | None + tax_amount: float | None + shipping_amount: float | None + discount_amount: float | None + total_amount: float + currency: str + + # VAT information + vat_regime: str | None = None + vat_rate: float | None = None + vat_rate_label: str | None = None + vat_destination_country: str | None = None + + # Shipping + ship_country_iso: str + tracking_number: str | None + tracking_provider: str | None + tracking_url: str | None = None + shipment_number: str | None = None + shipping_carrier: str | None = None + + # Item count + item_count: int = 0 + + # Timestamps + order_date: datetime + confirmed_at: datetime | None + shipped_at: datetime | None + delivered_at: datetime | None + cancelled_at: datetime | None + created_at: datetime + updated_at: datetime + + +class AdminOrderListResponse(BaseModel): + """Cross-vendor order list for admin.""" + + orders: list[AdminOrderItem] + total: int + skip: int + limit: int + + +class AdminOrderStats(BaseModel): + """Order statistics for admin dashboard.""" + + total_orders: int = 0 + pending_orders: int = 0 + processing_orders: int = 0 + shipped_orders: int = 0 + delivered_orders: int = 0 + cancelled_orders: int = 0 + refunded_orders: int = 0 + total_revenue: float = 0.0 + + # By channel + direct_orders: int = 0 + letzshop_orders: int = 0 + + # Vendors + vendors_with_orders: int = 0 + + +class AdminOrderStatusUpdate(BaseModel): + """Admin version of status update with reason.""" + + status: str = Field( + ..., pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" + ) + tracking_number: str | None = None + tracking_provider: str | None = None + reason: str | None = Field(None, description="Reason for status change") + + +class AdminVendorWithOrders(BaseModel): + """Vendor with order count.""" + + id: int + name: str + vendor_code: str + order_count: int = 0 + + +class AdminVendorsWithOrdersResponse(BaseModel): + """Response for vendors with orders list.""" + + vendors: list[AdminVendorWithOrders] + + +# ============================================================================ +# Letzshop-specific Schemas +# ============================================================================ + + +class LetzshopOrderImport(BaseModel): + """Schema for importing a Letzshop order from shipment data.""" + + shipment_id: str + order_id: str + order_number: str + order_date: datetime + + # Customer + customer_email: str + customer_locale: str | None = None + + # Shipping address + ship_first_name: str + ship_last_name: str + ship_company: str | None = None + ship_address_line_1: str + ship_address_line_2: str | None = None + ship_city: str + ship_postal_code: str + ship_country_iso: str + + # Billing address + bill_first_name: str + bill_last_name: str + bill_company: str | None = None + bill_address_line_1: str + bill_address_line_2: str | None = None + bill_city: str + bill_postal_code: str + bill_country_iso: str + + # Totals + total_amount: float + currency: str = "EUR" + + # State + letzshop_state: str # unconfirmed, confirmed, declined + + # Items + inventory_units: list[dict] + + # Raw data + raw_data: dict | None = None + + +class LetzshopShippingInfo(BaseModel): + """Shipping info retrieved from Letzshop.""" + + tracking_number: str + tracking_provider: str + shipment_id: str + + +class LetzshopOrderConfirmItem(BaseModel): + """Schema for confirming/declining a single item.""" + + item_id: int + external_item_id: str + action: str = Field(..., pattern="^(confirm|decline)$") + + +class LetzshopOrderConfirmRequest(BaseModel): + """Schema for confirming/declining order items.""" + + items: list[LetzshopOrderConfirmItem] + + +# ============================================================================ +# Mark as Shipped Schemas +# ============================================================================ + + +class MarkAsShippedRequest(BaseModel): + """Schema for marking an order as shipped with tracking info.""" + + tracking_number: str | None = Field(None, max_length=100) + tracking_url: str | None = Field(None, max_length=500) + shipping_carrier: str | None = Field(None, max_length=50) + + +class ShippingLabelInfo(BaseModel): + """Shipping label information for an order.""" + + shipment_number: str | None = None + shipping_carrier: str | None = None + label_url: str | None = None + tracking_number: str | None = None + tracking_url: str | None = None diff --git a/app/modules/orders/services/__init__.py b/app/modules/orders/services/__init__.py index 9cfbb55a..abe285ba 100644 --- a/app/modules/orders/services/__init__.py +++ b/app/modules/orders/services/__init__.py @@ -2,26 +2,26 @@ """ Orders module services. -Re-exports order-related services from their source locations. +This module contains the canonical implementations of order-related services. """ -from app.services.order_service import ( +from app.modules.orders.services.order_service import ( order_service, OrderService, ) -from app.services.order_inventory_service import ( +from app.modules.orders.services.order_inventory_service import ( order_inventory_service, OrderInventoryService, ) -from app.services.order_item_exception_service import ( +from app.modules.orders.services.order_item_exception_service import ( order_item_exception_service, OrderItemExceptionService, ) -from app.services.invoice_service import ( +from app.modules.orders.services.invoice_service import ( invoice_service, InvoiceService, ) -from app.services.invoice_pdf_service import ( +from app.modules.orders.services.invoice_pdf_service import ( invoice_pdf_service, InvoicePDFService, ) diff --git a/app/modules/orders/services/invoice_pdf_service.py b/app/modules/orders/services/invoice_pdf_service.py new file mode 100644 index 00000000..f11cf41d --- /dev/null +++ b/app/modules/orders/services/invoice_pdf_service.py @@ -0,0 +1,150 @@ +# app/modules/orders/services/invoice_pdf_service.py +""" +Invoice PDF generation service using WeasyPrint. + +Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint. +Stores generated PDFs in the configured storage location. +""" + +import logging +from datetime import UTC, datetime +from pathlib import Path + +from jinja2 import Environment, FileSystemLoader +from sqlalchemy.orm import Session + +from app.modules.orders.models.invoice import Invoice + +logger = logging.getLogger(__name__) + +# Template directory +TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices" + +# PDF storage directory (relative to project root) +PDF_STORAGE_DIR = Path("storage") / "invoices" + + +class InvoicePDFService: + """Service for generating invoice PDFs.""" + + def __init__(self): + """Initialize the PDF service with Jinja2 environment.""" + self.env = Environment( + loader=FileSystemLoader(str(TEMPLATE_DIR)), + autoescape=True, + ) + + def _ensure_storage_dir(self, vendor_id: int) -> Path: + """Ensure the storage directory exists for a vendor.""" + storage_path = PDF_STORAGE_DIR / str(vendor_id) + storage_path.mkdir(parents=True, exist_ok=True) + return storage_path + + def _get_pdf_filename(self, invoice: Invoice) -> str: + """Generate PDF filename for an invoice.""" + safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-") + return f"{safe_number}.pdf" + + def generate_pdf( + self, + db: Session, + invoice: Invoice, + force_regenerate: bool = False, + ) -> str: + """ + Generate PDF for an invoice. + + Args: + db: Database session + invoice: Invoice to generate PDF for + force_regenerate: If True, regenerate even if PDF already exists + + Returns: + Path to the generated PDF file + """ + # Check if PDF already exists + if invoice.pdf_path and not force_regenerate: + if Path(invoice.pdf_path).exists(): + logger.debug(f"PDF already exists for invoice {invoice.invoice_number}") + return invoice.pdf_path + + # Ensure storage directory exists + storage_dir = self._ensure_storage_dir(invoice.vendor_id) + pdf_filename = self._get_pdf_filename(invoice) + pdf_path = storage_dir / pdf_filename + + # Render HTML template + html_content = self._render_html(invoice) + + # Generate PDF using WeasyPrint + try: + from weasyprint import HTML + + html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR)) + html_doc.write_pdf(str(pdf_path)) + + logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}") + except ImportError: + logger.error("WeasyPrint not installed. Install with: pip install weasyprint") + raise RuntimeError("WeasyPrint not installed") + except Exception as e: + logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}") + raise + + # Update invoice record with PDF path and timestamp + invoice.pdf_path = str(pdf_path) + invoice.pdf_generated_at = datetime.now(UTC) + db.flush() + + return str(pdf_path) + + def _render_html(self, invoice: Invoice) -> str: + """Render the invoice HTML template.""" + template = self.env.get_template("invoice.html") + + context = { + "invoice": invoice, + "seller": invoice.seller_details, + "buyer": invoice.buyer_details, + "line_items": invoice.line_items, + "bank_details": invoice.bank_details, + "payment_terms": invoice.payment_terms, + "footer_text": invoice.footer_text, + "now": datetime.now(UTC), + } + + return template.render(**context) + + def get_pdf_path(self, invoice: Invoice) -> str | None: + """Get the PDF path for an invoice if it exists.""" + if invoice.pdf_path and Path(invoice.pdf_path).exists(): + return invoice.pdf_path + return None + + def delete_pdf(self, invoice: Invoice, db: Session) -> bool: + """Delete the PDF file for an invoice.""" + if not invoice.pdf_path: + return False + + pdf_path = Path(invoice.pdf_path) + if pdf_path.exists(): + try: + pdf_path.unlink() + logger.info(f"Deleted PDF for invoice {invoice.invoice_number}") + except Exception as e: + logger.error(f"Failed to delete PDF {pdf_path}: {e}") + return False + + invoice.pdf_path = None + invoice.pdf_generated_at = None + db.flush() + + return True + + def regenerate_pdf(self, db: Session, invoice: Invoice) -> str: + """Force regenerate PDF for an invoice.""" + return self.generate_pdf(db, invoice, force_regenerate=True) + + +# Singleton instance +invoice_pdf_service = InvoicePDFService() diff --git a/app/modules/orders/services/invoice_service.py b/app/modules/orders/services/invoice_service.py new file mode 100644 index 00000000..2730a14d --- /dev/null +++ b/app/modules/orders/services/invoice_service.py @@ -0,0 +1,587 @@ +# app/modules/orders/services/invoice_service.py +""" +Invoice service for generating and managing invoices. + +Handles: +- Vendor invoice settings management +- Invoice generation from orders +- VAT calculation (Luxembourg, EU, B2B reverse charge) +- Invoice number sequencing +- PDF generation (via separate module) +""" + +import logging +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any + +from sqlalchemy import and_, func +from sqlalchemy.orm import Session + +from app.exceptions import ValidationException +from app.exceptions.invoice import ( + InvoiceNotFoundException, + InvoiceSettingsNotFoundException, + OrderNotFoundException, +) +from app.modules.orders.models.invoice import ( + Invoice, + InvoiceStatus, + VATRegime, + VendorInvoiceSettings, +) +from app.modules.orders.models.order import Order +from app.modules.orders.schemas.invoice import ( + VendorInvoiceSettingsCreate, + VendorInvoiceSettingsUpdate, +) +from models.database.vendor import Vendor + +logger = logging.getLogger(__name__) + + +# EU VAT rates by country code (2024 standard rates) +EU_VAT_RATES: dict[str, Decimal] = { + "AT": Decimal("20.00"), + "BE": Decimal("21.00"), + "BG": Decimal("20.00"), + "HR": Decimal("25.00"), + "CY": Decimal("19.00"), + "CZ": Decimal("21.00"), + "DK": Decimal("25.00"), + "EE": Decimal("22.00"), + "FI": Decimal("24.00"), + "FR": Decimal("20.00"), + "DE": Decimal("19.00"), + "GR": Decimal("24.00"), + "HU": Decimal("27.00"), + "IE": Decimal("23.00"), + "IT": Decimal("22.00"), + "LV": Decimal("21.00"), + "LT": Decimal("21.00"), + "LU": Decimal("17.00"), + "MT": Decimal("18.00"), + "NL": Decimal("21.00"), + "PL": Decimal("23.00"), + "PT": Decimal("23.00"), + "RO": Decimal("19.00"), + "SK": Decimal("20.00"), + "SI": Decimal("22.00"), + "ES": Decimal("21.00"), + "SE": Decimal("25.00"), +} + +LU_VAT_RATES = { + "standard": Decimal("17.00"), + "intermediate": Decimal("14.00"), + "reduced": Decimal("8.00"), + "super_reduced": Decimal("3.00"), +} + + +class InvoiceService: + """Service for invoice operations.""" + + # ========================================================================= + # VAT Calculation + # ========================================================================= + + def get_vat_rate_for_country(self, country_iso: str) -> Decimal: + """Get standard VAT rate for EU country.""" + return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00")) + + def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str: + """Get human-readable VAT rate label.""" + country_names = { + "AT": "Austria", "BE": "Belgium", "BG": "Bulgaria", "HR": "Croatia", + "CY": "Cyprus", "CZ": "Czech Republic", "DK": "Denmark", "EE": "Estonia", + "FI": "Finland", "FR": "France", "DE": "Germany", "GR": "Greece", + "HU": "Hungary", "IE": "Ireland", "IT": "Italy", "LV": "Latvia", + "LT": "Lithuania", "LU": "Luxembourg", "MT": "Malta", "NL": "Netherlands", + "PL": "Poland", "PT": "Portugal", "RO": "Romania", "SK": "Slovakia", + "SI": "Slovenia", "ES": "Spain", "SE": "Sweden", + } + country_name = country_names.get(country_iso.upper(), country_iso) + return f"{country_name} VAT {vat_rate}%" + + def determine_vat_regime( + self, + seller_country: str, + buyer_country: str, + buyer_vat_number: str | None, + seller_oss_registered: bool, + ) -> tuple[VATRegime, Decimal, str | None]: + """Determine VAT regime and rate for invoice.""" + seller_country = seller_country.upper() + buyer_country = buyer_country.upper() + + if seller_country == buyer_country: + vat_rate = self.get_vat_rate_for_country(seller_country) + return VATRegime.DOMESTIC, vat_rate, None + + if buyer_country in EU_VAT_RATES: + if buyer_vat_number: + return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country + + if seller_oss_registered: + vat_rate = self.get_vat_rate_for_country(buyer_country) + return VATRegime.OSS, vat_rate, buyer_country + else: + vat_rate = self.get_vat_rate_for_country(seller_country) + return VATRegime.ORIGIN, vat_rate, buyer_country + + return VATRegime.EXEMPT, Decimal("0.00"), buyer_country + + # ========================================================================= + # Invoice Settings Management + # ========================================================================= + + def get_settings( + self, db: Session, vendor_id: int + ) -> VendorInvoiceSettings | None: + """Get vendor invoice settings.""" + return ( + db.query(VendorInvoiceSettings) + .filter(VendorInvoiceSettings.vendor_id == vendor_id) + .first() + ) + + def get_settings_or_raise( + self, db: Session, vendor_id: int + ) -> VendorInvoiceSettings: + """Get vendor invoice settings or raise exception.""" + settings = self.get_settings(db, vendor_id) + if not settings: + raise InvoiceSettingsNotFoundException(vendor_id) + return settings + + def create_settings( + self, + db: Session, + vendor_id: int, + data: VendorInvoiceSettingsCreate, + ) -> VendorInvoiceSettings: + """Create vendor invoice settings.""" + existing = self.get_settings(db, vendor_id) + if existing: + raise ValidationException( + "Invoice settings already exist for this vendor" + ) + + settings = VendorInvoiceSettings( + vendor_id=vendor_id, + **data.model_dump(), + ) + db.add(settings) + db.flush() + db.refresh(settings) + + logger.info(f"Created invoice settings for vendor {vendor_id}") + return settings + + def update_settings( + self, + db: Session, + vendor_id: int, + data: VendorInvoiceSettingsUpdate, + ) -> VendorInvoiceSettings: + """Update vendor invoice settings.""" + settings = self.get_settings_or_raise(db, vendor_id) + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(settings, key, value) + + settings.updated_at = datetime.now(UTC) + db.flush() + db.refresh(settings) + + logger.info(f"Updated invoice settings for vendor {vendor_id}") + return settings + + def create_settings_from_vendor( + self, + db: Session, + vendor: Vendor, + ) -> VendorInvoiceSettings: + """Create invoice settings from vendor/company info.""" + company = vendor.company + + settings = VendorInvoiceSettings( + vendor_id=vendor.id, + company_name=company.legal_name if company else vendor.name, + company_address=vendor.effective_business_address, + company_city=None, + company_postal_code=None, + company_country="LU", + vat_number=vendor.effective_tax_number, + is_vat_registered=bool(vendor.effective_tax_number), + ) + db.add(settings) + db.flush() + db.refresh(settings) + + logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}") + return settings + + # ========================================================================= + # Invoice Number Generation + # ========================================================================= + + def _get_next_invoice_number( + self, db: Session, settings: VendorInvoiceSettings + ) -> str: + """Generate next invoice number and increment counter.""" + number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding) + invoice_number = f"{settings.invoice_prefix}{number}" + + settings.invoice_next_number += 1 + db.flush() + + return invoice_number + + # ========================================================================= + # Invoice Creation + # ========================================================================= + + def create_invoice_from_order( + self, + db: Session, + vendor_id: int, + order_id: int, + notes: str | None = None, + ) -> Invoice: + """Create an invoice from an order.""" + settings = self.get_settings_or_raise(db, vendor_id) + + order = ( + db.query(Order) + .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id)) + .first() + ) + if not order: + raise OrderNotFoundException(f"Order {order_id} not found") + + existing = ( + db.query(Invoice) + .filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id)) + .first() + ) + if existing: + raise ValidationException(f"Invoice already exists for order {order_id}") + + buyer_country = order.bill_country_iso + vat_regime, vat_rate, destination_country = self.determine_vat_regime( + seller_country=settings.company_country, + buyer_country=buyer_country, + buyer_vat_number=None, + seller_oss_registered=settings.is_oss_registered, + ) + + seller_details = { + "company_name": settings.company_name, + "address": settings.company_address, + "city": settings.company_city, + "postal_code": settings.company_postal_code, + "country": settings.company_country, + "vat_number": settings.vat_number, + } + + buyer_details = { + "name": f"{order.bill_first_name} {order.bill_last_name}".strip(), + "email": order.customer_email, + "address": order.bill_address_line_1, + "city": order.bill_city, + "postal_code": order.bill_postal_code, + "country": order.bill_country_iso, + "vat_number": None, + } + if order.bill_company: + buyer_details["company"] = order.bill_company + + line_items = [] + for item in order.items: + line_items.append({ + "description": item.product_name, + "quantity": item.quantity, + "unit_price_cents": item.unit_price_cents, + "total_cents": item.total_price_cents, + "sku": item.product_sku, + "ean": item.gtin, + }) + + subtotal_cents = sum(item["total_cents"] for item in line_items) + + if vat_rate > 0: + vat_amount_cents = int(subtotal_cents * float(vat_rate) / 100) + else: + vat_amount_cents = 0 + + total_cents = subtotal_cents + vat_amount_cents + + vat_rate_label = None + if vat_rate > 0: + if destination_country: + vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate) + else: + vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate) + + invoice_number = self._get_next_invoice_number(db, settings) + + invoice = Invoice( + vendor_id=vendor_id, + order_id=order_id, + invoice_number=invoice_number, + invoice_date=datetime.now(UTC), + status=InvoiceStatus.DRAFT.value, + seller_details=seller_details, + buyer_details=buyer_details, + line_items=line_items, + vat_regime=vat_regime.value, + destination_country=destination_country, + vat_rate=vat_rate, + vat_rate_label=vat_rate_label, + currency=order.currency, + subtotal_cents=subtotal_cents, + vat_amount_cents=vat_amount_cents, + total_cents=total_cents, + payment_terms=settings.payment_terms, + bank_details={ + "bank_name": settings.bank_name, + "iban": settings.bank_iban, + "bic": settings.bank_bic, + } if settings.bank_iban else None, + footer_text=settings.footer_text, + notes=notes, + ) + + db.add(invoice) + db.flush() + db.refresh(invoice) + + logger.info( + f"Created invoice {invoice_number} for order {order_id} " + f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})" + ) + + return invoice + + # ========================================================================= + # Invoice Retrieval + # ========================================================================= + + def get_invoice( + self, db: Session, vendor_id: int, invoice_id: int + ) -> Invoice | None: + """Get invoice by ID.""" + return ( + db.query(Invoice) + .filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id)) + .first() + ) + + def get_invoice_or_raise( + self, db: Session, vendor_id: int, invoice_id: int + ) -> Invoice: + """Get invoice by ID or raise exception.""" + invoice = self.get_invoice(db, vendor_id, invoice_id) + if not invoice: + raise InvoiceNotFoundException(invoice_id) + return invoice + + def get_invoice_by_number( + self, db: Session, vendor_id: int, invoice_number: str + ) -> Invoice | None: + """Get invoice by invoice number.""" + return ( + db.query(Invoice) + .filter( + and_( + Invoice.invoice_number == invoice_number, + Invoice.vendor_id == vendor_id, + ) + ) + .first() + ) + + def get_invoice_by_order_id( + self, db: Session, vendor_id: int, order_id: int + ) -> Invoice | None: + """Get invoice by order ID.""" + return ( + db.query(Invoice) + .filter( + and_( + Invoice.order_id == order_id, + Invoice.vendor_id == vendor_id, + ) + ) + .first() + ) + + def list_invoices( + self, + db: Session, + vendor_id: int, + status: str | None = None, + page: int = 1, + per_page: int = 20, + ) -> tuple[list[Invoice], int]: + """List invoices for vendor with pagination.""" + query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id) + + if status: + query = query.filter(Invoice.status == status) + + total = query.count() + + invoices = ( + query.order_by(Invoice.invoice_date.desc()) + .offset((page - 1) * per_page) + .limit(per_page) + .all() + ) + + return invoices, total + + # ========================================================================= + # Invoice Status Management + # ========================================================================= + + def update_status( + self, + db: Session, + vendor_id: int, + invoice_id: int, + new_status: str, + ) -> Invoice: + """Update invoice status.""" + invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) + + valid_statuses = [s.value for s in InvoiceStatus] + if new_status not in valid_statuses: + raise ValidationException(f"Invalid status: {new_status}") + + if invoice.status == InvoiceStatus.CANCELLED.value: + raise ValidationException("Cannot change status of cancelled invoice") + + invoice.status = new_status + invoice.updated_at = datetime.now(UTC) + db.flush() + db.refresh(invoice) + + logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}") + return invoice + + def mark_as_issued( + self, db: Session, vendor_id: int, invoice_id: int + ) -> Invoice: + """Mark invoice as issued.""" + return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value) + + def mark_as_paid( + self, db: Session, vendor_id: int, invoice_id: int + ) -> Invoice: + """Mark invoice as paid.""" + return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value) + + def cancel_invoice( + self, db: Session, vendor_id: int, invoice_id: int + ) -> Invoice: + """Cancel invoice.""" + return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value) + + # ========================================================================= + # Statistics + # ========================================================================= + + def get_invoice_stats( + self, db: Session, vendor_id: int + ) -> dict[str, Any]: + """Get invoice statistics for vendor.""" + total_count = ( + db.query(func.count(Invoice.id)) + .filter(Invoice.vendor_id == vendor_id) + .scalar() + or 0 + ) + + total_revenue = ( + db.query(func.sum(Invoice.total_cents)) + .filter( + and_( + Invoice.vendor_id == vendor_id, + Invoice.status.in_([ + InvoiceStatus.ISSUED.value, + InvoiceStatus.PAID.value, + ]), + ) + ) + .scalar() + or 0 + ) + + draft_count = ( + db.query(func.count(Invoice.id)) + .filter( + and_( + Invoice.vendor_id == vendor_id, + Invoice.status == InvoiceStatus.DRAFT.value, + ) + ) + .scalar() + or 0 + ) + + paid_count = ( + db.query(func.count(Invoice.id)) + .filter( + and_( + Invoice.vendor_id == vendor_id, + Invoice.status == InvoiceStatus.PAID.value, + ) + ) + .scalar() + or 0 + ) + + return { + "total_invoices": total_count, + "total_revenue_cents": total_revenue, + "total_revenue": total_revenue / 100 if total_revenue else 0, + "draft_count": draft_count, + "paid_count": paid_count, + } + + # ========================================================================= + # PDF Generation + # ========================================================================= + + def generate_pdf( + self, + db: Session, + vendor_id: int, + invoice_id: int, + force_regenerate: bool = False, + ) -> str: + """Generate PDF for an invoice.""" + from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service + + invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) + return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate) + + def get_pdf_path( + self, + db: Session, + vendor_id: int, + invoice_id: int, + ) -> str | None: + """Get PDF path for an invoice if it exists.""" + from app.modules.orders.services.invoice_pdf_service import invoice_pdf_service + + invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) + return invoice_pdf_service.get_pdf_path(invoice) + + +# Singleton instance +invoice_service = InvoiceService() diff --git a/app/modules/orders/services/order_inventory_service.py b/app/modules/orders/services/order_inventory_service.py new file mode 100644 index 00000000..2256582e --- /dev/null +++ b/app/modules/orders/services/order_inventory_service.py @@ -0,0 +1,591 @@ +# app/modules/orders/services/order_inventory_service.py +""" +Order-Inventory Integration Service. + +This service orchestrates inventory operations for orders: +- Reserve inventory when orders are confirmed +- Fulfill (deduct) inventory when orders are shipped +- Release reservations when orders are cancelled + +All operations are logged to the inventory_transactions table for audit trail. +""" + +import logging +from sqlalchemy.orm import Session + +from app.exceptions import ( + InsufficientInventoryException, + InventoryNotFoundException, + OrderNotFoundException, + ValidationException, +) +from app.modules.inventory.models.inventory import Inventory +from app.modules.inventory.models.inventory_transaction import ( + InventoryTransaction, + TransactionType, +) +from app.modules.inventory.schemas.inventory import InventoryReserve +from app.modules.inventory.services.inventory_service import inventory_service +from app.modules.orders.models.order import Order, OrderItem + +logger = logging.getLogger(__name__) + +# Default location for inventory operations +DEFAULT_LOCATION = "DEFAULT" + + +class OrderInventoryService: + """ + Orchestrate order and inventory operations together. + """ + + def get_order_with_items( + self, db: Session, vendor_id: int, order_id: int + ) -> Order: + """Get order with items or raise OrderNotFoundException.""" + order = ( + db.query(Order) + .filter(Order.id == order_id, Order.vendor_id == vendor_id) + .first() + ) + if not order: + raise OrderNotFoundException(f"Order {order_id} not found") + return order + + def _find_inventory_location( + self, db: Session, product_id: int, vendor_id: int + ) -> str | None: + """ + Find the location with available inventory for a product. + """ + inventory = ( + db.query(Inventory) + .filter( + Inventory.product_id == product_id, + Inventory.vendor_id == vendor_id, + Inventory.quantity > Inventory.reserved_quantity, + ) + .first() + ) + return inventory.location if inventory else None + + def _is_placeholder_product(self, order_item: OrderItem) -> bool: + """Check if the order item uses a placeholder product.""" + if not order_item.product: + return True + return order_item.product.gtin == "0000000000000" + + def _log_transaction( + self, + db: Session, + vendor_id: int, + product_id: int, + inventory: Inventory, + transaction_type: TransactionType, + quantity_change: int, + order: Order, + reason: str | None = None, + ) -> InventoryTransaction: + """Create an inventory transaction record for audit trail.""" + transaction = InventoryTransaction.create_transaction( + vendor_id=vendor_id, + product_id=product_id, + inventory_id=inventory.id if inventory else None, + transaction_type=transaction_type, + quantity_change=quantity_change, + quantity_after=inventory.quantity if inventory else 0, + reserved_after=inventory.reserved_quantity if inventory else 0, + location=inventory.location if inventory else None, + warehouse=inventory.warehouse if inventory else None, + order_id=order.id, + order_number=order.order_number, + reason=reason, + created_by="system", + ) + db.add(transaction) + return transaction + + def reserve_for_order( + self, + db: Session, + vendor_id: int, + order_id: int, + skip_missing: bool = True, + ) -> dict: + """Reserve inventory for all items in an order.""" + order = self.get_order_with_items(db, vendor_id, order_id) + + reserved_count = 0 + skipped_items = [] + + for item in order.items: + if self._is_placeholder_product(item): + skipped_items.append({ + "item_id": item.id, + "reason": "placeholder_product", + }) + continue + + location = self._find_inventory_location(db, item.product_id, vendor_id) + + if not location: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": "no_inventory", + }) + continue + else: + raise InventoryNotFoundException( + f"No inventory found for product {item.product_id}" + ) + + try: + reserve_data = InventoryReserve( + product_id=item.product_id, + location=location, + quantity=item.quantity, + ) + updated_inventory = inventory_service.reserve_inventory( + db, vendor_id, reserve_data + ) + reserved_count += 1 + + self._log_transaction( + db=db, + vendor_id=vendor_id, + product_id=item.product_id, + inventory=updated_inventory, + transaction_type=TransactionType.RESERVE, + quantity_change=0, + order=order, + reason=f"Reserved for order {order.order_number}", + ) + + logger.info( + f"Reserved {item.quantity} units of product {item.product_id} " + f"for order {order.order_number}" + ) + except InsufficientInventoryException: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": "insufficient_inventory", + }) + else: + raise + + logger.info( + f"Order {order.order_number}: reserved {reserved_count} items, " + f"skipped {len(skipped_items)}" + ) + + return { + "order_id": order_id, + "order_number": order.order_number, + "reserved_count": reserved_count, + "skipped_items": skipped_items, + } + + def fulfill_order( + self, + db: Session, + vendor_id: int, + order_id: int, + skip_missing: bool = True, + ) -> dict: + """Fulfill (deduct) inventory when an order is shipped.""" + order = self.get_order_with_items(db, vendor_id, order_id) + + fulfilled_count = 0 + skipped_items = [] + + for item in order.items: + if item.is_fully_shipped: + continue + + if self._is_placeholder_product(item): + skipped_items.append({ + "item_id": item.id, + "reason": "placeholder_product", + }) + continue + + quantity_to_fulfill = item.remaining_quantity + + location = self._find_inventory_location(db, item.product_id, vendor_id) + + if not location: + inventory = ( + db.query(Inventory) + .filter( + Inventory.product_id == item.product_id, + Inventory.vendor_id == vendor_id, + ) + .first() + ) + if inventory: + location = inventory.location + + if not location: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": "no_inventory", + }) + continue + else: + raise InventoryNotFoundException( + f"No inventory found for product {item.product_id}" + ) + + try: + reserve_data = InventoryReserve( + product_id=item.product_id, + location=location, + quantity=quantity_to_fulfill, + ) + updated_inventory = inventory_service.fulfill_reservation( + db, vendor_id, reserve_data + ) + fulfilled_count += 1 + + item.shipped_quantity = item.quantity + item.inventory_fulfilled = True + + self._log_transaction( + db=db, + vendor_id=vendor_id, + product_id=item.product_id, + inventory=updated_inventory, + transaction_type=TransactionType.FULFILL, + quantity_change=-quantity_to_fulfill, + order=order, + reason=f"Fulfilled for order {order.order_number}", + ) + + logger.info( + f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} " + f"for order {order.order_number}" + ) + except (InsufficientInventoryException, InventoryNotFoundException) as e: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": str(e), + }) + else: + raise + + logger.info( + f"Order {order.order_number}: fulfilled {fulfilled_count} items, " + f"skipped {len(skipped_items)}" + ) + + return { + "order_id": order_id, + "order_number": order.order_number, + "fulfilled_count": fulfilled_count, + "skipped_items": skipped_items, + } + + def fulfill_item( + self, + db: Session, + vendor_id: int, + order_id: int, + item_id: int, + quantity: int | None = None, + skip_missing: bool = True, + ) -> dict: + """Fulfill (deduct) inventory for a specific order item.""" + order = self.get_order_with_items(db, vendor_id, order_id) + + item = None + for order_item in order.items: + if order_item.id == item_id: + item = order_item + break + + if not item: + raise ValidationException(f"Item {item_id} not found in order {order_id}") + + if item.is_fully_shipped: + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": 0, + "message": "Item already fully shipped", + } + + quantity_to_fulfill = quantity or item.remaining_quantity + + if quantity_to_fulfill > item.remaining_quantity: + raise ValidationException( + f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining" + ) + + if quantity_to_fulfill <= 0: + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": 0, + "message": "Nothing to fulfill", + } + + if self._is_placeholder_product(item): + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": 0, + "message": "Placeholder product - skipped", + } + + location = self._find_inventory_location(db, item.product_id, vendor_id) + + if not location: + inventory = ( + db.query(Inventory) + .filter( + Inventory.product_id == item.product_id, + Inventory.vendor_id == vendor_id, + ) + .first() + ) + if inventory: + location = inventory.location + + if not location: + if skip_missing: + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": 0, + "message": "No inventory found", + } + else: + raise InventoryNotFoundException( + f"No inventory found for product {item.product_id}" + ) + + try: + reserve_data = InventoryReserve( + product_id=item.product_id, + location=location, + quantity=quantity_to_fulfill, + ) + updated_inventory = inventory_service.fulfill_reservation( + db, vendor_id, reserve_data + ) + + item.shipped_quantity += quantity_to_fulfill + + if item.is_fully_shipped: + item.inventory_fulfilled = True + + self._log_transaction( + db=db, + vendor_id=vendor_id, + product_id=item.product_id, + inventory=updated_inventory, + transaction_type=TransactionType.FULFILL, + quantity_change=-quantity_to_fulfill, + order=order, + reason=f"Partial shipment for order {order.order_number}", + ) + + logger.info( + f"Fulfilled {quantity_to_fulfill} of {item.quantity} units " + f"for item {item_id} in order {order.order_number}" + ) + + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": quantity_to_fulfill, + "shipped_quantity": item.shipped_quantity, + "remaining_quantity": item.remaining_quantity, + "is_fully_shipped": item.is_fully_shipped, + } + + except (InsufficientInventoryException, InventoryNotFoundException) as e: + if skip_missing: + return { + "order_id": order_id, + "item_id": item_id, + "fulfilled_quantity": 0, + "message": str(e), + } + else: + raise + + def release_order_reservation( + self, + db: Session, + vendor_id: int, + order_id: int, + skip_missing: bool = True, + ) -> dict: + """Release reserved inventory when an order is cancelled.""" + order = self.get_order_with_items(db, vendor_id, order_id) + + released_count = 0 + skipped_items = [] + + for item in order.items: + if self._is_placeholder_product(item): + skipped_items.append({ + "item_id": item.id, + "reason": "placeholder_product", + }) + continue + + inventory = ( + db.query(Inventory) + .filter( + Inventory.product_id == item.product_id, + Inventory.vendor_id == vendor_id, + ) + .first() + ) + + if not inventory: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": "no_inventory", + }) + continue + else: + raise InventoryNotFoundException( + f"No inventory found for product {item.product_id}" + ) + + try: + reserve_data = InventoryReserve( + product_id=item.product_id, + location=inventory.location, + quantity=item.quantity, + ) + updated_inventory = inventory_service.release_reservation( + db, vendor_id, reserve_data + ) + released_count += 1 + + self._log_transaction( + db=db, + vendor_id=vendor_id, + product_id=item.product_id, + inventory=updated_inventory, + transaction_type=TransactionType.RELEASE, + quantity_change=0, + order=order, + reason=f"Released for cancelled order {order.order_number}", + ) + + logger.info( + f"Released {item.quantity} units of product {item.product_id} " + f"for cancelled order {order.order_number}" + ) + except Exception as e: + if skip_missing: + skipped_items.append({ + "item_id": item.id, + "product_id": item.product_id, + "reason": str(e), + }) + else: + raise + + logger.info( + f"Order {order.order_number}: released {released_count} items, " + f"skipped {len(skipped_items)}" + ) + + return { + "order_id": order_id, + "order_number": order.order_number, + "released_count": released_count, + "skipped_items": skipped_items, + } + + def handle_status_change( + self, + db: Session, + vendor_id: int, + order_id: int, + old_status: str | None, + new_status: str, + ) -> dict | None: + """Handle inventory operations based on order status changes.""" + if old_status == new_status: + return None + + result = None + + if new_status == "processing": + result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True) + logger.info(f"Order {order_id} confirmed: inventory reserved") + + elif new_status == "shipped": + result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True) + logger.info(f"Order {order_id} shipped: inventory fulfilled") + + elif new_status == "partially_shipped": + logger.info( + f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment" + ) + result = {"order_id": order_id, "status": "partially_shipped"} + + elif new_status == "cancelled": + if old_status and old_status not in ("cancelled", "refunded"): + result = self.release_order_reservation( + db, vendor_id, order_id, skip_missing=True + ) + logger.info(f"Order {order_id} cancelled: reservations released") + + return result + + def get_shipment_status( + self, + db: Session, + vendor_id: int, + order_id: int, + ) -> dict: + """Get detailed shipment status for an order.""" + order = self.get_order_with_items(db, vendor_id, order_id) + + items = [] + for item in order.items: + items.append({ + "item_id": item.id, + "product_id": item.product_id, + "product_name": item.product_name, + "quantity": item.quantity, + "shipped_quantity": item.shipped_quantity, + "remaining_quantity": item.remaining_quantity, + "is_fully_shipped": item.is_fully_shipped, + "is_partially_shipped": item.is_partially_shipped, + }) + + return { + "order_id": order_id, + "order_number": order.order_number, + "order_status": order.status, + "is_fully_shipped": order.is_fully_shipped, + "is_partially_shipped": order.is_partially_shipped, + "shipped_item_count": order.shipped_item_count, + "total_item_count": len(order.items), + "total_shipped_units": order.total_shipped_units, + "total_ordered_units": order.total_ordered_units, + "items": items, + } + + +# Create service instance +order_inventory_service = OrderInventoryService() diff --git a/app/modules/orders/services/order_item_exception_service.py b/app/modules/orders/services/order_item_exception_service.py new file mode 100644 index 00000000..ce99fb60 --- /dev/null +++ b/app/modules/orders/services/order_item_exception_service.py @@ -0,0 +1,466 @@ +# app/modules/orders/services/order_item_exception_service.py +""" +Service for managing order item exceptions (unmatched products). + +This service handles: +- Creating exceptions when products are not found during order import +- Resolving exceptions by assigning products +- Auto-matching when new products are imported +- Querying and statistics for exceptions +""" + +import logging +from datetime import UTC, datetime + +from sqlalchemy import and_, func, or_ +from sqlalchemy.orm import Session, joinedload + +from app.exceptions import ( + ExceptionAlreadyResolvedException, + InvalidProductForExceptionException, + OrderItemExceptionNotFoundException, + ProductNotFoundException, +) +from app.modules.orders.models.order import Order, OrderItem +from app.modules.orders.models.order_item_exception import OrderItemException +from models.database.product import Product + +logger = logging.getLogger(__name__) + + +class OrderItemExceptionService: + """Service for order item exception CRUD and resolution workflow.""" + + # ========================================================================= + # Exception Creation + # ========================================================================= + + def create_exception( + self, + db: Session, + order_item: OrderItem, + vendor_id: int, + original_gtin: str | None, + original_product_name: str | None, + original_sku: str | None, + exception_type: str = "product_not_found", + ) -> OrderItemException: + """Create an exception record for an unmatched order item.""" + exception = OrderItemException( + order_item_id=order_item.id, + vendor_id=vendor_id, + original_gtin=original_gtin, + original_product_name=original_product_name, + original_sku=original_sku, + exception_type=exception_type, + status="pending", + ) + db.add(exception) + db.flush() + + logger.info( + f"Created order item exception {exception.id} for order item " + f"{order_item.id}, GTIN: {original_gtin}" + ) + + return exception + + # ========================================================================= + # Exception Retrieval + # ========================================================================= + + def get_exception_by_id( + self, + db: Session, + exception_id: int, + vendor_id: int | None = None, + ) -> OrderItemException: + """Get an exception by ID, optionally filtered by vendor.""" + query = db.query(OrderItemException).filter( + OrderItemException.id == exception_id + ) + + if vendor_id is not None: + query = query.filter(OrderItemException.vendor_id == vendor_id) + + exception = query.first() + + if not exception: + raise OrderItemExceptionNotFoundException(exception_id) + + return exception + + def get_pending_exceptions( + self, + db: Session, + vendor_id: int | None = None, + status: str | None = None, + search: str | None = None, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[OrderItemException], int]: + """Get exceptions with pagination and filtering.""" + query = ( + db.query(OrderItemException) + .join(OrderItem) + .join(Order) + .options( + joinedload(OrderItemException.order_item).joinedload(OrderItem.order) + ) + ) + + if vendor_id is not None: + query = query.filter(OrderItemException.vendor_id == vendor_id) + + if status: + query = query.filter(OrderItemException.status == status) + + if search: + search_pattern = f"%{search}%" + query = query.filter( + or_( + OrderItemException.original_gtin.ilike(search_pattern), + OrderItemException.original_product_name.ilike(search_pattern), + OrderItemException.original_sku.ilike(search_pattern), + Order.order_number.ilike(search_pattern), + ) + ) + + total = query.count() + + exceptions = ( + query.order_by(OrderItemException.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return exceptions, total + + def get_exceptions_for_order( + self, + db: Session, + order_id: int, + ) -> list[OrderItemException]: + """Get all exceptions for items in an order.""" + return ( + db.query(OrderItemException) + .join(OrderItem) + .filter(OrderItem.order_id == order_id) + .all() + ) + + # ========================================================================= + # Exception Statistics + # ========================================================================= + + def get_exception_stats( + self, + db: Session, + vendor_id: int | None = None, + ) -> dict[str, int]: + """Get exception counts by status.""" + query = db.query( + OrderItemException.status, + func.count(OrderItemException.id).label("count"), + ) + + if vendor_id is not None: + query = query.filter(OrderItemException.vendor_id == vendor_id) + + results = query.group_by(OrderItemException.status).all() + + stats = { + "pending": 0, + "resolved": 0, + "ignored": 0, + "total": 0, + } + + for status, count in results: + if status in stats: + stats[status] = count + stats["total"] += count + + orders_query = ( + db.query(func.count(func.distinct(OrderItem.order_id))) + .join(OrderItemException) + .filter(OrderItemException.status == "pending") + ) + + if vendor_id is not None: + orders_query = orders_query.filter( + OrderItemException.vendor_id == vendor_id + ) + + stats["orders_with_exceptions"] = orders_query.scalar() or 0 + + return stats + + # ========================================================================= + # Exception Resolution + # ========================================================================= + + def resolve_exception( + self, + db: Session, + exception_id: int, + product_id: int, + resolved_by: int, + notes: str | None = None, + vendor_id: int | None = None, + ) -> OrderItemException: + """Resolve an exception by assigning a product.""" + exception = self.get_exception_by_id(db, exception_id, vendor_id) + + if exception.status == "resolved": + raise ExceptionAlreadyResolvedException(exception_id) + + product = db.query(Product).filter(Product.id == product_id).first() + if not product: + raise ProductNotFoundException(product_id) + + if product.vendor_id != exception.vendor_id: + raise InvalidProductForExceptionException( + product_id, "Product belongs to a different vendor" + ) + + if not product.is_active: + raise InvalidProductForExceptionException( + product_id, "Product is not active" + ) + + exception.status = "resolved" + exception.resolved_product_id = product_id + exception.resolved_at = datetime.now(UTC) + exception.resolved_by = resolved_by + exception.resolution_notes = notes + + order_item = exception.order_item + order_item.product_id = product_id + order_item.needs_product_match = False + + if product.marketplace_product: + order_item.product_name = product.marketplace_product.get_title("en") + order_item.product_sku = product.vendor_sku or order_item.product_sku + + db.flush() + + logger.info( + f"Resolved exception {exception_id} with product {product_id} " + f"by user {resolved_by}" + ) + + return exception + + def ignore_exception( + self, + db: Session, + exception_id: int, + resolved_by: int, + notes: str, + vendor_id: int | None = None, + ) -> OrderItemException: + """Mark an exception as ignored.""" + exception = self.get_exception_by_id(db, exception_id, vendor_id) + + if exception.status == "resolved": + raise ExceptionAlreadyResolvedException(exception_id) + + exception.status = "ignored" + exception.resolved_at = datetime.now(UTC) + exception.resolved_by = resolved_by + exception.resolution_notes = notes + + db.flush() + + logger.info( + f"Ignored exception {exception_id} by user {resolved_by}: {notes}" + ) + + return exception + + # ========================================================================= + # Auto-Matching + # ========================================================================= + + def auto_match_by_gtin( + self, + db: Session, + vendor_id: int, + gtin: str, + product_id: int, + ) -> list[OrderItemException]: + """Auto-resolve pending exceptions matching a GTIN.""" + if not gtin: + return [] + + pending = ( + db.query(OrderItemException) + .filter( + and_( + OrderItemException.vendor_id == vendor_id, + OrderItemException.original_gtin == gtin, + OrderItemException.status == "pending", + ) + ) + .all() + ) + + if not pending: + return [] + + product = db.query(Product).filter(Product.id == product_id).first() + if not product: + logger.warning(f"Product {product_id} not found for auto-match") + return [] + + resolved = [] + now = datetime.now(UTC) + + for exception in pending: + exception.status = "resolved" + exception.resolved_product_id = product_id + exception.resolved_at = now + exception.resolution_notes = "Auto-matched during product import" + + order_item = exception.order_item + order_item.product_id = product_id + order_item.needs_product_match = False + if product.marketplace_product: + order_item.product_name = product.marketplace_product.get_title("en") + + resolved.append(exception) + + if resolved: + db.flush() + logger.info( + f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} " + f"with product {product_id}" + ) + + return resolved + + def auto_match_batch( + self, + db: Session, + vendor_id: int, + gtin_to_product: dict[str, int], + ) -> int: + """Batch auto-match multiple GTINs after bulk import.""" + if not gtin_to_product: + return 0 + + total_resolved = 0 + + for gtin, product_id in gtin_to_product.items(): + resolved = self.auto_match_by_gtin(db, vendor_id, gtin, product_id) + total_resolved += len(resolved) + + return total_resolved + + # ========================================================================= + # Confirmation Checks + # ========================================================================= + + def order_has_unresolved_exceptions( + self, + db: Session, + order_id: int, + ) -> bool: + """Check if order has any unresolved exceptions.""" + count = ( + db.query(func.count(OrderItemException.id)) + .join(OrderItem) + .filter( + and_( + OrderItem.order_id == order_id, + OrderItemException.status.in_(["pending", "ignored"]), + ) + ) + .scalar() + ) + + return count > 0 + + def get_unresolved_exception_count( + self, + db: Session, + order_id: int, + ) -> int: + """Get count of unresolved exceptions for an order.""" + return ( + db.query(func.count(OrderItemException.id)) + .join(OrderItem) + .filter( + and_( + OrderItem.order_id == order_id, + OrderItemException.status.in_(["pending", "ignored"]), + ) + ) + .scalar() + ) or 0 + + # ========================================================================= + # Bulk Operations + # ========================================================================= + + def bulk_resolve_by_gtin( + self, + db: Session, + vendor_id: int, + gtin: str, + product_id: int, + resolved_by: int, + notes: str | None = None, + ) -> int: + """Bulk resolve all pending exceptions for a GTIN.""" + product = db.query(Product).filter(Product.id == product_id).first() + if not product: + raise ProductNotFoundException(product_id) + + if product.vendor_id != vendor_id: + raise InvalidProductForExceptionException( + product_id, "Product belongs to a different vendor" + ) + + pending = ( + db.query(OrderItemException) + .filter( + and_( + OrderItemException.vendor_id == vendor_id, + OrderItemException.original_gtin == gtin, + OrderItemException.status == "pending", + ) + ) + .all() + ) + + now = datetime.now(UTC) + resolution_notes = notes or f"Bulk resolved for GTIN {gtin}" + + for exception in pending: + exception.status = "resolved" + exception.resolved_product_id = product_id + exception.resolved_at = now + exception.resolved_by = resolved_by + exception.resolution_notes = resolution_notes + + order_item = exception.order_item + order_item.product_id = product_id + order_item.needs_product_match = False + if product.marketplace_product: + order_item.product_name = product.marketplace_product.get_title("en") + + db.flush() + + logger.info( + f"Bulk resolved {len(pending)} exceptions for GTIN {gtin} " + f"with product {product_id} by user {resolved_by}" + ) + + return len(pending) + + +# Global service instance +order_item_exception_service = OrderItemExceptionService() diff --git a/app/modules/orders/services/order_service.py b/app/modules/orders/services/order_service.py new file mode 100644 index 00000000..838abbb9 --- /dev/null +++ b/app/modules/orders/services/order_service.py @@ -0,0 +1,1325 @@ +# app/modules/orders/services/order_service.py +""" +Unified order service for all sales channels. + +This service handles: +- Order creation (direct and marketplace) +- Order status management +- Order retrieval and filtering +- Customer creation for marketplace imports +- Order item management + +All orders use snapshotted customer and address data. + +All monetary calculations use integer cents internally for precision. +See docs/architecture/money-handling.md for details. +""" + +import logging +import random +import string +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import and_, func, or_ +from sqlalchemy.orm import Session + +from app.exceptions import ( + CustomerNotFoundException, + InsufficientInventoryException, + OrderNotFoundException, + ValidationException, +) +from app.modules.customers.models.customer import Customer +from app.modules.orders.models.order import Order, OrderItem +from app.modules.orders.schemas.order import ( + AddressSnapshot, + CustomerSnapshot, + OrderCreate, + OrderItemCreate, + OrderUpdate, +) +from app.services.subscription_service import ( + subscription_service, + TierLimitExceededException, +) +from app.utils.money import Money, cents_to_euros, euros_to_cents +from app.utils.vat import ( + VATResult, + calculate_vat_amount, + determine_vat_regime, +) +from models.database.marketplace_product import MarketplaceProduct +from models.database.marketplace_product_translation import MarketplaceProductTranslation +from models.database.product import Product +from models.database.vendor import Vendor + +# Placeholder product constants +PLACEHOLDER_GTIN = "0000000000000" +PLACEHOLDER_MARKETPLACE_ID = "PLACEHOLDER" + +logger = logging.getLogger(__name__) + + +class OrderService: + """Unified service for order operations across all channels.""" + + # ========================================================================= + # Order Number Generation + # ========================================================================= + + def _generate_order_number(self, db: Session, vendor_id: int) -> str: + """ + Generate unique order number. + + Format: ORD-{VENDOR_ID}-{TIMESTAMP}-{RANDOM} + Example: ORD-1-20250110-A1B2C3 + """ + timestamp = datetime.now(UTC).strftime("%Y%m%d") + random_suffix = "".join( + random.choices(string.ascii_uppercase + string.digits, k=6) + ) + order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" + + # Ensure uniqueness + while db.query(Order).filter(Order.order_number == order_number).first(): + random_suffix = "".join( + random.choices(string.ascii_uppercase + string.digits, k=6) + ) + order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" + + return order_number + + # ========================================================================= + # Tax Calculation + # ========================================================================= + + def _calculate_tax_for_order( + self, + db: Session, + vendor_id: int, + subtotal_cents: int, + billing_country_iso: str, + buyer_vat_number: str | None = None, + ) -> VATResult: + """ + Calculate tax amount for an order based on billing destination. + + Uses the shared VAT utility to determine the correct VAT regime + and rate, consistent with invoice VAT calculation. + """ + from app.modules.orders.models.invoice import VendorInvoiceSettings + + # Get vendor invoice settings for seller country and OSS status + settings = ( + db.query(VendorInvoiceSettings) + .filter(VendorInvoiceSettings.vendor_id == vendor_id) + .first() + ) + + # Default to Luxembourg if no settings exist + seller_country = settings.company_country if settings else "LU" + seller_oss_registered = settings.is_oss_registered if settings else False + + # Determine VAT regime using shared utility + return determine_vat_regime( + seller_country=seller_country, + buyer_country=billing_country_iso or "LU", + buyer_vat_number=buyer_vat_number, + seller_oss_registered=seller_oss_registered, + ) + + # ========================================================================= + # Placeholder Product Management + # ========================================================================= + + def _get_or_create_placeholder_product( + self, + db: Session, + vendor_id: int, + ) -> Product: + """ + Get or create the vendor's placeholder product for unmatched items. + """ + # Check for existing placeholder product for this vendor + placeholder = ( + db.query(Product) + .filter( + and_( + Product.vendor_id == vendor_id, + Product.gtin == PLACEHOLDER_GTIN, + ) + ) + .first() + ) + + if placeholder: + return placeholder + + # Get or create placeholder marketplace product (shared) + mp = ( + db.query(MarketplaceProduct) + .filter( + MarketplaceProduct.marketplace_product_id == PLACEHOLDER_MARKETPLACE_ID + ) + .first() + ) + + if not mp: + mp = MarketplaceProduct( + marketplace_product_id=PLACEHOLDER_MARKETPLACE_ID, + marketplace="internal", + vendor_name="system", + product_type_enum="physical", + is_active=False, + ) + db.add(mp) + db.flush() + + # Add translation for placeholder + translation = MarketplaceProductTranslation( + marketplace_product_id=mp.id, + language="en", + title="Unmatched Product (Pending Resolution)", + description=( + "This is a placeholder for products not found during order import. " + "Please resolve the exception to assign the correct product." + ), + ) + db.add(translation) + db.flush() + + logger.info(f"Created placeholder MarketplaceProduct {mp.id}") + + # Create vendor-specific placeholder product + placeholder = Product( + vendor_id=vendor_id, + marketplace_product_id=mp.id, + gtin=PLACEHOLDER_GTIN, + gtin_type="placeholder", + is_active=False, + ) + db.add(placeholder) + db.flush() + + logger.info(f"Created placeholder product {placeholder.id} for vendor {vendor_id}") + + return placeholder + + # ========================================================================= + # Customer Management + # ========================================================================= + + def find_or_create_customer( + self, + db: Session, + vendor_id: int, + email: str, + first_name: str, + last_name: str, + phone: str | None = None, + is_active: bool = False, + ) -> Customer: + """ + Find existing customer by email or create new one. + """ + # Look for existing customer by email within vendor scope + customer = ( + db.query(Customer) + .filter( + and_( + Customer.vendor_id == vendor_id, + Customer.email == email, + ) + ) + .first() + ) + + if customer: + return customer + + # Generate a unique customer number + timestamp = datetime.now(UTC).strftime("%Y%m%d%H%M%S") + random_suffix = "".join(random.choices(string.digits, k=4)) + customer_number = f"CUST-{vendor_id}-{timestamp}-{random_suffix}" + + # Create new customer + customer = Customer( + vendor_id=vendor_id, + email=email, + first_name=first_name, + last_name=last_name, + phone=phone, + customer_number=customer_number, + hashed_password="", + is_active=is_active, + ) + db.add(customer) + db.flush() + + logger.info( + f"Created {'active' if is_active else 'inactive'} customer " + f"{customer.id} for vendor {vendor_id}: {email}" + ) + + return customer + + # ========================================================================= + # Order Creation + # ========================================================================= + + def create_order( + self, + db: Session, + vendor_id: int, + order_data: OrderCreate, + ) -> Order: + """ + Create a new direct order. + """ + # Check tier limit before creating order + subscription_service.check_order_limit(db, vendor_id) + + try: + # Get or create customer + if order_data.customer_id: + customer = ( + db.query(Customer) + .filter( + and_( + Customer.id == order_data.customer_id, + Customer.vendor_id == vendor_id, + ) + ) + .first() + ) + if not customer: + raise CustomerNotFoundException(str(order_data.customer_id)) + else: + # Create customer from snapshot + customer = self.find_or_create_customer( + db=db, + vendor_id=vendor_id, + email=order_data.customer.email, + first_name=order_data.customer.first_name, + last_name=order_data.customer.last_name, + phone=order_data.customer.phone, + is_active=True, + ) + + # Calculate order totals and validate products + subtotal_cents = 0 + order_items_data = [] + + for item_data in order_data.items: + product = ( + db.query(Product) + .filter( + and_( + Product.id == item_data.product_id, + Product.vendor_id == vendor_id, + Product.is_active == True, + ) + ) + .first() + ) + + if not product: + raise ValidationException( + f"Product {item_data.product_id} not found" + ) + + # Check inventory + if product.available_inventory < item_data.quantity: + raise InsufficientInventoryException( + product_id=product.id, + requested=item_data.quantity, + available=product.available_inventory, + ) + + # Get price in cents + unit_price_cents = ( + product.sale_price_cents + or product.price_cents + ) + if not unit_price_cents: + raise ValidationException(f"Product {product.id} has no price") + + # Calculate line total in cents + line_total_cents = Money.calculate_line_total( + unit_price_cents, item_data.quantity + ) + subtotal_cents += line_total_cents + + order_items_data.append( + { + "product_id": product.id, + "product_name": product.marketplace_product.get_title("en") + if product.marketplace_product + else str(product.id), + "product_sku": product.vendor_sku, + "gtin": product.gtin, + "gtin_type": product.gtin_type, + "quantity": item_data.quantity, + "unit_price_cents": unit_price_cents, + "total_price_cents": line_total_cents, + } + ) + + # Use billing address or shipping address for VAT + billing = order_data.billing_address or order_data.shipping_address + + # Calculate VAT using vendor settings + vat_result = self._calculate_tax_for_order( + db=db, + vendor_id=vendor_id, + subtotal_cents=subtotal_cents, + billing_country_iso=billing.country_iso, + buyer_vat_number=getattr(billing, 'vat_number', None), + ) + + # Calculate amounts in cents + tax_amount_cents = calculate_vat_amount(subtotal_cents, vat_result.rate) + shipping_amount_cents = 599 if subtotal_cents < 5000 else 0 + discount_amount_cents = 0 + total_amount_cents = Money.calculate_order_total( + subtotal_cents, tax_amount_cents, shipping_amount_cents, discount_amount_cents + ) + + # Generate order number + order_number = self._generate_order_number(db, vendor_id) + + # Create order with snapshots + order = Order( + vendor_id=vendor_id, + customer_id=customer.id, + order_number=order_number, + channel="direct", + status="pending", + # Financials + subtotal_cents=subtotal_cents, + tax_amount_cents=tax_amount_cents, + shipping_amount_cents=shipping_amount_cents, + discount_amount_cents=discount_amount_cents, + total_amount_cents=total_amount_cents, + currency="EUR", + # VAT information + vat_regime=vat_result.regime.value, + vat_rate=vat_result.rate, + vat_rate_label=vat_result.label, + vat_destination_country=vat_result.destination_country, + # Customer snapshot + customer_first_name=order_data.customer.first_name, + customer_last_name=order_data.customer.last_name, + customer_email=order_data.customer.email, + customer_phone=order_data.customer.phone, + customer_locale=order_data.customer.locale, + # Shipping address snapshot + ship_first_name=order_data.shipping_address.first_name, + ship_last_name=order_data.shipping_address.last_name, + ship_company=order_data.shipping_address.company, + ship_address_line_1=order_data.shipping_address.address_line_1, + ship_address_line_2=order_data.shipping_address.address_line_2, + ship_city=order_data.shipping_address.city, + ship_postal_code=order_data.shipping_address.postal_code, + ship_country_iso=order_data.shipping_address.country_iso, + # Billing address snapshot + bill_first_name=billing.first_name, + bill_last_name=billing.last_name, + bill_company=billing.company, + bill_address_line_1=billing.address_line_1, + bill_address_line_2=billing.address_line_2, + bill_city=billing.city, + bill_postal_code=billing.postal_code, + bill_country_iso=billing.country_iso, + # Other + shipping_method=order_data.shipping_method, + customer_notes=order_data.customer_notes, + order_date=datetime.now(UTC), + ) + + db.add(order) + db.flush() + + # Create order items + for item_data in order_items_data: + order_item = OrderItem(order_id=order.id, **item_data) + db.add(order_item) + + db.flush() + db.refresh(order) + + # Increment order count for subscription tracking + subscription_service.increment_order_count(db, vendor_id) + + logger.info( + f"Order {order.order_number} created for vendor {vendor_id}, " + f"total: EUR {cents_to_euros(total_amount_cents):.2f}" + ) + + return order + + except ( + ValidationException, + InsufficientInventoryException, + CustomerNotFoundException, + TierLimitExceededException, + ): + raise + except Exception as e: + logger.error(f"Error creating order: {str(e)}") + raise ValidationException(f"Failed to create order: {str(e)}") + + def create_letzshop_order( + self, + db: Session, + vendor_id: int, + shipment_data: dict[str, Any], + skip_limit_check: bool = False, + ) -> Order: + """ + Create an order from Letzshop shipment data. + """ + from app.modules.orders.services.order_item_exception_service import ( + order_item_exception_service, + ) + + # Check tier limit before creating order + if not skip_limit_check: + can_create, message = subscription_service.can_create_order(db, vendor_id) + if not can_create: + raise TierLimitExceededException( + message=message or "Order limit exceeded", + limit_type="orders", + current=0, + limit=0, + ) + + order_data = shipment_data.get("order", {}) + + # Generate order number using Letzshop order number + letzshop_order_number = order_data.get("number", "") + order_number = f"LS-{vendor_id}-{letzshop_order_number}" + + # Check if order already exists + existing = ( + db.query(Order) + .filter(Order.order_number == order_number) + .first() + ) + if existing: + updated = False + + # Update tracking if available and not already set + tracking_data = shipment_data.get("tracking") or {} + new_tracking = tracking_data.get("code") or tracking_data.get("number") + if new_tracking and not existing.tracking_number: + existing.tracking_number = new_tracking + tracking_provider = tracking_data.get("provider") + if not tracking_provider and tracking_data.get("carrier"): + carrier = tracking_data.get("carrier", {}) + tracking_provider = carrier.get("code") or carrier.get("name") + existing.tracking_provider = tracking_provider + updated = True + logger.info( + f"Updated tracking for order {order_number}: " + f"{tracking_provider} {new_tracking}" + ) + + # Update shipment number if not already set + shipment_number = shipment_data.get("number") + if shipment_number and not existing.shipment_number: + existing.shipment_number = shipment_number + updated = True + + # Update carrier if not already set + shipment_data_obj = shipment_data.get("data") or {} + if shipment_data_obj and not existing.shipping_carrier: + carrier_name = shipment_data_obj.get("__typename", "").lower() + if carrier_name: + existing.shipping_carrier = carrier_name + updated = True + + if updated: + existing.updated_at = datetime.now(UTC) + + return existing + + # Parse inventory units + inventory_units = shipment_data.get("inventoryUnits", []) + if isinstance(inventory_units, dict): + inventory_units = inventory_units.get("nodes", []) + + # Collect all GTINs and check for items without GTINs + gtins = set() + has_items_without_gtin = False + for unit in inventory_units: + variant = unit.get("variant", {}) or {} + trade_id = variant.get("tradeId", {}) or {} + gtin = trade_id.get("number") + if gtin: + gtins.add(gtin) + else: + has_items_without_gtin = True + + # Batch query all products by GTIN + products_by_gtin: dict[str, Product] = {} + if gtins: + products = ( + db.query(Product) + .filter( + and_( + Product.vendor_id == vendor_id, + Product.gtin.in_(gtins), + ) + ) + .all() + ) + products_by_gtin = {p.gtin: p for p in products if p.gtin} + + # Identify missing GTINs + missing_gtins = gtins - set(products_by_gtin.keys()) + placeholder = None + if missing_gtins or has_items_without_gtin: + placeholder = self._get_or_create_placeholder_product(db, vendor_id) + if missing_gtins: + logger.warning( + f"Order {order_number}: {len(missing_gtins)} product(s) not found. " + f"GTINs: {missing_gtins}. Using placeholder and creating exceptions." + ) + if has_items_without_gtin: + logger.warning( + f"Order {order_number}: Some items have no GTIN. " + f"Using placeholder and creating exceptions." + ) + + # Parse address data + ship_address = order_data.get("shipAddress", {}) or {} + bill_address = order_data.get("billAddress", {}) or {} + ship_country = ship_address.get("country", {}) or {} + bill_country = bill_address.get("country", {}) or {} + + # Extract customer info + customer_email = order_data.get("email", "") + ship_first_name = ship_address.get("firstName", "") or "" + ship_last_name = ship_address.get("lastName", "") or "" + + # Parse order date + order_date = datetime.now(UTC) + completed_at_str = order_data.get("completedAt") + if completed_at_str: + try: + if completed_at_str.endswith("Z"): + completed_at_str = completed_at_str[:-1] + "+00:00" + order_date = datetime.fromisoformat(completed_at_str) + except (ValueError, TypeError): + pass + + # Parse total amount + total_str = order_data.get("total", "0") + try: + total_euros = float(str(total_str).split()[0]) + total_amount_cents = euros_to_cents(total_euros) + except (ValueError, IndexError): + total_amount_cents = 0 + + # Map Letzshop state to status + letzshop_state = shipment_data.get("state", "unconfirmed") + status_mapping = { + "unconfirmed": "pending", + "confirmed": "processing", + "declined": "cancelled", + } + status = status_mapping.get(letzshop_state, "pending") + + # Parse tracking info if available + tracking_data = shipment_data.get("tracking") or {} + tracking_number = tracking_data.get("code") or tracking_data.get("number") + tracking_provider = tracking_data.get("provider") + if not tracking_provider and tracking_data.get("carrier"): + carrier = tracking_data.get("carrier", {}) + tracking_provider = carrier.get("code") or carrier.get("name") + + # Parse shipment number and carrier + shipment_number = shipment_data.get("number") + shipping_carrier = None + shipment_data_obj = shipment_data.get("data") or {} + if shipment_data_obj: + shipping_carrier = shipment_data_obj.get("__typename", "").lower() + + # Find or create customer (inactive) + customer = self.find_or_create_customer( + db=db, + vendor_id=vendor_id, + email=customer_email, + first_name=ship_first_name, + last_name=ship_last_name, + is_active=False, + ) + + # Create order + order = Order( + vendor_id=vendor_id, + customer_id=customer.id, + order_number=order_number, + channel="letzshop", + external_order_id=order_data.get("id"), + external_shipment_id=shipment_data.get("id"), + external_order_number=letzshop_order_number, + external_data=shipment_data, + status=status, + total_amount_cents=total_amount_cents, + currency="EUR", + customer_first_name=ship_first_name, + customer_last_name=ship_last_name, + customer_email=customer_email, + customer_locale=order_data.get("locale"), + ship_first_name=ship_first_name, + ship_last_name=ship_last_name, + ship_company=ship_address.get("company"), + ship_address_line_1=ship_address.get("streetName", "") or "", + ship_address_line_2=ship_address.get("streetNumber"), + ship_city=ship_address.get("city", "") or "", + ship_postal_code=ship_address.get("postalCode", "") or "", + ship_country_iso=ship_country.get("iso", "") or "", + bill_first_name=bill_address.get("firstName", "") or ship_first_name, + bill_last_name=bill_address.get("lastName", "") or ship_last_name, + bill_company=bill_address.get("company"), + bill_address_line_1=bill_address.get("streetName", "") or "", + bill_address_line_2=bill_address.get("streetNumber"), + bill_city=bill_address.get("city", "") or "", + bill_postal_code=bill_address.get("postalCode", "") or "", + bill_country_iso=bill_country.get("iso", "") or "", + order_date=order_date, + confirmed_at=datetime.now(UTC) if status == "processing" else None, + cancelled_at=datetime.now(UTC) if status == "cancelled" else None, + tracking_number=tracking_number, + tracking_provider=tracking_provider, + shipment_number=shipment_number, + shipping_carrier=shipping_carrier, + ) + + db.add(order) + db.flush() + + # Create order items from inventory units + exceptions_created = 0 + for unit in inventory_units: + variant = unit.get("variant", {}) or {} + product_info = variant.get("product", {}) or {} + trade_id = variant.get("tradeId", {}) or {} + product_name_dict = product_info.get("name", {}) or {} + + gtin = trade_id.get("number") + gtin_type = trade_id.get("parser") + + # Get product from map, or use placeholder if not found + product = products_by_gtin.get(gtin) + needs_product_match = False + + if not product: + product = placeholder + needs_product_match = True + + # Get product name from marketplace data + product_name = ( + product_name_dict.get("en") + or product_name_dict.get("fr") + or product_name_dict.get("de") + or str(product_name_dict) + ) + + # Get price + unit_price_cents = 0 + price_str = variant.get("price", "0") + try: + price_euros = float(str(price_str).split()[0]) + unit_price_cents = euros_to_cents(price_euros) + except (ValueError, IndexError): + pass + + item_state = unit.get("state") + + order_item = OrderItem( + order_id=order.id, + product_id=product.id if product else None, + product_name=product_name, + product_sku=variant.get("sku"), + gtin=gtin, + gtin_type=gtin_type, + quantity=1, + unit_price_cents=unit_price_cents, + total_price_cents=unit_price_cents, + external_item_id=unit.get("id"), + external_variant_id=variant.get("id"), + item_state=item_state, + needs_product_match=needs_product_match, + ) + db.add(order_item) + db.flush() + + # Create exception record for unmatched items + if needs_product_match: + order_item_exception_service.create_exception( + db=db, + order_item=order_item, + vendor_id=vendor_id, + original_gtin=gtin, + original_product_name=product_name, + original_sku=variant.get("sku"), + exception_type="product_not_found", + ) + exceptions_created += 1 + + db.flush() + db.refresh(order) + + if exceptions_created > 0: + logger.info( + f"Created {exceptions_created} order item exception(s) for " + f"order {order.order_number}" + ) + + # Increment order count for subscription tracking + subscription_service.increment_order_count(db, vendor_id) + + logger.info( + f"Letzshop order {order.order_number} created for vendor {vendor_id}, " + f"status: {status}, items: {len(inventory_units)}" + ) + + return order + + # ========================================================================= + # Order Retrieval + # ========================================================================= + + def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order: + """Get order by ID within vendor scope.""" + order = ( + db.query(Order) + .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id)) + .first() + ) + + if not order: + raise OrderNotFoundException(str(order_id)) + + return order + + def get_order_by_external_shipment_id( + self, + db: Session, + vendor_id: int, + shipment_id: str, + ) -> Order | None: + """Get order by external shipment ID (for Letzshop).""" + return ( + db.query(Order) + .filter( + and_( + Order.vendor_id == vendor_id, + Order.external_shipment_id == shipment_id, + ) + ) + .first() + ) + + def get_vendor_orders( + self, + db: Session, + vendor_id: int, + skip: int = 0, + limit: int = 50, + status: str | None = None, + channel: str | None = None, + search: str | None = None, + customer_id: int | None = None, + ) -> tuple[list[Order], int]: + """Get orders for vendor with filtering.""" + query = db.query(Order).filter(Order.vendor_id == vendor_id) + + if status: + query = query.filter(Order.status == status) + + if channel: + query = query.filter(Order.channel == channel) + + if customer_id: + query = query.filter(Order.customer_id == customer_id) + + if search: + search_term = f"%{search}%" + query = query.filter( + or_( + Order.order_number.ilike(search_term), + Order.external_order_number.ilike(search_term), + Order.customer_email.ilike(search_term), + Order.customer_first_name.ilike(search_term), + Order.customer_last_name.ilike(search_term), + ) + ) + + # Order by most recent first + query = query.order_by(Order.order_date.desc()) + + total = query.count() + orders = query.offset(skip).limit(limit).all() + + return orders, total + + def get_customer_orders( + self, + db: Session, + vendor_id: int, + customer_id: int, + skip: int = 0, + limit: int = 50, + ) -> tuple[list[Order], int]: + """Get orders for a specific customer.""" + return self.get_vendor_orders( + db=db, + vendor_id=vendor_id, + skip=skip, + limit=limit, + customer_id=customer_id, + ) + + def get_order_stats(self, db: Session, vendor_id: int) -> dict[str, int]: + """Get order counts by status for a vendor.""" + status_counts = ( + db.query(Order.status, func.count(Order.id).label("count")) + .filter(Order.vendor_id == vendor_id) + .group_by(Order.status) + .all() + ) + + stats = { + "pending": 0, + "processing": 0, + "partially_shipped": 0, + "shipped": 0, + "delivered": 0, + "cancelled": 0, + "refunded": 0, + "total": 0, + } + + for status, count in status_counts: + if status in stats: + stats[status] = count + stats["total"] += count + + # Also count by channel + channel_counts = ( + db.query(Order.channel, func.count(Order.id).label("count")) + .filter(Order.vendor_id == vendor_id) + .group_by(Order.channel) + .all() + ) + + for channel, count in channel_counts: + stats[f"{channel}_orders"] = count + + return stats + + # ========================================================================= + # Order Updates + # ========================================================================= + + def update_order_status( + self, + db: Session, + vendor_id: int, + order_id: int, + order_update: OrderUpdate, + ) -> Order: + """Update order status and tracking information.""" + from app.modules.orders.services.order_inventory_service import ( + order_inventory_service, + ) + + order = self.get_order(db, vendor_id, order_id) + + now = datetime.now(UTC) + old_status = order.status + + if order_update.status: + order.status = order_update.status + + # Update timestamps based on status + if order_update.status == "processing" and not order.confirmed_at: + order.confirmed_at = now + elif order_update.status == "shipped" and not order.shipped_at: + order.shipped_at = now + elif order_update.status == "delivered" and not order.delivered_at: + order.delivered_at = now + elif order_update.status == "cancelled" and not order.cancelled_at: + order.cancelled_at = now + + # Handle inventory operations based on status change + try: + inventory_result = order_inventory_service.handle_status_change( + db=db, + vendor_id=vendor_id, + order_id=order_id, + old_status=old_status, + new_status=order_update.status, + ) + if inventory_result: + logger.info( + f"Order {order.order_number} inventory update: " + f"{inventory_result.get('reserved_count', 0)} reserved, " + f"{inventory_result.get('fulfilled_count', 0)} fulfilled, " + f"{inventory_result.get('released_count', 0)} released" + ) + except Exception as e: + logger.warning( + f"Order {order.order_number} inventory operation failed: {e}" + ) + + if order_update.tracking_number: + order.tracking_number = order_update.tracking_number + + if order_update.tracking_provider: + order.tracking_provider = order_update.tracking_provider + + if order_update.internal_notes: + order.internal_notes = order_update.internal_notes + + order.updated_at = now + db.flush() + db.refresh(order) + + logger.info(f"Order {order.order_number} updated: status={order.status}") + + return order + + def set_order_tracking( + self, + db: Session, + vendor_id: int, + order_id: int, + tracking_number: str, + tracking_provider: str, + ) -> Order: + """Set tracking information and mark as shipped.""" + order = self.get_order(db, vendor_id, order_id) + + now = datetime.now(UTC) + order.tracking_number = tracking_number + order.tracking_provider = tracking_provider + order.status = "shipped" + order.shipped_at = now + order.updated_at = now + + db.flush() + db.refresh(order) + + logger.info( + f"Order {order.order_number} shipped: " + f"{tracking_provider} {tracking_number}" + ) + + return order + + def update_item_state( + self, + db: Session, + vendor_id: int, + order_id: int, + item_id: int, + state: str, + ) -> OrderItem: + """Update the state of an order item (for marketplace confirmation).""" + order = self.get_order(db, vendor_id, order_id) + + item = ( + db.query(OrderItem) + .filter( + and_( + OrderItem.id == item_id, + OrderItem.order_id == order.id, + ) + ) + .first() + ) + + if not item: + raise ValidationException(f"Order item {item_id} not found") + + item.item_state = state + item.updated_at = datetime.now(UTC) + + # Check if all items are processed + all_items = db.query(OrderItem).filter(OrderItem.order_id == order.id).all() + all_confirmed = all( + i.item_state in ("confirmed_available", "confirmed_unavailable") + for i in all_items + ) + + if all_confirmed: + has_available = any( + i.item_state == "confirmed_available" for i in all_items + ) + all_unavailable = all( + i.item_state == "confirmed_unavailable" for i in all_items + ) + + now = datetime.now(UTC) + if all_unavailable: + order.status = "cancelled" + order.cancelled_at = now + elif has_available: + order.status = "processing" + order.confirmed_at = now + + order.updated_at = now + + db.flush() + db.refresh(item) + + return item + + # ========================================================================= + # Admin Methods (cross-vendor) + # ========================================================================= + + def get_all_orders_admin( + self, + db: Session, + skip: int = 0, + limit: int = 50, + vendor_id: int | None = None, + status: str | None = None, + channel: str | None = None, + search: str | None = None, + ) -> tuple[list[dict], int]: + """Get orders across all vendors for admin.""" + query = db.query(Order).join(Vendor) + + if vendor_id: + query = query.filter(Order.vendor_id == vendor_id) + + if status: + query = query.filter(Order.status == status) + + if channel: + query = query.filter(Order.channel == channel) + + if search: + search_term = f"%{search}%" + query = query.filter( + or_( + Order.order_number.ilike(search_term), + Order.external_order_number.ilike(search_term), + Order.customer_email.ilike(search_term), + Order.customer_first_name.ilike(search_term), + Order.customer_last_name.ilike(search_term), + ) + ) + + query = query.order_by(Order.order_date.desc()) + + total = query.count() + orders = query.offset(skip).limit(limit).all() + + result = [] + for order in orders: + item_count = len(order.items) if order.items else 0 + + result.append( + { + "id": order.id, + "vendor_id": order.vendor_id, + "vendor_name": order.vendor.name if order.vendor else None, + "vendor_code": order.vendor.vendor_code if order.vendor else None, + "customer_id": order.customer_id, + "customer_full_name": order.customer_full_name, + "customer_email": order.customer_email, + "order_number": order.order_number, + "channel": order.channel, + "status": order.status, + "external_order_number": order.external_order_number, + "external_shipment_id": order.external_shipment_id, + "subtotal": order.subtotal, + "tax_amount": order.tax_amount, + "shipping_amount": order.shipping_amount, + "discount_amount": order.discount_amount, + "total_amount": order.total_amount, + "currency": order.currency, + "ship_country_iso": order.ship_country_iso, + "tracking_number": order.tracking_number, + "tracking_provider": order.tracking_provider, + "item_count": item_count, + "order_date": order.order_date, + "confirmed_at": order.confirmed_at, + "shipped_at": order.shipped_at, + "delivered_at": order.delivered_at, + "cancelled_at": order.cancelled_at, + "created_at": order.created_at, + "updated_at": order.updated_at, + } + ) + + return result, total + + def get_order_stats_admin(self, db: Session) -> dict: + """Get platform-wide order statistics.""" + # Get status counts + status_counts = ( + db.query(Order.status, func.count(Order.id)) + .group_by(Order.status) + .all() + ) + + stats = { + "total_orders": 0, + "pending_orders": 0, + "processing_orders": 0, + "partially_shipped_orders": 0, + "shipped_orders": 0, + "delivered_orders": 0, + "cancelled_orders": 0, + "refunded_orders": 0, + "total_revenue": 0.0, + "direct_orders": 0, + "letzshop_orders": 0, + "vendors_with_orders": 0, + } + + for status, count in status_counts: + stats["total_orders"] += count + key = f"{status}_orders" + if key in stats: + stats[key] = count + + # Get channel counts + channel_counts = ( + db.query(Order.channel, func.count(Order.id)) + .group_by(Order.channel) + .all() + ) + + for channel, count in channel_counts: + key = f"{channel}_orders" + if key in stats: + stats[key] = count + + # Get total revenue + revenue_cents = ( + db.query(func.sum(Order.total_amount_cents)) + .filter(Order.status == "delivered") + .scalar() + ) + stats["total_revenue"] = cents_to_euros(revenue_cents) if revenue_cents else 0.0 + + # Count vendors with orders + vendors_count = ( + db.query(func.count(func.distinct(Order.vendor_id))).scalar() or 0 + ) + stats["vendors_with_orders"] = vendors_count + + return stats + + def get_order_by_id_admin(self, db: Session, order_id: int) -> Order: + """Get order by ID without vendor scope (admin only).""" + order = db.query(Order).filter(Order.id == order_id).first() + + if not order: + raise OrderNotFoundException(str(order_id)) + + return order + + def get_vendors_with_orders_admin(self, db: Session) -> list[dict]: + """Get list of vendors that have orders (admin only).""" + results = ( + db.query( + Vendor.id, + Vendor.name, + Vendor.vendor_code, + func.count(Order.id).label("order_count"), + ) + .join(Order, Order.vendor_id == Vendor.id) + .group_by(Vendor.id, Vendor.name, Vendor.vendor_code) + .order_by(func.count(Order.id).desc()) + .all() + ) + + return [ + { + "id": row.id, + "name": row.name, + "vendor_code": row.vendor_code, + "order_count": row.order_count, + } + for row in results + ] + + def mark_as_shipped_admin( + self, + db: Session, + order_id: int, + tracking_number: str | None = None, + tracking_url: str | None = None, + shipping_carrier: str | None = None, + ) -> Order: + """Mark an order as shipped with optional tracking info (admin only).""" + order = db.query(Order).filter(Order.id == order_id).first() + + if not order: + raise OrderNotFoundException(str(order_id)) + + order.status = "shipped" + order.shipped_at = datetime.now(UTC) + order.updated_at = datetime.now(UTC) + + if tracking_number: + order.tracking_number = tracking_number + if tracking_url: + order.tracking_url = tracking_url + if shipping_carrier: + order.shipping_carrier = shipping_carrier + + logger.info( + f"Order {order.order_number} marked as shipped. " + f"Tracking: {tracking_number or 'N/A'}, Carrier: {shipping_carrier or 'N/A'}" + ) + + return order + + def get_shipping_label_info_admin( + self, + db: Session, + order_id: int, + ) -> dict[str, Any]: + """Get shipping label information for an order (admin only).""" + from app.services.admin_settings_service import admin_settings_service # noqa: MOD-004 + + order = db.query(Order).filter(Order.id == order_id).first() + + if not order: + raise OrderNotFoundException(str(order_id)) + + label_url = None + carrier = order.shipping_carrier + + # Generate label URL based on carrier + if order.shipment_number and carrier: + setting_key = f"carrier_{carrier}_label_url" + prefix = admin_settings_service.get_setting_value(db, setting_key) + + if prefix: + label_url = prefix + order.shipment_number + + return { + "shipment_number": order.shipment_number, + "shipping_carrier": carrier, + "label_url": label_url, + "tracking_number": order.tracking_number, + "tracking_url": order.tracking_url, + } + + +# Create service instance +order_service = OrderService() diff --git a/app/services/admin_customer_service.py b/app/services/admin_customer_service.py index 31912e06..b0ed0474 100644 --- a/app/services/admin_customer_service.py +++ b/app/services/admin_customer_service.py @@ -1,242 +1,23 @@ # app/services/admin_customer_service.py """ -Admin customer management service. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Handles customer operations for admin users across all vendors. +The canonical implementation is now in: + app/modules/customers/services/admin_customer_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.customers.services import admin_customer_service """ -import logging -from typing import Any +from app.modules.customers.services.admin_customer_service import ( + admin_customer_service, + AdminCustomerService, +) -from sqlalchemy import func -from sqlalchemy.orm import Session - -from app.exceptions.customer import CustomerNotFoundException -from models.database.customer import Customer -from models.database.vendor import Vendor - -logger = logging.getLogger(__name__) - - -class AdminCustomerService: - """Service for admin-level customer management across vendors.""" - - def list_customers( - self, - db: Session, - vendor_id: int | None = None, - search: str | None = None, - is_active: bool | None = None, - skip: int = 0, - limit: int = 20, - ) -> tuple[list[dict[str, Any]], int]: - """ - Get paginated list of customers across all vendors. - - Args: - db: Database session - vendor_id: Optional vendor ID filter - search: Search by email, name, or customer number - is_active: Filter by active status - skip: Number of records to skip - limit: Maximum records to return - - Returns: - Tuple of (customers list, total count) - """ - # Build query - query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id) - - # Apply filters - if vendor_id: - query = query.filter(Customer.vendor_id == vendor_id) - - if search: - search_term = f"%{search}%" - query = query.filter( - (Customer.email.ilike(search_term)) - | (Customer.first_name.ilike(search_term)) - | (Customer.last_name.ilike(search_term)) - | (Customer.customer_number.ilike(search_term)) - ) - - if is_active is not None: - query = query.filter(Customer.is_active == is_active) - - # Get total count - total = query.count() - - # Get paginated results with vendor info - customers = ( - query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code) - .order_by(Customer.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - # Format response - result = [] - for row in customers: - customer = row[0] - vendor_name = row[1] - vendor_code = row[2] - - customer_dict = { - "id": customer.id, - "vendor_id": customer.vendor_id, - "email": customer.email, - "first_name": customer.first_name, - "last_name": customer.last_name, - "phone": customer.phone, - "customer_number": customer.customer_number, - "marketing_consent": customer.marketing_consent, - "preferred_language": customer.preferred_language, - "last_order_date": customer.last_order_date, - "total_orders": customer.total_orders, - "total_spent": float(customer.total_spent) if customer.total_spent else 0, - "is_active": customer.is_active, - "created_at": customer.created_at, - "updated_at": customer.updated_at, - "vendor_name": vendor_name, - "vendor_code": vendor_code, - } - result.append(customer_dict) - - return result, total - - def get_customer_stats( - self, - db: Session, - vendor_id: int | None = None, - ) -> dict[str, Any]: - """ - Get customer statistics. - - Args: - db: Database session - vendor_id: Optional vendor ID filter - - Returns: - Dict with customer statistics - """ - query = db.query(Customer) - - if vendor_id: - query = query.filter(Customer.vendor_id == vendor_id) - - total = query.count() - active = query.filter(Customer.is_active == True).count() # noqa: E712 - inactive = query.filter(Customer.is_active == False).count() # noqa: E712 - with_orders = query.filter(Customer.total_orders > 0).count() - - # Total spent across all customers - total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar() - total_spent = float(total_spent_result) if total_spent_result else 0 - - # Average order value - total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar() - total_orders = int(total_orders_result) if total_orders_result else 0 - avg_order_value = total_spent / total_orders if total_orders > 0 else 0 - - return { - "total": total, - "active": active, - "inactive": inactive, - "with_orders": with_orders, - "total_spent": total_spent, - "total_orders": total_orders, - "avg_order_value": round(avg_order_value, 2), - } - - def get_customer( - self, - db: Session, - customer_id: int, - ) -> dict[str, Any]: - """ - Get customer details by ID. - - Args: - db: Database session - customer_id: Customer ID - - Returns: - Customer dict with vendor info - - Raises: - CustomerNotFoundException: If customer not found - """ - result = ( - db.query(Customer) - .join(Vendor, Customer.vendor_id == Vendor.id) - .add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code) - .filter(Customer.id == customer_id) - .first() - ) - - if not result: - raise CustomerNotFoundException(str(customer_id)) - - customer = result[0] - return { - "id": customer.id, - "vendor_id": customer.vendor_id, - "email": customer.email, - "first_name": customer.first_name, - "last_name": customer.last_name, - "phone": customer.phone, - "customer_number": customer.customer_number, - "marketing_consent": customer.marketing_consent, - "preferred_language": customer.preferred_language, - "last_order_date": customer.last_order_date, - "total_orders": customer.total_orders, - "total_spent": float(customer.total_spent) if customer.total_spent else 0, - "is_active": customer.is_active, - "created_at": customer.created_at, - "updated_at": customer.updated_at, - "vendor_name": result[1], - "vendor_code": result[2], - } - - def toggle_customer_status( - self, - db: Session, - customer_id: int, - admin_email: str, - ) -> dict[str, Any]: - """ - Toggle customer active status. - - Args: - db: Database session - customer_id: Customer ID - admin_email: Admin user email for logging - - Returns: - Dict with customer ID, new status, and message - - Raises: - CustomerNotFoundException: If customer not found - """ - customer = db.query(Customer).filter(Customer.id == customer_id).first() - - if not customer: - raise CustomerNotFoundException(str(customer_id)) - - customer.is_active = not customer.is_active - db.flush() - db.refresh(customer) - - status = "activated" if customer.is_active else "deactivated" - logger.info(f"Customer {customer.email} {status} by admin {admin_email}") - - return { - "id": customer.id, - "is_active": customer.is_active, - "message": f"Customer {status} successfully", - } - - -# Singleton instance -admin_customer_service = AdminCustomerService() +__all__ = [ + "admin_customer_service", + "AdminCustomerService", +] diff --git a/app/services/admin_notification_service.py b/app/services/admin_notification_service.py index 5daf9ee1..4c84a183 100644 --- a/app/services/admin_notification_service.py +++ b/app/services/admin_notification_service.py @@ -1,701 +1,37 @@ # app/services/admin_notification_service.py """ -Admin notification service. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides functionality for: -- Creating and managing admin notifications -- Managing platform alerts -- Notification statistics and queries +The canonical implementation is now in: + app/modules/messaging/services/admin_notification_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.services import admin_notification_service """ -import logging -from datetime import datetime, timedelta -from typing import Any - -from sqlalchemy import and_, case, func, or_ -from sqlalchemy.orm import Session - -from models.database.admin import AdminNotification, PlatformAlert -from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate - -logger = logging.getLogger(__name__) - - -# ============================================================================ -# NOTIFICATION TYPES -# ============================================================================ - -class NotificationType: - """Notification type constants.""" - - SYSTEM_ALERT = "system_alert" - IMPORT_FAILURE = "import_failure" - EXPORT_FAILURE = "export_failure" - ORDER_SYNC_FAILURE = "order_sync_failure" - VENDOR_ISSUE = "vendor_issue" - CUSTOMER_MESSAGE = "customer_message" - VENDOR_MESSAGE = "vendor_message" - SECURITY_ALERT = "security_alert" - PERFORMANCE_ALERT = "performance_alert" - ORDER_EXCEPTION = "order_exception" - CRITICAL_ERROR = "critical_error" - - -class Priority: - """Priority level constants.""" - - LOW = "low" - NORMAL = "normal" - HIGH = "high" - CRITICAL = "critical" - - -class AlertType: - """Platform alert type constants.""" - - SECURITY = "security" - PERFORMANCE = "performance" - CAPACITY = "capacity" - INTEGRATION = "integration" - DATABASE = "database" - SYSTEM = "system" - - -class Severity: - """Alert severity constants.""" - - INFO = "info" - WARNING = "warning" - ERROR = "error" - CRITICAL = "critical" - - -# ============================================================================ -# ADMIN NOTIFICATION SERVICE -# ============================================================================ - - -class AdminNotificationService: - """Service for managing admin notifications.""" - - def create_notification( - self, - db: Session, - notification_type: str, - title: str, - message: str, - priority: str = Priority.NORMAL, - action_required: bool = False, - action_url: str | None = None, - metadata: dict[str, Any] | None = None, - ) -> AdminNotification: - """ - Create a new admin notification. - - Args: - db: Database session - notification_type: Type of notification - title: Notification title - message: Notification message - priority: Priority level (low, normal, high, critical) - action_required: Whether action is required - action_url: URL to relevant admin page - metadata: Additional contextual data - - Returns: - Created AdminNotification - """ - notification = AdminNotification( - type=notification_type, - title=title, - message=message, - priority=priority, - action_required=action_required, - action_url=action_url, - notification_metadata=metadata, - ) - db.add(notification) - db.flush() - - logger.info( - f"Created notification: {notification_type} - {title} (priority: {priority})" - ) - - return notification - - def create_from_schema( - self, - db: Session, - data: AdminNotificationCreate, - ) -> AdminNotification: - """Create notification from Pydantic schema.""" - return self.create_notification( - db=db, - notification_type=data.type, - title=data.title, - message=data.message, - priority=data.priority, - action_required=data.action_required, - action_url=data.action_url, - metadata=data.metadata, - ) - - def get_notifications( - self, - db: Session, - priority: str | None = None, - is_read: bool | None = None, - notification_type: str | None = None, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[AdminNotification], int, int]: - """ - Get paginated admin notifications. - - Returns: - Tuple of (notifications, total_count, unread_count) - """ - query = db.query(AdminNotification) - - # Apply filters - if priority: - query = query.filter(AdminNotification.priority == priority) - - if is_read is not None: - query = query.filter(AdminNotification.is_read == is_read) - - if notification_type: - query = query.filter(AdminNotification.type == notification_type) - - # Get counts - total = query.count() - unread_count = ( - db.query(AdminNotification) - .filter(AdminNotification.is_read == False) # noqa: E712 - .count() - ) - - # Get paginated results ordered by priority and date - priority_order = case( - (AdminNotification.priority == "critical", 1), - (AdminNotification.priority == "high", 2), - (AdminNotification.priority == "normal", 3), - (AdminNotification.priority == "low", 4), - else_=5, - ) - - notifications = ( - query.order_by( - AdminNotification.is_read, # Unread first - priority_order, - AdminNotification.created_at.desc(), - ) - .offset(skip) - .limit(limit) - .all() - ) - - return notifications, total, unread_count - - def get_unread_count(self, db: Session) -> int: - """Get count of unread notifications.""" - return ( - db.query(AdminNotification) - .filter(AdminNotification.is_read == False) # noqa: E712 - .count() - ) - - def get_recent_notifications( - self, - db: Session, - limit: int = 5, - ) -> list[AdminNotification]: - """Get recent unread notifications for header dropdown.""" - priority_order = case( - (AdminNotification.priority == "critical", 1), - (AdminNotification.priority == "high", 2), - (AdminNotification.priority == "normal", 3), - (AdminNotification.priority == "low", 4), - else_=5, - ) - - return ( - db.query(AdminNotification) - .filter(AdminNotification.is_read == False) # noqa: E712 - .order_by(priority_order, AdminNotification.created_at.desc()) - .limit(limit) - .all() - ) - - def mark_as_read( - self, - db: Session, - notification_id: int, - user_id: int, - ) -> AdminNotification | None: - """Mark a notification as read.""" - notification = ( - db.query(AdminNotification) - .filter(AdminNotification.id == notification_id) - .first() - ) - - if notification and not notification.is_read: - notification.is_read = True - notification.read_at = datetime.utcnow() - notification.read_by_user_id = user_id - db.flush() - - return notification - - def mark_all_as_read( - self, - db: Session, - user_id: int, - ) -> int: - """Mark all unread notifications as read. Returns count of updated.""" - now = datetime.utcnow() - count = ( - db.query(AdminNotification) - .filter(AdminNotification.is_read == False) # noqa: E712 - .update( - { - AdminNotification.is_read: True, - AdminNotification.read_at: now, - AdminNotification.read_by_user_id: user_id, - } - ) - ) - db.flush() - return count - - def delete_old_notifications( - self, - db: Session, - days: int = 30, - ) -> int: - """Delete notifications older than specified days.""" - cutoff = datetime.utcnow() - timedelta(days=days) - count = ( - db.query(AdminNotification) - .filter( - and_( - AdminNotification.is_read == True, # noqa: E712 - AdminNotification.created_at < cutoff, - ) - ) - .delete() - ) - db.flush() - return count - - def delete_notification( - self, - db: Session, - notification_id: int, - ) -> bool: - """ - Delete a notification by ID. - - Returns: - True if notification was deleted, False if not found. - """ - notification = ( - db.query(AdminNotification) - .filter(AdminNotification.id == notification_id) - .first() - ) - - if notification: - db.delete(notification) - db.flush() - logger.info(f"Deleted notification {notification_id}") - return True - - return False - - # ========================================================================= - # CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES - # ========================================================================= - - def notify_import_failure( - self, - db: Session, - vendor_name: str, - job_id: int, - error_message: str, - vendor_id: int | None = None, - ) -> AdminNotification: - """Create notification for import job failure.""" - return self.create_notification( - db=db, - notification_type=NotificationType.IMPORT_FAILURE, - title=f"Import Failed: {vendor_name}", - message=error_message, - priority=Priority.HIGH, - action_required=True, - action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs" - if vendor_id - else "/admin/marketplace", - metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id}, - ) - - def notify_order_sync_failure( - self, - db: Session, - vendor_name: str, - error_message: str, - vendor_id: int | None = None, - ) -> AdminNotification: - """Create notification for order sync failure.""" - return self.create_notification( - db=db, - notification_type=NotificationType.ORDER_SYNC_FAILURE, - title=f"Order Sync Failed: {vendor_name}", - message=error_message, - priority=Priority.HIGH, - action_required=True, - action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs" - if vendor_id - else "/admin/marketplace/letzshop", - metadata={"vendor_name": vendor_name, "vendor_id": vendor_id}, - ) - - def notify_order_exception( - self, - db: Session, - vendor_name: str, - order_number: str, - exception_count: int, - vendor_id: int | None = None, - ) -> AdminNotification: - """Create notification for order item exceptions.""" - return self.create_notification( - db=db, - notification_type=NotificationType.ORDER_EXCEPTION, - title=f"Order Exception: {order_number}", - message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})", - priority=Priority.NORMAL, - action_required=True, - action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions" - if vendor_id - else "/admin/marketplace/letzshop", - metadata={ - "vendor_name": vendor_name, - "order_number": order_number, - "exception_count": exception_count, - "vendor_id": vendor_id, - }, - ) - - def notify_critical_error( - self, - db: Session, - error_type: str, - error_message: str, - details: dict[str, Any] | None = None, - ) -> AdminNotification: - """Create notification for critical application errors.""" - return self.create_notification( - db=db, - notification_type=NotificationType.CRITICAL_ERROR, - title=f"Critical Error: {error_type}", - message=error_message, - priority=Priority.CRITICAL, - action_required=True, - action_url="/admin/logs", - metadata=details, - ) - - def notify_vendor_issue( - self, - db: Session, - vendor_name: str, - issue_type: str, - message: str, - vendor_id: int | None = None, - ) -> AdminNotification: - """Create notification for vendor-related issues.""" - return self.create_notification( - db=db, - notification_type=NotificationType.VENDOR_ISSUE, - title=f"Vendor Issue: {vendor_name}", - message=message, - priority=Priority.HIGH, - action_required=True, - action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors", - metadata={ - "vendor_name": vendor_name, - "issue_type": issue_type, - "vendor_id": vendor_id, - }, - ) - - def notify_security_alert( - self, - db: Session, - title: str, - message: str, - details: dict[str, Any] | None = None, - ) -> AdminNotification: - """Create notification for security-related alerts.""" - return self.create_notification( - db=db, - notification_type=NotificationType.SECURITY_ALERT, - title=title, - message=message, - priority=Priority.CRITICAL, - action_required=True, - action_url="/admin/audit", - metadata=details, - ) - - -# ============================================================================ -# PLATFORM ALERT SERVICE -# ============================================================================ - - -class PlatformAlertService: - """Service for managing platform-wide alerts.""" - - def create_alert( - self, - db: Session, - alert_type: str, - severity: str, - title: str, - description: str | None = None, - affected_vendors: list[int] | None = None, - affected_systems: list[str] | None = None, - auto_generated: bool = True, - ) -> PlatformAlert: - """Create a new platform alert.""" - now = datetime.utcnow() - - alert = PlatformAlert( - alert_type=alert_type, - severity=severity, - title=title, - description=description, - affected_vendors=affected_vendors, - affected_systems=affected_systems, - auto_generated=auto_generated, - first_occurred_at=now, - last_occurred_at=now, - ) - db.add(alert) - db.flush() - - logger.info(f"Created platform alert: {alert_type} - {title} ({severity})") - - return alert - - def create_from_schema( - self, - db: Session, - data: PlatformAlertCreate, - ) -> PlatformAlert: - """Create alert from Pydantic schema.""" - return self.create_alert( - db=db, - alert_type=data.alert_type, - severity=data.severity, - title=data.title, - description=data.description, - affected_vendors=data.affected_vendors, - affected_systems=data.affected_systems, - auto_generated=data.auto_generated, - ) - - def get_alerts( - self, - db: Session, - severity: str | None = None, - alert_type: str | None = None, - is_resolved: bool | None = None, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[PlatformAlert], int, int, int]: - """ - Get paginated platform alerts. - - Returns: - Tuple of (alerts, total_count, active_count, critical_count) - """ - query = db.query(PlatformAlert) - - # Apply filters - if severity: - query = query.filter(PlatformAlert.severity == severity) - - if alert_type: - query = query.filter(PlatformAlert.alert_type == alert_type) - - if is_resolved is not None: - query = query.filter(PlatformAlert.is_resolved == is_resolved) - - # Get counts - total = query.count() - active_count = ( - db.query(PlatformAlert) - .filter(PlatformAlert.is_resolved == False) # noqa: E712 - .count() - ) - critical_count = ( - db.query(PlatformAlert) - .filter( - and_( - PlatformAlert.is_resolved == False, # noqa: E712 - PlatformAlert.severity == Severity.CRITICAL, - ) - ) - .count() - ) - - # Get paginated results - severity_order = case( - (PlatformAlert.severity == "critical", 1), - (PlatformAlert.severity == "error", 2), - (PlatformAlert.severity == "warning", 3), - (PlatformAlert.severity == "info", 4), - else_=5, - ) - - alerts = ( - query.order_by( - PlatformAlert.is_resolved, # Unresolved first - severity_order, - PlatformAlert.last_occurred_at.desc(), - ) - .offset(skip) - .limit(limit) - .all() - ) - - return alerts, total, active_count, critical_count - - def resolve_alert( - self, - db: Session, - alert_id: int, - user_id: int, - resolution_notes: str | None = None, - ) -> PlatformAlert | None: - """Resolve a platform alert.""" - alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first() - - if alert and not alert.is_resolved: - alert.is_resolved = True - alert.resolved_at = datetime.utcnow() - alert.resolved_by_user_id = user_id - alert.resolution_notes = resolution_notes - db.flush() - - logger.info(f"Resolved platform alert {alert_id}") - - return alert - - def get_statistics(self, db: Session) -> dict[str, int]: - """Get alert statistics.""" - today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) - - total = db.query(PlatformAlert).count() - active = ( - db.query(PlatformAlert) - .filter(PlatformAlert.is_resolved == False) # noqa: E712 - .count() - ) - critical = ( - db.query(PlatformAlert) - .filter( - and_( - PlatformAlert.is_resolved == False, # noqa: E712 - PlatformAlert.severity == Severity.CRITICAL, - ) - ) - .count() - ) - resolved_today = ( - db.query(PlatformAlert) - .filter( - and_( - PlatformAlert.is_resolved == True, # noqa: E712 - PlatformAlert.resolved_at >= today_start, - ) - ) - .count() - ) - - return { - "total_alerts": total, - "active_alerts": active, - "critical_alerts": critical, - "resolved_today": resolved_today, - } - - def increment_occurrence( - self, - db: Session, - alert_id: int, - ) -> PlatformAlert | None: - """Increment occurrence count for repeated alert.""" - alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first() - - if alert: - alert.occurrence_count += 1 - alert.last_occurred_at = datetime.utcnow() - db.flush() - - return alert - - def find_similar_active_alert( - self, - db: Session, - alert_type: str, - title: str, - ) -> PlatformAlert | None: - """Find an active alert with same type and title.""" - return ( - db.query(PlatformAlert) - .filter( - and_( - PlatformAlert.alert_type == alert_type, - PlatformAlert.title == title, - PlatformAlert.is_resolved == False, # noqa: E712 - ) - ) - .first() - ) - - def create_or_increment_alert( - self, - db: Session, - alert_type: str, - severity: str, - title: str, - description: str | None = None, - affected_vendors: list[int] | None = None, - affected_systems: list[str] | None = None, - ) -> PlatformAlert: - """Create alert or increment occurrence if similar exists.""" - existing = self.find_similar_active_alert(db, alert_type, title) - - if existing: - self.increment_occurrence(db, existing.id) - return existing - - return self.create_alert( - db=db, - alert_type=alert_type, - severity=severity, - title=title, - description=description, - affected_vendors=affected_vendors, - affected_systems=affected_systems, - ) - - -# Singleton instances -admin_notification_service = AdminNotificationService() -platform_alert_service = PlatformAlertService() +from app.modules.messaging.services.admin_notification_service import ( + admin_notification_service, + AdminNotificationService, + platform_alert_service, + PlatformAlertService, + # Constants + NotificationType, + Priority, + AlertType, + Severity, +) + +__all__ = [ + "admin_notification_service", + "AdminNotificationService", + "platform_alert_service", + "PlatformAlertService", + # Constants + "NotificationType", + "Priority", + "AlertType", + "Severity", +] diff --git a/app/services/background_tasks_service.py b/app/services/background_tasks_service.py index a41f8fc7..51538849 100644 --- a/app/services/background_tasks_service.py +++ b/app/services/background_tasks_service.py @@ -1,194 +1,23 @@ # app/services/background_tasks_service.py """ -Background Tasks Service -Service for monitoring background tasks across the system +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/monitoring/services/background_tasks_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.monitoring.services import background_tasks_service """ -from datetime import UTC, datetime +from app.modules.monitoring.services.background_tasks_service import ( + background_tasks_service, + BackgroundTasksService, +) -from sqlalchemy import case, desc, func -from sqlalchemy.orm import Session - -from models.database.architecture_scan import ArchitectureScan -from models.database.marketplace_import_job import MarketplaceImportJob -from models.database.test_run import TestRun - - -class BackgroundTasksService: - """Service for monitoring background tasks""" - - def get_import_jobs( - self, db: Session, status: str | None = None, limit: int = 50 - ) -> list[MarketplaceImportJob]: - """Get import jobs with optional status filter""" - query = db.query(MarketplaceImportJob) - if status: - query = query.filter(MarketplaceImportJob.status == status) - return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all() - - def get_test_runs( - self, db: Session, status: str | None = None, limit: int = 50 - ) -> list[TestRun]: - """Get test runs with optional status filter""" - query = db.query(TestRun) - if status: - query = query.filter(TestRun.status == status) - return query.order_by(desc(TestRun.timestamp)).limit(limit).all() - - def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]: - """Get currently running import jobs""" - return ( - db.query(MarketplaceImportJob) - .filter(MarketplaceImportJob.status == "processing") - .all() - ) - - def get_running_test_runs(self, db: Session) -> list[TestRun]: - """Get currently running test runs""" - # noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped - return db.query(TestRun).filter(TestRun.status == "running").all() - - def get_import_stats(self, db: Session) -> dict: - """Get import job statistics""" - today_start = datetime.now(UTC).replace( - hour=0, minute=0, second=0, microsecond=0 - ) - - stats = db.query( - func.count(MarketplaceImportJob.id).label("total"), - func.sum( - case((MarketplaceImportJob.status == "processing", 1), else_=0) - ).label("running"), - func.sum( - case( - ( - MarketplaceImportJob.status.in_( - ["completed", "completed_with_errors"] - ), - 1, - ), - else_=0, - ) - ).label("completed"), - func.sum( - case((MarketplaceImportJob.status == "failed", 1), else_=0) - ).label("failed"), - ).first() - - today_count = ( - db.query(func.count(MarketplaceImportJob.id)) - .filter(MarketplaceImportJob.created_at >= today_start) - .scalar() - or 0 - ) - - return { - "total": stats.total or 0, - "running": stats.running or 0, - "completed": stats.completed or 0, - "failed": stats.failed or 0, - "today": today_count, - } - - def get_test_run_stats(self, db: Session) -> dict: - """Get test run statistics""" - today_start = datetime.now(UTC).replace( - hour=0, minute=0, second=0, microsecond=0 - ) - - stats = db.query( - func.count(TestRun.id).label("total"), - func.sum(case((TestRun.status == "running", 1), else_=0)).label( - "running" - ), - func.sum(case((TestRun.status == "passed", 1), else_=0)).label( - "completed" - ), - func.sum( - case((TestRun.status.in_(["failed", "error"]), 1), else_=0) - ).label("failed"), - func.avg(TestRun.duration_seconds).label("avg_duration"), - ).first() - - today_count = ( - db.query(func.count(TestRun.id)) - .filter(TestRun.timestamp >= today_start) - .scalar() - or 0 - ) - - return { - "total": stats.total or 0, - "running": stats.running or 0, - "completed": stats.completed or 0, - "failed": stats.failed or 0, - "today": today_count, - "avg_duration": round(stats.avg_duration or 0, 1), - } - - def get_code_quality_scans( - self, db: Session, status: str | None = None, limit: int = 50 - ) -> list[ArchitectureScan]: - """Get code quality scans with optional status filter""" - query = db.query(ArchitectureScan) - if status: - query = query.filter(ArchitectureScan.status == status) - return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all() - - def get_running_scans(self, db: Session) -> list[ArchitectureScan]: - """Get currently running code quality scans""" - return ( - db.query(ArchitectureScan) - .filter(ArchitectureScan.status.in_(["pending", "running"])) - .all() - ) - - def get_scan_stats(self, db: Session) -> dict: - """Get code quality scan statistics""" - today_start = datetime.now(UTC).replace( - hour=0, minute=0, second=0, microsecond=0 - ) - - stats = db.query( - func.count(ArchitectureScan.id).label("total"), - func.sum( - case( - (ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0 - ) - ).label("running"), - func.sum( - case( - ( - ArchitectureScan.status.in_( - ["completed", "completed_with_warnings"] - ), - 1, - ), - else_=0, - ) - ).label("completed"), - func.sum( - case((ArchitectureScan.status == "failed", 1), else_=0) - ).label("failed"), - func.avg(ArchitectureScan.duration_seconds).label("avg_duration"), - ).first() - - today_count = ( - db.query(func.count(ArchitectureScan.id)) - .filter(ArchitectureScan.timestamp >= today_start) - .scalar() - or 0 - ) - - return { - "total": stats.total or 0, - "running": stats.running or 0, - "completed": stats.completed or 0, - "failed": stats.failed or 0, - "today": today_count, - "avg_duration": round(stats.avg_duration or 0, 1), - } - - -# Singleton instance -background_tasks_service = BackgroundTasksService() +__all__ = [ + "background_tasks_service", + "BackgroundTasksService", +] diff --git a/app/services/billing_service.py b/app/services/billing_service.py index 5c2da065..b3fd6ace 100644 --- a/app/services/billing_service.py +++ b/app/services/billing_service.py @@ -1,588 +1,35 @@ # app/services/billing_service.py """ -Billing service for subscription and payment operations. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides: -- Subscription status and usage queries -- Tier management -- Invoice history -- Add-on management +The canonical implementation is now in: + app/modules/billing/services/billing_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.billing.services import billing_service """ -import logging -from datetime import datetime - -from sqlalchemy.orm import Session - -from app.services.stripe_service import stripe_service -from app.services.subscription_service import subscription_service -from models.database.subscription import ( - AddOnProduct, - BillingHistory, - SubscriptionTier, - VendorAddOn, - VendorSubscription, +from app.modules.billing.services.billing_service import ( + BillingService, + billing_service, + BillingServiceError, + PaymentSystemNotConfiguredError, + TierNotFoundError, + StripePriceNotConfiguredError, + NoActiveSubscriptionError, + SubscriptionNotCancelledError, ) -from models.database.vendor import Vendor -logger = logging.getLogger(__name__) - - -class BillingServiceError(Exception): - """Base exception for billing service errors.""" - - pass - - -class PaymentSystemNotConfiguredError(BillingServiceError): - """Raised when Stripe is not configured.""" - - def __init__(self): - super().__init__("Payment system not configured") - - -class TierNotFoundError(BillingServiceError): - """Raised when a tier is not found.""" - - def __init__(self, tier_code: str): - self.tier_code = tier_code - super().__init__(f"Tier '{tier_code}' not found") - - -class StripePriceNotConfiguredError(BillingServiceError): - """Raised when Stripe price is not configured for a tier.""" - - def __init__(self, tier_code: str): - self.tier_code = tier_code - super().__init__(f"Stripe price not configured for tier '{tier_code}'") - - -class NoActiveSubscriptionError(BillingServiceError): - """Raised when no active subscription exists.""" - - def __init__(self): - super().__init__("No active subscription found") - - -class SubscriptionNotCancelledError(BillingServiceError): - """Raised when trying to reactivate a non-cancelled subscription.""" - - def __init__(self): - super().__init__("Subscription is not cancelled") - - -class BillingService: - """Service for billing operations.""" - - def get_subscription_with_tier( - self, db: Session, vendor_id: int - ) -> tuple[VendorSubscription, SubscriptionTier | None]: - """ - Get subscription and its tier info. - - Returns: - Tuple of (subscription, tier) where tier may be None - """ - subscription = subscription_service.get_or_create_subscription(db, vendor_id) - - tier = ( - db.query(SubscriptionTier) - .filter(SubscriptionTier.code == subscription.tier) - .first() - ) - - return subscription, tier - - def get_available_tiers( - self, db: Session, current_tier: str - ) -> tuple[list[dict], dict[str, int]]: - """ - Get all available tiers with upgrade/downgrade flags. - - Returns: - Tuple of (tier_list, tier_order_map) - """ - tiers = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.is_active == True, # noqa: E712 - SubscriptionTier.is_public == True, # noqa: E712 - ) - .order_by(SubscriptionTier.display_order) - .all() - ) - - tier_order = {t.code: t.display_order for t in tiers} - current_order = tier_order.get(current_tier, 0) - - tier_list = [] - for tier in tiers: - tier_list.append({ - "code": tier.code, - "name": tier.name, - "description": tier.description, - "price_monthly_cents": tier.price_monthly_cents, - "price_annual_cents": tier.price_annual_cents, - "orders_per_month": tier.orders_per_month, - "products_limit": tier.products_limit, - "team_members": tier.team_members, - "features": tier.features or [], - "is_current": tier.code == current_tier, - "can_upgrade": tier.display_order > current_order, - "can_downgrade": tier.display_order < current_order, - }) - - return tier_list, tier_order - - def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier: - """ - Get a tier by its code. - - Raises: - TierNotFoundError: If tier doesn't exist - """ - tier = ( - db.query(SubscriptionTier) - .filter( - SubscriptionTier.code == tier_code, - SubscriptionTier.is_active == True, # noqa: E712 - ) - .first() - ) - - if not tier: - raise TierNotFoundError(tier_code) - - return tier - - def get_vendor(self, db: Session, vendor_id: int) -> Vendor: - """ - Get vendor by ID. - - Raises: - VendorNotFoundException from app.exceptions - """ - from app.exceptions import VendorNotFoundException - - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - raise VendorNotFoundException(str(vendor_id), identifier_type="id") - - return vendor - - def create_checkout_session( - self, - db: Session, - vendor_id: int, - tier_code: str, - is_annual: bool, - success_url: str, - cancel_url: str, - ) -> dict: - """ - Create a Stripe checkout session. - - Returns: - Dict with checkout_url and session_id - - Raises: - PaymentSystemNotConfiguredError: If Stripe not configured - TierNotFoundError: If tier doesn't exist - StripePriceNotConfiguredError: If price not configured - """ - if not stripe_service.is_configured: - raise PaymentSystemNotConfiguredError() - - vendor = self.get_vendor(db, vendor_id) - tier = self.get_tier_by_code(db, tier_code) - - price_id = ( - tier.stripe_price_annual_id - if is_annual and tier.stripe_price_annual_id - else tier.stripe_price_monthly_id - ) - - if not price_id: - raise StripePriceNotConfiguredError(tier_code) - - # Check if this is a new subscription (for trial) - existing_sub = subscription_service.get_subscription(db, vendor_id) - trial_days = None - if not existing_sub or not existing_sub.stripe_subscription_id: - from app.core.config import settings - trial_days = settings.stripe_trial_days - - session = stripe_service.create_checkout_session( - db=db, - vendor=vendor, - price_id=price_id, - success_url=success_url, - cancel_url=cancel_url, - trial_days=trial_days, - ) - - # Update subscription with tier info - subscription = subscription_service.get_or_create_subscription(db, vendor_id) - subscription.tier = tier_code - subscription.is_annual = is_annual - - return { - "checkout_url": session.url, - "session_id": session.id, - } - - def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict: - """ - Create a Stripe customer portal session. - - Returns: - Dict with portal_url - - Raises: - PaymentSystemNotConfiguredError: If Stripe not configured - NoActiveSubscriptionError: If no subscription with customer ID - """ - if not stripe_service.is_configured: - raise PaymentSystemNotConfiguredError() - - subscription = subscription_service.get_subscription(db, vendor_id) - - if not subscription or not subscription.stripe_customer_id: - raise NoActiveSubscriptionError() - - session = stripe_service.create_portal_session( - customer_id=subscription.stripe_customer_id, - return_url=return_url, - ) - - return {"portal_url": session.url} - - def get_invoices( - self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20 - ) -> tuple[list[BillingHistory], int]: - """ - Get invoice history for a vendor. - - Returns: - Tuple of (invoices, total_count) - """ - query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id) - - total = query.count() - - invoices = ( - query.order_by(BillingHistory.invoice_date.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - return invoices, total - - def get_available_addons( - self, db: Session, category: str | None = None - ) -> list[AddOnProduct]: - """Get available add-on products.""" - query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712 - - if category: - query = query.filter(AddOnProduct.category == category) - - return query.order_by(AddOnProduct.display_order).all() - - def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]: - """Get vendor's purchased add-ons.""" - return ( - db.query(VendorAddOn) - .filter(VendorAddOn.vendor_id == vendor_id) - .all() - ) - - def cancel_subscription( - self, db: Session, vendor_id: int, reason: str | None, immediately: bool - ) -> dict: - """ - Cancel a subscription. - - Returns: - Dict with message and effective_date - - Raises: - NoActiveSubscriptionError: If no subscription to cancel - """ - subscription = subscription_service.get_subscription(db, vendor_id) - - if not subscription or not subscription.stripe_subscription_id: - raise NoActiveSubscriptionError() - - if stripe_service.is_configured: - stripe_service.cancel_subscription( - subscription_id=subscription.stripe_subscription_id, - immediately=immediately, - cancellation_reason=reason, - ) - - subscription.cancelled_at = datetime.utcnow() - subscription.cancellation_reason = reason - - effective_date = ( - datetime.utcnow().isoformat() - if immediately - else subscription.period_end.isoformat() - if subscription.period_end - else datetime.utcnow().isoformat() - ) - - return { - "message": "Subscription cancelled successfully", - "effective_date": effective_date, - } - - def reactivate_subscription(self, db: Session, vendor_id: int) -> dict: - """ - Reactivate a cancelled subscription. - - Returns: - Dict with success message - - Raises: - NoActiveSubscriptionError: If no subscription - SubscriptionNotCancelledError: If not cancelled - """ - subscription = subscription_service.get_subscription(db, vendor_id) - - if not subscription or not subscription.stripe_subscription_id: - raise NoActiveSubscriptionError() - - if not subscription.cancelled_at: - raise SubscriptionNotCancelledError() - - if stripe_service.is_configured: - stripe_service.reactivate_subscription(subscription.stripe_subscription_id) - - subscription.cancelled_at = None - subscription.cancellation_reason = None - - return {"message": "Subscription reactivated successfully"} - - def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict: - """ - Get upcoming invoice preview. - - Returns: - Dict with amount_due_cents, currency, next_payment_date, line_items - - Raises: - NoActiveSubscriptionError: If no subscription with customer ID - """ - subscription = subscription_service.get_subscription(db, vendor_id) - - if not subscription or not subscription.stripe_customer_id: - raise NoActiveSubscriptionError() - - if not stripe_service.is_configured: - # Return empty preview if Stripe not configured - return { - "amount_due_cents": 0, - "currency": "EUR", - "next_payment_date": None, - "line_items": [], - } - - invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id) - - if not invoice: - return { - "amount_due_cents": 0, - "currency": "EUR", - "next_payment_date": None, - "line_items": [], - } - - line_items = [] - if invoice.lines and invoice.lines.data: - for line in invoice.lines.data: - line_items.append({ - "description": line.description or "", - "amount_cents": line.amount, - "quantity": line.quantity or 1, - }) - - return { - "amount_due_cents": invoice.amount_due, - "currency": invoice.currency.upper(), - "next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat() - if invoice.next_payment_attempt - else None, - "line_items": line_items, - } - - def change_tier( - self, - db: Session, - vendor_id: int, - new_tier_code: str, - is_annual: bool, - ) -> dict: - """ - Change subscription tier (upgrade/downgrade). - - Returns: - Dict with message, new_tier, effective_immediately - - Raises: - TierNotFoundError: If tier doesn't exist - NoActiveSubscriptionError: If no subscription - StripePriceNotConfiguredError: If price not configured - """ - subscription = subscription_service.get_subscription(db, vendor_id) - - if not subscription or not subscription.stripe_subscription_id: - raise NoActiveSubscriptionError() - - tier = self.get_tier_by_code(db, new_tier_code) - - price_id = ( - tier.stripe_price_annual_id - if is_annual and tier.stripe_price_annual_id - else tier.stripe_price_monthly_id - ) - - if not price_id: - raise StripePriceNotConfiguredError(new_tier_code) - - # Update in Stripe - if stripe_service.is_configured: - stripe_service.update_subscription( - subscription_id=subscription.stripe_subscription_id, - new_price_id=price_id, - ) - - # Update local subscription - old_tier = subscription.tier - subscription.tier = new_tier_code - subscription.tier_id = tier.id - subscription.is_annual = is_annual - subscription.updated_at = datetime.utcnow() - - is_upgrade = self._is_upgrade(db, old_tier, new_tier_code) - - return { - "message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}", - "new_tier": new_tier_code, - "effective_immediately": True, - } - - def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool: - """Check if tier change is an upgrade.""" - old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first() - new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first() - - if not old or not new: - return False - - return new.display_order > old.display_order - - def purchase_addon( - self, - db: Session, - vendor_id: int, - addon_code: str, - domain_name: str | None, - quantity: int, - success_url: str, - cancel_url: str, - ) -> dict: - """ - Create checkout session for add-on purchase. - - Returns: - Dict with checkout_url and session_id - - Raises: - PaymentSystemNotConfiguredError: If Stripe not configured - AddonNotFoundError: If addon doesn't exist - """ - if not stripe_service.is_configured: - raise PaymentSystemNotConfiguredError() - - addon = ( - db.query(AddOnProduct) - .filter( - AddOnProduct.code == addon_code, - AddOnProduct.is_active == True, # noqa: E712 - ) - .first() - ) - - if not addon: - raise BillingServiceError(f"Add-on '{addon_code}' not found") - - if not addon.stripe_price_id: - raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'") - - vendor = self.get_vendor(db, vendor_id) - subscription = subscription_service.get_or_create_subscription(db, vendor_id) - - # Create checkout session for add-on - session = stripe_service.create_checkout_session( - db=db, - vendor=vendor, - price_id=addon.stripe_price_id, - success_url=success_url, - cancel_url=cancel_url, - quantity=quantity, - metadata={ - "addon_code": addon_code, - "domain_name": domain_name or "", - }, - ) - - return { - "checkout_url": session.url, - "session_id": session.id, - } - - def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict: - """ - Cancel a purchased add-on. - - Returns: - Dict with message and addon_code - - Raises: - BillingServiceError: If addon not found or not owned by vendor - """ - vendor_addon = ( - db.query(VendorAddOn) - .filter( - VendorAddOn.id == addon_id, - VendorAddOn.vendor_id == vendor_id, - ) - .first() - ) - - if not vendor_addon: - raise BillingServiceError("Add-on not found") - - addon_code = vendor_addon.addon_product.code - - # Cancel in Stripe if applicable - if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id: - try: - stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id) - except Exception as e: - logger.warning(f"Failed to cancel addon in Stripe: {e}") - - # Mark as cancelled - vendor_addon.status = "cancelled" - vendor_addon.cancelled_at = datetime.utcnow() - - return { - "message": "Add-on cancelled successfully", - "addon_code": addon_code, - } - - -# Create service instance -billing_service = BillingService() +__all__ = [ + "BillingService", + "billing_service", + "BillingServiceError", + "PaymentSystemNotConfiguredError", + "TierNotFoundError", + "StripePriceNotConfiguredError", + "NoActiveSubscriptionError", + "SubscriptionNotCancelledError", +] diff --git a/app/services/code_quality_service.py b/app/services/code_quality_service.py index d3162e78..77d6bd91 100644 --- a/app/services/code_quality_service.py +++ b/app/services/code_quality_service.py @@ -1,820 +1,35 @@ +# app/services/code_quality_service.py """ -Code Quality Service -Business logic for managing code quality scans and violations -Supports multiple validator types: architecture, security, performance +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/dev_tools/services/code_quality_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.dev_tools.services import code_quality_service """ -import json -import logging -import subprocess -from datetime import datetime - -from sqlalchemy import desc, func -from sqlalchemy.orm import Session - -from app.exceptions import ( - ScanParseException, - ScanTimeoutException, - ViolationNotFoundException, -) -from models.database.architecture_scan import ( - ArchitectureScan, - ArchitectureViolation, - ViolationAssignment, - ViolationComment, +from app.modules.dev_tools.services.code_quality_service import ( + code_quality_service, + CodeQualityService, + VALIDATOR_ARCHITECTURE, + VALIDATOR_SECURITY, + VALIDATOR_PERFORMANCE, + VALID_VALIDATOR_TYPES, + VALIDATOR_SCRIPTS, + VALIDATOR_NAMES, ) -logger = logging.getLogger(__name__) - -# Validator type constants -VALIDATOR_ARCHITECTURE = "architecture" -VALIDATOR_SECURITY = "security" -VALIDATOR_PERFORMANCE = "performance" - -VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE] - -# Map validator types to their scripts -VALIDATOR_SCRIPTS = { - VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py", - VALIDATOR_SECURITY: "scripts/validate_security.py", - VALIDATOR_PERFORMANCE: "scripts/validate_performance.py", -} - -# Human-readable names -VALIDATOR_NAMES = { - VALIDATOR_ARCHITECTURE: "Architecture", - VALIDATOR_SECURITY: "Security", - VALIDATOR_PERFORMANCE: "Performance", -} - - -class CodeQualityService: - """Service for managing code quality scans and violations""" - - def run_scan( - self, - db: Session, - triggered_by: str = "manual", - validator_type: str = VALIDATOR_ARCHITECTURE, - ) -> ArchitectureScan: - """ - Run a code quality validator and store results in database - - Args: - db: Database session - triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd') - validator_type: Type of validator ('architecture', 'security', 'performance') - - Returns: - ArchitectureScan object with results - - Raises: - ValueError: If validator_type is invalid - ScanTimeoutException: If validator times out - ScanParseException: If validator output cannot be parsed - """ - if validator_type not in VALID_VALIDATOR_TYPES: - raise ValueError( - f"Invalid validator type: {validator_type}. " - f"Must be one of: {VALID_VALIDATOR_TYPES}" - ) - - script_path = VALIDATOR_SCRIPTS[validator_type] - validator_name = VALIDATOR_NAMES[validator_type] - - logger.info( - f"Starting {validator_name} scan (triggered by: {triggered_by})" - ) - - # Get git commit hash - git_commit = self._get_git_commit_hash() - - # Run validator with JSON output - start_time = datetime.now() - try: - result = subprocess.run( - ["python", script_path, "--json"], - capture_output=True, - text=True, - timeout=300, # 5 minute timeout - ) - except subprocess.TimeoutExpired: - logger.error(f"{validator_name} scan timed out after 5 minutes") - raise ScanTimeoutException(timeout_seconds=300) - - duration = (datetime.now() - start_time).total_seconds() - - # Parse JSON output (get only the JSON part, skip progress messages) - try: - # Find the JSON part in stdout - lines = result.stdout.strip().split("\n") - json_start = -1 - for i, line in enumerate(lines): - if line.strip().startswith("{"): - json_start = i - break - - if json_start == -1: - raise ValueError("No JSON output found") - - json_output = "\n".join(lines[json_start:]) - data = json.loads(json_output) - except (json.JSONDecodeError, ValueError) as e: - logger.error(f"Failed to parse {validator_name} validator output: {e}") - logger.error(f"Stdout: {result.stdout}") - logger.error(f"Stderr: {result.stderr}") - raise ScanParseException(reason=str(e)) - - # Create scan record - scan = ArchitectureScan( - timestamp=datetime.now(), - validator_type=validator_type, - total_files=data.get("files_checked", 0), - total_violations=data.get("total_violations", 0), - errors=data.get("errors", 0), - warnings=data.get("warnings", 0), - duration_seconds=duration, - triggered_by=triggered_by, - git_commit_hash=git_commit, - ) - db.add(scan) - db.flush() # Get scan.id - - # Create violation records - violations_data = data.get("violations", []) - logger.info(f"Creating {len(violations_data)} {validator_name} violation records") - - for v in violations_data: - violation = ArchitectureViolation( - scan_id=scan.id, - validator_type=validator_type, - rule_id=v["rule_id"], - rule_name=v["rule_name"], - severity=v["severity"], - file_path=v["file_path"], - line_number=v["line_number"], - message=v["message"], - context=v.get("context", ""), - suggestion=v.get("suggestion", ""), - status="open", - ) - db.add(violation) - - db.flush() - db.refresh(scan) - - logger.info( - f"{validator_name} scan completed: {scan.total_violations} violations found" - ) - return scan - - def run_all_scans( - self, db: Session, triggered_by: str = "manual" - ) -> list[ArchitectureScan]: - """ - Run all validators and return list of scans - - Args: - db: Database session - triggered_by: Who/what triggered the scan - - Returns: - List of ArchitectureScan objects (one per validator) - """ - results = [] - for validator_type in VALID_VALIDATOR_TYPES: - try: - scan = self.run_scan(db, triggered_by, validator_type) - results.append(scan) - except Exception as e: - logger.error(f"Failed to run {validator_type} scan: {e}") - # Continue with other validators even if one fails - return results - - def get_latest_scan( - self, db: Session, validator_type: str = None - ) -> ArchitectureScan | None: - """ - Get the most recent scan - - Args: - db: Database session - validator_type: Optional filter by validator type - - Returns: - Most recent ArchitectureScan or None - """ - query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)) - - if validator_type: - query = query.filter(ArchitectureScan.validator_type == validator_type) - - return query.first() - - def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]: - """ - Get the most recent scan for each validator type - - Returns: - Dictionary mapping validator_type to its latest scan - """ - result = {} - for vtype in VALID_VALIDATOR_TYPES: - scan = self.get_latest_scan(db, validator_type=vtype) - if scan: - result[vtype] = scan - return result - - def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None: - """Get scan by ID""" - return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first() - - def create_pending_scan( - self, db: Session, validator_type: str, triggered_by: str - ) -> ArchitectureScan: - """ - Create a new scan record with pending status. - - Args: - db: Database session - validator_type: Type of validator (architecture, security, performance) - triggered_by: Who triggered the scan (e.g., "manual:username") - - Returns: - The created ArchitectureScan record with ID populated - """ - scan = ArchitectureScan( - timestamp=datetime.now(UTC), - validator_type=validator_type, - status="pending", - triggered_by=triggered_by, - ) - db.add(scan) - db.flush() # Get scan.id - return scan - - def get_running_scans(self, db: Session) -> list[ArchitectureScan]: - """ - Get all currently running scans (pending or running status). - - Returns: - List of scans with status 'pending' or 'running', newest first - """ - return ( - db.query(ArchitectureScan) - .filter(ArchitectureScan.status.in_(["pending", "running"])) - .order_by(ArchitectureScan.timestamp.desc()) - .all() - ) - - def get_scan_history( - self, db: Session, limit: int = 30, validator_type: str = None - ) -> list[ArchitectureScan]: - """ - Get scan history for trend graphs - - Args: - db: Database session - limit: Maximum number of scans to return - validator_type: Optional filter by validator type - - Returns: - List of ArchitectureScan objects, newest first - """ - query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp)) - - if validator_type: - query = query.filter(ArchitectureScan.validator_type == validator_type) - - return query.limit(limit).all() - - def get_violations( - self, - db: Session, - scan_id: int = None, - validator_type: str = None, - severity: str = None, - status: str = None, - rule_id: str = None, - file_path: str = None, - limit: int = 100, - offset: int = 0, - ) -> tuple[list[ArchitectureViolation], int]: - """ - Get violations with filtering and pagination - - Args: - db: Database session - scan_id: Filter by scan ID (if None, use latest scan(s)) - validator_type: Filter by validator type - severity: Filter by severity ('error', 'warning') - status: Filter by status ('open', 'assigned', 'resolved', etc.) - rule_id: Filter by rule ID - file_path: Filter by file path (partial match) - limit: Page size - offset: Page offset - - Returns: - Tuple of (violations list, total count) - """ - # Build query - query = db.query(ArchitectureViolation) - - # If scan_id specified, filter by it - if scan_id is not None: - query = query.filter(ArchitectureViolation.scan_id == scan_id) - else: - # If no scan_id, get violations from latest scan(s) - if validator_type: - # Get latest scan for specific validator type - latest_scan = self.get_latest_scan(db, validator_type) - if not latest_scan: - return [], 0 - query = query.filter(ArchitectureViolation.scan_id == latest_scan.id) - else: - # Get violations from latest scans of all types - latest_scans = self.get_latest_scans_by_type(db) - if not latest_scans: - return [], 0 - scan_ids = [s.id for s in latest_scans.values()] - query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids)) - - # Apply validator_type filter if specified (for scan_id queries) - if validator_type and scan_id is not None: - query = query.filter(ArchitectureViolation.validator_type == validator_type) - - # Apply other filters - if severity: - query = query.filter(ArchitectureViolation.severity == severity) - - if status: - query = query.filter(ArchitectureViolation.status == status) - - if rule_id: - query = query.filter(ArchitectureViolation.rule_id == rule_id) - - if file_path: - query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%")) - - # Get total count - total = query.count() - - # Get page of results - violations = ( - query.order_by( - ArchitectureViolation.severity.desc(), - ArchitectureViolation.validator_type, - ArchitectureViolation.file_path, - ) - .limit(limit) - .offset(offset) - .all() - ) - - return violations, total - - def get_violation_by_id( - self, db: Session, violation_id: int - ) -> ArchitectureViolation | None: - """Get single violation with details""" - return ( - db.query(ArchitectureViolation) - .filter(ArchitectureViolation.id == violation_id) - .first() - ) - - def assign_violation( - self, - db: Session, - violation_id: int, - user_id: int, - assigned_by: int, - due_date: datetime = None, - priority: str = "medium", - ) -> ViolationAssignment: - """ - Assign violation to a developer - - Args: - db: Database session - violation_id: Violation ID - user_id: User to assign to - assigned_by: User who is assigning - due_date: Due date (optional) - priority: Priority level ('low', 'medium', 'high', 'critical') - - Returns: - ViolationAssignment object - """ - # Update violation status - violation = self.get_violation_by_id(db, violation_id) - if violation: - violation.status = "assigned" - violation.assigned_to = user_id - - # Create assignment record - assignment = ViolationAssignment( - violation_id=violation_id, - user_id=user_id, - assigned_by=assigned_by, - due_date=due_date, - priority=priority, - ) - db.add(assignment) - db.flush() - - logger.info(f"Violation {violation_id} assigned to user {user_id}") - return assignment - - def resolve_violation( - self, db: Session, violation_id: int, resolved_by: int, resolution_note: str - ) -> ArchitectureViolation: - """ - Mark violation as resolved - - Args: - db: Database session - violation_id: Violation ID - resolved_by: User who resolved it - resolution_note: Note about resolution - - Returns: - Updated ArchitectureViolation object - """ - violation = self.get_violation_by_id(db, violation_id) - if not violation: - raise ViolationNotFoundException(violation_id) - - violation.status = "resolved" - violation.resolved_at = datetime.now() - violation.resolved_by = resolved_by - violation.resolution_note = resolution_note - - db.flush() - logger.info(f"Violation {violation_id} resolved by user {resolved_by}") - return violation - - def ignore_violation( - self, db: Session, violation_id: int, ignored_by: int, reason: str - ) -> ArchitectureViolation: - """ - Mark violation as ignored/won't fix - - Args: - db: Database session - violation_id: Violation ID - ignored_by: User who ignored it - reason: Reason for ignoring - - Returns: - Updated ArchitectureViolation object - """ - violation = self.get_violation_by_id(db, violation_id) - if not violation: - raise ViolationNotFoundException(violation_id) - - violation.status = "ignored" - violation.resolved_at = datetime.now() - violation.resolved_by = ignored_by - violation.resolution_note = f"Ignored: {reason}" - - db.flush() - logger.info(f"Violation {violation_id} ignored by user {ignored_by}") - return violation - - def add_comment( - self, db: Session, violation_id: int, user_id: int, comment: str - ) -> ViolationComment: - """ - Add comment to violation - - Args: - db: Database session - violation_id: Violation ID - user_id: User posting comment - comment: Comment text - - Returns: - ViolationComment object - """ - comment_obj = ViolationComment( - violation_id=violation_id, user_id=user_id, comment=comment - ) - db.add(comment_obj) - db.flush() - - logger.info(f"Comment added to violation {violation_id} by user {user_id}") - return comment_obj - - def get_dashboard_stats( - self, db: Session, validator_type: str = None - ) -> dict: - """ - Get statistics for dashboard - - Args: - db: Database session - validator_type: Optional filter by validator type. If None, returns combined stats. - - Returns: - Dictionary with various statistics including per-validator breakdown - """ - # Get latest scans by type - latest_scans = self.get_latest_scans_by_type(db) - - if not latest_scans: - return self._empty_dashboard_stats() - - # If specific validator type requested - if validator_type and validator_type in latest_scans: - scan = latest_scans[validator_type] - return self._get_stats_for_scan(db, scan, validator_type) - - # Combined stats across all validators - return self._get_combined_stats(db, latest_scans) - - def _empty_dashboard_stats(self) -> dict: - """Return empty dashboard stats structure""" - return { - "total_violations": 0, - "errors": 0, - "warnings": 0, - "info": 0, - "open": 0, - "assigned": 0, - "resolved": 0, - "ignored": 0, - "technical_debt_score": 100, - "trend": [], - "by_severity": {}, - "by_rule": {}, - "by_module": {}, - "top_files": [], - "last_scan": None, - "by_validator": {}, - } - - def _get_stats_for_scan( - self, db: Session, scan: ArchitectureScan, validator_type: str - ) -> dict: - """Get stats for a single scan/validator type""" - # Get violation counts by status - status_counts = ( - db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id)) - .filter(ArchitectureViolation.scan_id == scan.id) - .group_by(ArchitectureViolation.status) - .all() - ) - status_dict = {status: count for status, count in status_counts} - - # Get violations by severity - severity_counts = ( - db.query( - ArchitectureViolation.severity, func.count(ArchitectureViolation.id) - ) - .filter(ArchitectureViolation.scan_id == scan.id) - .group_by(ArchitectureViolation.severity) - .all() - ) - by_severity = {sev: count for sev, count in severity_counts} - - # Get violations by rule - rule_counts = ( - db.query( - ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id) - ) - .filter(ArchitectureViolation.scan_id == scan.id) - .group_by(ArchitectureViolation.rule_id) - .all() - ) - by_rule = { - rule: count - for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10] - } - - # Get top violating files - file_counts = ( - db.query( - ArchitectureViolation.file_path, - func.count(ArchitectureViolation.id).label("count"), - ) - .filter(ArchitectureViolation.scan_id == scan.id) - .group_by(ArchitectureViolation.file_path) - .order_by(desc("count")) - .limit(10) - .all() - ) - top_files = [{"file": file, "count": count} for file, count in file_counts] - - # Get violations by module - by_module = self._get_violations_by_module(db, scan.id) - - # Get trend for this validator type - trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type) - trend = [ - { - "timestamp": s.timestamp.isoformat(), - "violations": s.total_violations, - "errors": s.errors, - "warnings": s.warnings, - } - for s in reversed(trend_scans) - ] - - return { - "total_violations": scan.total_violations, - "errors": scan.errors, - "warnings": scan.warnings, - "info": by_severity.get("info", 0), - "open": status_dict.get("open", 0), - "assigned": status_dict.get("assigned", 0), - "resolved": status_dict.get("resolved", 0), - "ignored": status_dict.get("ignored", 0), - "technical_debt_score": self._calculate_score(scan.errors, scan.warnings), - "trend": trend, - "by_severity": by_severity, - "by_rule": by_rule, - "by_module": by_module, - "top_files": top_files, - "last_scan": scan.timestamp.isoformat(), - "validator_type": validator_type, - "by_validator": { - validator_type: { - "total_violations": scan.total_violations, - "errors": scan.errors, - "warnings": scan.warnings, - "last_scan": scan.timestamp.isoformat(), - } - }, - } - - def _get_combined_stats( - self, db: Session, latest_scans: dict[str, ArchitectureScan] - ) -> dict: - """Get combined stats across all validators""" - # Aggregate totals - total_violations = sum(s.total_violations for s in latest_scans.values()) - total_errors = sum(s.errors for s in latest_scans.values()) - total_warnings = sum(s.warnings for s in latest_scans.values()) - - # Get all scan IDs - scan_ids = [s.id for s in latest_scans.values()] - - # Get violation counts by status - status_counts = ( - db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id)) - .filter(ArchitectureViolation.scan_id.in_(scan_ids)) - .group_by(ArchitectureViolation.status) - .all() - ) - status_dict = {status: count for status, count in status_counts} - - # Get violations by severity - severity_counts = ( - db.query( - ArchitectureViolation.severity, func.count(ArchitectureViolation.id) - ) - .filter(ArchitectureViolation.scan_id.in_(scan_ids)) - .group_by(ArchitectureViolation.severity) - .all() - ) - by_severity = {sev: count for sev, count in severity_counts} - - # Get violations by rule (across all validators) - rule_counts = ( - db.query( - ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id) - ) - .filter(ArchitectureViolation.scan_id.in_(scan_ids)) - .group_by(ArchitectureViolation.rule_id) - .all() - ) - by_rule = { - rule: count - for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10] - } - - # Get top violating files - file_counts = ( - db.query( - ArchitectureViolation.file_path, - func.count(ArchitectureViolation.id).label("count"), - ) - .filter(ArchitectureViolation.scan_id.in_(scan_ids)) - .group_by(ArchitectureViolation.file_path) - .order_by(desc("count")) - .limit(10) - .all() - ) - top_files = [{"file": file, "count": count} for file, count in file_counts] - - # Get violations by module - by_module = {} - for scan_id in scan_ids: - module_counts = self._get_violations_by_module(db, scan_id) - for module, count in module_counts.items(): - by_module[module] = by_module.get(module, 0) + count - by_module = dict( - sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10] - ) - - # Per-validator breakdown - by_validator = {} - for vtype, scan in latest_scans.items(): - by_validator[vtype] = { - "total_violations": scan.total_violations, - "errors": scan.errors, - "warnings": scan.warnings, - "last_scan": scan.timestamp.isoformat(), - } - - # Get most recent scan timestamp - most_recent = max(latest_scans.values(), key=lambda s: s.timestamp) - - return { - "total_violations": total_violations, - "errors": total_errors, - "warnings": total_warnings, - "info": by_severity.get("info", 0), - "open": status_dict.get("open", 0), - "assigned": status_dict.get("assigned", 0), - "resolved": status_dict.get("resolved", 0), - "ignored": status_dict.get("ignored", 0), - "technical_debt_score": self._calculate_score(total_errors, total_warnings), - "trend": [], # Combined trend would need special handling - "by_severity": by_severity, - "by_rule": by_rule, - "by_module": by_module, - "top_files": top_files, - "last_scan": most_recent.timestamp.isoformat(), - "by_validator": by_validator, - } - - def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]: - """Extract module from file paths and count violations""" - by_module = {} - violations = ( - db.query(ArchitectureViolation.file_path) - .filter(ArchitectureViolation.scan_id == scan_id) - .all() - ) - - for v in violations: - path_parts = v.file_path.split("/") - if len(path_parts) >= 2: - module = "/".join(path_parts[:2]) - else: - module = path_parts[0] - by_module[module] = by_module.get(module, 0) + 1 - - return dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]) - - def _calculate_score(self, errors: int, warnings: int) -> int: - """Calculate technical debt score (0-100)""" - score = 100 - (errors * 0.5 + warnings * 0.05) - return max(0, min(100, int(score))) - - def calculate_technical_debt_score( - self, db: Session, scan_id: int = None, validator_type: str = None - ) -> int: - """ - Calculate technical debt score (0-100) - - Formula: 100 - (errors * 0.5 + warnings * 0.05) - Capped at 0 minimum - - Args: - db: Database session - scan_id: Scan ID (if None, use latest) - validator_type: Filter by validator type - - Returns: - Score from 0-100 - """ - if scan_id is None: - latest_scan = self.get_latest_scan(db, validator_type) - if not latest_scan: - return 100 - scan_id = latest_scan.id - - scan = self.get_scan_by_id(db, scan_id) - if not scan: - return 100 - - return self._calculate_score(scan.errors, scan.warnings) - - def _get_git_commit_hash(self) -> str | None: - """Get current git commit hash""" - try: - result = subprocess.run( - ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5 - ) - if result.returncode == 0: - return result.stdout.strip()[:40] - except Exception: - pass - return None - - -# Singleton instance -code_quality_service = CodeQualityService() +__all__ = [ + "code_quality_service", + "CodeQualityService", + "VALIDATOR_ARCHITECTURE", + "VALIDATOR_SECURITY", + "VALIDATOR_PERFORMANCE", + "VALID_VALIDATOR_TYPES", + "VALIDATOR_SCRIPTS", + "VALIDATOR_NAMES", +] diff --git a/app/services/customer_address_service.py b/app/services/customer_address_service.py index a1548b7c..8af43d33 100644 --- a/app/services/customer_address_service.py +++ b/app/services/customer_address_service.py @@ -1,314 +1,23 @@ # app/services/customer_address_service.py """ -Customer Address Service +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Business logic for managing customer addresses with vendor isolation. +The canonical implementation is now in: + app/modules/customers/services/customer_address_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.customers.services import customer_address_service """ -import logging - -from sqlalchemy.orm import Session - -from app.exceptions import ( - AddressLimitExceededException, - AddressNotFoundException, +from app.modules.customers.services.customer_address_service import ( + customer_address_service, + CustomerAddressService, ) -from models.database.customer import CustomerAddress -from models.schema.customer import CustomerAddressCreate, CustomerAddressUpdate -logger = logging.getLogger(__name__) - - -class CustomerAddressService: - """Service for managing customer addresses with vendor isolation.""" - - MAX_ADDRESSES_PER_CUSTOMER = 10 - - def list_addresses( - self, db: Session, vendor_id: int, customer_id: int - ) -> list[CustomerAddress]: - """ - Get all addresses for a customer. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - - Returns: - List of customer addresses - """ - return ( - db.query(CustomerAddress) - .filter( - CustomerAddress.vendor_id == vendor_id, - CustomerAddress.customer_id == customer_id, - ) - .order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc()) - .all() - ) - - def get_address( - self, db: Session, vendor_id: int, customer_id: int, address_id: int - ) -> CustomerAddress: - """ - Get a specific address with ownership validation. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_id: Address ID - - Returns: - Customer address - - Raises: - AddressNotFoundException: If address not found or doesn't belong to customer - """ - address = ( - db.query(CustomerAddress) - .filter( - CustomerAddress.id == address_id, - CustomerAddress.vendor_id == vendor_id, - CustomerAddress.customer_id == customer_id, - ) - .first() - ) - - if not address: - raise AddressNotFoundException(address_id) - - return address - - def get_default_address( - self, db: Session, vendor_id: int, customer_id: int, address_type: str - ) -> CustomerAddress | None: - """ - Get the default address for a specific type. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_type: 'shipping' or 'billing' - - Returns: - Default address or None if not set - """ - return ( - db.query(CustomerAddress) - .filter( - CustomerAddress.vendor_id == vendor_id, - CustomerAddress.customer_id == customer_id, - CustomerAddress.address_type == address_type, - CustomerAddress.is_default == True, # noqa: E712 - ) - .first() - ) - - def create_address( - self, - db: Session, - vendor_id: int, - customer_id: int, - address_data: CustomerAddressCreate, - ) -> CustomerAddress: - """ - Create a new address for a customer. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_data: Address creation data - - Returns: - Created customer address - - Raises: - AddressLimitExceededException: If max addresses reached - """ - # Check address limit - current_count = ( - db.query(CustomerAddress) - .filter( - CustomerAddress.vendor_id == vendor_id, - CustomerAddress.customer_id == customer_id, - ) - .count() - ) - - if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER: - raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER) - - # If setting as default, clear other defaults of same type - if address_data.is_default: - self._clear_other_defaults( - db, vendor_id, customer_id, address_data.address_type - ) - - # Create the address - address = CustomerAddress( - vendor_id=vendor_id, - customer_id=customer_id, - address_type=address_data.address_type, - first_name=address_data.first_name, - last_name=address_data.last_name, - company=address_data.company, - address_line_1=address_data.address_line_1, - address_line_2=address_data.address_line_2, - city=address_data.city, - postal_code=address_data.postal_code, - country_name=address_data.country_name, - country_iso=address_data.country_iso, - is_default=address_data.is_default, - ) - - db.add(address) - db.flush() - - logger.info( - f"Created address {address.id} for customer {customer_id} " - f"(type={address_data.address_type}, default={address_data.is_default})" - ) - - return address - - def update_address( - self, - db: Session, - vendor_id: int, - customer_id: int, - address_id: int, - address_data: CustomerAddressUpdate, - ) -> CustomerAddress: - """ - Update an existing address. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_id: Address ID - address_data: Address update data - - Returns: - Updated customer address - - Raises: - AddressNotFoundException: If address not found - """ - address = self.get_address(db, vendor_id, customer_id, address_id) - - # Update only provided fields - update_data = address_data.model_dump(exclude_unset=True) - - # Handle default flag - clear others if setting to default - if update_data.get("is_default") is True: - # Use updated type if provided, otherwise current type - address_type = update_data.get("address_type", address.address_type) - self._clear_other_defaults( - db, vendor_id, customer_id, address_type, exclude_id=address_id - ) - - for field, value in update_data.items(): - setattr(address, field, value) - - db.flush() - - logger.info(f"Updated address {address_id} for customer {customer_id}") - - return address - - def delete_address( - self, db: Session, vendor_id: int, customer_id: int, address_id: int - ) -> None: - """ - Delete an address. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_id: Address ID - - Raises: - AddressNotFoundException: If address not found - """ - address = self.get_address(db, vendor_id, customer_id, address_id) - - db.delete(address) - db.flush() - - logger.info(f"Deleted address {address_id} for customer {customer_id}") - - def set_default( - self, db: Session, vendor_id: int, customer_id: int, address_id: int - ) -> CustomerAddress: - """ - Set an address as the default for its type. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_id: Address ID - - Returns: - Updated customer address - - Raises: - AddressNotFoundException: If address not found - """ - address = self.get_address(db, vendor_id, customer_id, address_id) - - # Clear other defaults of same type - self._clear_other_defaults( - db, vendor_id, customer_id, address.address_type, exclude_id=address_id - ) - - # Set this one as default - address.is_default = True - db.flush() - - logger.info( - f"Set address {address_id} as default {address.address_type} " - f"for customer {customer_id}" - ) - - return address - - def _clear_other_defaults( - self, - db: Session, - vendor_id: int, - customer_id: int, - address_type: str, - exclude_id: int | None = None, - ) -> None: - """ - Clear the default flag on other addresses of the same type. - - Args: - db: Database session - vendor_id: Vendor ID for isolation - customer_id: Customer ID - address_type: 'shipping' or 'billing' - exclude_id: Address ID to exclude from clearing - """ - query = db.query(CustomerAddress).filter( - CustomerAddress.vendor_id == vendor_id, - CustomerAddress.customer_id == customer_id, - CustomerAddress.address_type == address_type, - CustomerAddress.is_default == True, # noqa: E712 - ) - - if exclude_id: - query = query.filter(CustomerAddress.id != exclude_id) - - query.update({"is_default": False}, synchronize_session=False) - - -# Singleton instance -customer_address_service = CustomerAddressService() +__all__ = [ + "customer_address_service", + "CustomerAddressService", +] diff --git a/app/services/customer_service.py b/app/services/customer_service.py index 0bda80b7..85b02a56 100644 --- a/app/services/customer_service.py +++ b/app/services/customer_service.py @@ -1,664 +1,23 @@ # app/services/customer_service.py """ -Customer management service. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Handles customer registration, authentication, and profile management -with complete vendor isolation. +The canonical implementation is now in: + app/modules/customers/services/customer_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.customers.services import customer_service """ -import logging -from datetime import UTC, datetime, timedelta -from typing import Any - -from sqlalchemy import and_ -from sqlalchemy.orm import Session - -from app.exceptions.customer import ( - CustomerNotActiveException, - CustomerNotFoundException, - CustomerValidationException, - DuplicateCustomerEmailException, - InvalidCustomerCredentialsException, - InvalidPasswordResetTokenException, - PasswordTooShortException, +from app.modules.customers.services.customer_service import ( + customer_service, + CustomerService, ) -from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException -from app.services.auth_service import AuthService -from models.database.customer import Customer -from models.database.password_reset_token import PasswordResetToken -from models.database.vendor import Vendor -from models.schema.auth import UserLogin -from models.schema.customer import CustomerRegister, CustomerUpdate -logger = logging.getLogger(__name__) - - -class CustomerService: - """Service for managing vendor-scoped customers.""" - - def __init__(self): - self.auth_service = AuthService() - - def register_customer( - self, db: Session, vendor_id: int, customer_data: CustomerRegister - ) -> Customer: - """ - Register a new customer for a specific vendor. - - Args: - db: Database session - vendor_id: Vendor ID - customer_data: Customer registration data - - Returns: - Customer: Created customer object - - Raises: - VendorNotFoundException: If vendor doesn't exist - VendorNotActiveException: If vendor is not active - DuplicateCustomerEmailException: If email already exists for this vendor - CustomerValidationException: If customer data is invalid - """ - # Verify vendor exists and is active - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - raise VendorNotFoundException(str(vendor_id), identifier_type="id") - - if not vendor.is_active: - raise VendorNotActiveException(vendor.vendor_code) - - # Check if email already exists for this vendor - existing_customer = ( - db.query(Customer) - .filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == customer_data.email.lower(), - ) - ) - .first() - ) - - if existing_customer: - raise DuplicateCustomerEmailException( - customer_data.email, vendor.vendor_code - ) - - # Generate unique customer number for this vendor - customer_number = self._generate_customer_number( - db, vendor_id, vendor.vendor_code - ) - - # Hash password - hashed_password = self.auth_service.hash_password(customer_data.password) - - # Create customer - customer = Customer( - vendor_id=vendor_id, - email=customer_data.email.lower(), - hashed_password=hashed_password, - first_name=customer_data.first_name, - last_name=customer_data.last_name, - phone=customer_data.phone, - customer_number=customer_number, - marketing_consent=( - customer_data.marketing_consent - if hasattr(customer_data, "marketing_consent") - else False - ), - is_active=True, - ) - - try: - db.add(customer) - db.flush() - db.refresh(customer) - - logger.info( - f"Customer registered successfully: {customer.email} " - f"(ID: {customer.id}, Number: {customer.customer_number}) " - f"for vendor {vendor.vendor_code}" - ) - - return customer - - except Exception as e: - logger.error(f"Error registering customer: {str(e)}") - raise CustomerValidationException( - message="Failed to register customer", details={"error": str(e)} - ) - - def login_customer( - self, db: Session, vendor_id: int, credentials: UserLogin - ) -> dict[str, Any]: - """ - Authenticate customer and generate JWT token. - - Args: - db: Database session - vendor_id: Vendor ID - credentials: Login credentials - - Returns: - Dict containing customer and token data - - Raises: - VendorNotFoundException: If vendor doesn't exist - InvalidCustomerCredentialsException: If credentials are invalid - CustomerNotActiveException: If customer account is inactive - """ - # Verify vendor exists - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - raise VendorNotFoundException(str(vendor_id), identifier_type="id") - - # Find customer by email (vendor-scoped) - customer = ( - db.query(Customer) - .filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == credentials.email_or_username.lower(), - ) - ) - .first() - ) - - if not customer: - raise InvalidCustomerCredentialsException() - - # Verify password using auth_manager directly - if not self.auth_service.auth_manager.verify_password( - credentials.password, customer.hashed_password - ): - raise InvalidCustomerCredentialsException() - - # Check if customer is active - if not customer.is_active: - raise CustomerNotActiveException(customer.email) - - # Generate JWT token with customer context - # Use auth_manager directly since Customer is not a User model - from datetime import datetime - - from jose import jwt - - auth_manager = self.auth_service.auth_manager - expires_delta = timedelta(minutes=auth_manager.token_expire_minutes) - expire = datetime.now(UTC) + expires_delta - - payload = { - "sub": str(customer.id), - "email": customer.email, - "vendor_id": vendor_id, - "type": "customer", - "exp": expire, - "iat": datetime.now(UTC), - } - - token = jwt.encode( - payload, auth_manager.secret_key, algorithm=auth_manager.algorithm - ) - - token_data = { - "access_token": token, - "token_type": "bearer", - "expires_in": auth_manager.token_expire_minutes * 60, - } - - logger.info( - f"Customer login successful: {customer.email} " - f"for vendor {vendor.vendor_code}" - ) - - return {"customer": customer, "token_data": token_data} - - def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer: - """ - Get customer by ID with vendor isolation. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - - Returns: - Customer: Customer object - - Raises: - CustomerNotFoundException: If customer not found - """ - customer = ( - db.query(Customer) - .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id)) - .first() - ) - - if not customer: - raise CustomerNotFoundException(str(customer_id)) - - return customer - - def get_customer_by_email( - self, db: Session, vendor_id: int, email: str - ) -> Customer | None: - """ - Get customer by email (vendor-scoped). - - Args: - db: Database session - vendor_id: Vendor ID - email: Customer email - - Returns: - Optional[Customer]: Customer object or None - """ - return ( - db.query(Customer) - .filter( - and_(Customer.vendor_id == vendor_id, Customer.email == email.lower()) - ) - .first() - ) - - def get_vendor_customers( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 100, - search: str | None = None, - is_active: bool | None = None, - ) -> tuple[list[Customer], int]: - """ - Get all customers for a vendor with filtering and pagination. - - Args: - db: Database session - vendor_id: Vendor ID - skip: Pagination offset - limit: Pagination limit - search: Search in name/email - is_active: Filter by active status - - Returns: - Tuple of (customers, total_count) - """ - from sqlalchemy import or_ - - query = db.query(Customer).filter(Customer.vendor_id == vendor_id) - - if search: - search_pattern = f"%{search}%" - query = query.filter( - or_( - Customer.email.ilike(search_pattern), - Customer.first_name.ilike(search_pattern), - Customer.last_name.ilike(search_pattern), - Customer.customer_number.ilike(search_pattern), - ) - ) - - if is_active is not None: - query = query.filter(Customer.is_active == is_active) - - # Order by most recent first - query = query.order_by(Customer.created_at.desc()) - - total = query.count() - customers = query.offset(skip).limit(limit).all() - - return customers, total - - def get_customer_orders( - self, - db: Session, - vendor_id: int, - customer_id: int, - skip: int = 0, - limit: int = 50, - ) -> tuple[list, int]: - """ - Get orders for a specific customer. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - skip: Pagination offset - limit: Pagination limit - - Returns: - Tuple of (orders, total_count) - - Raises: - CustomerNotFoundException: If customer not found - """ - from models.database.order import Order - - # Verify customer belongs to vendor - self.get_customer(db, vendor_id, customer_id) - - # Get customer orders - query = ( - db.query(Order) - .filter( - Order.customer_id == customer_id, - Order.vendor_id == vendor_id, - ) - .order_by(Order.created_at.desc()) - ) - - total = query.count() - orders = query.offset(skip).limit(limit).all() - - return orders, total - - def get_customer_statistics( - self, db: Session, vendor_id: int, customer_id: int - ) -> dict: - """ - Get detailed statistics for a customer. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - - Returns: - Dict with customer statistics - """ - from sqlalchemy import func - - from models.database.order import Order - - customer = self.get_customer(db, vendor_id, customer_id) - - # Get order statistics - order_stats = ( - db.query( - func.count(Order.id).label("total_orders"), - func.sum(Order.total_cents).label("total_spent_cents"), - func.avg(Order.total_cents).label("avg_order_cents"), - func.max(Order.created_at).label("last_order_date"), - ) - .filter(Order.customer_id == customer_id) - .first() - ) - - total_orders = order_stats.total_orders or 0 - total_spent_cents = order_stats.total_spent_cents or 0 - avg_order_cents = order_stats.avg_order_cents or 0 - - return { - "customer_id": customer_id, - "total_orders": total_orders, - "total_spent": total_spent_cents / 100, # Convert to euros - "average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0, - "last_order_date": order_stats.last_order_date, - "member_since": customer.created_at, - "is_active": customer.is_active, - } - - def toggle_customer_status( - self, db: Session, vendor_id: int, customer_id: int - ) -> Customer: - """ - Toggle customer active status. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - - Returns: - Customer: Updated customer - """ - customer = self.get_customer(db, vendor_id, customer_id) - customer.is_active = not customer.is_active - - db.flush() - db.refresh(customer) - - action = "activated" if customer.is_active else "deactivated" - logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})") - - return customer - - def update_customer( - self, - db: Session, - vendor_id: int, - customer_id: int, - customer_data: CustomerUpdate, - ) -> Customer: - """ - Update customer profile. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - customer_data: Updated customer data - - Returns: - Customer: Updated customer object - - Raises: - CustomerNotFoundException: If customer not found - CustomerValidationException: If update data is invalid - """ - customer = self.get_customer(db, vendor_id, customer_id) - - # Update fields - update_data = customer_data.model_dump(exclude_unset=True) - - for field, value in update_data.items(): - if field == "email" and value: - # Check if new email already exists for this vendor - existing = ( - db.query(Customer) - .filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == value.lower(), - Customer.id != customer_id, - ) - ) - .first() - ) - - if existing: - raise DuplicateCustomerEmailException(value, "vendor") - - setattr(customer, field, value.lower()) - elif hasattr(customer, field): - setattr(customer, field, value) - - try: - db.flush() - db.refresh(customer) - - logger.info(f"Customer updated: {customer.email} (ID: {customer.id})") - - return customer - - except Exception as e: - logger.error(f"Error updating customer: {str(e)}") - raise CustomerValidationException( - message="Failed to update customer", details={"error": str(e)} - ) - - def deactivate_customer( - self, db: Session, vendor_id: int, customer_id: int - ) -> Customer: - """ - Deactivate customer account. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - - Returns: - Customer: Deactivated customer object - - Raises: - CustomerNotFoundException: If customer not found - """ - customer = self.get_customer(db, vendor_id, customer_id) - customer.is_active = False - - db.flush() - db.refresh(customer) - - logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})") - - return customer - - def update_customer_stats( - self, db: Session, customer_id: int, order_total: float - ) -> None: - """ - Update customer statistics after order. - - Args: - db: Database session - customer_id: Customer ID - order_total: Order total amount - """ - customer = db.query(Customer).filter(Customer.id == customer_id).first() - - if customer: - customer.total_orders += 1 - customer.total_spent += order_total - customer.last_order_date = datetime.utcnow() - - logger.debug(f"Updated stats for customer {customer.email}") - - def _generate_customer_number( - self, db: Session, vendor_id: int, vendor_code: str - ) -> str: - """ - Generate unique customer number for vendor. - - Format: {VENDOR_CODE}-CUST-{SEQUENCE} - Example: VENDORA-CUST-00001 - - Args: - db: Database session - vendor_id: Vendor ID - vendor_code: Vendor code - - Returns: - str: Unique customer number - """ - # Get count of customers for this vendor - count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count() - - # Generate number with padding - sequence = str(count + 1).zfill(5) - customer_number = f"{vendor_code.upper()}-CUST-{sequence}" - - # Ensure uniqueness (in case of deletions) - while ( - db.query(Customer) - .filter( - and_( - Customer.vendor_id == vendor_id, - Customer.customer_number == customer_number, - ) - ) - .first() - ): - count += 1 - sequence = str(count + 1).zfill(5) - customer_number = f"{vendor_code.upper()}-CUST-{sequence}" - - return customer_number - - def get_customer_for_password_reset( - self, db: Session, vendor_id: int, email: str - ) -> Customer | None: - """ - Get active customer by email for password reset. - - Args: - db: Database session - vendor_id: Vendor ID - email: Customer email - - Returns: - Customer if found and active, None otherwise - """ - return ( - db.query(Customer) - .filter( - Customer.vendor_id == vendor_id, - Customer.email == email.lower(), - Customer.is_active == True, # noqa: E712 - ) - .first() - ) - - def validate_and_reset_password( - self, - db: Session, - vendor_id: int, - reset_token: str, - new_password: str, - ) -> Customer: - """ - Validate reset token and update customer password. - - Args: - db: Database session - vendor_id: Vendor ID - reset_token: Password reset token from email - new_password: New password - - Returns: - Customer: Updated customer - - Raises: - PasswordTooShortException: If password too short - InvalidPasswordResetTokenException: If token invalid/expired - CustomerNotActiveException: If customer not active - """ - # Validate password length - if len(new_password) < 8: - raise PasswordTooShortException(min_length=8) - - # Find valid token - token_record = PasswordResetToken.find_valid_token(db, reset_token) - - if not token_record: - raise InvalidPasswordResetTokenException() - - # Get the customer and verify they belong to this vendor - customer = ( - db.query(Customer) - .filter(Customer.id == token_record.customer_id) - .first() - ) - - if not customer or customer.vendor_id != vendor_id: - raise InvalidPasswordResetTokenException() - - if not customer.is_active: - raise CustomerNotActiveException(customer.email) - - # Hash the new password and update customer - hashed_password = self.auth_service.hash_password(new_password) - customer.hashed_password = hashed_password - - # Mark token as used - token_record.mark_used(db) - - logger.info(f"Password reset completed for customer {customer.id}") - - return customer - - -# Singleton instance -customer_service = CustomerService() +__all__ = [ + "customer_service", + "CustomerService", +] diff --git a/app/services/inventory_import_service.py b/app/services/inventory_import_service.py index b7d3cfa7..2e3ef6e3 100644 --- a/app/services/inventory_import_service.py +++ b/app/services/inventory_import_service.py @@ -1,250 +1,25 @@ # app/services/inventory_import_service.py """ -Inventory import service for bulk importing stock from TSV/CSV files. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports two formats: -1. One row per unit (quantity = count of rows): - BIN EAN PRODUCT - SA-10-02 0810050910101 Product Name - SA-10-02 0810050910101 Product Name (2nd unit) +The canonical implementation is now in: + app/modules/inventory/services/inventory_import_service.py -2. With explicit quantity column: - BIN EAN PRODUCT QUANTITY - SA-10-02 0810050910101 Product Name 12 +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: -Products are matched by GTIN/EAN to existing vendor products. + from app.modules.inventory.services import inventory_import_service """ -import csv -import io -import logging -from collections import defaultdict -from dataclasses import dataclass, field +from app.modules.inventory.services.inventory_import_service import ( + inventory_import_service, + InventoryImportService, + ImportResult, +) -from sqlalchemy.orm import Session - -from models.database.inventory import Inventory -from models.database.product import Product - -logger = logging.getLogger(__name__) - - -@dataclass -class ImportResult: - """Result of an inventory import operation.""" - - success: bool = True - total_rows: int = 0 - entries_created: int = 0 - entries_updated: int = 0 - quantity_imported: int = 0 - unmatched_gtins: list = field(default_factory=list) - errors: list = field(default_factory=list) - - -class InventoryImportService: - """Service for importing inventory from TSV/CSV files.""" - - def import_from_text( - self, - db: Session, - content: str, - vendor_id: int, - warehouse: str = "strassen", - delimiter: str = "\t", - clear_existing: bool = False, - ) -> ImportResult: - """ - Import inventory from TSV/CSV text content. - - Args: - db: Database session - content: TSV/CSV content as string - vendor_id: Vendor ID for inventory - warehouse: Warehouse name (default: "strassen") - delimiter: Column delimiter (default: tab) - clear_existing: If True, clear existing inventory before import - - Returns: - ImportResult with summary and errors - """ - result = ImportResult() - - try: - # Parse CSV/TSV - reader = csv.DictReader(io.StringIO(content), delimiter=delimiter) - - # Normalize headers (case-insensitive, strip whitespace) - if reader.fieldnames: - reader.fieldnames = [h.strip().upper() for h in reader.fieldnames] - - # Validate required columns - required = {"BIN", "EAN"} - if not reader.fieldnames or not required.issubset(set(reader.fieldnames)): - result.success = False - result.errors.append( - f"Missing required columns. Found: {reader.fieldnames}, Required: {required}" - ) - return result - - has_quantity = "QUANTITY" in reader.fieldnames - - # Group entries by (EAN, BIN) - # Key: (ean, bin) -> quantity - inventory_data: dict[tuple[str, str], int] = defaultdict(int) - product_names: dict[str, str] = {} # EAN -> product name (for logging) - - for row in reader: - result.total_rows += 1 - - ean = row.get("EAN", "").strip() - bin_loc = row.get("BIN", "").strip() - product_name = row.get("PRODUCT", "").strip() - - if not ean or not bin_loc: - result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN") - continue - - # Get quantity - if has_quantity: - try: - qty = int(row.get("QUANTITY", "1").strip()) - except ValueError: - result.errors.append( - f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'" - ) - continue - else: - qty = 1 # Each row = 1 unit - - inventory_data[(ean, bin_loc)] += qty - if product_name: - product_names[ean] = product_name - - # Clear existing inventory if requested - if clear_existing: - db.query(Inventory).filter( - Inventory.vendor_id == vendor_id, - Inventory.warehouse == warehouse, - ).delete() - db.flush() - - # Build EAN to Product mapping for this vendor - products = ( - db.query(Product) - .filter( - Product.vendor_id == vendor_id, - Product.gtin.isnot(None), - ) - .all() - ) - ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin} - - # Track unmatched GTINs - unmatched: dict[str, int] = {} # EAN -> total quantity - - # Process inventory entries - for (ean, bin_loc), quantity in inventory_data.items(): - product = ean_to_product.get(ean) - - if not product: - # Track unmatched - if ean not in unmatched: - unmatched[ean] = 0 - unmatched[ean] += quantity - continue - - # Upsert inventory entry - existing = ( - db.query(Inventory) - .filter( - Inventory.product_id == product.id, - Inventory.warehouse == warehouse, - Inventory.bin_location == bin_loc, - ) - .first() - ) - - if existing: - existing.quantity = quantity - existing.gtin = ean - result.entries_updated += 1 - else: - inv = Inventory( - product_id=product.id, - vendor_id=vendor_id, - warehouse=warehouse, - bin_location=bin_loc, - location=bin_loc, # Legacy field - quantity=quantity, - gtin=ean, - ) - db.add(inv) - result.entries_created += 1 - - result.quantity_imported += quantity - - db.flush() - - # Format unmatched GTINs for result - for ean, qty in unmatched.items(): - product_name = product_names.get(ean, "Unknown") - result.unmatched_gtins.append( - {"gtin": ean, "quantity": qty, "product_name": product_name} - ) - - if result.unmatched_gtins: - logger.warning( - f"Import had {len(result.unmatched_gtins)} unmatched GTINs" - ) - - except Exception as e: - logger.exception("Inventory import failed") - result.success = False - result.errors.append(str(e)) - - return result - - def import_from_file( - self, - db: Session, - file_path: str, - vendor_id: int, - warehouse: str = "strassen", - clear_existing: bool = False, - ) -> ImportResult: - """ - Import inventory from a TSV/CSV file. - - Args: - db: Database session - file_path: Path to TSV/CSV file - vendor_id: Vendor ID for inventory - warehouse: Warehouse name - clear_existing: If True, clear existing inventory before import - - Returns: - ImportResult with summary and errors - """ - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - except Exception as e: - return ImportResult(success=False, errors=[f"Failed to read file: {e}"]) - - # Detect delimiter - first_line = content.split("\n")[0] if content else "" - delimiter = "\t" if "\t" in first_line else "," - - return self.import_from_text( - db=db, - content=content, - vendor_id=vendor_id, - warehouse=warehouse, - delimiter=delimiter, - clear_existing=clear_existing, - ) - - -# Singleton instance -inventory_import_service = InventoryImportService() +__all__ = [ + "inventory_import_service", + "InventoryImportService", + "ImportResult", +] diff --git a/app/services/inventory_service.py b/app/services/inventory_service.py index e91cc675..ee59cf38 100644 --- a/app/services/inventory_service.py +++ b/app/services/inventory_service.py @@ -1,949 +1,23 @@ # app/services/inventory_service.py -import logging -from datetime import UTC, datetime +""" +LEGACY LOCATION - Re-exports from module for backwards compatibility. -from sqlalchemy import func -from sqlalchemy.orm import Session +The canonical implementation is now in: + app/modules/inventory/services/inventory_service.py -from app.exceptions import ( - InsufficientInventoryException, - InvalidQuantityException, - InventoryNotFoundException, - InventoryValidationException, - ProductNotFoundException, - ValidationException, - VendorNotFoundException, -) -from models.database.inventory import Inventory -from models.database.product import Product -from models.database.vendor import Vendor -from models.schema.inventory import ( - AdminInventoryItem, - AdminInventoryListResponse, - AdminInventoryLocationsResponse, - AdminInventoryStats, - AdminLowStockItem, - AdminVendorsWithInventoryResponse, - AdminVendorWithInventory, - InventoryAdjust, - InventoryCreate, - InventoryLocationResponse, - InventoryReserve, - InventoryUpdate, - ProductInventorySummary, +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.inventory.services import inventory_service +""" + +from app.modules.inventory.services.inventory_service import ( + inventory_service, + InventoryService, ) -logger = logging.getLogger(__name__) - - -class InventoryService: - """Service for inventory operations with vendor isolation.""" - - def set_inventory( - self, db: Session, vendor_id: int, inventory_data: InventoryCreate - ) -> Inventory: - """ - Set exact inventory quantity for a product at a location (replaces existing). - - Args: - db: Database session - vendor_id: Vendor ID (from middleware) - inventory_data: Inventory data - - Returns: - Inventory object - """ - try: - # Validate product belongs to vendor - product = self._get_vendor_product(db, vendor_id, inventory_data.product_id) - - # Validate location - location = self._validate_location(inventory_data.location) - - # Validate quantity - self._validate_quantity(inventory_data.quantity, allow_zero=True) - - # Check if inventory entry exists - existing = self._get_inventory_entry( - db, inventory_data.product_id, location - ) - - if existing: - old_qty = existing.quantity - existing.quantity = inventory_data.quantity - existing.updated_at = datetime.now(UTC) - db.flush() - db.refresh(existing) - - logger.info( - f"Set inventory for product {inventory_data.product_id} at {location}: " - f"{old_qty} → {inventory_data.quantity}" - ) - return existing - # Create new inventory entry - new_inventory = Inventory( - product_id=inventory_data.product_id, - vendor_id=vendor_id, - warehouse="strassen", # Default warehouse - bin_location=location, # Use location as bin location - location=location, # Keep for backward compatibility - quantity=inventory_data.quantity, - gtin=product.marketplace_product.gtin, # Optional reference - ) - db.add(new_inventory) - db.flush() - db.refresh(new_inventory) - - logger.info( - f"Created inventory for product {inventory_data.product_id} at {location}: " - f"{inventory_data.quantity}" - ) - return new_inventory - - except ( - ProductNotFoundException, - InvalidQuantityException, - InventoryValidationException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error setting inventory: {str(e)}") - raise ValidationException("Failed to set inventory") - - def adjust_inventory( - self, db: Session, vendor_id: int, inventory_data: InventoryAdjust - ) -> Inventory: - """ - Adjust inventory by adding or removing quantity. - Positive quantity = add, negative = remove. - - Args: - db: Database session - vendor_id: Vendor ID - inventory_data: Adjustment data - - Returns: - Updated Inventory object - """ - try: - # Validate product belongs to vendor - product = self._get_vendor_product(db, vendor_id, inventory_data.product_id) - - # Validate location - location = self._validate_location(inventory_data.location) - - # Check if inventory exists - existing = self._get_inventory_entry( - db, inventory_data.product_id, location - ) - - if not existing: - # Create new if adding, error if removing - if inventory_data.quantity < 0: - raise InventoryNotFoundException( - f"No inventory found for product {inventory_data.product_id} at {location}" - ) - - # Create with positive quantity - new_inventory = Inventory( - product_id=inventory_data.product_id, - vendor_id=vendor_id, - warehouse="strassen", # Default warehouse - bin_location=location, # Use location as bin location - location=location, # Keep for backward compatibility - quantity=inventory_data.quantity, - gtin=product.marketplace_product.gtin, - ) - db.add(new_inventory) - db.flush() - db.refresh(new_inventory) - - logger.info( - f"Created inventory for product {inventory_data.product_id} at {location}: " - f"+{inventory_data.quantity}" - ) - return new_inventory - - # Adjust existing inventory - old_qty = existing.quantity - new_qty = old_qty + inventory_data.quantity - - # Validate resulting quantity - if new_qty < 0: - raise InsufficientInventoryException( - f"Insufficient inventory. Available: {old_qty}, " - f"Requested removal: {abs(inventory_data.quantity)}" - ) - - existing.quantity = new_qty - existing.updated_at = datetime.now(UTC) - db.flush() - db.refresh(existing) - - logger.info( - f"Adjusted inventory for product {inventory_data.product_id} at {location}: " - f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}" - ) - return existing - - except ( - ProductNotFoundException, - InventoryNotFoundException, - InsufficientInventoryException, - InventoryValidationException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error adjusting inventory: {str(e)}") - raise ValidationException("Failed to adjust inventory") - - def reserve_inventory( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve - ) -> Inventory: - """ - Reserve inventory for an order (increases reserved_quantity). - - Args: - db: Database session - vendor_id: Vendor ID - reserve_data: Reservation data - - Returns: - Updated Inventory object - """ - try: - # Validate product - product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) - - # Validate location and quantity - location = self._validate_location(reserve_data.location) - self._validate_quantity(reserve_data.quantity, allow_zero=False) - - # Get inventory entry - inventory = self._get_inventory_entry(db, reserve_data.product_id, location) - if not inventory: - raise InventoryNotFoundException( - f"No inventory found for product {reserve_data.product_id} at {location}" - ) - - # Check available quantity - available = inventory.quantity - inventory.reserved_quantity - if available < reserve_data.quantity: - raise InsufficientInventoryException( - f"Insufficient available inventory. Available: {available}, " - f"Requested: {reserve_data.quantity}" - ) - - # Reserve inventory - inventory.reserved_quantity += reserve_data.quantity - inventory.updated_at = datetime.now(UTC) - db.flush() - db.refresh(inventory) - - logger.info( - f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} " - f"at {location}" - ) - return inventory - - except ( - ProductNotFoundException, - InventoryNotFoundException, - InsufficientInventoryException, - InvalidQuantityException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error reserving inventory: {str(e)}") - raise ValidationException("Failed to reserve inventory") - - def release_reservation( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve - ) -> Inventory: - """ - Release reserved inventory (decreases reserved_quantity). - - Args: - db: Database session - vendor_id: Vendor ID - reserve_data: Reservation data - - Returns: - Updated Inventory object - """ - try: - # Validate product - product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) - - location = self._validate_location(reserve_data.location) - self._validate_quantity(reserve_data.quantity, allow_zero=False) - - inventory = self._get_inventory_entry(db, reserve_data.product_id, location) - if not inventory: - raise InventoryNotFoundException( - f"No inventory found for product {reserve_data.product_id} at {location}" - ) - - # Validate reserved quantity - if inventory.reserved_quantity < reserve_data.quantity: - logger.warning( - f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, " - f"Requested: {reserve_data.quantity}" - ) - inventory.reserved_quantity = 0 - else: - inventory.reserved_quantity -= reserve_data.quantity - - inventory.updated_at = datetime.now(UTC) - db.flush() - db.refresh(inventory) - - logger.info( - f"Released {reserve_data.quantity} units for product {reserve_data.product_id} " - f"at {location}" - ) - return inventory - - except ( - ProductNotFoundException, - InventoryNotFoundException, - InvalidQuantityException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error releasing reservation: {str(e)}") - raise ValidationException("Failed to release reservation") - - def fulfill_reservation( - self, db: Session, vendor_id: int, reserve_data: InventoryReserve - ) -> Inventory: - """ - Fulfill a reservation (decreases both quantity and reserved_quantity). - Use when order is shipped/completed. - - Args: - db: Database session - vendor_id: Vendor ID - reserve_data: Reservation data - - Returns: - Updated Inventory object - """ - try: - product = self._get_vendor_product(db, vendor_id, reserve_data.product_id) - location = self._validate_location(reserve_data.location) - self._validate_quantity(reserve_data.quantity, allow_zero=False) - - inventory = self._get_inventory_entry(db, reserve_data.product_id, location) - if not inventory: - raise InventoryNotFoundException( - f"No inventory found for product {reserve_data.product_id} at {location}" - ) - - # Validate quantities - if inventory.quantity < reserve_data.quantity: - raise InsufficientInventoryException( - f"Insufficient inventory. Available: {inventory.quantity}, " - f"Requested: {reserve_data.quantity}" - ) - - if inventory.reserved_quantity < reserve_data.quantity: - logger.warning( - f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, " - f"Fulfilling: {reserve_data.quantity}" - ) - - # Fulfill (remove from both quantity and reserved) - inventory.quantity -= reserve_data.quantity - inventory.reserved_quantity = max( - 0, inventory.reserved_quantity - reserve_data.quantity - ) - inventory.updated_at = datetime.now(UTC) - db.flush() - db.refresh(inventory) - - logger.info( - f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} " - f"at {location}" - ) - return inventory - - except ( - ProductNotFoundException, - InventoryNotFoundException, - InsufficientInventoryException, - InvalidQuantityException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error fulfilling reservation: {str(e)}") - raise ValidationException("Failed to fulfill reservation") - - def get_product_inventory( - self, db: Session, vendor_id: int, product_id: int - ) -> ProductInventorySummary: - """ - Get inventory summary for a product across all locations. - - Args: - db: Database session - vendor_id: Vendor ID - product_id: Product ID - - Returns: - ProductInventorySummary - """ - try: - product = self._get_vendor_product(db, vendor_id, product_id) - - inventory_entries = ( - db.query(Inventory).filter(Inventory.product_id == product_id).all() - ) - - if not inventory_entries: - return ProductInventorySummary( - product_id=product_id, - vendor_id=vendor_id, - product_sku=product.vendor_sku, - product_title=product.marketplace_product.get_title() or "", - total_quantity=0, - total_reserved=0, - total_available=0, - locations=[], - ) - - total_qty = sum(inv.quantity for inv in inventory_entries) - total_reserved = sum(inv.reserved_quantity for inv in inventory_entries) - total_available = sum(inv.available_quantity for inv in inventory_entries) - - locations = [ - InventoryLocationResponse( - location=inv.location, - quantity=inv.quantity, - reserved_quantity=inv.reserved_quantity, - available_quantity=inv.available_quantity, - ) - for inv in inventory_entries - ] - - return ProductInventorySummary( - product_id=product_id, - vendor_id=vendor_id, - product_sku=product.vendor_sku, - product_title=product.marketplace_product.get_title() or "", - total_quantity=total_qty, - total_reserved=total_reserved, - total_available=total_available, - locations=locations, - ) - - except ProductNotFoundException: - raise - except Exception as e: - logger.error(f"Error getting product inventory: {str(e)}") - raise ValidationException("Failed to retrieve product inventory") - - def get_vendor_inventory( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 100, - location: str | None = None, - low_stock_threshold: int | None = None, - ) -> list[Inventory]: - """ - Get all inventory for a vendor with filtering. - - Args: - db: Database session - vendor_id: Vendor ID - skip: Pagination offset - limit: Pagination limit - location: Filter by location - low_stock_threshold: Filter items below threshold - - Returns: - List of Inventory objects - """ - try: - query = db.query(Inventory).filter(Inventory.vendor_id == vendor_id) - - if location: - query = query.filter(Inventory.location.ilike(f"%{location}%")) - - if low_stock_threshold is not None: - query = query.filter(Inventory.quantity <= low_stock_threshold) - - return query.offset(skip).limit(limit).all() - - except Exception as e: - logger.error(f"Error getting vendor inventory: {str(e)}") - raise ValidationException("Failed to retrieve vendor inventory") - - def update_inventory( - self, - db: Session, - vendor_id: int, - inventory_id: int, - inventory_update: InventoryUpdate, - ) -> Inventory: - """Update inventory entry.""" - try: - inventory = self._get_inventory_by_id(db, inventory_id) - - # Verify ownership - if inventory.vendor_id != vendor_id: - raise InventoryNotFoundException(f"Inventory {inventory_id} not found") - - # Update fields - if inventory_update.quantity is not None: - self._validate_quantity(inventory_update.quantity, allow_zero=True) - inventory.quantity = inventory_update.quantity - - if inventory_update.reserved_quantity is not None: - self._validate_quantity( - inventory_update.reserved_quantity, allow_zero=True - ) - inventory.reserved_quantity = inventory_update.reserved_quantity - - if inventory_update.location: - inventory.location = self._validate_location(inventory_update.location) - - inventory.updated_at = datetime.now(UTC) - db.flush() - db.refresh(inventory) - - logger.info(f"Updated inventory {inventory_id}") - return inventory - - except ( - InventoryNotFoundException, - InvalidQuantityException, - InventoryValidationException, - ): - db.rollback() - raise - except Exception as e: - db.rollback() - logger.error(f"Error updating inventory: {str(e)}") - raise ValidationException("Failed to update inventory") - - def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool: - """Delete inventory entry.""" - try: - inventory = self._get_inventory_by_id(db, inventory_id) - - # Verify ownership - if inventory.vendor_id != vendor_id: - raise InventoryNotFoundException(f"Inventory {inventory_id} not found") - - db.delete(inventory) - db.flush() - - logger.info(f"Deleted inventory {inventory_id}") - return True - - except InventoryNotFoundException: - raise - except Exception as e: - db.rollback() - logger.error(f"Error deleting inventory: {str(e)}") - raise ValidationException("Failed to delete inventory") - - # ========================================================================= - # Admin Methods (cross-vendor operations) - # ========================================================================= - - def get_all_inventory_admin( - self, - db: Session, - skip: int = 0, - limit: int = 50, - vendor_id: int | None = None, - location: str | None = None, - low_stock: int | None = None, - search: str | None = None, - ) -> AdminInventoryListResponse: - """ - Get inventory across all vendors with filtering (admin only). - - Args: - db: Database session - skip: Pagination offset - limit: Pagination limit - vendor_id: Filter by vendor - location: Filter by location - low_stock: Filter items below threshold - search: Search by product title or SKU - - Returns: - AdminInventoryListResponse - """ - query = db.query(Inventory).join(Product).join(Vendor) - - # Apply filters - if vendor_id is not None: - query = query.filter(Inventory.vendor_id == vendor_id) - - if location: - query = query.filter(Inventory.location.ilike(f"%{location}%")) - - if low_stock is not None: - query = query.filter(Inventory.quantity <= low_stock) - - if search: - from models.database.marketplace_product import MarketplaceProduct - from models.database.marketplace_product_translation import ( - MarketplaceProductTranslation, - ) - - query = ( - query.join(MarketplaceProduct) - .outerjoin(MarketplaceProductTranslation) - .filter( - (MarketplaceProductTranslation.title.ilike(f"%{search}%")) - | (Product.vendor_sku.ilike(f"%{search}%")) - ) - ) - - # Get total count before pagination - total = query.count() - - # Apply pagination - inventories = query.offset(skip).limit(limit).all() - - # Build response with vendor/product info - items = [] - for inv in inventories: - product = inv.product - vendor = inv.vendor - title = None - if product and product.marketplace_product: - title = product.marketplace_product.get_title() - - items.append( - AdminInventoryItem( - id=inv.id, - product_id=inv.product_id, - vendor_id=inv.vendor_id, - vendor_name=vendor.name if vendor else None, - vendor_code=vendor.vendor_code if vendor else None, - product_title=title, - product_sku=product.vendor_sku if product else None, - location=inv.location, - quantity=inv.quantity, - reserved_quantity=inv.reserved_quantity, - available_quantity=inv.available_quantity, - gtin=inv.gtin, - created_at=inv.created_at, - updated_at=inv.updated_at, - ) - ) - - return AdminInventoryListResponse( - inventories=items, - total=total, - skip=skip, - limit=limit, - vendor_filter=vendor_id, - location_filter=location, - ) - - def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats: - """Get platform-wide inventory statistics (admin only).""" - # Total entries - total_entries = db.query(func.count(Inventory.id)).scalar() or 0 - - # Aggregate quantities - totals = db.query( - func.sum(Inventory.quantity).label("total_qty"), - func.sum(Inventory.reserved_quantity).label("total_reserved"), - ).first() - - total_quantity = totals.total_qty or 0 - total_reserved = totals.total_reserved or 0 - total_available = total_quantity - total_reserved - - # Low stock count (default threshold: 10) - low_stock_count = ( - db.query(func.count(Inventory.id)) - .filter(Inventory.quantity <= 10) - .scalar() - or 0 - ) - - # Vendors with inventory - vendors_with_inventory = ( - db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0 - ) - - # Unique locations - unique_locations = ( - db.query(func.count(func.distinct(Inventory.location))).scalar() or 0 - ) - - return AdminInventoryStats( - total_entries=total_entries, - total_quantity=total_quantity, - total_reserved=total_reserved, - total_available=total_available, - low_stock_count=low_stock_count, - vendors_with_inventory=vendors_with_inventory, - unique_locations=unique_locations, - ) - - def get_low_stock_items_admin( - self, - db: Session, - threshold: int = 10, - vendor_id: int | None = None, - limit: int = 50, - ) -> list[AdminLowStockItem]: - """Get items with low stock levels (admin only).""" - query = ( - db.query(Inventory) - .join(Product) - .join(Vendor) - .filter(Inventory.quantity <= threshold) - ) - - if vendor_id is not None: - query = query.filter(Inventory.vendor_id == vendor_id) - - # Order by quantity ascending (most critical first) - query = query.order_by(Inventory.quantity.asc()) - - inventories = query.limit(limit).all() - - items = [] - for inv in inventories: - product = inv.product - vendor = inv.vendor - title = None - if product and product.marketplace_product: - title = product.marketplace_product.get_title() - - items.append( - AdminLowStockItem( - id=inv.id, - product_id=inv.product_id, - vendor_id=inv.vendor_id, - vendor_name=vendor.name if vendor else None, - product_title=title, - location=inv.location, - quantity=inv.quantity, - reserved_quantity=inv.reserved_quantity, - available_quantity=inv.available_quantity, - ) - ) - - return items - - def get_vendors_with_inventory_admin( - self, db: Session - ) -> AdminVendorsWithInventoryResponse: - """Get list of vendors that have inventory entries (admin only).""" - # noqa: SVC-005 - Admin function, intentionally cross-vendor - # Use subquery to avoid DISTINCT on JSON columns (PostgreSQL can't compare JSON) - vendor_ids_subquery = ( - db.query(Inventory.vendor_id) - .distinct() - .subquery() - ) - vendors = ( - db.query(Vendor) - .filter(Vendor.id.in_(db.query(vendor_ids_subquery.c.vendor_id))) - .order_by(Vendor.name) - .all() - ) - - return AdminVendorsWithInventoryResponse( - vendors=[ - AdminVendorWithInventory( - id=v.id, name=v.name, vendor_code=v.vendor_code - ) - for v in vendors - ] - ) - - def get_inventory_locations_admin( - self, db: Session, vendor_id: int | None = None - ) -> AdminInventoryLocationsResponse: - """Get list of unique inventory locations (admin only).""" - query = db.query(func.distinct(Inventory.location)) - - if vendor_id is not None: - query = query.filter(Inventory.vendor_id == vendor_id) - - locations = [loc[0] for loc in query.all()] - - return AdminInventoryLocationsResponse(locations=sorted(locations)) - - def get_vendor_inventory_admin( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 50, - location: str | None = None, - low_stock: int | None = None, - ) -> AdminInventoryListResponse: - """Get inventory for a specific vendor (admin only).""" - # Verify vendor exists - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - raise VendorNotFoundException(f"Vendor {vendor_id} not found") - - # Use the existing method - inventories = self.get_vendor_inventory( - db=db, - vendor_id=vendor_id, - skip=skip, - limit=limit, - location=location, - low_stock_threshold=low_stock, - ) - - # Build response with product info - items = [] - for inv in inventories: - product = inv.product - title = None - if product and product.marketplace_product: - title = product.marketplace_product.get_title() - - items.append( - AdminInventoryItem( - id=inv.id, - product_id=inv.product_id, - vendor_id=inv.vendor_id, - vendor_name=vendor.name, - vendor_code=vendor.vendor_code, - product_title=title, - product_sku=product.vendor_sku if product else None, - location=inv.location, - quantity=inv.quantity, - reserved_quantity=inv.reserved_quantity, - available_quantity=inv.available_quantity, - gtin=inv.gtin, - created_at=inv.created_at, - updated_at=inv.updated_at, - ) - ) - - # Get total count for pagination - total_query = db.query(func.count(Inventory.id)).filter( - Inventory.vendor_id == vendor_id - ) - if location: - total_query = total_query.filter(Inventory.location.ilike(f"%{location}%")) - if low_stock is not None: - total_query = total_query.filter(Inventory.quantity <= low_stock) - total = total_query.scalar() or 0 - - return AdminInventoryListResponse( - inventories=items, - total=total, - skip=skip, - limit=limit, - vendor_filter=vendor_id, - location_filter=location, - ) - - def get_product_inventory_admin( - self, db: Session, product_id: int - ) -> ProductInventorySummary: - """Get inventory summary for a product (admin only - no vendor check).""" - product = db.query(Product).filter(Product.id == product_id).first() - if not product: - raise ProductNotFoundException(f"Product {product_id} not found") - - # Use existing method with the product's vendor_id - return self.get_product_inventory(db, product.vendor_id, product_id) - - def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor: - """Verify vendor exists and return it.""" - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - raise VendorNotFoundException(f"Vendor {vendor_id} not found") - return vendor - - def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory: - """Get inventory by ID (admin only - returns inventory with vendor_id).""" - inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first() - if not inventory: - raise InventoryNotFoundException(f"Inventory {inventory_id} not found") - return inventory - - # ========================================================================= - # Private helper methods - # ========================================================================= - - def _get_vendor_product( - self, db: Session, vendor_id: int, product_id: int - ) -> Product: - """Get product and verify it belongs to vendor.""" - product = ( - db.query(Product) - .filter(Product.id == product_id, Product.vendor_id == vendor_id) - .first() - ) - - if not product: - raise ProductNotFoundException( - f"Product {product_id} not found in your catalog" - ) - - return product - - def _get_inventory_entry( - self, db: Session, product_id: int, location: str - ) -> Inventory | None: - """Get inventory entry by product and location.""" - return ( - db.query(Inventory) - .filter(Inventory.product_id == product_id, Inventory.location == location) - .first() - ) - - def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory: - """Get inventory by ID or raise exception.""" - inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first() - if not inventory: - raise InventoryNotFoundException(f"Inventory {inventory_id} not found") - return inventory - - def _validate_location(self, location: str) -> str: - """Validate and normalize location.""" - if not location or not location.strip(): - raise InventoryValidationException("Location is required") - return location.strip().upper() - - def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None: - """Validate quantity value.""" - if quantity is None: - raise InvalidQuantityException("Quantity is required") - - if not isinstance(quantity, int): - raise InvalidQuantityException("Quantity must be an integer") - - if quantity < 0: - raise InvalidQuantityException("Quantity cannot be negative") - - if not allow_zero and quantity == 0: - raise InvalidQuantityException("Quantity must be positive") - - -# Create service instance -inventory_service = InventoryService() +__all__ = [ + "inventory_service", + "InventoryService", +] diff --git a/app/services/inventory_transaction_service.py b/app/services/inventory_transaction_service.py index fe489fb4..e95a9faa 100644 --- a/app/services/inventory_transaction_service.py +++ b/app/services/inventory_transaction_service.py @@ -1,431 +1,23 @@ # app/services/inventory_transaction_service.py """ -Inventory Transaction Service. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides query operations for inventory transaction history. -All transaction WRITES are handled by OrderInventoryService. -This service handles transaction READS for reporting and auditing. +The canonical implementation is now in: + app/modules/inventory/services/inventory_transaction_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.inventory.services import inventory_transaction_service """ -import logging -from sqlalchemy import func -from sqlalchemy.orm import Session +from app.modules.inventory.services.inventory_transaction_service import ( + inventory_transaction_service, + InventoryTransactionService, +) -from app.exceptions import OrderNotFoundException, ProductNotFoundException -from models.database.inventory import Inventory -from models.database.inventory_transaction import InventoryTransaction -from models.database.order import Order -from models.database.product import Product - -logger = logging.getLogger(__name__) - - -class InventoryTransactionService: - """Service for querying inventory transaction history.""" - - def get_vendor_transactions( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 50, - product_id: int | None = None, - transaction_type: str | None = None, - ) -> tuple[list[dict], int]: - """ - Get inventory transactions for a vendor with optional filters. - - Args: - db: Database session - vendor_id: Vendor ID - skip: Pagination offset - limit: Pagination limit - product_id: Optional product filter - transaction_type: Optional transaction type filter - - Returns: - Tuple of (transactions with product details, total count) - """ - # Build query - query = db.query(InventoryTransaction).filter( - InventoryTransaction.vendor_id == vendor_id - ) - - # Apply filters - if product_id: - query = query.filter(InventoryTransaction.product_id == product_id) - if transaction_type: - query = query.filter( - InventoryTransaction.transaction_type == transaction_type - ) - - # Get total count - total = query.count() - - # Get transactions with pagination (newest first) - transactions = ( - query.order_by(InventoryTransaction.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - # Build result with product details - result = [] - for tx in transactions: - product = db.query(Product).filter(Product.id == tx.product_id).first() - product_title = None - product_sku = None - if product: - product_sku = product.vendor_sku - if product.marketplace_product: - product_title = product.marketplace_product.get_title() - - result.append( - { - "id": tx.id, - "vendor_id": tx.vendor_id, - "product_id": tx.product_id, - "inventory_id": tx.inventory_id, - "transaction_type": ( - tx.transaction_type.value if tx.transaction_type else None - ), - "quantity_change": tx.quantity_change, - "quantity_after": tx.quantity_after, - "reserved_after": tx.reserved_after, - "location": tx.location, - "warehouse": tx.warehouse, - "order_id": tx.order_id, - "order_number": tx.order_number, - "reason": tx.reason, - "created_by": tx.created_by, - "created_at": tx.created_at, - "product_title": product_title, - "product_sku": product_sku, - } - ) - - return result, total - - def get_product_history( - self, - db: Session, - vendor_id: int, - product_id: int, - limit: int = 50, - ) -> dict: - """ - Get transaction history for a specific product. - - Args: - db: Database session - vendor_id: Vendor ID - product_id: Product ID - limit: Max transactions to return - - Returns: - Dict with product info, current inventory, and transactions - - Raises: - ProductNotFoundException: If product not found or doesn't belong to vendor - """ - # Get product details - product = ( - db.query(Product) - .filter(Product.id == product_id, Product.vendor_id == vendor_id) - .first() - ) - - if not product: - raise ProductNotFoundException( - f"Product {product_id} not found in vendor catalog" - ) - - product_title = None - product_sku = product.vendor_sku - if product.marketplace_product: - product_title = product.marketplace_product.get_title() - - # Get current inventory - inventory = ( - db.query(Inventory) - .filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id) - .first() - ) - - current_quantity = inventory.quantity if inventory else 0 - current_reserved = inventory.reserved_quantity if inventory else 0 - - # Get transactions - transactions = ( - db.query(InventoryTransaction) - .filter( - InventoryTransaction.vendor_id == vendor_id, - InventoryTransaction.product_id == product_id, - ) - .order_by(InventoryTransaction.created_at.desc()) - .limit(limit) - .all() - ) - - total = ( - db.query(func.count(InventoryTransaction.id)) - .filter( - InventoryTransaction.vendor_id == vendor_id, - InventoryTransaction.product_id == product_id, - ) - .scalar() - or 0 - ) - - return { - "product_id": product_id, - "product_title": product_title, - "product_sku": product_sku, - "current_quantity": current_quantity, - "current_reserved": current_reserved, - "transactions": [ - { - "id": tx.id, - "vendor_id": tx.vendor_id, - "product_id": tx.product_id, - "inventory_id": tx.inventory_id, - "transaction_type": ( - tx.transaction_type.value if tx.transaction_type else None - ), - "quantity_change": tx.quantity_change, - "quantity_after": tx.quantity_after, - "reserved_after": tx.reserved_after, - "location": tx.location, - "warehouse": tx.warehouse, - "order_id": tx.order_id, - "order_number": tx.order_number, - "reason": tx.reason, - "created_by": tx.created_by, - "created_at": tx.created_at, - } - for tx in transactions - ], - "total": total, - } - - def get_order_history( - self, - db: Session, - vendor_id: int, - order_id: int, - ) -> dict: - """ - Get all inventory transactions for a specific order. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - - Returns: - Dict with order info and transactions - - Raises: - OrderNotFoundException: If order not found or doesn't belong to vendor - """ - # Verify order belongs to vendor - order = ( - db.query(Order) - .filter(Order.id == order_id, Order.vendor_id == vendor_id) - .first() - ) - - if not order: - raise OrderNotFoundException(f"Order {order_id} not found") - - # Get transactions for this order - transactions = ( - db.query(InventoryTransaction) - .filter(InventoryTransaction.order_id == order_id) - .order_by(InventoryTransaction.created_at.asc()) - .all() - ) - - # Build result with product details - result = [] - for tx in transactions: - product = db.query(Product).filter(Product.id == tx.product_id).first() - product_title = None - product_sku = None - if product: - product_sku = product.vendor_sku - if product.marketplace_product: - product_title = product.marketplace_product.get_title() - - result.append( - { - "id": tx.id, - "vendor_id": tx.vendor_id, - "product_id": tx.product_id, - "inventory_id": tx.inventory_id, - "transaction_type": ( - tx.transaction_type.value if tx.transaction_type else None - ), - "quantity_change": tx.quantity_change, - "quantity_after": tx.quantity_after, - "reserved_after": tx.reserved_after, - "location": tx.location, - "warehouse": tx.warehouse, - "order_id": tx.order_id, - "order_number": tx.order_number, - "reason": tx.reason, - "created_by": tx.created_by, - "created_at": tx.created_at, - "product_title": product_title, - "product_sku": product_sku, - } - ) - - return { - "order_id": order_id, - "order_number": order.order_number, - "transactions": result, - } - - - # ========================================================================= - # Admin Methods (cross-vendor operations) - # ========================================================================= - - def get_all_transactions_admin( - self, - db: Session, - skip: int = 0, - limit: int = 50, - vendor_id: int | None = None, - product_id: int | None = None, - transaction_type: str | None = None, - order_id: int | None = None, - ) -> tuple[list[dict], int]: - """ - Get inventory transactions across all vendors (admin only). - - Args: - db: Database session - skip: Pagination offset - limit: Pagination limit - vendor_id: Optional vendor filter - product_id: Optional product filter - transaction_type: Optional transaction type filter - order_id: Optional order filter - - Returns: - Tuple of (transactions with details, total count) - """ - from models.database.vendor import Vendor - - # Build query - query = db.query(InventoryTransaction) - - # Apply filters - if vendor_id: - query = query.filter(InventoryTransaction.vendor_id == vendor_id) - if product_id: - query = query.filter(InventoryTransaction.product_id == product_id) - if transaction_type: - query = query.filter( - InventoryTransaction.transaction_type == transaction_type - ) - if order_id: - query = query.filter(InventoryTransaction.order_id == order_id) - - # Get total count - total = query.count() - - # Get transactions with pagination (newest first) - transactions = ( - query.order_by(InventoryTransaction.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - # Build result with vendor and product details - result = [] - for tx in transactions: - vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first() - product = db.query(Product).filter(Product.id == tx.product_id).first() - - product_title = None - product_sku = None - if product: - product_sku = product.vendor_sku - if product.marketplace_product: - product_title = product.marketplace_product.get_title() - - result.append( - { - "id": tx.id, - "vendor_id": tx.vendor_id, - "vendor_name": vendor.name if vendor else None, - "vendor_code": vendor.vendor_code if vendor else None, - "product_id": tx.product_id, - "inventory_id": tx.inventory_id, - "transaction_type": ( - tx.transaction_type.value if tx.transaction_type else None - ), - "quantity_change": tx.quantity_change, - "quantity_after": tx.quantity_after, - "reserved_after": tx.reserved_after, - "location": tx.location, - "warehouse": tx.warehouse, - "order_id": tx.order_id, - "order_number": tx.order_number, - "reason": tx.reason, - "created_by": tx.created_by, - "created_at": tx.created_at, - "product_title": product_title, - "product_sku": product_sku, - } - ) - - return result, total - - def get_transaction_stats_admin(self, db: Session) -> dict: - """ - Get transaction statistics across the platform (admin only). - - Returns: - Dict with transaction counts by type - """ - from sqlalchemy import func as sql_func - - # Count by transaction type - type_counts = ( - db.query( - InventoryTransaction.transaction_type, - sql_func.count(InventoryTransaction.id).label("count"), - ) - .group_by(InventoryTransaction.transaction_type) - .all() - ) - - # Total transactions - total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0 - - # Transactions today - from datetime import UTC, datetime, timedelta - - today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0) - today_count = ( - db.query(sql_func.count(InventoryTransaction.id)) - .filter(InventoryTransaction.created_at >= today_start) - .scalar() - or 0 - ) - - return { - "total_transactions": total, - "transactions_today": today_count, - "by_type": {tc.transaction_type.value: tc.count for tc in type_counts}, - } - - -# Create service instance -inventory_transaction_service = InventoryTransactionService() +__all__ = [ + "inventory_transaction_service", + "InventoryTransactionService", +] diff --git a/app/services/invoice_pdf_service.py b/app/services/invoice_pdf_service.py index e9bff6d3..c259eb0d 100644 --- a/app/services/invoice_pdf_service.py +++ b/app/services/invoice_pdf_service.py @@ -1,164 +1,23 @@ # app/services/invoice_pdf_service.py """ -Invoice PDF generation service using WeasyPrint. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint. -Stores generated PDFs in the configured storage location. +The canonical implementation is now in: + app/modules/orders/services/invoice_pdf_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.services import invoice_pdf_service """ -import logging -import os -from datetime import UTC, datetime -from pathlib import Path +from app.modules.orders.services.invoice_pdf_service import ( + invoice_pdf_service, + InvoicePDFService, +) -from jinja2 import Environment, FileSystemLoader -from sqlalchemy.orm import Session - -from app.core.config import settings -from models.database.invoice import Invoice - -logger = logging.getLogger(__name__) - -# Template directory -TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices" - -# PDF storage directory (relative to project root) -PDF_STORAGE_DIR = Path("storage") / "invoices" - - -class InvoicePDFService: - """Service for generating invoice PDFs.""" - - def __init__(self): - """Initialize the PDF service with Jinja2 environment.""" - self.env = Environment( - loader=FileSystemLoader(str(TEMPLATE_DIR)), - autoescape=True, - ) - - def _ensure_storage_dir(self, vendor_id: int) -> Path: - """Ensure the storage directory exists for a vendor.""" - storage_path = PDF_STORAGE_DIR / str(vendor_id) - storage_path.mkdir(parents=True, exist_ok=True) - return storage_path - - def _get_pdf_filename(self, invoice: Invoice) -> str: - """Generate PDF filename for an invoice.""" - # Sanitize invoice number for filename - safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-") - return f"{safe_number}.pdf" - - def generate_pdf( - self, - db: Session, - invoice: Invoice, - force_regenerate: bool = False, - ) -> str: - """ - Generate PDF for an invoice. - - Args: - db: Database session - invoice: Invoice to generate PDF for - force_regenerate: If True, regenerate even if PDF already exists - - Returns: - Path to the generated PDF file - """ - # Check if PDF already exists - if invoice.pdf_path and not force_regenerate: - if Path(invoice.pdf_path).exists(): - logger.debug(f"PDF already exists for invoice {invoice.invoice_number}") - return invoice.pdf_path - - # Ensure storage directory exists - storage_dir = self._ensure_storage_dir(invoice.vendor_id) - pdf_filename = self._get_pdf_filename(invoice) - pdf_path = storage_dir / pdf_filename - - # Render HTML template - html_content = self._render_html(invoice) - - # Generate PDF using WeasyPrint - try: - from weasyprint import HTML - - html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR)) - html_doc.write_pdf(str(pdf_path)) - - logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}") - except ImportError: - logger.error("WeasyPrint not installed. Install with: pip install weasyprint") - raise RuntimeError("WeasyPrint not installed") - except Exception as e: - logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}") - raise - - # Update invoice record with PDF path and timestamp - invoice.pdf_path = str(pdf_path) - invoice.pdf_generated_at = datetime.now(UTC) - db.flush() - - return str(pdf_path) - - def _render_html(self, invoice: Invoice) -> str: - """Render the invoice HTML template.""" - template = self.env.get_template("invoice.html") - - # Prepare template context - context = { - "invoice": invoice, - "seller": invoice.seller_details, - "buyer": invoice.buyer_details, - "line_items": invoice.line_items, - "bank_details": invoice.bank_details, - "payment_terms": invoice.payment_terms, - "footer_text": invoice.footer_text, - "now": datetime.now(UTC), - } - - return template.render(**context) - - def get_pdf_path(self, invoice: Invoice) -> str | None: - """Get the PDF path for an invoice if it exists.""" - if invoice.pdf_path and Path(invoice.pdf_path).exists(): - return invoice.pdf_path - return None - - def delete_pdf(self, invoice: Invoice, db: Session) -> bool: - """ - Delete the PDF file for an invoice. - - Args: - invoice: Invoice whose PDF to delete - db: Database session - - Returns: - True if deleted, False if not found - """ - if not invoice.pdf_path: - return False - - pdf_path = Path(invoice.pdf_path) - if pdf_path.exists(): - try: - pdf_path.unlink() - logger.info(f"Deleted PDF for invoice {invoice.invoice_number}") - except Exception as e: - logger.error(f"Failed to delete PDF {pdf_path}: {e}") - return False - - # Clear PDF fields - invoice.pdf_path = None - invoice.pdf_generated_at = None - db.flush() - - return True - - def regenerate_pdf(self, db: Session, invoice: Invoice) -> str: - """Force regenerate PDF for an invoice.""" - return self.generate_pdf(db, invoice, force_regenerate=True) - - -# Singleton instance -invoice_pdf_service = InvoicePDFService() +__all__ = [ + "invoice_pdf_service", + "InvoicePDFService", +] diff --git a/app/services/invoice_service.py b/app/services/invoice_service.py index 3d9e1348..ec5840f1 100644 --- a/app/services/invoice_service.py +++ b/app/services/invoice_service.py @@ -1,675 +1,23 @@ # app/services/invoice_service.py """ -Invoice service for generating and managing invoices. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Handles: -- Vendor invoice settings management -- Invoice generation from orders -- VAT calculation (Luxembourg, EU, B2B reverse charge) -- Invoice number sequencing -- PDF generation (via separate module) +The canonical implementation is now in: + app/modules/orders/services/invoice_service.py -VAT Logic: -- Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate) -- EU cross-border B2C with OSS: Use destination country VAT rate -- EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle) -- EU B2B with valid VAT number: Reverse charge (0% VAT) +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.services import invoice_service """ -import logging -from datetime import UTC, datetime -from decimal import Decimal -from typing import Any - -from sqlalchemy import and_, func -from sqlalchemy.orm import Session - -from app.exceptions import ( - ValidationException, -) -from app.exceptions.invoice import ( - InvoiceNotFoundException, - InvoicePDFGenerationException, - InvoicePDFNotFoundException, - InvoiceSettingsAlreadyExistException, - InvoiceSettingsNotFoundException, - InvoiceValidationException, - OrderNotFoundException, -) -from models.database.invoice import ( - Invoice, - InvoiceStatus, - VATRegime, - VendorInvoiceSettings, -) -from models.database.order import Order -from models.database.vendor import Vendor -from models.schema.invoice import ( - InvoiceBuyerDetails, - InvoiceCreate, - InvoiceLineItem, - InvoiceManualCreate, - InvoiceSellerDetails, - VendorInvoiceSettingsCreate, - VendorInvoiceSettingsUpdate, +from app.modules.orders.services.invoice_service import ( + invoice_service, + InvoiceService, ) -logger = logging.getLogger(__name__) - - -# EU VAT rates by country code (2024 standard rates) -EU_VAT_RATES: dict[str, Decimal] = { - "AT": Decimal("20.00"), # Austria - "BE": Decimal("21.00"), # Belgium - "BG": Decimal("20.00"), # Bulgaria - "HR": Decimal("25.00"), # Croatia - "CY": Decimal("19.00"), # Cyprus - "CZ": Decimal("21.00"), # Czech Republic - "DK": Decimal("25.00"), # Denmark - "EE": Decimal("22.00"), # Estonia - "FI": Decimal("24.00"), # Finland - "FR": Decimal("20.00"), # France - "DE": Decimal("19.00"), # Germany - "GR": Decimal("24.00"), # Greece - "HU": Decimal("27.00"), # Hungary - "IE": Decimal("23.00"), # Ireland - "IT": Decimal("22.00"), # Italy - "LV": Decimal("21.00"), # Latvia - "LT": Decimal("21.00"), # Lithuania - "LU": Decimal("17.00"), # Luxembourg (standard) - "MT": Decimal("18.00"), # Malta - "NL": Decimal("21.00"), # Netherlands - "PL": Decimal("23.00"), # Poland - "PT": Decimal("23.00"), # Portugal - "RO": Decimal("19.00"), # Romania - "SK": Decimal("20.00"), # Slovakia - "SI": Decimal("22.00"), # Slovenia - "ES": Decimal("21.00"), # Spain - "SE": Decimal("25.00"), # Sweden -} - -# Luxembourg specific VAT rates -LU_VAT_RATES = { - "standard": Decimal("17.00"), - "intermediate": Decimal("14.00"), - "reduced": Decimal("8.00"), - "super_reduced": Decimal("3.00"), -} - - -class InvoiceService: - """Service for invoice operations.""" - - # ========================================================================= - # VAT Calculation - # ========================================================================= - - def get_vat_rate_for_country(self, country_iso: str) -> Decimal: - """Get standard VAT rate for EU country.""" - return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00")) - - def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str: - """Get human-readable VAT rate label.""" - country_names = { - "AT": "Austria", - "BE": "Belgium", - "BG": "Bulgaria", - "HR": "Croatia", - "CY": "Cyprus", - "CZ": "Czech Republic", - "DK": "Denmark", - "EE": "Estonia", - "FI": "Finland", - "FR": "France", - "DE": "Germany", - "GR": "Greece", - "HU": "Hungary", - "IE": "Ireland", - "IT": "Italy", - "LV": "Latvia", - "LT": "Lithuania", - "LU": "Luxembourg", - "MT": "Malta", - "NL": "Netherlands", - "PL": "Poland", - "PT": "Portugal", - "RO": "Romania", - "SK": "Slovakia", - "SI": "Slovenia", - "ES": "Spain", - "SE": "Sweden", - } - country_name = country_names.get(country_iso.upper(), country_iso) - return f"{country_name} VAT {vat_rate}%" - - def determine_vat_regime( - self, - seller_country: str, - buyer_country: str, - buyer_vat_number: str | None, - seller_oss_registered: bool, - ) -> tuple[VATRegime, Decimal, str | None]: - """ - Determine VAT regime and rate for invoice. - - Returns: (regime, vat_rate, destination_country) - """ - seller_country = seller_country.upper() - buyer_country = buyer_country.upper() - - # Same country = domestic VAT - if seller_country == buyer_country: - vat_rate = self.get_vat_rate_for_country(seller_country) - return VATRegime.DOMESTIC, vat_rate, None - - # Different EU countries - if buyer_country in EU_VAT_RATES: - # B2B with valid VAT number = reverse charge - if buyer_vat_number: - return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country - - # B2C cross-border - if seller_oss_registered: - # OSS: use destination country VAT - vat_rate = self.get_vat_rate_for_country(buyer_country) - return VATRegime.OSS, vat_rate, buyer_country - else: - # No OSS: use origin country VAT - vat_rate = self.get_vat_rate_for_country(seller_country) - return VATRegime.ORIGIN, vat_rate, buyer_country - - # Non-EU = VAT exempt (export) - return VATRegime.EXEMPT, Decimal("0.00"), buyer_country - - # ========================================================================= - # Invoice Settings Management - # ========================================================================= - - def get_settings( - self, db: Session, vendor_id: int - ) -> VendorInvoiceSettings | None: - """Get vendor invoice settings.""" - return ( - db.query(VendorInvoiceSettings) - .filter(VendorInvoiceSettings.vendor_id == vendor_id) - .first() - ) - - def get_settings_or_raise( - self, db: Session, vendor_id: int - ) -> VendorInvoiceSettings: - """Get vendor invoice settings or raise exception.""" - settings = self.get_settings(db, vendor_id) - if not settings: - raise InvoiceSettingsNotFoundException(vendor_id) - return settings - - def create_settings( - self, - db: Session, - vendor_id: int, - data: VendorInvoiceSettingsCreate, - ) -> VendorInvoiceSettings: - """Create vendor invoice settings.""" - # Check if settings already exist - existing = self.get_settings(db, vendor_id) - if existing: - raise ValidationException( - "Invoice settings already exist for this vendor" - ) - - settings = VendorInvoiceSettings( - vendor_id=vendor_id, - **data.model_dump(), - ) - db.add(settings) - db.flush() - db.refresh(settings) - - logger.info(f"Created invoice settings for vendor {vendor_id}") - return settings - - def update_settings( - self, - db: Session, - vendor_id: int, - data: VendorInvoiceSettingsUpdate, - ) -> VendorInvoiceSettings: - """Update vendor invoice settings.""" - settings = self.get_settings_or_raise(db, vendor_id) - - update_data = data.model_dump(exclude_unset=True) - for key, value in update_data.items(): - setattr(settings, key, value) - - settings.updated_at = datetime.now(UTC) - db.flush() - db.refresh(settings) - - logger.info(f"Updated invoice settings for vendor {vendor_id}") - return settings - - def create_settings_from_vendor( - self, - db: Session, - vendor: Vendor, - ) -> VendorInvoiceSettings: - """ - Create invoice settings from vendor/company info. - - Used for initial setup based on existing vendor data. - """ - company = vendor.company - - settings = VendorInvoiceSettings( - vendor_id=vendor.id, - company_name=company.legal_name if company else vendor.name, - company_address=vendor.effective_business_address, - company_city=None, # Would need to parse from address - company_postal_code=None, - company_country="LU", - vat_number=vendor.effective_tax_number, - is_vat_registered=bool(vendor.effective_tax_number), - ) - db.add(settings) - db.flush() - db.refresh(settings) - - logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}") - return settings - - # ========================================================================= - # Invoice Number Generation - # ========================================================================= - - def _get_next_invoice_number( - self, db: Session, settings: VendorInvoiceSettings - ) -> str: - """Generate next invoice number and increment counter.""" - number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding) - invoice_number = f"{settings.invoice_prefix}{number}" - - # Increment counter - settings.invoice_next_number += 1 - db.flush() - - return invoice_number - - # ========================================================================= - # Invoice Creation - # ========================================================================= - - def create_invoice_from_order( - self, - db: Session, - vendor_id: int, - order_id: int, - notes: str | None = None, - ) -> Invoice: - """ - Create an invoice from an order. - - Captures snapshots of seller/buyer details and calculates VAT. - """ - # Get invoice settings - settings = self.get_settings_or_raise(db, vendor_id) - - # Get order - order = ( - db.query(Order) - .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id)) - .first() - ) - if not order: - raise OrderNotFoundException(f"Order {order_id} not found") - - # Check for existing invoice - existing = ( - db.query(Invoice) - .filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id)) - .first() - ) - if existing: - raise ValidationException(f"Invoice already exists for order {order_id}") - - # Determine VAT regime - buyer_country = order.bill_country_iso - vat_regime, vat_rate, destination_country = self.determine_vat_regime( - seller_country=settings.company_country, - buyer_country=buyer_country, - buyer_vat_number=None, # TODO: Add B2B VAT number support - seller_oss_registered=settings.is_oss_registered, - ) - - # Build seller details snapshot - seller_details = { - "company_name": settings.company_name, - "address": settings.company_address, - "city": settings.company_city, - "postal_code": settings.company_postal_code, - "country": settings.company_country, - "vat_number": settings.vat_number, - } - - # Build buyer details snapshot - buyer_details = { - "name": f"{order.bill_first_name} {order.bill_last_name}".strip(), - "email": order.customer_email, - "address": order.bill_address_line_1, - "city": order.bill_city, - "postal_code": order.bill_postal_code, - "country": order.bill_country_iso, - "vat_number": None, # TODO: B2B support - } - if order.bill_company: - buyer_details["company"] = order.bill_company - - # Build line items from order items - line_items = [] - for item in order.items: - line_items.append({ - "description": item.product_name, - "quantity": item.quantity, - "unit_price_cents": item.unit_price_cents, - "total_cents": item.total_price_cents, - "sku": item.product_sku, - "ean": item.gtin, - }) - - # Calculate amounts - subtotal_cents = sum(item["total_cents"] for item in line_items) - - # Calculate VAT - if vat_rate > 0: - vat_amount_cents = int( - subtotal_cents * float(vat_rate) / 100 - ) - else: - vat_amount_cents = 0 - - total_cents = subtotal_cents + vat_amount_cents - - # Get VAT label - vat_rate_label = None - if vat_rate > 0: - if destination_country: - vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate) - else: - vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate) - - # Generate invoice number - invoice_number = self._get_next_invoice_number(db, settings) - - # Create invoice - invoice = Invoice( - vendor_id=vendor_id, - order_id=order_id, - invoice_number=invoice_number, - invoice_date=datetime.now(UTC), - status=InvoiceStatus.DRAFT.value, - seller_details=seller_details, - buyer_details=buyer_details, - line_items=line_items, - vat_regime=vat_regime.value, - destination_country=destination_country, - vat_rate=vat_rate, - vat_rate_label=vat_rate_label, - currency=order.currency, - subtotal_cents=subtotal_cents, - vat_amount_cents=vat_amount_cents, - total_cents=total_cents, - payment_terms=settings.payment_terms, - bank_details={ - "bank_name": settings.bank_name, - "iban": settings.bank_iban, - "bic": settings.bank_bic, - } if settings.bank_iban else None, - footer_text=settings.footer_text, - notes=notes, - ) - - db.add(invoice) - db.flush() - db.refresh(invoice) - - logger.info( - f"Created invoice {invoice_number} for order {order_id} " - f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})" - ) - - return invoice - - # ========================================================================= - # Invoice Retrieval - # ========================================================================= - - def get_invoice( - self, db: Session, vendor_id: int, invoice_id: int - ) -> Invoice | None: - """Get invoice by ID.""" - return ( - db.query(Invoice) - .filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id)) - .first() - ) - - def get_invoice_or_raise( - self, db: Session, vendor_id: int, invoice_id: int - ) -> Invoice: - """Get invoice by ID or raise exception.""" - invoice = self.get_invoice(db, vendor_id, invoice_id) - if not invoice: - raise InvoiceNotFoundException(invoice_id) - return invoice - - def get_invoice_by_number( - self, db: Session, vendor_id: int, invoice_number: str - ) -> Invoice | None: - """Get invoice by invoice number.""" - return ( - db.query(Invoice) - .filter( - and_( - Invoice.invoice_number == invoice_number, - Invoice.vendor_id == vendor_id, - ) - ) - .first() - ) - - def get_invoice_by_order_id( - self, db: Session, vendor_id: int, order_id: int - ) -> Invoice | None: - """Get invoice by order ID.""" - return ( - db.query(Invoice) - .filter( - and_( - Invoice.order_id == order_id, - Invoice.vendor_id == vendor_id, - ) - ) - .first() - ) - - def list_invoices( - self, - db: Session, - vendor_id: int, - status: str | None = None, - page: int = 1, - per_page: int = 20, - ) -> tuple[list[Invoice], int]: - """ - List invoices for vendor with pagination. - - Returns: (invoices, total_count) - """ - query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id) - - if status: - query = query.filter(Invoice.status == status) - - # Get total count - total = query.count() - - # Apply pagination and order - invoices = ( - query.order_by(Invoice.invoice_date.desc()) - .offset((page - 1) * per_page) - .limit(per_page) - .all() - ) - - return invoices, total - - # ========================================================================= - # Invoice Status Management - # ========================================================================= - - def update_status( - self, - db: Session, - vendor_id: int, - invoice_id: int, - new_status: str, - ) -> Invoice: - """Update invoice status.""" - invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) - - # Validate status transition - valid_statuses = [s.value for s in InvoiceStatus] - if new_status not in valid_statuses: - raise ValidationException(f"Invalid status: {new_status}") - - # Cannot change cancelled invoices - if invoice.status == InvoiceStatus.CANCELLED.value: - raise ValidationException("Cannot change status of cancelled invoice") - - invoice.status = new_status - invoice.updated_at = datetime.now(UTC) - db.flush() - db.refresh(invoice) - - logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}") - return invoice - - def mark_as_issued( - self, db: Session, vendor_id: int, invoice_id: int - ) -> Invoice: - """Mark invoice as issued.""" - return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value) - - def mark_as_paid( - self, db: Session, vendor_id: int, invoice_id: int - ) -> Invoice: - """Mark invoice as paid.""" - return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value) - - def cancel_invoice( - self, db: Session, vendor_id: int, invoice_id: int - ) -> Invoice: - """Cancel invoice.""" - return self.update_status( - db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value - ) - - # ========================================================================= - # Statistics - # ========================================================================= - - def get_invoice_stats( - self, db: Session, vendor_id: int - ) -> dict[str, Any]: - """Get invoice statistics for vendor.""" - total_count = ( - db.query(func.count(Invoice.id)) - .filter(Invoice.vendor_id == vendor_id) - .scalar() - or 0 - ) - - total_revenue = ( - db.query(func.sum(Invoice.total_cents)) - .filter( - and_( - Invoice.vendor_id == vendor_id, - Invoice.status.in_([ - InvoiceStatus.ISSUED.value, - InvoiceStatus.PAID.value, - ]), - ) - ) - .scalar() - or 0 - ) - - draft_count = ( - db.query(func.count(Invoice.id)) - .filter( - and_( - Invoice.vendor_id == vendor_id, - Invoice.status == InvoiceStatus.DRAFT.value, - ) - ) - .scalar() - or 0 - ) - - paid_count = ( - db.query(func.count(Invoice.id)) - .filter( - and_( - Invoice.vendor_id == vendor_id, - Invoice.status == InvoiceStatus.PAID.value, - ) - ) - .scalar() - or 0 - ) - - return { - "total_invoices": total_count, - "total_revenue_cents": total_revenue, - "total_revenue": total_revenue / 100 if total_revenue else 0, - "draft_count": draft_count, - "paid_count": paid_count, - } - - - # ========================================================================= - # PDF Generation - # ========================================================================= - - def generate_pdf( - self, - db: Session, - vendor_id: int, - invoice_id: int, - force_regenerate: bool = False, - ) -> str: - """ - Generate PDF for an invoice. - - Returns path to the generated PDF. - """ - from app.services.invoice_pdf_service import invoice_pdf_service - - invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) - return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate) - - def get_pdf_path( - self, - db: Session, - vendor_id: int, - invoice_id: int, - ) -> str | None: - """Get PDF path for an invoice if it exists.""" - from app.services.invoice_pdf_service import invoice_pdf_service - - invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id) - return invoice_pdf_service.get_pdf_path(invoice) - - -# Singleton instance -invoice_service = InvoiceService() +__all__ = [ + "invoice_service", + "InvoiceService", +] diff --git a/app/services/letzshop/__init__.py b/app/services/letzshop/__init__.py index c7e8af06..8d9416e5 100644 --- a/app/services/letzshop/__init__.py +++ b/app/services/letzshop/__init__.py @@ -1,33 +1,33 @@ # app/services/letzshop/__init__.py """ -Letzshop marketplace integration services. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides: -- GraphQL client for API communication -- Credential management service -- Order import service -- Fulfillment sync service -- Vendor directory sync service +The canonical implementation is now in: + app/modules/marketplace/services/letzshop/ + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.marketplace.services.letzshop import LetzshopClient """ -from .client_service import ( - LetzshopAPIError, - LetzshopAuthError, +from app.modules.marketplace.services.letzshop import ( + # Client LetzshopClient, LetzshopClientError, + LetzshopAuthError, + LetzshopAPIError, LetzshopConnectionError, -) -from .credentials_service import ( + # Credentials + LetzshopCredentialsService, CredentialsError, CredentialsNotFoundError, - LetzshopCredentialsService, -) -from .order_service import ( + # Order Service LetzshopOrderService, OrderNotFoundError, VendorNotFoundError, -) -from .vendor_sync_service import ( + # Vendor Sync Service LetzshopVendorSyncService, get_vendor_sync_service, ) diff --git a/app/services/letzshop_export_service.py b/app/services/letzshop_export_service.py index 5e4dee29..16718d47 100644 --- a/app/services/letzshop_export_service.py +++ b/app/services/letzshop_export_service.py @@ -1,338 +1,25 @@ # app/services/letzshop_export_service.py """ -Service for exporting products to Letzshop CSV format. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Generates Google Shopping compatible CSV files for Letzshop marketplace. +The canonical implementation is now in: + app/modules/marketplace/services/letzshop_export_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.marketplace.services import letzshop_export_service """ -import csv -import io -import logging -from datetime import UTC, datetime +from app.modules.marketplace.services.letzshop_export_service import ( + LetzshopExportService, + letzshop_export_service, + LETZSHOP_CSV_COLUMNS, +) -from sqlalchemy.orm import Session, joinedload - -from models.database.letzshop import LetzshopSyncLog -from models.database.marketplace_product import MarketplaceProduct -from models.database.product import Product - -logger = logging.getLogger(__name__) - -# Letzshop CSV columns in order -LETZSHOP_CSV_COLUMNS = [ - "id", - "title", - "description", - "link", - "image_link", - "additional_image_link", - "availability", - "price", - "sale_price", - "brand", - "gtin", - "mpn", - "google_product_category", - "product_type", - "condition", - "adult", - "multipack", - "is_bundle", - "age_group", - "color", - "gender", - "material", - "pattern", - "size", - "size_type", - "size_system", - "item_group_id", - "custom_label_0", - "custom_label_1", - "custom_label_2", - "custom_label_3", - "custom_label_4", - "identifier_exists", - "unit_pricing_measure", - "unit_pricing_base_measure", - "shipping", - "atalanda:tax_rate", - "atalanda:quantity", - "atalanda:boost_sort", - "atalanda:delivery_method", +__all__ = [ + "LetzshopExportService", + "letzshop_export_service", + "LETZSHOP_CSV_COLUMNS", ] - - -class LetzshopExportService: - """Service for exporting products to Letzshop CSV format.""" - - def __init__(self, default_tax_rate: float = 17.0): - """ - Initialize the export service. - - Args: - default_tax_rate: Default VAT rate for Luxembourg (17%) - """ - self.default_tax_rate = default_tax_rate - - def export_vendor_products( - self, - db: Session, - vendor_id: int, - language: str = "en", - include_inactive: bool = False, - ) -> str: - """ - Export all products for a vendor in Letzshop CSV format. - - Args: - db: Database session - vendor_id: Vendor ID to export products for - language: Language for title/description (en, fr, de) - include_inactive: Whether to include inactive products - - Returns: - CSV string content - """ - # Query products for this vendor with their marketplace product data - query = ( - db.query(Product) - .filter(Product.vendor_id == vendor_id) - .options( - joinedload(Product.marketplace_product).joinedload( - MarketplaceProduct.translations - ) - ) - ) - - if not include_inactive: - query = query.filter(Product.is_active == True) - - products = query.all() - - logger.info( - f"Exporting {len(products)} products for vendor {vendor_id} in {language}" - ) - - return self._generate_csv(products, language) - - def export_marketplace_products( - self, - db: Session, - marketplace: str = "Letzshop", - language: str = "en", - limit: int | None = None, - ) -> str: - """ - Export marketplace products directly (admin use). - - Args: - db: Database session - marketplace: Filter by marketplace source - language: Language for title/description - limit: Optional limit on number of products - - Returns: - CSV string content - """ - query = ( - db.query(MarketplaceProduct) - .filter(MarketplaceProduct.is_active == True) - .options(joinedload(MarketplaceProduct.translations)) - ) - - if marketplace: - query = query.filter( - MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") - ) - - if limit: - query = query.limit(limit) - - products = query.all() - - logger.info( - f"Exporting {len(products)} marketplace products for {marketplace} in {language}" - ) - - return self._generate_csv_from_marketplace_products(products, language) - - def _generate_csv(self, products: list[Product], language: str) -> str: - """Generate CSV from vendor Product objects.""" - output = io.StringIO() - writer = csv.DictWriter( - output, - fieldnames=LETZSHOP_CSV_COLUMNS, - delimiter="\t", - quoting=csv.QUOTE_MINIMAL, - ) - writer.writeheader() - - for product in products: - if product.marketplace_product: - row = self._product_to_row(product, language) - writer.writerow(row) - - return output.getvalue() - - def _generate_csv_from_marketplace_products( - self, products: list[MarketplaceProduct], language: str - ) -> str: - """Generate CSV from MarketplaceProduct objects directly.""" - output = io.StringIO() - writer = csv.DictWriter( - output, - fieldnames=LETZSHOP_CSV_COLUMNS, - delimiter="\t", - quoting=csv.QUOTE_MINIMAL, - ) - writer.writeheader() - - for mp in products: - row = self._marketplace_product_to_row(mp, language) - writer.writerow(row) - - return output.getvalue() - - def _product_to_row(self, product: Product, language: str) -> dict: - """Convert a Product (with MarketplaceProduct) to a CSV row.""" - mp = product.marketplace_product - return self._marketplace_product_to_row( - mp, language, vendor_sku=product.vendor_sku - ) - - def _marketplace_product_to_row( - self, - mp: MarketplaceProduct, - language: str, - vendor_sku: str | None = None, - ) -> dict: - """Convert a MarketplaceProduct to a CSV row dict.""" - # Get localized title and description - title = mp.get_title(language) or "" - description = mp.get_description(language) or "" - - # Format price with currency - price = "" - if mp.price_numeric: - price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}" - elif mp.price: - price = mp.price - - # Format sale price - sale_price = "" - if mp.sale_price_numeric: - sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}" - elif mp.sale_price: - sale_price = mp.sale_price - - # Additional images - join with comma if multiple - additional_images = "" - if mp.additional_images: - additional_images = ",".join(mp.additional_images) - elif mp.additional_image_link: - additional_images = mp.additional_image_link - - # Determine identifier_exists - identifier_exists = mp.identifier_exists - if not identifier_exists: - identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no" - - return { - "id": vendor_sku or mp.marketplace_product_id, - "title": title, - "description": description, - "link": mp.link or mp.source_url or "", - "image_link": mp.image_link or "", - "additional_image_link": additional_images, - "availability": mp.availability or "in stock", - "price": price, - "sale_price": sale_price, - "brand": mp.brand or "", - "gtin": mp.gtin or "", - "mpn": mp.mpn or "", - "google_product_category": mp.google_product_category or "", - "product_type": mp.product_type_raw or "", - "condition": mp.condition or "new", - "adult": mp.adult or "no", - "multipack": str(mp.multipack) if mp.multipack else "", - "is_bundle": mp.is_bundle or "no", - "age_group": mp.age_group or "", - "color": mp.color or "", - "gender": mp.gender or "", - "material": mp.material or "", - "pattern": mp.pattern or "", - "size": mp.size or "", - "size_type": mp.size_type or "", - "size_system": mp.size_system or "", - "item_group_id": mp.item_group_id or "", - "custom_label_0": mp.custom_label_0 or "", - "custom_label_1": mp.custom_label_1 or "", - "custom_label_2": mp.custom_label_2 or "", - "custom_label_3": mp.custom_label_3 or "", - "custom_label_4": mp.custom_label_4 or "", - "identifier_exists": identifier_exists, - "unit_pricing_measure": mp.unit_pricing_measure or "", - "unit_pricing_base_measure": mp.unit_pricing_base_measure or "", - "shipping": mp.shipping or "", - "atalanda:tax_rate": str(self.default_tax_rate), - "atalanda:quantity": "", # Would need inventory data - "atalanda:boost_sort": "", - "atalanda:delivery_method": "", - } - - def log_export( - self, - db: Session, - vendor_id: int, - started_at: datetime, - completed_at: datetime, - files_processed: int, - files_succeeded: int, - files_failed: int, - products_exported: int, - triggered_by: str, - error_details: dict | None = None, - ) -> LetzshopSyncLog: - """ - Log an export operation to the sync log. - - Args: - db: Database session - vendor_id: Vendor ID - started_at: When the export started - completed_at: When the export completed - files_processed: Number of language files to export (e.g., 3) - files_succeeded: Number of files successfully exported - files_failed: Number of files that failed - products_exported: Total products in the export - triggered_by: Who triggered the export (e.g., "admin:123") - error_details: Optional error details if any failures - - Returns: - Created LetzshopSyncLog entry - """ - sync_log = LetzshopSyncLog( - vendor_id=vendor_id, - operation_type="product_export", - direction="outbound", - status="completed" if files_failed == 0 else "partial", - records_processed=files_processed, - records_succeeded=files_succeeded, - records_failed=files_failed, - started_at=started_at, - completed_at=completed_at, - duration_seconds=int((completed_at - started_at).total_seconds()), - triggered_by=triggered_by, - error_details={ - "products_exported": products_exported, - **(error_details or {}), - } if products_exported or error_details else None, - ) - db.add(sync_log) - db.flush() - return sync_log - - -# Singleton instance -letzshop_export_service = LetzshopExportService() diff --git a/app/services/marketplace_import_job_service.py b/app/services/marketplace_import_job_service.py index a36973c2..79c1d7fc 100644 --- a/app/services/marketplace_import_job_service.py +++ b/app/services/marketplace_import_job_service.py @@ -1,334 +1,23 @@ # app/services/marketplace_import_job_service.py -import logging +""" +LEGACY LOCATION - Re-exports from module for backwards compatibility. -from sqlalchemy.orm import Session +The canonical implementation is now in: + app/modules/marketplace/services/marketplace_import_job_service.py -from app.exceptions import ( - ImportJobNotFoundException, - ImportJobNotOwnedException, - ValidationException, -) -from models.database.marketplace_import_job import ( - MarketplaceImportError, - MarketplaceImportJob, -) -from models.database.user import User -from models.database.vendor import Vendor -from models.schema.marketplace_import_job import ( - AdminMarketplaceImportJobResponse, - MarketplaceImportJobRequest, - MarketplaceImportJobResponse, +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.marketplace.services import marketplace_import_job_service +""" + +from app.modules.marketplace.services.marketplace_import_job_service import ( + MarketplaceImportJobService, + marketplace_import_job_service, ) -logger = logging.getLogger(__name__) - - -class MarketplaceImportJobService: - """Service class for Marketplace operations.""" - - def create_import_job( - self, - db: Session, - request: MarketplaceImportJobRequest, - vendor: Vendor, # CHANGED: Vendor object from middleware - user: User, - ) -> MarketplaceImportJob: - """ - Create a new marketplace import job. - - Args: - db: Database session - request: Import request data - vendor: Vendor object (from middleware) - user: User creating the job - - Returns: - Created MarketplaceImportJob object - """ - try: - # Create marketplace import job record - import_job = MarketplaceImportJob( - status="pending", - source_url=request.source_url, - marketplace=request.marketplace, - language=request.language, - vendor_id=vendor.id, - user_id=user.id, - ) - - db.add(import_job) - db.flush() - db.refresh(import_job) - - logger.info( - f"Created marketplace import job {import_job.id}: " - f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) " - f"by user {user.username}" - ) - - return import_job - - except Exception as e: - logger.error(f"Error creating import job: {str(e)}") - raise ValidationException("Failed to create import job") - - def get_import_job_by_id( - self, db: Session, job_id: int, user: User - ) -> MarketplaceImportJob: - """Get a marketplace import job by ID with access control.""" - try: - job = ( - db.query(MarketplaceImportJob) - .filter(MarketplaceImportJob.id == job_id) - .first() - ) - - if not job: - raise ImportJobNotFoundException(job_id) - - # Users can only see their own jobs, admins can see all - if user.role != "admin" and job.user_id != user.id: - raise ImportJobNotOwnedException(job_id, user.id) - - return job - - except (ImportJobNotFoundException, ImportJobNotOwnedException): - raise - except Exception as e: - logger.error(f"Error getting import job {job_id}: {str(e)}") - raise ValidationException("Failed to retrieve import job") - - def get_import_job_for_vendor( - self, db: Session, job_id: int, vendor_id: int - ) -> MarketplaceImportJob: - """ - Get a marketplace import job by ID with vendor access control. - - Validates that the job belongs to the specified vendor. - - Args: - db: Database session - job_id: Import job ID - vendor_id: Vendor ID from token (to verify ownership) - - Raises: - ImportJobNotFoundException: If job not found - UnauthorizedVendorAccessException: If job doesn't belong to vendor - """ - from app.exceptions import UnauthorizedVendorAccessException - - try: - job = ( - db.query(MarketplaceImportJob) - .filter(MarketplaceImportJob.id == job_id) - .first() - ) - - if not job: - raise ImportJobNotFoundException(job_id) - - # Verify job belongs to vendor (service layer validation) - if job.vendor_id != vendor_id: - raise UnauthorizedVendorAccessException( - vendor_code=str(vendor_id), - user_id=0, # Not user-specific, but vendor mismatch - ) - - return job - - except (ImportJobNotFoundException, UnauthorizedVendorAccessException): - raise - except Exception as e: - logger.error( - f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}" - ) - raise ValidationException("Failed to retrieve import job") - - def get_import_jobs( - self, - db: Session, - vendor: Vendor, # ADDED: Vendor filter - user: User, - marketplace: str | None = None, - skip: int = 0, - limit: int = 50, - ) -> list[MarketplaceImportJob]: - """Get marketplace import jobs for a specific vendor.""" - try: - query = db.query(MarketplaceImportJob).filter( - MarketplaceImportJob.vendor_id == vendor.id - ) - - # Users can only see their own jobs, admins can see all vendor jobs - if user.role != "admin": - query = query.filter(MarketplaceImportJob.user_id == user.id) - - # Apply marketplace filter - if marketplace: - query = query.filter( - MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") - ) - - # Order by creation date (newest first) and apply pagination - jobs = ( - query.order_by( - MarketplaceImportJob.created_at.desc(), - MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp - ) - .offset(skip) - .limit(limit) - .all() - ) - - return jobs - - except Exception as e: - logger.error(f"Error getting import jobs: {str(e)}") - raise ValidationException("Failed to retrieve import jobs") - - def convert_to_response_model( - self, job: MarketplaceImportJob - ) -> MarketplaceImportJobResponse: - """Convert database model to API response model.""" - return MarketplaceImportJobResponse( - job_id=job.id, - status=job.status, - marketplace=job.marketplace, - language=job.language, - vendor_id=job.vendor_id, - vendor_code=job.vendor.vendor_code if job.vendor else None, - vendor_name=job.vendor.name if job.vendor else None, - source_url=job.source_url, - imported=job.imported_count or 0, - updated=job.updated_count or 0, - total_processed=job.total_processed or 0, - error_count=job.error_count or 0, - error_message=job.error_message, - created_at=job.created_at, - started_at=job.started_at, - completed_at=job.completed_at, - ) - - def convert_to_admin_response_model( - self, job: MarketplaceImportJob - ) -> AdminMarketplaceImportJobResponse: - """Convert database model to admin API response model with extra fields.""" - return AdminMarketplaceImportJobResponse( - id=job.id, - job_id=job.id, - status=job.status, - marketplace=job.marketplace, - language=job.language, - vendor_id=job.vendor_id, - vendor_code=job.vendor.vendor_code if job.vendor else None, - vendor_name=job.vendor.name if job.vendor else None, - source_url=job.source_url, - imported=job.imported_count or 0, - updated=job.updated_count or 0, - total_processed=job.total_processed or 0, - error_count=job.error_count or 0, - error_message=job.error_message, - error_details=[], - created_at=job.created_at, - started_at=job.started_at, - completed_at=job.completed_at, - created_by_name=job.user.username if job.user else None, - ) - - def get_all_import_jobs_paginated( - self, - db: Session, - marketplace: str | None = None, - status: str | None = None, - page: int = 1, - limit: int = 100, - ) -> tuple[list[MarketplaceImportJob], int]: - """Get all marketplace import jobs with pagination (for admin).""" - try: - query = db.query(MarketplaceImportJob) - - if marketplace: - query = query.filter( - MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%") - ) - if status: - query = query.filter(MarketplaceImportJob.status == status) - - total = query.count() - skip = (page - 1) * limit - jobs = ( - query.order_by( - MarketplaceImportJob.created_at.desc(), - MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp - ) - .offset(skip) - .limit(limit) - .all() - ) - - return jobs, total - - except Exception as e: - logger.error(f"Error getting all import jobs: {str(e)}") - raise ValidationException("Failed to retrieve import jobs") - - def get_import_job_by_id_admin( - self, db: Session, job_id: int - ) -> MarketplaceImportJob: - """Get a marketplace import job by ID (admin - no access control).""" - job = ( - db.query(MarketplaceImportJob) - .filter(MarketplaceImportJob.id == job_id) - .first() - ) - if not job: - raise ImportJobNotFoundException(job_id) - return job - - def get_import_job_errors( - self, - db: Session, - job_id: int, - error_type: str | None = None, - page: int = 1, - limit: int = 50, - ) -> tuple[list[MarketplaceImportError], int]: - """ - Get import errors for a specific job with pagination. - - Args: - db: Database session - job_id: Import job ID - error_type: Optional filter by error type - page: Page number (1-indexed) - limit: Number of items per page - - Returns: - Tuple of (list of errors, total count) - """ - try: - query = db.query(MarketplaceImportError).filter( - MarketplaceImportError.import_job_id == job_id - ) - - if error_type: - query = query.filter(MarketplaceImportError.error_type == error_type) - - total = query.count() - - offset = (page - 1) * limit - errors = ( - query.order_by(MarketplaceImportError.row_number) - .offset(offset) - .limit(limit) - .all() - ) - - return errors, total - - except Exception as e: - logger.error(f"Error getting import job errors for job {job_id}: {str(e)}") - raise ValidationException("Failed to retrieve import errors") - - -marketplace_import_job_service = MarketplaceImportJobService() +__all__ = [ + "MarketplaceImportJobService", + "marketplace_import_job_service", +] diff --git a/app/services/marketplace_product_service.py b/app/services/marketplace_product_service.py index 720bcc37..901ce9d7 100644 --- a/app/services/marketplace_product_service.py +++ b/app/services/marketplace_product_service.py @@ -1,1075 +1,23 @@ # app/services/marketplace_product_service.py """ -MarketplaceProduct service for managing product operations and data processing. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This module provides classes and functions for: -- MarketplaceProduct CRUD operations with validation -- Advanced product filtering and search -- Inventory information integration -- CSV export functionality +The canonical implementation is now in: + app/modules/marketplace/services/marketplace_product_service.py -Note: Title and description are now stored in MarketplaceProductTranslation table. -Use get_title(language) and get_description(language) methods on the model. +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.marketplace.services import marketplace_product_service """ -import csv -import logging -from collections.abc import Generator -from datetime import UTC, datetime -from io import StringIO - -from sqlalchemy import or_ -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, joinedload - -from app.exceptions import ( - InvalidMarketplaceProductDataException, - MarketplaceProductAlreadyExistsException, - MarketplaceProductNotFoundException, - MarketplaceProductValidationException, - ValidationException, -) -from app.utils.data_processing import GTINProcessor, PriceProcessor -from models.database.inventory import Inventory -from models.database.marketplace_product import MarketplaceProduct -from models.database.marketplace_product_translation import ( - MarketplaceProductTranslation, -) -from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse -from models.schema.marketplace_product import ( - MarketplaceProductCreate, - MarketplaceProductUpdate, +from app.modules.marketplace.services.marketplace_product_service import ( + MarketplaceProductService, + marketplace_product_service, ) -logger = logging.getLogger(__name__) - - -class MarketplaceProductService: - """Service class for MarketplaceProduct operations following the application's service pattern.""" - - def __init__(self): - """Class constructor.""" - self.gtin_processor = GTINProcessor() - self.price_processor = PriceProcessor() - - def create_product( - self, - db: Session, - product_data: MarketplaceProductCreate, - title: str | None = None, - description: str | None = None, - language: str = "en", - ) -> MarketplaceProduct: - """Create a new product with validation. - - Args: - db: Database session - product_data: Product data from schema - title: Product title (stored in translations table) - description: Product description (stored in translations table) - language: Language code for translation (default: 'en') - - Returns: - Created MarketplaceProduct instance - """ - try: - # Process and validate GTIN if provided - if product_data.gtin: - normalized_gtin = self.gtin_processor.normalize(product_data.gtin) - if not normalized_gtin: - raise InvalidMarketplaceProductDataException( - "Invalid GTIN format", field="gtin" - ) - product_data.gtin = normalized_gtin - - # Process price if provided - if product_data.price: - try: - parsed_price, currency = self.price_processor.parse_price_currency( - product_data.price - ) - if parsed_price: - product_data.price = parsed_price - product_data.currency = currency - except ValueError as e: - # Convert ValueError to domain-specific exception - raise InvalidMarketplaceProductDataException(str(e), field="price") - - # Set default marketplace if not provided - if not product_data.marketplace: - product_data.marketplace = "Letzshop" - - # Validate required fields - if ( - not product_data.marketplace_product_id - or not product_data.marketplace_product_id.strip() - ): - raise MarketplaceProductValidationException( - "MarketplaceProduct ID is required", field="marketplace_product_id" - ) - - # Create the product (without title/description - those go in translations) - product_dict = product_data.model_dump() - # Remove any title/description if present in schema (for backwards compatibility) - product_dict.pop("title", None) - product_dict.pop("description", None) - - db_product = MarketplaceProduct(**product_dict) - db.add(db_product) - db.flush() # Get the ID - - # Create translation if title is provided - if title and title.strip(): - translation = MarketplaceProductTranslation( - marketplace_product_id=db_product.id, - language=language, - title=title.strip(), - description=description.strip() if description else None, - ) - db.add(translation) - - db.flush() - db.refresh(db_product) - - logger.info(f"Created product {db_product.marketplace_product_id}") - return db_product - - except ( - InvalidMarketplaceProductDataException, - MarketplaceProductValidationException, - ): - raise # Re-raise custom exceptions - except IntegrityError as e: - logger.error(f"Database integrity error: {str(e)}") - if "marketplace_product_id" in str(e).lower() or "unique" in str(e).lower(): - raise MarketplaceProductAlreadyExistsException( - product_data.marketplace_product_id - ) - raise MarketplaceProductValidationException( - "Data integrity constraint violation" - ) - except Exception as e: - logger.error(f"Error creating product: {str(e)}") - raise ValidationException("Failed to create product") - - def get_product_by_id( - self, db: Session, marketplace_product_id: str - ) -> MarketplaceProduct | None: - """Get a product by its ID.""" - try: - return ( - db.query(MarketplaceProduct) - .options(joinedload(MarketplaceProduct.translations)) - .filter( - MarketplaceProduct.marketplace_product_id == marketplace_product_id - ) - .first() - ) - except Exception as e: - logger.error(f"Error getting product {marketplace_product_id}: {str(e)}") - return None - - def get_product_by_id_or_raise( - self, db: Session, marketplace_product_id: str - ) -> MarketplaceProduct: - """ - Get a product by its ID or raise exception. - - Args: - db: Database session - marketplace_product_id: MarketplaceProduct ID to find - - Returns: - MarketplaceProduct object - - Raises: - MarketplaceProductNotFoundException: If product doesn't exist - """ - product = self.get_product_by_id(db, marketplace_product_id) - if not product: - raise MarketplaceProductNotFoundException(marketplace_product_id) - return product - - def get_products_with_filters( - self, - db: Session, - skip: int = 0, - limit: int = 100, - brand: str | None = None, - category: str | None = None, - availability: str | None = None, - marketplace: str | None = None, - vendor_name: str | None = None, - search: str | None = None, - language: str = "en", - ) -> tuple[list[MarketplaceProduct], int]: - """ - Get products with filtering and pagination. - - Args: - db: Database session - skip: Number of records to skip - limit: Maximum records to return - brand: Brand filter - category: Category filter - availability: Availability filter - marketplace: Marketplace filter - vendor_name: Vendor name filter - search: Search term (searches in translations too) - language: Language for search (default: 'en') - - Returns: - Tuple of (products_list, total_count) - """ - try: - query = db.query(MarketplaceProduct).options( - joinedload(MarketplaceProduct.translations) - ) - - # Apply filters - if brand: - query = query.filter(MarketplaceProduct.brand.ilike(f"%{brand}%")) - if category: - query = query.filter( - MarketplaceProduct.google_product_category.ilike(f"%{category}%") - ) - if availability: - query = query.filter(MarketplaceProduct.availability == availability) - if marketplace: - query = query.filter( - MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") - ) - if vendor_name: - query = query.filter( - MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") - ) - if search: - # Search in marketplace, vendor_name, brand, and translations - search_term = f"%{search}%" - # Use subquery to get distinct IDs (PostgreSQL can't compare JSON for DISTINCT) - id_subquery = ( - db.query(MarketplaceProduct.id) - .outerjoin(MarketplaceProductTranslation) - .filter( - or_( - MarketplaceProduct.marketplace.ilike(search_term), - MarketplaceProduct.vendor_name.ilike(search_term), - MarketplaceProduct.brand.ilike(search_term), - MarketplaceProduct.gtin.ilike(search_term), - MarketplaceProduct.marketplace_product_id.ilike(search_term), - MarketplaceProductTranslation.title.ilike(search_term), - MarketplaceProductTranslation.description.ilike(search_term), - ) - ) - .distinct() - .subquery() - ) - query = query.filter(MarketplaceProduct.id.in_( - db.query(id_subquery.c.id) - )) - - total = query.count() - products = query.offset(skip).limit(limit).all() - - return products, total - - except Exception as e: - logger.error(f"Error getting products with filters: {str(e)}") - raise ValidationException("Failed to retrieve products") - - def update_product( - self, - db: Session, - marketplace_product_id: str, - product_update: MarketplaceProductUpdate, - title: str | None = None, - description: str | None = None, - language: str = "en", - ) -> MarketplaceProduct: - """Update product with validation. - - Args: - db: Database session - marketplace_product_id: ID of product to update - product_update: Product update data from schema - title: Updated title (stored in translations table) - description: Updated description (stored in translations table) - language: Language code for translation (default: 'en') - - Returns: - Updated MarketplaceProduct instance - """ - try: - product = self.get_product_by_id_or_raise(db, marketplace_product_id) - - # Update fields - update_data = product_update.model_dump(exclude_unset=True) - - # Remove title/description from update data (handled separately) - update_data.pop("title", None) - update_data.pop("description", None) - - # Validate GTIN if being updated - if "gtin" in update_data and update_data["gtin"]: - normalized_gtin = self.gtin_processor.normalize(update_data["gtin"]) - if not normalized_gtin: - raise InvalidMarketplaceProductDataException( - "Invalid GTIN format", field="gtin" - ) - update_data["gtin"] = normalized_gtin - - # Process price if being updated - if "price" in update_data and update_data["price"]: - try: - parsed_price, currency = self.price_processor.parse_price_currency( - update_data["price"] - ) - if parsed_price: - update_data["price"] = parsed_price - update_data["currency"] = currency - except ValueError as e: - # Convert ValueError to domain-specific exception - raise InvalidMarketplaceProductDataException(str(e), field="price") - - # Apply updates to product - for key, value in update_data.items(): - if hasattr(product, key): - setattr(product, key, value) - - product.updated_at = datetime.now(UTC) - - # Update or create translation if title/description provided - if title is not None or description is not None: - self._update_or_create_translation( - db, product, title, description, language - ) - - db.flush() - db.refresh(product) - - logger.info(f"Updated product {marketplace_product_id}") - return product - - except ( - MarketplaceProductNotFoundException, - InvalidMarketplaceProductDataException, - MarketplaceProductValidationException, - ): - raise # Re-raise custom exceptions - except Exception as e: - logger.error(f"Error updating product {marketplace_product_id}: {str(e)}") - raise ValidationException("Failed to update product") - - def _update_or_create_translation( - self, - db: Session, - product: MarketplaceProduct, - title: str | None, - description: str | None, - language: str, - ) -> None: - """Update existing translation or create new one.""" - existing = ( - db.query(MarketplaceProductTranslation) - .filter( - MarketplaceProductTranslation.marketplace_product_id == product.id, - MarketplaceProductTranslation.language == language, - ) - .first() - ) - - if existing: - if title is not None: - existing.title = title.strip() if title else existing.title - if description is not None: - existing.description = description.strip() if description else None - existing.updated_at = datetime.now(UTC) - else: - # Only create if we have a title - if title and title.strip(): - new_translation = MarketplaceProductTranslation( - marketplace_product_id=product.id, - language=language, - title=title.strip(), - description=description.strip() if description else None, - ) - db.add(new_translation) - - def delete_product(self, db: Session, marketplace_product_id: str) -> bool: - """ - Delete product and associated inventory. - - Args: - db: Database session - marketplace_product_id: MarketplaceProduct ID to delete - - Returns: - True if deletion successful - - Raises: - MarketplaceProductNotFoundException: If product doesn't exist - """ - try: - product = self.get_product_by_id_or_raise(db, marketplace_product_id) - - # Delete associated inventory entries if GTIN exists - if product.gtin: - db.query(Inventory).filter(Inventory.gtin == product.gtin).delete() - - # Translations will be cascade deleted - db.delete(product) - db.flush() - - logger.info(f"Deleted product {marketplace_product_id}") - return True - - except MarketplaceProductNotFoundException: - raise # Re-raise custom exceptions - except Exception as e: - logger.error(f"Error deleting product {marketplace_product_id}: {str(e)}") - raise ValidationException("Failed to delete product") - - def get_inventory_info( - self, db: Session, gtin: str - ) -> InventorySummaryResponse | None: - """ - Get inventory information for a product by GTIN. - - Args: - db: Database session - gtin: GTIN to look up inventory for - - Returns: - InventorySummaryResponse if inventory found, None otherwise - """ - try: - # noqa: SVC-005 - Admin/internal function for inventory lookup by GTIN - inventory_entries = db.query(Inventory).filter(Inventory.gtin == gtin).all() - if not inventory_entries: - return None - - total_quantity = sum(entry.quantity for entry in inventory_entries) - locations = [ - InventoryLocationResponse( - location=entry.location, - quantity=entry.quantity, - reserved_quantity=entry.reserved_quantity or 0, - available_quantity=entry.quantity - (entry.reserved_quantity or 0), - ) - for entry in inventory_entries - ] - - return InventorySummaryResponse( - gtin=gtin, total_quantity=total_quantity, locations=locations - ) - - except Exception as e: - logger.error(f"Error getting inventory info for GTIN {gtin}: {str(e)}") - return None - - def generate_csv_export( - self, - db: Session, - marketplace: str | None = None, - vendor_name: str | None = None, - language: str = "en", - ) -> Generator[str, None, None]: - """ - Generate CSV export with streaming for memory efficiency and proper CSV escaping. - - Args: - db: Database session - marketplace: Optional marketplace filter - vendor_name: Optional vendor name filter - language: Language code for title/description (default: 'en') - - Yields: - CSV content as strings with proper escaping - """ - try: - # Create a StringIO buffer for CSV writing - output = StringIO() - writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL) - - # Write header row - headers = [ - "marketplace_product_id", - "title", - "description", - "link", - "image_link", - "availability", - "price", - "currency", - "brand", - "gtin", - "marketplace", - "vendor_name", - ] - writer.writerow(headers) - yield output.getvalue() - - # Clear buffer for reuse - output.seek(0) - output.truncate(0) - - batch_size = 1000 - offset = 0 - - while True: - query = db.query(MarketplaceProduct).options( - joinedload(MarketplaceProduct.translations) - ) - - # Apply marketplace filters - if marketplace: - query = query.filter( - MarketplaceProduct.marketplace.ilike(f"%{marketplace}%") - ) - if vendor_name: - query = query.filter( - MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") - ) - - products = query.offset(offset).limit(batch_size).all() - if not products: - break - - for product in products: - # Get title and description from translations - title = product.get_title(language) or "" - description = product.get_description(language) or "" - - # Create CSV row with proper escaping - row_data = [ - product.marketplace_product_id or "", - title, - description, - product.link or "", - product.image_link or "", - product.availability or "", - product.price or "", - product.currency or "", - product.brand or "", - product.gtin or "", - product.marketplace or "", - product.vendor_name or "", - ] - - writer.writerow(row_data) - yield output.getvalue() - - # Clear buffer for next row - output.seek(0) - output.truncate(0) - - offset += batch_size - - except Exception as e: - logger.error(f"Error generating CSV export: {str(e)}") - raise ValidationException("Failed to generate CSV export") - - def product_exists(self, db: Session, marketplace_product_id: str) -> bool: - """Check if product exists by ID.""" - try: - return ( - db.query(MarketplaceProduct) - .filter( - MarketplaceProduct.marketplace_product_id == marketplace_product_id - ) - .first() - is not None - ) - except Exception as e: - logger.error(f"Error checking if product exists: {str(e)}") - return False - - # Private helper methods - def _validate_product_data(self, product_data: dict) -> None: - """Validate product data structure.""" - required_fields = ["marketplace_product_id"] - - for field in required_fields: - if field not in product_data or not product_data[field]: - raise MarketplaceProductValidationException( - f"{field} is required", field=field - ) - - def _normalize_product_data(self, product_data: dict) -> dict: - """Normalize and clean product data.""" - normalized = product_data.copy() - - # Trim whitespace from string fields - string_fields = [ - "marketplace_product_id", - "brand", - "marketplace", - "vendor_name", - ] - for field in string_fields: - if field in normalized and normalized[field]: - normalized[field] = normalized[field].strip() - - return normalized - - # ========================================================================= - # Admin-specific methods for marketplace product management - # ========================================================================= - - def get_admin_products( - self, - db: Session, - skip: int = 0, - limit: int = 50, - search: str | None = None, - marketplace: str | None = None, - vendor_name: str | None = None, - availability: str | None = None, - is_active: bool | None = None, - is_digital: bool | None = None, - language: str = "en", - ) -> tuple[list[dict], int]: - """ - Get marketplace products for admin with search and filtering. - - Returns: - Tuple of (products list as dicts, total count) - """ - query = db.query(MarketplaceProduct).options( - joinedload(MarketplaceProduct.translations) - ) - - if search: - search_term = f"%{search}%" - # Use subquery to get distinct IDs (PostgreSQL can't compare JSON for DISTINCT) - id_subquery = ( - db.query(MarketplaceProduct.id) - .outerjoin(MarketplaceProductTranslation) - .filter( - or_( - MarketplaceProductTranslation.title.ilike(search_term), - MarketplaceProduct.gtin.ilike(search_term), - MarketplaceProduct.sku.ilike(search_term), - MarketplaceProduct.brand.ilike(search_term), - MarketplaceProduct.mpn.ilike(search_term), - MarketplaceProduct.marketplace_product_id.ilike(search_term), - ) - ) - .distinct() - .subquery() - ) - query = query.filter(MarketplaceProduct.id.in_( - db.query(id_subquery.c.id) - )) - - if marketplace: - query = query.filter(MarketplaceProduct.marketplace == marketplace) - - if vendor_name: - query = query.filter( - MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%") - ) - - if availability: - query = query.filter(MarketplaceProduct.availability == availability) - - if is_active is not None: - query = query.filter(MarketplaceProduct.is_active == is_active) - - if is_digital is not None: - query = query.filter(MarketplaceProduct.is_digital == is_digital) - - total = query.count() - - products = ( - query.order_by(MarketplaceProduct.updated_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - result = [] - for product in products: - title = product.get_title(language) - result.append(self._build_admin_product_item(product, title)) - - return result, total - - def get_admin_product_stats( - self, - db: Session, - marketplace: str | None = None, - vendor_name: str | None = None, - ) -> dict: - """Get product statistics for admin dashboard. - - Args: - db: Database session - marketplace: Optional filter by marketplace (e.g., 'Letzshop') - vendor_name: Optional filter by vendor name - """ - from sqlalchemy import func - - # Build base filter - base_filters = [] - if marketplace: - base_filters.append(MarketplaceProduct.marketplace == marketplace) - if vendor_name: - base_filters.append(MarketplaceProduct.vendor_name == vendor_name) - - base_query = db.query(func.count(MarketplaceProduct.id)) - if base_filters: - base_query = base_query.filter(*base_filters) - - total = base_query.scalar() or 0 - - active_query = db.query(func.count(MarketplaceProduct.id)).filter( - MarketplaceProduct.is_active == True # noqa: E712 - ) - if base_filters: - active_query = active_query.filter(*base_filters) - active = active_query.scalar() or 0 - inactive = total - active - - digital_query = db.query(func.count(MarketplaceProduct.id)).filter( - MarketplaceProduct.is_digital == True # noqa: E712 - ) - if base_filters: - digital_query = digital_query.filter(*base_filters) - digital = digital_query.scalar() or 0 - physical = total - digital - - marketplace_query = db.query( - MarketplaceProduct.marketplace, - func.count(MarketplaceProduct.id), - ) - if base_filters: - marketplace_query = marketplace_query.filter(*base_filters) - marketplace_counts = marketplace_query.group_by( - MarketplaceProduct.marketplace - ).all() - by_marketplace = {mp or "unknown": count for mp, count in marketplace_counts} - - return { - "total": total, - "active": active, - "inactive": inactive, - "digital": digital, - "physical": physical, - "by_marketplace": by_marketplace, - } - - def get_marketplaces_list(self, db: Session) -> list[str]: - """Get list of unique marketplaces in the product catalog.""" - marketplaces = ( - db.query(MarketplaceProduct.marketplace) - .distinct() - .filter(MarketplaceProduct.marketplace.isnot(None)) - .all() - ) - return [m[0] for m in marketplaces if m[0]] - - def get_source_vendors_list(self, db: Session) -> list[str]: - """Get list of unique vendor names in the product catalog.""" - vendors = ( - db.query(MarketplaceProduct.vendor_name) - .distinct() - .filter(MarketplaceProduct.vendor_name.isnot(None)) - .all() - ) - return [v[0] for v in vendors if v[0]] - - def get_admin_product_detail(self, db: Session, product_id: int) -> dict: - """Get detailed product information by database ID.""" - product = ( - db.query(MarketplaceProduct) - .options(joinedload(MarketplaceProduct.translations)) - .filter(MarketplaceProduct.id == product_id) - .first() - ) - - if not product: - raise MarketplaceProductNotFoundException( - f"Marketplace product with ID {product_id} not found" - ) - - translations = {} - for t in product.translations: - translations[t.language] = { - "title": t.title, - "description": t.description, - "short_description": t.short_description, - } - - return { - "id": product.id, - "marketplace_product_id": product.marketplace_product_id, - "gtin": product.gtin, - "mpn": product.mpn, - "sku": product.sku, - "brand": product.brand, - "marketplace": product.marketplace, - "vendor_name": product.vendor_name, - "source_url": product.source_url, - "price": product.price, - "price_numeric": product.price_numeric, - "sale_price": product.sale_price, - "sale_price_numeric": product.sale_price_numeric, - "currency": product.currency, - "availability": product.availability, - "condition": product.condition, - "image_link": product.image_link, - "additional_images": product.additional_images, - "is_active": product.is_active, - "is_digital": product.is_digital, - "product_type_enum": product.product_type_enum, - "platform": product.platform, - "google_product_category": product.google_product_category, - "category_path": product.category_path, - "color": product.color, - "size": product.size, - "weight": product.weight, - "weight_unit": product.weight_unit, - "translations": translations, - "created_at": product.created_at.isoformat() - if product.created_at - else None, - "updated_at": product.updated_at.isoformat() - if product.updated_at - else None, - } - - def copy_to_vendor_catalog( - self, - db: Session, - marketplace_product_ids: list[int], - vendor_id: int, - skip_existing: bool = True, - ) -> dict: - """ - Copy marketplace products to a vendor's catalog. - - Creates independent vendor products with ALL fields copied from the - marketplace product. Each vendor product is a standalone entity - no - field inheritance or fallback logic. The marketplace_product_id FK is - kept for "view original source" feature. - - Also copies ALL translations from the marketplace product. - - Returns: - Dict with copied, skipped, failed counts and details - """ - from models.database.product import Product - from models.database.product_translation import ProductTranslation - from models.database.vendor import Vendor - - vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first() - if not vendor: - from app.exceptions import VendorNotFoundException - - raise VendorNotFoundException(str(vendor_id), identifier_type="id") - - marketplace_products = ( - db.query(MarketplaceProduct) - .options(joinedload(MarketplaceProduct.translations)) - .filter(MarketplaceProduct.id.in_(marketplace_product_ids)) - .all() - ) - - if not marketplace_products: - raise MarketplaceProductNotFoundException("No marketplace products found") - - # Check product limit from subscription - from app.services.subscription_service import subscription_service - from sqlalchemy import func - - current_products = ( - db.query(func.count(Product.id)) - .filter(Product.vendor_id == vendor_id) - .scalar() - or 0 - ) - - subscription = subscription_service.get_or_create_subscription(db, vendor_id) - products_limit = subscription.products_limit - remaining_capacity = ( - products_limit - current_products if products_limit is not None else None - ) - - copied = 0 - skipped = 0 - failed = 0 - limit_reached = False - details = [] - - for mp in marketplace_products: - # Check if we've hit the product limit - if remaining_capacity is not None and copied >= remaining_capacity: - limit_reached = True - details.append( - { - "id": mp.id, - "status": "skipped", - "reason": "Product limit reached", - } - ) - skipped += 1 - continue - try: - existing = ( - db.query(Product) - .filter( - Product.vendor_id == vendor_id, - Product.marketplace_product_id == mp.id, - ) - .first() - ) - - if existing: - skipped += 1 - details.append( - { - "id": mp.id, - "status": "skipped", - "reason": "Already exists in catalog", - } - ) - continue - - # Create vendor product with ALL fields copied from marketplace - product = Product( - vendor_id=vendor_id, - marketplace_product_id=mp.id, - # === Vendor settings (defaults) === - is_active=True, - is_featured=False, - # === Product identifiers === - gtin=mp.gtin, - gtin_type=mp.gtin_type if hasattr(mp, "gtin_type") else None, - # === Pricing (copy from marketplace) === - price_cents=mp.price_cents, - sale_price_cents=mp.sale_price_cents, - currency=mp.currency or "EUR", - # === Product info === - brand=mp.brand, - condition=mp.condition, - availability=mp.availability, - # === Media === - primary_image_url=mp.image_link, - additional_images=mp.additional_images, - # === Digital product fields === - download_url=mp.download_url if hasattr(mp, "download_url") else None, - license_type=mp.license_type if hasattr(mp, "license_type") else None, - ) - - db.add(product) - db.flush() # Get product.id for translations - - # Copy ALL translations from marketplace product - translations_copied = 0 - for mpt in mp.translations: - product_translation = ProductTranslation( - product_id=product.id, - language=mpt.language, - title=mpt.title, - description=mpt.description, - short_description=mpt.short_description, - meta_title=mpt.meta_title, - meta_description=mpt.meta_description, - url_slug=mpt.url_slug, - ) - db.add(product_translation) - translations_copied += 1 - - copied += 1 - details.append({ - "id": mp.id, - "status": "copied", - "gtin": mp.gtin, - "translations_copied": translations_copied, - }) - - except Exception as e: - logger.error(f"Failed to copy product {mp.id}: {str(e)}") - failed += 1 - details.append({"id": mp.id, "status": "failed", "reason": str(e)}) - - db.flush() - - # Auto-match pending order item exceptions - # Collect GTINs and their product IDs from newly copied products - from app.services.order_item_exception_service import ( - order_item_exception_service, - ) - - gtin_to_product: dict[str, int] = {} - for detail in details: - if detail.get("status") == "copied" and detail.get("gtin"): - # Find the product we just created - product = ( - db.query(Product) - .filter( - Product.vendor_id == vendor_id, - Product.gtin == detail["gtin"], - ) - .first() - ) - if product: - gtin_to_product[detail["gtin"]] = product.id - - auto_matched = 0 - if gtin_to_product: - auto_matched = order_item_exception_service.auto_match_batch( - db, vendor_id, gtin_to_product - ) - if auto_matched: - logger.info( - f"Auto-matched {auto_matched} order item exceptions " - f"during product copy to vendor {vendor_id}" - ) - - logger.info( - f"Copied {copied} products to vendor {vendor.name} " - f"(skipped: {skipped}, failed: {failed}, auto_matched: {auto_matched})" - ) - - return { - "copied": copied, - "skipped": skipped, - "failed": failed, - "auto_matched": auto_matched, - "limit_reached": limit_reached, - "details": details if len(details) <= 100 else None, - } - - def _build_admin_product_item( - self, product: MarketplaceProduct, title: str | None - ) -> dict: - """Build a product list item dict for admin view.""" - return { - "id": product.id, - "marketplace_product_id": product.marketplace_product_id, - "title": title, - "brand": product.brand, - "gtin": product.gtin, - "sku": product.sku, - "marketplace": product.marketplace, - "vendor_name": product.vendor_name, - "price_numeric": product.price_numeric, - "currency": product.currency, - "availability": product.availability, - "image_link": product.image_link, - "is_active": product.is_active, - "is_digital": product.is_digital, - "product_type_enum": product.product_type_enum, - "created_at": product.created_at.isoformat() - if product.created_at - else None, - "updated_at": product.updated_at.isoformat() - if product.updated_at - else None, - } - - -# Create service instance -marketplace_product_service = MarketplaceProductService() +__all__ = [ + "MarketplaceProductService", + "marketplace_product_service", +] diff --git a/app/services/message_attachment_service.py b/app/services/message_attachment_service.py index 7fd4cd5d..78bff889 100644 --- a/app/services/message_attachment_service.py +++ b/app/services/message_attachment_service.py @@ -1,225 +1,23 @@ # app/services/message_attachment_service.py """ -Attachment handling service for messaging system. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Handles file upload, validation, storage, and retrieval. +The canonical implementation is now in: + app/modules/messaging/services/message_attachment_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.services import message_attachment_service """ -import logging -import os -import uuid -from datetime import datetime -from pathlib import Path +from app.modules.messaging.services.message_attachment_service import ( + message_attachment_service, + MessageAttachmentService, +) -from fastapi import UploadFile -from sqlalchemy.orm import Session - -from app.services.admin_settings_service import admin_settings_service - -logger = logging.getLogger(__name__) - -# Allowed MIME types for attachments -ALLOWED_MIME_TYPES = { - # Images - "image/jpeg", - "image/png", - "image/gif", - "image/webp", - # Documents - "application/pdf", - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-excel", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - # Archives - "application/zip", - # Text - "text/plain", - "text/csv", -} - -IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"} - -# Default max file size in MB -DEFAULT_MAX_FILE_SIZE_MB = 10 - - -class MessageAttachmentService: - """Service for handling message attachments.""" - - def __init__(self, storage_base: str = "uploads/messages"): - self.storage_base = storage_base - - def get_max_file_size_bytes(self, db: Session) -> int: - """Get maximum file size from platform settings.""" - max_mb = admin_settings_service.get_setting_value( - db, - "message_attachment_max_size_mb", - default=DEFAULT_MAX_FILE_SIZE_MB, - ) - try: - max_mb = int(max_mb) - except (TypeError, ValueError): - max_mb = DEFAULT_MAX_FILE_SIZE_MB - return max_mb * 1024 * 1024 # Convert to bytes - - def validate_file_type(self, mime_type: str) -> bool: - """Check if file type is allowed.""" - return mime_type in ALLOWED_MIME_TYPES - - def is_image(self, mime_type: str) -> bool: - """Check if file is an image.""" - return mime_type in IMAGE_MIME_TYPES - - async def validate_and_store( - self, - db: Session, - file: UploadFile, - conversation_id: int, - ) -> dict: - """ - Validate and store an uploaded file. - - Returns dict with file metadata for MessageAttachment creation. - - Raises: - ValueError: If file type or size is invalid - """ - # Validate MIME type - content_type = file.content_type or "application/octet-stream" - if not self.validate_file_type(content_type): - raise ValueError( - f"File type '{content_type}' not allowed. " - "Allowed types: images (JPEG, PNG, GIF, WebP), " - "PDF, Office documents, ZIP, text files." - ) - - # Read file content - content = await file.read() - file_size = len(content) - - # Validate file size - max_size = self.get_max_file_size_bytes(db) - if file_size > max_size: - raise ValueError( - f"File size {file_size / 1024 / 1024:.1f}MB exceeds " - f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB" - ) - - # Generate unique filename - original_filename = file.filename or "attachment" - ext = Path(original_filename).suffix.lower() - unique_filename = f"{uuid.uuid4().hex}{ext}" - - # Create storage path: uploads/messages/YYYY/MM/conversation_id/filename - now = datetime.utcnow() - relative_path = os.path.join( - self.storage_base, - str(now.year), - f"{now.month:02d}", - str(conversation_id), - ) - - # Ensure directory exists - os.makedirs(relative_path, exist_ok=True) - - # Full file path - file_path = os.path.join(relative_path, unique_filename) - - # Write file - with open(file_path, "wb") as f: - f.write(content) - - # Prepare metadata - is_image = self.is_image(content_type) - metadata = { - "filename": unique_filename, - "original_filename": original_filename, - "file_path": file_path, - "file_size": file_size, - "mime_type": content_type, - "is_image": is_image, - } - - # Generate thumbnail for images - if is_image: - thumbnail_data = self._create_thumbnail(content, file_path) - metadata.update(thumbnail_data) - - logger.info( - f"Stored attachment {unique_filename} for conversation {conversation_id} " - f"({file_size} bytes, type: {content_type})" - ) - - return metadata - - def _create_thumbnail(self, content: bytes, original_path: str) -> dict: - """Create thumbnail for image attachments.""" - try: - from PIL import Image - import io - - img = Image.open(io.BytesIO(content)) - width, height = img.size - - # Create thumbnail - img.thumbnail((200, 200)) - - thumb_path = original_path.replace(".", "_thumb.") - img.save(thumb_path) - - return { - "image_width": width, - "image_height": height, - "thumbnail_path": thumb_path, - } - except ImportError: - logger.warning("PIL not installed, skipping thumbnail generation") - return {} - except Exception as e: - logger.error(f"Failed to create thumbnail: {e}") - return {} - - def delete_attachment( - self, file_path: str, thumbnail_path: str | None = None - ) -> bool: - """Delete attachment files from storage.""" - try: - if os.path.exists(file_path): - os.remove(file_path) - logger.info(f"Deleted attachment file: {file_path}") - - if thumbnail_path and os.path.exists(thumbnail_path): - os.remove(thumbnail_path) - logger.info(f"Deleted thumbnail: {thumbnail_path}") - - return True - except Exception as e: - logger.error(f"Failed to delete attachment {file_path}: {e}") - return False - - def get_download_url(self, file_path: str) -> str: - """ - Get download URL for an attachment. - - For local storage, returns a relative path that can be served - by the static file handler or a dedicated download endpoint. - """ - # Convert local path to URL path - # Assumes files are served from /static/uploads or similar - return f"/static/{file_path}" - - def get_file_content(self, file_path: str) -> bytes | None: - """Read file content from storage.""" - try: - if os.path.exists(file_path): - with open(file_path, "rb") as f: - return f.read() - return None - except Exception as e: - logger.error(f"Failed to read file {file_path}: {e}") - return None - - -# Singleton instance -message_attachment_service = MessageAttachmentService() +__all__ = [ + "message_attachment_service", + "MessageAttachmentService", +] diff --git a/app/services/messaging_service.py b/app/services/messaging_service.py index 143f7212..e46010c9 100644 --- a/app/services/messaging_service.py +++ b/app/services/messaging_service.py @@ -1,684 +1,23 @@ # app/services/messaging_service.py """ -Messaging service for conversation and message management. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides functionality for: -- Creating conversations between different participant types -- Sending messages with attachments -- Managing read status and unread counts -- Conversation listing with filters -- Multi-tenant data isolation +The canonical implementation is now in: + app/modules/messaging/services/messaging_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.services import messaging_service """ -import logging -from datetime import UTC, datetime -from typing import Any - -from sqlalchemy import and_, func, or_ -from sqlalchemy.orm import Session, joinedload - -from models.database.customer import Customer -from models.database.message import ( - Conversation, - ConversationParticipant, - ConversationType, - Message, - MessageAttachment, - ParticipantType, +from app.modules.messaging.services.messaging_service import ( + messaging_service, + MessagingService, ) -from models.database.user import User -logger = logging.getLogger(__name__) - - -class MessagingService: - """Service for managing conversations and messages.""" - - # ========================================================================= - # CONVERSATION MANAGEMENT - # ========================================================================= - - def create_conversation( - self, - db: Session, - conversation_type: ConversationType, - subject: str, - initiator_type: ParticipantType, - initiator_id: int, - recipient_type: ParticipantType, - recipient_id: int, - vendor_id: int | None = None, - initial_message: str | None = None, - ) -> Conversation: - """ - Create a new conversation between two participants. - - Args: - db: Database session - conversation_type: Type of conversation channel - subject: Conversation subject line - initiator_type: Type of initiating participant - initiator_id: ID of initiating participant - recipient_type: Type of receiving participant - recipient_id: ID of receiving participant - vendor_id: Required for vendor_customer/admin_customer types - initial_message: Optional first message content - - Returns: - Created Conversation object - """ - # Validate vendor_id requirement - if conversation_type in [ - ConversationType.VENDOR_CUSTOMER, - ConversationType.ADMIN_CUSTOMER, - ]: - if not vendor_id: - raise ValueError( - f"vendor_id required for {conversation_type.value} conversations" - ) - - # Create conversation - conversation = Conversation( - conversation_type=conversation_type, - subject=subject, - vendor_id=vendor_id, - ) - db.add(conversation) - db.flush() - - # Add participants - initiator_vendor_id = ( - vendor_id if initiator_type == ParticipantType.VENDOR else None - ) - recipient_vendor_id = ( - vendor_id if recipient_type == ParticipantType.VENDOR else None - ) - - initiator = ConversationParticipant( - conversation_id=conversation.id, - participant_type=initiator_type, - participant_id=initiator_id, - vendor_id=initiator_vendor_id, - unread_count=0, # Initiator has read their own message - ) - recipient = ConversationParticipant( - conversation_id=conversation.id, - participant_type=recipient_type, - participant_id=recipient_id, - vendor_id=recipient_vendor_id, - unread_count=1 if initial_message else 0, - ) - - db.add(initiator) - db.add(recipient) - db.flush() - - # Add initial message if provided - if initial_message: - self.send_message( - db=db, - conversation_id=conversation.id, - sender_type=initiator_type, - sender_id=initiator_id, - content=initial_message, - _skip_unread_update=True, # Already set above - ) - - logger.info( - f"Created {conversation_type.value} conversation {conversation.id}: " - f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}" - ) - - return conversation - - def get_conversation( - self, - db: Session, - conversation_id: int, - participant_type: ParticipantType, - participant_id: int, - ) -> Conversation | None: - """ - Get conversation if participant has access. - - Validates that the requester is a participant. - """ - conversation = ( - db.query(Conversation) - .options( - joinedload(Conversation.participants), - joinedload(Conversation.messages).joinedload(Message.attachments), - ) - .filter(Conversation.id == conversation_id) - .first() - ) - - if not conversation: - return None - - # Verify participant access - has_access = any( - p.participant_type == participant_type - and p.participant_id == participant_id - for p in conversation.participants - ) - - if not has_access: - logger.warning( - f"Access denied to conversation {conversation_id} for " - f"{participant_type.value}:{participant_id}" - ) - return None - - return conversation - - def list_conversations( - self, - db: Session, - participant_type: ParticipantType, - participant_id: int, - vendor_id: int | None = None, - conversation_type: ConversationType | None = None, - is_closed: bool | None = None, - skip: int = 0, - limit: int = 20, - ) -> tuple[list[Conversation], int, int]: - """ - List conversations for a participant with filters. - - Returns: - Tuple of (conversations, total_count, total_unread) - """ - # Base query: conversations where user is a participant - query = ( - db.query(Conversation) - .join(ConversationParticipant) - .filter( - and_( - ConversationParticipant.participant_type == participant_type, - ConversationParticipant.participant_id == participant_id, - ) - ) - ) - - # Multi-tenant filter for vendor users - if participant_type == ParticipantType.VENDOR and vendor_id: - query = query.filter(ConversationParticipant.vendor_id == vendor_id) - - # Customer vendor isolation - if participant_type == ParticipantType.CUSTOMER and vendor_id: - query = query.filter(Conversation.vendor_id == vendor_id) - - # Type filter - if conversation_type: - query = query.filter(Conversation.conversation_type == conversation_type) - - # Status filter - if is_closed is not None: - query = query.filter(Conversation.is_closed == is_closed) - - # Get total count - total = query.count() - - # Get total unread across all conversations - unread_query = db.query( - func.sum(ConversationParticipant.unread_count) - ).filter( - and_( - ConversationParticipant.participant_type == participant_type, - ConversationParticipant.participant_id == participant_id, - ) - ) - - if participant_type == ParticipantType.VENDOR and vendor_id: - unread_query = unread_query.filter( - ConversationParticipant.vendor_id == vendor_id - ) - - total_unread = unread_query.scalar() or 0 - - # Get paginated results, ordered by last activity - conversations = ( - query.options(joinedload(Conversation.participants)) - .order_by(Conversation.last_message_at.desc().nullslast()) - .offset(skip) - .limit(limit) - .all() - ) - - return conversations, total, total_unread - - def close_conversation( - self, - db: Session, - conversation_id: int, - closer_type: ParticipantType, - closer_id: int, - ) -> Conversation | None: - """Close a conversation thread.""" - conversation = self.get_conversation( - db, conversation_id, closer_type, closer_id - ) - - if not conversation: - return None - - conversation.is_closed = True - conversation.closed_at = datetime.now(UTC) - conversation.closed_by_type = closer_type - conversation.closed_by_id = closer_id - - # Add system message - self.send_message( - db=db, - conversation_id=conversation_id, - sender_type=closer_type, - sender_id=closer_id, - content="Conversation closed", - is_system_message=True, - ) - - db.flush() - return conversation - - def reopen_conversation( - self, - db: Session, - conversation_id: int, - opener_type: ParticipantType, - opener_id: int, - ) -> Conversation | None: - """Reopen a closed conversation.""" - conversation = self.get_conversation( - db, conversation_id, opener_type, opener_id - ) - - if not conversation: - return None - - conversation.is_closed = False - conversation.closed_at = None - conversation.closed_by_type = None - conversation.closed_by_id = None - - # Add system message - self.send_message( - db=db, - conversation_id=conversation_id, - sender_type=opener_type, - sender_id=opener_id, - content="Conversation reopened", - is_system_message=True, - ) - - db.flush() - return conversation - - # ========================================================================= - # MESSAGE MANAGEMENT - # ========================================================================= - - def send_message( - self, - db: Session, - conversation_id: int, - sender_type: ParticipantType, - sender_id: int, - content: str, - attachments: list[dict[str, Any]] | None = None, - is_system_message: bool = False, - _skip_unread_update: bool = False, - ) -> Message: - """ - Send a message in a conversation. - - Args: - db: Database session - conversation_id: Target conversation ID - sender_type: Type of sender - sender_id: ID of sender - content: Message text content - attachments: List of attachment dicts with file metadata - is_system_message: Whether this is a system-generated message - _skip_unread_update: Internal flag to skip unread increment - - Returns: - Created Message object - """ - # Create message - message = Message( - conversation_id=conversation_id, - sender_type=sender_type, - sender_id=sender_id, - content=content, - is_system_message=is_system_message, - ) - db.add(message) - db.flush() - - # Add attachments if any - if attachments: - for att_data in attachments: - attachment = MessageAttachment( - message_id=message.id, - filename=att_data["filename"], - original_filename=att_data["original_filename"], - file_path=att_data["file_path"], - file_size=att_data["file_size"], - mime_type=att_data["mime_type"], - is_image=att_data.get("is_image", False), - image_width=att_data.get("image_width"), - image_height=att_data.get("image_height"), - thumbnail_path=att_data.get("thumbnail_path"), - ) - db.add(attachment) - - # Update conversation metadata - conversation = ( - db.query(Conversation).filter(Conversation.id == conversation_id).first() - ) - - if conversation: - conversation.last_message_at = datetime.now(UTC) - conversation.message_count += 1 - - # Update unread counts for other participants - if not _skip_unread_update: - db.query(ConversationParticipant).filter( - and_( - ConversationParticipant.conversation_id == conversation_id, - or_( - ConversationParticipant.participant_type != sender_type, - ConversationParticipant.participant_id != sender_id, - ), - ) - ).update( - { - ConversationParticipant.unread_count: ConversationParticipant.unread_count - + 1 - } - ) - - db.flush() - - logger.info( - f"Message {message.id} sent in conversation {conversation_id} " - f"by {sender_type.value}:{sender_id}" - ) - - return message - - def delete_message( - self, - db: Session, - message_id: int, - deleter_type: ParticipantType, - deleter_id: int, - ) -> Message | None: - """Soft delete a message (for moderation).""" - message = db.query(Message).filter(Message.id == message_id).first() - - if not message: - return None - - # Verify deleter has access to conversation - conversation = self.get_conversation( - db, message.conversation_id, deleter_type, deleter_id - ) - if not conversation: - return None - - message.is_deleted = True - message.deleted_at = datetime.now(UTC) - message.deleted_by_type = deleter_type - message.deleted_by_id = deleter_id - - db.flush() - return message - - def mark_conversation_read( - self, - db: Session, - conversation_id: int, - reader_type: ParticipantType, - reader_id: int, - ) -> bool: - """Mark all messages in conversation as read for participant.""" - result = ( - db.query(ConversationParticipant) - .filter( - and_( - ConversationParticipant.conversation_id == conversation_id, - ConversationParticipant.participant_type == reader_type, - ConversationParticipant.participant_id == reader_id, - ) - ) - .update( - { - ConversationParticipant.unread_count: 0, - ConversationParticipant.last_read_at: datetime.now(UTC), - } - ) - ) - db.flush() - return result > 0 - - def get_unread_count( - self, - db: Session, - participant_type: ParticipantType, - participant_id: int, - vendor_id: int | None = None, - ) -> int: - """Get total unread message count for a participant.""" - query = db.query(func.sum(ConversationParticipant.unread_count)).filter( - and_( - ConversationParticipant.participant_type == participant_type, - ConversationParticipant.participant_id == participant_id, - ) - ) - - if vendor_id: - query = query.filter(ConversationParticipant.vendor_id == vendor_id) - - return query.scalar() or 0 - - # ========================================================================= - # PARTICIPANT HELPERS - # ========================================================================= - - def get_participant_info( - self, - db: Session, - participant_type: ParticipantType, - participant_id: int, - ) -> dict[str, Any] | None: - """Get display info for a participant (name, email, avatar).""" - if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]: - user = db.query(User).filter(User.id == participant_id).first() - if user: - return { - "id": user.id, - "type": participant_type.value, - "name": f"{user.first_name or ''} {user.last_name or ''}".strip() - or user.username, - "email": user.email, - "avatar_url": None, # Could add avatar support later - } - elif participant_type == ParticipantType.CUSTOMER: - customer = db.query(Customer).filter(Customer.id == participant_id).first() - if customer: - return { - "id": customer.id, - "type": participant_type.value, - "name": f"{customer.first_name or ''} {customer.last_name or ''}".strip() - or customer.email, - "email": customer.email, - "avatar_url": None, - } - return None - - def get_other_participant( - self, - conversation: Conversation, - my_type: ParticipantType, - my_id: int, - ) -> ConversationParticipant | None: - """Get the other participant in a conversation.""" - for p in conversation.participants: - if p.participant_type != my_type or p.participant_id != my_id: - return p - return None - - # ========================================================================= - # NOTIFICATION PREFERENCES - # ========================================================================= - - def update_notification_preferences( - self, - db: Session, - conversation_id: int, - participant_type: ParticipantType, - participant_id: int, - email_notifications: bool | None = None, - muted: bool | None = None, - ) -> bool: - """Update notification preferences for a participant in a conversation.""" - updates = {} - if email_notifications is not None: - updates[ConversationParticipant.email_notifications] = email_notifications - if muted is not None: - updates[ConversationParticipant.muted] = muted - - if not updates: - return False - - result = ( - db.query(ConversationParticipant) - .filter( - and_( - ConversationParticipant.conversation_id == conversation_id, - ConversationParticipant.participant_type == participant_type, - ConversationParticipant.participant_id == participant_id, - ) - ) - .update(updates) - ) - db.flush() - return result > 0 - - # ========================================================================= - # RECIPIENT QUERIES - # ========================================================================= - - def get_vendor_recipients( - self, - db: Session, - vendor_id: int | None = None, - search: str | None = None, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[dict], int]: - """ - Get list of vendor users as potential recipients. - - Args: - db: Database session - vendor_id: Optional vendor ID filter - search: Search term for name/email - skip: Pagination offset - limit: Max results - - Returns: - Tuple of (recipients list, total count) - """ - from models.database.vendor import VendorUser - - query = ( - db.query(User, VendorUser) - .join(VendorUser, User.id == VendorUser.user_id) - .filter(User.is_active == True) # noqa: E712 - ) - - if vendor_id: - query = query.filter(VendorUser.vendor_id == vendor_id) - - if search: - search_pattern = f"%{search}%" - query = query.filter( - (User.username.ilike(search_pattern)) - | (User.email.ilike(search_pattern)) - | (User.first_name.ilike(search_pattern)) - | (User.last_name.ilike(search_pattern)) - ) - - total = query.count() - results = query.offset(skip).limit(limit).all() - - recipients = [] - for user, vendor_user in results: - name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username - recipients.append({ - "id": user.id, - "type": ParticipantType.VENDOR, - "name": name, - "email": user.email, - "vendor_id": vendor_user.vendor_id, - "vendor_name": vendor_user.vendor.name if vendor_user.vendor else None, - }) - - return recipients, total - - def get_customer_recipients( - self, - db: Session, - vendor_id: int | None = None, - search: str | None = None, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[dict], int]: - """ - Get list of customers as potential recipients. - - Args: - db: Database session - vendor_id: Optional vendor ID filter (required for vendor users) - search: Search term for name/email - skip: Pagination offset - limit: Max results - - Returns: - Tuple of (recipients list, total count) - """ - query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712 - - if vendor_id: - query = query.filter(Customer.vendor_id == vendor_id) - - if search: - search_pattern = f"%{search}%" - query = query.filter( - (Customer.email.ilike(search_pattern)) - | (Customer.first_name.ilike(search_pattern)) - | (Customer.last_name.ilike(search_pattern)) - ) - - total = query.count() - results = query.offset(skip).limit(limit).all() - - recipients = [] - for customer in results: - name = f"{customer.first_name or ''} {customer.last_name or ''}".strip() - recipients.append({ - "id": customer.id, - "type": ParticipantType.CUSTOMER, - "name": name or customer.email, - "email": customer.email, - "vendor_id": customer.vendor_id, - }) - - return recipients, total - - -# Singleton instance -messaging_service = MessagingService() +__all__ = [ + "messaging_service", + "MessagingService", +] diff --git a/app/services/order_inventory_service.py b/app/services/order_inventory_service.py index f514a46e..d979453b 100644 --- a/app/services/order_inventory_service.py +++ b/app/services/order_inventory_service.py @@ -1,735 +1,23 @@ # app/services/order_inventory_service.py """ -Order-Inventory Integration Service. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This service orchestrates inventory operations for orders: -- Reserve inventory when orders are confirmed -- Fulfill (deduct) inventory when orders are shipped -- Release reservations when orders are cancelled +The canonical implementation is now in: + app/modules/orders/services/order_inventory_service.py -This is the critical link between the order and inventory systems -that ensures stock accuracy. +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: -All operations are logged to the inventory_transactions table for audit trail. + from app.modules.orders.services import order_inventory_service """ -import logging -from sqlalchemy.orm import Session - -from app.exceptions import ( - InsufficientInventoryException, - InventoryNotFoundException, - OrderNotFoundException, - ValidationException, +from app.modules.orders.services.order_inventory_service import ( + order_inventory_service, + OrderInventoryService, ) -from app.services.inventory_service import inventory_service -from models.database.inventory import Inventory -from models.database.inventory_transaction import InventoryTransaction, TransactionType -from models.database.order import Order, OrderItem -from models.schema.inventory import InventoryReserve -logger = logging.getLogger(__name__) - -# Default location for inventory operations -DEFAULT_LOCATION = "DEFAULT" - - -class OrderInventoryService: - """ - Orchestrate order and inventory operations together. - - This service ensures that: - 1. When orders are confirmed, inventory is reserved - 2. When orders are shipped, inventory is fulfilled (deducted) - 3. When orders are cancelled, reservations are released - - Note: Letzshop orders with unmatched products (placeholder) skip - inventory operations for those items. - """ - - def get_order_with_items( - self, db: Session, vendor_id: int, order_id: int - ) -> Order: - """Get order with items or raise OrderNotFoundException.""" - order = ( - db.query(Order) - .filter(Order.id == order_id, Order.vendor_id == vendor_id) - .first() - ) - if not order: - raise OrderNotFoundException(f"Order {order_id} not found") - return order - - def _find_inventory_location( - self, db: Session, product_id: int, vendor_id: int - ) -> str | None: - """ - Find the location with available inventory for a product. - - Returns the first location with available quantity, or None if no - inventory exists. - """ - inventory = ( - db.query(Inventory) - .filter( - Inventory.product_id == product_id, - Inventory.vendor_id == vendor_id, - Inventory.quantity > Inventory.reserved_quantity, - ) - .first() - ) - return inventory.location if inventory else None - - def _is_placeholder_product(self, order_item: OrderItem) -> bool: - """Check if the order item uses a placeholder product.""" - if not order_item.product: - return True - # Check if it's the placeholder product (GTIN 0000000000000) - return order_item.product.gtin == "0000000000000" - - def _log_transaction( - self, - db: Session, - vendor_id: int, - product_id: int, - inventory: Inventory, - transaction_type: TransactionType, - quantity_change: int, - order: Order, - reason: str | None = None, - ) -> InventoryTransaction: - """ - Create an inventory transaction record for audit trail. - - Args: - db: Database session - vendor_id: Vendor ID - product_id: Product ID - inventory: Inventory record after the operation - transaction_type: Type of transaction - quantity_change: Change in quantity (positive = add, negative = remove) - order: Order associated with this transaction - reason: Optional reason for the transaction - - Returns: - Created InventoryTransaction - """ - transaction = InventoryTransaction.create_transaction( - vendor_id=vendor_id, - product_id=product_id, - inventory_id=inventory.id if inventory else None, - transaction_type=transaction_type, - quantity_change=quantity_change, - quantity_after=inventory.quantity if inventory else 0, - reserved_after=inventory.reserved_quantity if inventory else 0, - location=inventory.location if inventory else None, - warehouse=inventory.warehouse if inventory else None, - order_id=order.id, - order_number=order.order_number, - reason=reason, - created_by="system", - ) - db.add(transaction) - return transaction - - def reserve_for_order( - self, - db: Session, - vendor_id: int, - order_id: int, - skip_missing: bool = True, - ) -> dict: - """ - Reserve inventory for all items in an order. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - skip_missing: If True, skip items without inventory instead of failing - - Returns: - Dict with reserved count and any skipped items - - Raises: - InsufficientInventoryException: If skip_missing=False and inventory unavailable - """ - order = self.get_order_with_items(db, vendor_id, order_id) - - reserved_count = 0 - skipped_items = [] - - for item in order.items: - # Skip placeholder products - if self._is_placeholder_product(item): - skipped_items.append({ - "item_id": item.id, - "reason": "placeholder_product", - }) - continue - - # Find inventory location - location = self._find_inventory_location(db, item.product_id, vendor_id) - - if not location: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": "no_inventory", - }) - continue - else: - raise InventoryNotFoundException( - f"No inventory found for product {item.product_id}" - ) - - try: - reserve_data = InventoryReserve( - product_id=item.product_id, - location=location, - quantity=item.quantity, - ) - updated_inventory = inventory_service.reserve_inventory( - db, vendor_id, reserve_data - ) - reserved_count += 1 - - # Log transaction for audit trail - self._log_transaction( - db=db, - vendor_id=vendor_id, - product_id=item.product_id, - inventory=updated_inventory, - transaction_type=TransactionType.RESERVE, - quantity_change=0, # Reserve doesn't change quantity, only reserved_quantity - order=order, - reason=f"Reserved for order {order.order_number}", - ) - - logger.info( - f"Reserved {item.quantity} units of product {item.product_id} " - f"for order {order.order_number}" - ) - except InsufficientInventoryException: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": "insufficient_inventory", - }) - else: - raise - - logger.info( - f"Order {order.order_number}: reserved {reserved_count} items, " - f"skipped {len(skipped_items)}" - ) - - return { - "order_id": order_id, - "order_number": order.order_number, - "reserved_count": reserved_count, - "skipped_items": skipped_items, - } - - def fulfill_order( - self, - db: Session, - vendor_id: int, - order_id: int, - skip_missing: bool = True, - ) -> dict: - """ - Fulfill (deduct) inventory when an order is shipped. - - This decreases both the total quantity and reserved quantity, - effectively consuming the reserved stock. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - skip_missing: If True, skip items without inventory - - Returns: - Dict with fulfilled count and any skipped items - """ - order = self.get_order_with_items(db, vendor_id, order_id) - - fulfilled_count = 0 - skipped_items = [] - - for item in order.items: - # Skip already fully shipped items - if item.is_fully_shipped: - continue - - # Skip placeholder products - if self._is_placeholder_product(item): - skipped_items.append({ - "item_id": item.id, - "reason": "placeholder_product", - }) - continue - - # Only fulfill remaining quantity - quantity_to_fulfill = item.remaining_quantity - - # Find inventory location - location = self._find_inventory_location(db, item.product_id, vendor_id) - - # Also check for inventory with reserved quantity - if not location: - inventory = ( - db.query(Inventory) - .filter( - Inventory.product_id == item.product_id, - Inventory.vendor_id == vendor_id, - ) - .first() - ) - if inventory: - location = inventory.location - - if not location: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": "no_inventory", - }) - continue - else: - raise InventoryNotFoundException( - f"No inventory found for product {item.product_id}" - ) - - try: - reserve_data = InventoryReserve( - product_id=item.product_id, - location=location, - quantity=quantity_to_fulfill, - ) - updated_inventory = inventory_service.fulfill_reservation( - db, vendor_id, reserve_data - ) - fulfilled_count += 1 - - # Update item shipped quantity - item.shipped_quantity = item.quantity - item.inventory_fulfilled = True - - # Log transaction for audit trail - self._log_transaction( - db=db, - vendor_id=vendor_id, - product_id=item.product_id, - inventory=updated_inventory, - transaction_type=TransactionType.FULFILL, - quantity_change=-quantity_to_fulfill, # Negative because stock is consumed - order=order, - reason=f"Fulfilled for order {order.order_number}", - ) - - logger.info( - f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} " - f"for order {order.order_number}" - ) - except (InsufficientInventoryException, InventoryNotFoundException) as e: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": str(e), - }) - else: - raise - - logger.info( - f"Order {order.order_number}: fulfilled {fulfilled_count} items, " - f"skipped {len(skipped_items)}" - ) - - return { - "order_id": order_id, - "order_number": order.order_number, - "fulfilled_count": fulfilled_count, - "skipped_items": skipped_items, - } - - def fulfill_item( - self, - db: Session, - vendor_id: int, - order_id: int, - item_id: int, - quantity: int | None = None, - skip_missing: bool = True, - ) -> dict: - """ - Fulfill (deduct) inventory for a specific order item. - - Supports partial fulfillment - ship some units now, rest later. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - item_id: Order item ID - quantity: Quantity to ship (defaults to remaining quantity) - skip_missing: If True, skip if inventory not found - - Returns: - Dict with fulfillment result - - Raises: - ValidationException: If quantity exceeds remaining - """ - order = self.get_order_with_items(db, vendor_id, order_id) - - # Find the item - item = None - for order_item in order.items: - if order_item.id == item_id: - item = order_item - break - - if not item: - raise ValidationException(f"Item {item_id} not found in order {order_id}") - - # Check if already fully shipped - if item.is_fully_shipped: - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": 0, - "message": "Item already fully shipped", - } - - # Default to remaining quantity - quantity_to_fulfill = quantity or item.remaining_quantity - - # Validate quantity - if quantity_to_fulfill > item.remaining_quantity: - raise ValidationException( - f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining" - ) - - if quantity_to_fulfill <= 0: - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": 0, - "message": "Nothing to fulfill", - } - - # Skip placeholder products - if self._is_placeholder_product(item): - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": 0, - "message": "Placeholder product - skipped", - } - - # Find inventory location - location = self._find_inventory_location(db, item.product_id, vendor_id) - - if not location: - inventory = ( - db.query(Inventory) - .filter( - Inventory.product_id == item.product_id, - Inventory.vendor_id == vendor_id, - ) - .first() - ) - if inventory: - location = inventory.location - - if not location: - if skip_missing: - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": 0, - "message": "No inventory found", - } - else: - raise InventoryNotFoundException( - f"No inventory found for product {item.product_id}" - ) - - try: - reserve_data = InventoryReserve( - product_id=item.product_id, - location=location, - quantity=quantity_to_fulfill, - ) - updated_inventory = inventory_service.fulfill_reservation( - db, vendor_id, reserve_data - ) - - # Update item shipped quantity - item.shipped_quantity += quantity_to_fulfill - - # Mark as fulfilled only if fully shipped - if item.is_fully_shipped: - item.inventory_fulfilled = True - - # Log transaction - self._log_transaction( - db=db, - vendor_id=vendor_id, - product_id=item.product_id, - inventory=updated_inventory, - transaction_type=TransactionType.FULFILL, - quantity_change=-quantity_to_fulfill, - order=order, - reason=f"Partial shipment for order {order.order_number}", - ) - - logger.info( - f"Fulfilled {quantity_to_fulfill} of {item.quantity} units " - f"for item {item_id} in order {order.order_number}" - ) - - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": quantity_to_fulfill, - "shipped_quantity": item.shipped_quantity, - "remaining_quantity": item.remaining_quantity, - "is_fully_shipped": item.is_fully_shipped, - } - - except (InsufficientInventoryException, InventoryNotFoundException) as e: - if skip_missing: - return { - "order_id": order_id, - "item_id": item_id, - "fulfilled_quantity": 0, - "message": str(e), - } - else: - raise - - def release_order_reservation( - self, - db: Session, - vendor_id: int, - order_id: int, - skip_missing: bool = True, - ) -> dict: - """ - Release reserved inventory when an order is cancelled. - - This decreases the reserved quantity, making the stock available again. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - skip_missing: If True, skip items without inventory - - Returns: - Dict with released count and any skipped items - """ - order = self.get_order_with_items(db, vendor_id, order_id) - - released_count = 0 - skipped_items = [] - - for item in order.items: - # Skip placeholder products - if self._is_placeholder_product(item): - skipped_items.append({ - "item_id": item.id, - "reason": "placeholder_product", - }) - continue - - # Find inventory - look for any inventory for this product - inventory = ( - db.query(Inventory) - .filter( - Inventory.product_id == item.product_id, - Inventory.vendor_id == vendor_id, - ) - .first() - ) - - if not inventory: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": "no_inventory", - }) - continue - else: - raise InventoryNotFoundException( - f"No inventory found for product {item.product_id}" - ) - - try: - reserve_data = InventoryReserve( - product_id=item.product_id, - location=inventory.location, - quantity=item.quantity, - ) - updated_inventory = inventory_service.release_reservation( - db, vendor_id, reserve_data - ) - released_count += 1 - - # Log transaction for audit trail - self._log_transaction( - db=db, - vendor_id=vendor_id, - product_id=item.product_id, - inventory=updated_inventory, - transaction_type=TransactionType.RELEASE, - quantity_change=0, # Release doesn't change quantity, only reserved_quantity - order=order, - reason=f"Released for cancelled order {order.order_number}", - ) - - logger.info( - f"Released {item.quantity} units of product {item.product_id} " - f"for cancelled order {order.order_number}" - ) - except Exception as e: - if skip_missing: - skipped_items.append({ - "item_id": item.id, - "product_id": item.product_id, - "reason": str(e), - }) - else: - raise - - logger.info( - f"Order {order.order_number}: released {released_count} items, " - f"skipped {len(skipped_items)}" - ) - - return { - "order_id": order_id, - "order_number": order.order_number, - "released_count": released_count, - "skipped_items": skipped_items, - } - - def handle_status_change( - self, - db: Session, - vendor_id: int, - order_id: int, - old_status: str | None, - new_status: str, - ) -> dict | None: - """ - Handle inventory operations based on order status changes. - - Status transitions that trigger inventory operations: - - Any → processing: Reserve inventory (if not already reserved) - - processing → shipped: Fulfill inventory (deduct from stock) - - processing → partially_shipped: Partial fulfillment already done via fulfill_item - - Any → cancelled: Release reservations - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - old_status: Previous status (can be None for new orders) - new_status: New status - - Returns: - Result of inventory operation, or None if no operation needed - """ - # Skip if status didn't change - if old_status == new_status: - return None - - result = None - - # Transitioning to processing - reserve inventory - if new_status == "processing": - result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True) - logger.info(f"Order {order_id} confirmed: inventory reserved") - - # Transitioning to shipped - fulfill remaining inventory - elif new_status == "shipped": - result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True) - logger.info(f"Order {order_id} shipped: inventory fulfilled") - - # partially_shipped - no automatic fulfillment (handled via fulfill_item) - elif new_status == "partially_shipped": - logger.info( - f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment" - ) - result = {"order_id": order_id, "status": "partially_shipped"} - - # Transitioning to cancelled - release reservations - elif new_status == "cancelled": - # Only release if there was a previous status (order was in progress) - if old_status and old_status not in ("cancelled", "refunded"): - result = self.release_order_reservation( - db, vendor_id, order_id, skip_missing=True - ) - logger.info(f"Order {order_id} cancelled: reservations released") - - return result - - def get_shipment_status( - self, - db: Session, - vendor_id: int, - order_id: int, - ) -> dict: - """ - Get detailed shipment status for an order. - - Returns item-level shipment status for partial shipment tracking. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - - Returns: - Dict with shipment status details - """ - order = self.get_order_with_items(db, vendor_id, order_id) - - items = [] - for item in order.items: - items.append({ - "item_id": item.id, - "product_id": item.product_id, - "product_name": item.product_name, - "quantity": item.quantity, - "shipped_quantity": item.shipped_quantity, - "remaining_quantity": item.remaining_quantity, - "is_fully_shipped": item.is_fully_shipped, - "is_partially_shipped": item.is_partially_shipped, - }) - - return { - "order_id": order_id, - "order_number": order.order_number, - "order_status": order.status, - "is_fully_shipped": order.is_fully_shipped, - "is_partially_shipped": order.is_partially_shipped, - "shipped_item_count": order.shipped_item_count, - "total_item_count": len(order.items), - "total_shipped_units": order.total_shipped_units, - "total_ordered_units": order.total_ordered_units, - "items": items, - } - - -# Create service instance -order_inventory_service = OrderInventoryService() +__all__ = [ + "order_inventory_service", + "OrderInventoryService", +] diff --git a/app/services/order_item_exception_service.py b/app/services/order_item_exception_service.py index 3141e1e8..d75ba061 100644 --- a/app/services/order_item_exception_service.py +++ b/app/services/order_item_exception_service.py @@ -1,632 +1,23 @@ # app/services/order_item_exception_service.py """ -Service for managing order item exceptions (unmatched products). +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This service handles: -- Creating exceptions when products are not found during order import -- Resolving exceptions by assigning products -- Auto-matching when new products are imported -- Querying and statistics for exceptions +The canonical implementation is now in: + app/modules/orders/services/order_item_exception_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.services import order_item_exception_service """ -import logging -from datetime import UTC, datetime -from typing import Any - -from sqlalchemy import and_, func, or_ -from sqlalchemy.orm import Session, joinedload - -from app.exceptions import ( - ExceptionAlreadyResolvedException, - InvalidProductForExceptionException, - OrderItemExceptionNotFoundException, - ProductNotFoundException, +from app.modules.orders.services.order_item_exception_service import ( + order_item_exception_service, + OrderItemExceptionService, ) -from models.database.order import Order, OrderItem -from models.database.order_item_exception import OrderItemException -from models.database.product import Product -logger = logging.getLogger(__name__) - - -class OrderItemExceptionService: - """Service for order item exception CRUD and resolution workflow.""" - - # ========================================================================= - # Exception Creation - # ========================================================================= - - def create_exception( - self, - db: Session, - order_item: OrderItem, - vendor_id: int, - original_gtin: str | None, - original_product_name: str | None, - original_sku: str | None, - exception_type: str = "product_not_found", - ) -> OrderItemException: - """ - Create an exception record for an unmatched order item. - - Args: - db: Database session - order_item: The order item that couldn't be matched - vendor_id: Vendor ID (denormalized for efficient queries) - original_gtin: Original GTIN from marketplace - original_product_name: Original product name from marketplace - original_sku: Original SKU from marketplace - exception_type: Type of exception (product_not_found, gtin_mismatch, etc.) - - Returns: - Created OrderItemException - """ - exception = OrderItemException( - order_item_id=order_item.id, - vendor_id=vendor_id, - original_gtin=original_gtin, - original_product_name=original_product_name, - original_sku=original_sku, - exception_type=exception_type, - status="pending", - ) - db.add(exception) - db.flush() - - logger.info( - f"Created order item exception {exception.id} for order item " - f"{order_item.id}, GTIN: {original_gtin}" - ) - - return exception - - # ========================================================================= - # Exception Retrieval - # ========================================================================= - - def get_exception_by_id( - self, - db: Session, - exception_id: int, - vendor_id: int | None = None, - ) -> OrderItemException: - """ - Get an exception by ID, optionally filtered by vendor. - - Args: - db: Database session - exception_id: Exception ID - vendor_id: Optional vendor ID filter (for vendor-scoped access) - - Returns: - OrderItemException - - Raises: - OrderItemExceptionNotFoundException: If not found - """ - query = db.query(OrderItemException).filter( - OrderItemException.id == exception_id - ) - - if vendor_id is not None: - query = query.filter(OrderItemException.vendor_id == vendor_id) - - exception = query.first() - - if not exception: - raise OrderItemExceptionNotFoundException(exception_id) - - return exception - - def get_pending_exceptions( - self, - db: Session, - vendor_id: int | None = None, - status: str | None = None, - search: str | None = None, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[OrderItemException], int]: - """ - Get exceptions with pagination and filtering. - - Args: - db: Database session - vendor_id: Optional vendor filter - status: Optional status filter (pending, resolved, ignored) - search: Optional search in GTIN, product name, or order number - skip: Pagination offset - limit: Pagination limit - - Returns: - Tuple of (list of exceptions, total count) - """ - query = ( - db.query(OrderItemException) - .join(OrderItem) - .join(Order) - .options( - joinedload(OrderItemException.order_item).joinedload(OrderItem.order) - ) - ) - - # Apply filters - if vendor_id is not None: - query = query.filter(OrderItemException.vendor_id == vendor_id) - - if status: - query = query.filter(OrderItemException.status == status) - - if search: - search_pattern = f"%{search}%" - query = query.filter( - or_( - OrderItemException.original_gtin.ilike(search_pattern), - OrderItemException.original_product_name.ilike(search_pattern), - OrderItemException.original_sku.ilike(search_pattern), - Order.order_number.ilike(search_pattern), - ) - ) - - # Get total count - total = query.count() - - # Apply pagination and ordering - exceptions = ( - query.order_by(OrderItemException.created_at.desc()) - .offset(skip) - .limit(limit) - .all() - ) - - return exceptions, total - - def get_exceptions_for_order( - self, - db: Session, - order_id: int, - ) -> list[OrderItemException]: - """ - Get all exceptions for items in an order. - - Args: - db: Database session - order_id: Order ID - - Returns: - List of exceptions for the order - """ - return ( - db.query(OrderItemException) - .join(OrderItem) - .filter(OrderItem.order_id == order_id) - .all() - ) - - # ========================================================================= - # Exception Statistics - # ========================================================================= - - def get_exception_stats( - self, - db: Session, - vendor_id: int | None = None, - ) -> dict[str, int]: - """ - Get exception counts by status. - - Args: - db: Database session - vendor_id: Optional vendor filter - - Returns: - Dict with pending, resolved, ignored, total counts - """ - query = db.query( - OrderItemException.status, - func.count(OrderItemException.id).label("count"), - ) - - if vendor_id is not None: - query = query.filter(OrderItemException.vendor_id == vendor_id) - - results = query.group_by(OrderItemException.status).all() - - stats = { - "pending": 0, - "resolved": 0, - "ignored": 0, - "total": 0, - } - - for status, count in results: - if status in stats: - stats[status] = count - stats["total"] += count - - # Count orders with pending exceptions - orders_query = ( - db.query(func.count(func.distinct(OrderItem.order_id))) - .join(OrderItemException) - .filter(OrderItemException.status == "pending") - ) - - if vendor_id is not None: - orders_query = orders_query.filter( - OrderItemException.vendor_id == vendor_id - ) - - stats["orders_with_exceptions"] = orders_query.scalar() or 0 - - return stats - - # ========================================================================= - # Exception Resolution - # ========================================================================= - - def resolve_exception( - self, - db: Session, - exception_id: int, - product_id: int, - resolved_by: int, - notes: str | None = None, - vendor_id: int | None = None, - ) -> OrderItemException: - """ - Resolve an exception by assigning a product. - - This updates: - - The exception record (status, resolved_product_id, etc.) - - The order item (product_id, needs_product_match) - - Args: - db: Database session - exception_id: Exception ID to resolve - product_id: Product ID to assign - resolved_by: User ID who resolved - notes: Optional resolution notes - vendor_id: Optional vendor filter (for scoped access) - - Returns: - Updated OrderItemException - - Raises: - OrderItemExceptionNotFoundException: If exception not found - ExceptionAlreadyResolvedException: If already resolved - InvalidProductForExceptionException: If product is invalid - """ - exception = self.get_exception_by_id(db, exception_id, vendor_id) - - if exception.status == "resolved": - raise ExceptionAlreadyResolvedException(exception_id) - - # Validate product exists and belongs to vendor - product = db.query(Product).filter(Product.id == product_id).first() - if not product: - raise ProductNotFoundException(product_id) - - if product.vendor_id != exception.vendor_id: - raise InvalidProductForExceptionException( - product_id, "Product belongs to a different vendor" - ) - - if not product.is_active: - raise InvalidProductForExceptionException( - product_id, "Product is not active" - ) - - # Update exception - exception.status = "resolved" - exception.resolved_product_id = product_id - exception.resolved_at = datetime.now(UTC) - exception.resolved_by = resolved_by - exception.resolution_notes = notes - - # Update order item - order_item = exception.order_item - order_item.product_id = product_id - order_item.needs_product_match = False - - # Update product snapshot on order item - if product.marketplace_product: - order_item.product_name = product.marketplace_product.get_title("en") - order_item.product_sku = product.vendor_sku or order_item.product_sku - - db.flush() - - logger.info( - f"Resolved exception {exception_id} with product {product_id} " - f"by user {resolved_by}" - ) - - return exception - - def ignore_exception( - self, - db: Session, - exception_id: int, - resolved_by: int, - notes: str, - vendor_id: int | None = None, - ) -> OrderItemException: - """ - Mark an exception as ignored. - - Note: Ignored exceptions still block order confirmation. - This is for tracking purposes (e.g., product will never be matched). - - Args: - db: Database session - exception_id: Exception ID - resolved_by: User ID who ignored - notes: Reason for ignoring (required) - vendor_id: Optional vendor filter - - Returns: - Updated OrderItemException - """ - exception = self.get_exception_by_id(db, exception_id, vendor_id) - - if exception.status == "resolved": - raise ExceptionAlreadyResolvedException(exception_id) - - exception.status = "ignored" - exception.resolved_at = datetime.now(UTC) - exception.resolved_by = resolved_by - exception.resolution_notes = notes - - db.flush() - - logger.info( - f"Ignored exception {exception_id} by user {resolved_by}: {notes}" - ) - - return exception - - # ========================================================================= - # Auto-Matching - # ========================================================================= - - def auto_match_by_gtin( - self, - db: Session, - vendor_id: int, - gtin: str, - product_id: int, - ) -> list[OrderItemException]: - """ - Auto-resolve pending exceptions matching a GTIN. - - Called after a product is imported with a GTIN. - - Args: - db: Database session - vendor_id: Vendor ID - gtin: GTIN to match - product_id: Product ID to assign - - Returns: - List of resolved exceptions - """ - if not gtin: - return [] - - # Find pending exceptions for this GTIN - pending = ( - db.query(OrderItemException) - .filter( - and_( - OrderItemException.vendor_id == vendor_id, - OrderItemException.original_gtin == gtin, - OrderItemException.status == "pending", - ) - ) - .all() - ) - - if not pending: - return [] - - # Get product for snapshot update - product = db.query(Product).filter(Product.id == product_id).first() - if not product: - logger.warning(f"Product {product_id} not found for auto-match") - return [] - - resolved = [] - now = datetime.now(UTC) - - for exception in pending: - exception.status = "resolved" - exception.resolved_product_id = product_id - exception.resolved_at = now - exception.resolution_notes = "Auto-matched during product import" - - # Update order item - order_item = exception.order_item - order_item.product_id = product_id - order_item.needs_product_match = False - if product.marketplace_product: - order_item.product_name = product.marketplace_product.get_title("en") - - resolved.append(exception) - - if resolved: - db.flush() - logger.info( - f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} " - f"with product {product_id}" - ) - - return resolved - - def auto_match_batch( - self, - db: Session, - vendor_id: int, - gtin_to_product: dict[str, int], - ) -> int: - """ - Batch auto-match multiple GTINs after bulk import. - - Args: - db: Database session - vendor_id: Vendor ID - gtin_to_product: Dict mapping GTIN to product ID - - Returns: - Total number of resolved exceptions - """ - if not gtin_to_product: - return 0 - - total_resolved = 0 - - for gtin, product_id in gtin_to_product.items(): - resolved = self.auto_match_by_gtin(db, vendor_id, gtin, product_id) - total_resolved += len(resolved) - - return total_resolved - - # ========================================================================= - # Confirmation Checks - # ========================================================================= - - def order_has_unresolved_exceptions( - self, - db: Session, - order_id: int, - ) -> bool: - """ - Check if order has any unresolved exceptions. - - An order cannot be confirmed if it has pending or ignored exceptions. - - Args: - db: Database session - order_id: Order ID - - Returns: - True if order has unresolved exceptions - """ - count = ( - db.query(func.count(OrderItemException.id)) - .join(OrderItem) - .filter( - and_( - OrderItem.order_id == order_id, - OrderItemException.status.in_(["pending", "ignored"]), - ) - ) - .scalar() - ) - - return count > 0 - - def get_unresolved_exception_count( - self, - db: Session, - order_id: int, - ) -> int: - """ - Get count of unresolved exceptions for an order. - - Args: - db: Database session - order_id: Order ID - - Returns: - Count of unresolved exceptions - """ - return ( - db.query(func.count(OrderItemException.id)) - .join(OrderItem) - .filter( - and_( - OrderItem.order_id == order_id, - OrderItemException.status.in_(["pending", "ignored"]), - ) - ) - .scalar() - ) or 0 - - # ========================================================================= - # Bulk Operations - # ========================================================================= - - def bulk_resolve_by_gtin( - self, - db: Session, - vendor_id: int, - gtin: str, - product_id: int, - resolved_by: int, - notes: str | None = None, - ) -> int: - """ - Bulk resolve all pending exceptions for a GTIN. - - Args: - db: Database session - vendor_id: Vendor ID - gtin: GTIN to match - product_id: Product ID to assign - resolved_by: User ID who resolved - notes: Optional notes - - Returns: - Number of resolved exceptions - """ - # Validate product - product = db.query(Product).filter(Product.id == product_id).first() - if not product: - raise ProductNotFoundException(product_id) - - if product.vendor_id != vendor_id: - raise InvalidProductForExceptionException( - product_id, "Product belongs to a different vendor" - ) - - # Find and resolve all pending exceptions for this GTIN - pending = ( - db.query(OrderItemException) - .filter( - and_( - OrderItemException.vendor_id == vendor_id, - OrderItemException.original_gtin == gtin, - OrderItemException.status == "pending", - ) - ) - .all() - ) - - now = datetime.now(UTC) - resolution_notes = notes or f"Bulk resolved for GTIN {gtin}" - - for exception in pending: - exception.status = "resolved" - exception.resolved_product_id = product_id - exception.resolved_at = now - exception.resolved_by = resolved_by - exception.resolution_notes = resolution_notes - - # Update order item - order_item = exception.order_item - order_item.product_id = product_id - order_item.needs_product_match = False - if product.marketplace_product: - order_item.product_name = product.marketplace_product.get_title("en") - - db.flush() - - logger.info( - f"Bulk resolved {len(pending)} exceptions for GTIN {gtin} " - f"with product {product_id} by user {resolved_by}" - ) - - return len(pending) - - -# Global service instance -order_item_exception_service = OrderItemExceptionService() +__all__ = [ + "order_item_exception_service", + "OrderItemExceptionService", +] diff --git a/app/services/order_service.py b/app/services/order_service.py index 22801ad4..a26c5f9c 100644 --- a/app/services/order_service.py +++ b/app/services/order_service.py @@ -1,1536 +1,23 @@ # app/services/order_service.py """ -Unified order service for all sales channels. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This service handles: -- Order creation (direct and marketplace) -- Order status management -- Order retrieval and filtering -- Customer creation for marketplace imports -- Order item management +The canonical implementation is now in: + app/modules/orders/services/order_service.py -All orders use snapshotted customer and address data. +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: -All monetary calculations use integer cents internally for precision. -See docs/architecture/money-handling.md for details. + from app.modules.orders.services import order_service """ -import logging -import random -import string -from decimal import Decimal -from datetime import UTC, datetime -from typing import Any - -from sqlalchemy import and_, func, or_ -from sqlalchemy.orm import Session - -from app.exceptions import ( - CustomerNotFoundException, - InsufficientInventoryException, - OrderNotFoundException, - ValidationException, -) -from app.services.order_item_exception_service import order_item_exception_service -from app.services.order_inventory_service import order_inventory_service -from app.services.subscription_service import ( - subscription_service, - TierLimitExceededException, -) -from app.utils.money import Money, cents_to_euros, euros_to_cents -from app.utils.vat import ( - VATResult, - calculate_vat_amount, - determine_vat_regime, -) -from models.database.customer import Customer -from models.database.marketplace_product import MarketplaceProduct -from models.database.marketplace_product_translation import MarketplaceProductTranslation -from models.database.order import Order, OrderItem -from models.database.product import Product -from models.database.vendor import Vendor - -# Placeholder product constants -PLACEHOLDER_GTIN = "0000000000000" -PLACEHOLDER_MARKETPLACE_ID = "PLACEHOLDER" -from models.schema.order import ( - AddressSnapshot, - CustomerSnapshot, - OrderCreate, - OrderItemCreate, - OrderUpdate, +from app.modules.orders.services.order_service import ( + order_service, + OrderService, ) -logger = logging.getLogger(__name__) - - -class OrderService: - """Unified service for order operations across all channels.""" - - # ========================================================================= - # Order Number Generation - # ========================================================================= - - def _generate_order_number(self, db: Session, vendor_id: int) -> str: - """ - Generate unique order number. - - Format: ORD-{VENDOR_ID}-{TIMESTAMP}-{RANDOM} - Example: ORD-1-20250110-A1B2C3 - """ - timestamp = datetime.now(UTC).strftime("%Y%m%d") - random_suffix = "".join( - random.choices(string.ascii_uppercase + string.digits, k=6) - ) - order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" - - # Ensure uniqueness - while db.query(Order).filter(Order.order_number == order_number).first(): - random_suffix = "".join( - random.choices(string.ascii_uppercase + string.digits, k=6) - ) - order_number = f"ORD-{vendor_id}-{timestamp}-{random_suffix}" - - return order_number - - # ========================================================================= - # Tax Calculation - # ========================================================================= - - def _calculate_tax_for_order( - self, - db: Session, - vendor_id: int, - subtotal_cents: int, - billing_country_iso: str, - buyer_vat_number: str | None = None, - ) -> VATResult: - """ - Calculate tax amount for an order based on billing destination. - - Uses the shared VAT utility to determine the correct VAT regime - and rate, consistent with invoice VAT calculation. - - VAT Logic: - - Same country as seller: domestic VAT - - B2B with valid VAT number: reverse charge (0%) - - Cross-border + OSS registered: destination country VAT - - Cross-border + no OSS: origin country VAT - - Non-EU: VAT exempt (0%) - - Args: - db: Database session - vendor_id: Vendor ID (to get invoice settings) - subtotal_cents: Order subtotal in cents (before tax) - billing_country_iso: ISO 2-letter country code - buyer_vat_number: Buyer's VAT number for B2B detection - - Returns: - VATResult with regime, rate, destination country, and label - """ - from models.database.invoice import VendorInvoiceSettings - - # Get vendor invoice settings for seller country and OSS status - settings = ( - db.query(VendorInvoiceSettings) - .filter(VendorInvoiceSettings.vendor_id == vendor_id) - .first() - ) - - # Default to Luxembourg if no settings exist - seller_country = settings.company_country if settings else "LU" - seller_oss_registered = settings.is_oss_registered if settings else False - - # Determine VAT regime using shared utility - return determine_vat_regime( - seller_country=seller_country, - buyer_country=billing_country_iso or "LU", - buyer_vat_number=buyer_vat_number, - seller_oss_registered=seller_oss_registered, - ) - - # ========================================================================= - # Placeholder Product Management - # ========================================================================= - - def _get_or_create_placeholder_product( - self, - db: Session, - vendor_id: int, - ) -> Product: - """ - Get or create the vendor's placeholder product for unmatched items. - - When a marketplace order contains a GTIN that doesn't match any product - in the vendor's catalog, we link the order item to this placeholder - and create an exception for resolution. - - Args: - db: Database session - vendor_id: Vendor ID - - Returns: - Placeholder Product for the vendor - """ - # Check for existing placeholder product for this vendor - placeholder = ( - db.query(Product) - .filter( - and_( - Product.vendor_id == vendor_id, - Product.gtin == PLACEHOLDER_GTIN, - ) - ) - .first() - ) - - if placeholder: - return placeholder - - # Get or create placeholder marketplace product (shared) - mp = ( - db.query(MarketplaceProduct) - .filter( - MarketplaceProduct.marketplace_product_id == PLACEHOLDER_MARKETPLACE_ID - ) - .first() - ) - - if not mp: - mp = MarketplaceProduct( - marketplace_product_id=PLACEHOLDER_MARKETPLACE_ID, - marketplace="internal", - vendor_name="system", - product_type_enum="physical", - is_active=False, # Not for sale - ) - db.add(mp) - db.flush() - - # Add translation for placeholder - translation = MarketplaceProductTranslation( - marketplace_product_id=mp.id, - language="en", - title="Unmatched Product (Pending Resolution)", - description=( - "This is a placeholder for products not found during order import. " - "Please resolve the exception to assign the correct product." - ), - ) - db.add(translation) - db.flush() - - logger.info( - f"Created placeholder MarketplaceProduct {mp.id}" - ) - - # Create vendor-specific placeholder product - placeholder = Product( - vendor_id=vendor_id, - marketplace_product_id=mp.id, - gtin=PLACEHOLDER_GTIN, - gtin_type="placeholder", - is_active=False, - ) - db.add(placeholder) - db.flush() - - logger.info( - f"Created placeholder product {placeholder.id} for vendor {vendor_id}" - ) - - return placeholder - - # ========================================================================= - # Customer Management - # ========================================================================= - - def find_or_create_customer( - self, - db: Session, - vendor_id: int, - email: str, - first_name: str, - last_name: str, - phone: str | None = None, - is_active: bool = False, - ) -> Customer: - """ - Find existing customer by email or create new one. - - For marketplace imports, customers are created as inactive until - they register on the storefront. - - Args: - db: Database session - vendor_id: Vendor ID - email: Customer email - first_name: Customer first name - last_name: Customer last name - phone: Customer phone (optional) - is_active: Whether customer is active (default: False for imports) - - Returns: - Customer record (existing or newly created) - """ - # Look for existing customer by email within vendor scope - customer = ( - db.query(Customer) - .filter( - and_( - Customer.vendor_id == vendor_id, - Customer.email == email, - ) - ) - .first() - ) - - if customer: - return customer - - # Generate a unique customer number - timestamp = datetime.now(UTC).strftime("%Y%m%d%H%M%S") - random_suffix = "".join(random.choices(string.digits, k=4)) - customer_number = f"CUST-{vendor_id}-{timestamp}-{random_suffix}" - - # Create new customer - customer = Customer( - vendor_id=vendor_id, - email=email, - first_name=first_name, - last_name=last_name, - phone=phone, - customer_number=customer_number, - hashed_password="", # No password for imported customers - is_active=is_active, - ) - db.add(customer) - db.flush() - - logger.info( - f"Created {'active' if is_active else 'inactive'} customer " - f"{customer.id} for vendor {vendor_id}: {email}" - ) - - return customer - - # ========================================================================= - # Order Creation - # ========================================================================= - - def create_order( - self, - db: Session, - vendor_id: int, - order_data: OrderCreate, - ) -> Order: - """ - Create a new direct order. - - Args: - db: Database session - vendor_id: Vendor ID - order_data: Order creation data - - Returns: - Created Order object - - Raises: - ValidationException: If order data is invalid - InsufficientInventoryException: If not enough inventory - TierLimitExceededException: If vendor has reached order limit - """ - # Check tier limit before creating order - subscription_service.check_order_limit(db, vendor_id) - - try: - # Get or create customer - if order_data.customer_id: - customer = ( - db.query(Customer) - .filter( - and_( - Customer.id == order_data.customer_id, - Customer.vendor_id == vendor_id, - ) - ) - .first() - ) - if not customer: - raise CustomerNotFoundException(str(order_data.customer_id)) - else: - # Create customer from snapshot - customer = self.find_or_create_customer( - db=db, - vendor_id=vendor_id, - email=order_data.customer.email, - first_name=order_data.customer.first_name, - last_name=order_data.customer.last_name, - phone=order_data.customer.phone, - is_active=True, # Direct orders = active customers - ) - - # Calculate order totals and validate products - # All calculations use integer cents for precision - subtotal_cents = 0 - order_items_data = [] - - for item_data in order_data.items: - product = ( - db.query(Product) - .filter( - and_( - Product.id == item_data.product_id, - Product.vendor_id == vendor_id, - Product.is_active == True, - ) - ) - .first() - ) - - if not product: - raise ValidationException( - f"Product {item_data.product_id} not found" - ) - - # Check inventory - if product.available_inventory < item_data.quantity: - raise InsufficientInventoryException( - product_id=product.id, - requested=item_data.quantity, - available=product.available_inventory, - ) - - # Get price in cents (prefer sale price, then regular price) - unit_price_cents = ( - product.sale_price_cents - or product.price_cents - ) - if not unit_price_cents: - raise ValidationException(f"Product {product.id} has no price") - - # Calculate line total in cents - line_total_cents = Money.calculate_line_total( - unit_price_cents, item_data.quantity - ) - subtotal_cents += line_total_cents - - order_items_data.append( - { - "product_id": product.id, - "product_name": product.marketplace_product.get_title("en") - if product.marketplace_product - else str(product.id), - "product_sku": product.vendor_sku, - "gtin": product.gtin, - "gtin_type": product.gtin_type, - "quantity": item_data.quantity, - "unit_price_cents": unit_price_cents, - "total_price_cents": line_total_cents, - } - ) - - # Use billing address or shipping address for VAT - billing = order_data.billing_address or order_data.shipping_address - - # Calculate VAT using vendor settings (OSS, B2B handling) - vat_result = self._calculate_tax_for_order( - db=db, - vendor_id=vendor_id, - subtotal_cents=subtotal_cents, - billing_country_iso=billing.country_iso, - buyer_vat_number=getattr(billing, 'vat_number', None), - ) - - # Calculate amounts in cents - tax_amount_cents = calculate_vat_amount(subtotal_cents, vat_result.rate) - shipping_amount_cents = 599 if subtotal_cents < 5000 else 0 # €5.99 / €50 - discount_amount_cents = 0 - total_amount_cents = Money.calculate_order_total( - subtotal_cents, tax_amount_cents, shipping_amount_cents, discount_amount_cents - ) - - # Generate order number - order_number = self._generate_order_number(db, vendor_id) - - # Create order with snapshots - order = Order( - vendor_id=vendor_id, - customer_id=customer.id, - order_number=order_number, - channel="direct", - status="pending", - # Financials (in cents) - subtotal_cents=subtotal_cents, - tax_amount_cents=tax_amount_cents, - shipping_amount_cents=shipping_amount_cents, - discount_amount_cents=discount_amount_cents, - total_amount_cents=total_amount_cents, - currency="EUR", - # VAT information - vat_regime=vat_result.regime.value, - vat_rate=vat_result.rate, - vat_rate_label=vat_result.label, - vat_destination_country=vat_result.destination_country, - # Customer snapshot - customer_first_name=order_data.customer.first_name, - customer_last_name=order_data.customer.last_name, - customer_email=order_data.customer.email, - customer_phone=order_data.customer.phone, - customer_locale=order_data.customer.locale, - # Shipping address snapshot - ship_first_name=order_data.shipping_address.first_name, - ship_last_name=order_data.shipping_address.last_name, - ship_company=order_data.shipping_address.company, - ship_address_line_1=order_data.shipping_address.address_line_1, - ship_address_line_2=order_data.shipping_address.address_line_2, - ship_city=order_data.shipping_address.city, - ship_postal_code=order_data.shipping_address.postal_code, - ship_country_iso=order_data.shipping_address.country_iso, - # Billing address snapshot - bill_first_name=billing.first_name, - bill_last_name=billing.last_name, - bill_company=billing.company, - bill_address_line_1=billing.address_line_1, - bill_address_line_2=billing.address_line_2, - bill_city=billing.city, - bill_postal_code=billing.postal_code, - bill_country_iso=billing.country_iso, - # Other - shipping_method=order_data.shipping_method, - customer_notes=order_data.customer_notes, - order_date=datetime.now(UTC), - ) - - db.add(order) - db.flush() - - # Create order items - for item_data in order_items_data: - order_item = OrderItem(order_id=order.id, **item_data) - db.add(order_item) - - db.flush() - db.refresh(order) - - # Increment order count for subscription tracking - subscription_service.increment_order_count(db, vendor_id) - - logger.info( - f"Order {order.order_number} created for vendor {vendor_id}, " - f"total: EUR {cents_to_euros(total_amount_cents):.2f}" - ) - - return order - - except ( - ValidationException, - InsufficientInventoryException, - CustomerNotFoundException, - TierLimitExceededException, - ): - raise - except Exception as e: - logger.error(f"Error creating order: {str(e)}") - raise ValidationException(f"Failed to create order: {str(e)}") - - def create_letzshop_order( - self, - db: Session, - vendor_id: int, - shipment_data: dict[str, Any], - skip_limit_check: bool = False, - ) -> Order: - """ - Create an order from Letzshop shipment data. - - Validates all products exist BEFORE creating any database records. - This ensures we don't leave the session in an inconsistent state - if validation fails. - - Args: - db: Database session - vendor_id: Vendor ID - shipment_data: Raw shipment data from Letzshop API - skip_limit_check: If True, skip tier limit check (for batch imports - that check limit upfront) - - Returns: - Created Order object - - Raises: - ValidationException: If product not found by GTIN - TierLimitExceededException: If vendor has reached order limit - """ - # Check tier limit before creating order (unless skipped for batch ops) - if not skip_limit_check: - can_create, message = subscription_service.can_create_order(db, vendor_id) - if not can_create: - raise TierLimitExceededException( - message=message or "Order limit exceeded", - limit_type="orders", - current=0, # Will be filled by caller if needed - limit=0, - ) - - order_data = shipment_data.get("order", {}) - - # Generate order number using Letzshop order number - letzshop_order_number = order_data.get("number", "") - order_number = f"LS-{vendor_id}-{letzshop_order_number}" - - # Check if order already exists (read-only, safe to do first) - existing = ( - db.query(Order) - .filter(Order.order_number == order_number) - .first() - ) - if existing: - updated = False - - # Update tracking if available and not already set - tracking_data = shipment_data.get("tracking") or {} - new_tracking = tracking_data.get("code") or tracking_data.get("number") - if new_tracking and not existing.tracking_number: - existing.tracking_number = new_tracking - tracking_provider = tracking_data.get("provider") - if not tracking_provider and tracking_data.get("carrier"): - carrier = tracking_data.get("carrier", {}) - tracking_provider = carrier.get("code") or carrier.get("name") - existing.tracking_provider = tracking_provider - updated = True - logger.info( - f"Updated tracking for order {order_number}: " - f"{tracking_provider} {new_tracking}" - ) - - # Update shipment number if not already set - shipment_number = shipment_data.get("number") - if shipment_number and not existing.shipment_number: - existing.shipment_number = shipment_number - updated = True - - # Update carrier if not already set - shipment_data_obj = shipment_data.get("data") or {} - if shipment_data_obj and not existing.shipping_carrier: - carrier_name = shipment_data_obj.get("__typename", "").lower() - if carrier_name: - existing.shipping_carrier = carrier_name - updated = True - - if updated: - existing.updated_at = datetime.now(UTC) - - return existing - - # ===================================================================== - # PHASE 1: Parse and validate all data BEFORE any database writes - # ===================================================================== - - # Parse inventory units - inventory_units = shipment_data.get("inventoryUnits", []) - if isinstance(inventory_units, dict): - inventory_units = inventory_units.get("nodes", []) - - # Collect all GTINs and check for items without GTINs - gtins = set() - has_items_without_gtin = False - for unit in inventory_units: - variant = unit.get("variant", {}) or {} - trade_id = variant.get("tradeId", {}) or {} - gtin = trade_id.get("number") - if gtin: - gtins.add(gtin) - else: - has_items_without_gtin = True - - # Batch query all products by GTIN - products_by_gtin: dict[str, Product] = {} - if gtins: - products = ( - db.query(Product) - .filter( - and_( - Product.vendor_id == vendor_id, - Product.gtin.in_(gtins), - ) - ) - .all() - ) - products_by_gtin = {p.gtin: p for p in products if p.gtin} - - # Identify missing GTINs (graceful handling - no exception raised) - missing_gtins = gtins - set(products_by_gtin.keys()) - placeholder = None - if missing_gtins or has_items_without_gtin: - # Get or create placeholder product for unmatched items or items without GTIN - placeholder = self._get_or_create_placeholder_product(db, vendor_id) - if missing_gtins: - logger.warning( - f"Order {order_number}: {len(missing_gtins)} product(s) not found. " - f"GTINs: {missing_gtins}. Using placeholder and creating exceptions." - ) - if has_items_without_gtin: - logger.warning( - f"Order {order_number}: Some items have no GTIN. " - f"Using placeholder and creating exceptions." - ) - - # Parse address data - ship_address = order_data.get("shipAddress", {}) or {} - bill_address = order_data.get("billAddress", {}) or {} - ship_country = ship_address.get("country", {}) or {} - bill_country = bill_address.get("country", {}) or {} - - # Extract customer info - customer_email = order_data.get("email", "") - ship_first_name = ship_address.get("firstName", "") or "" - ship_last_name = ship_address.get("lastName", "") or "" - - # Parse order date - order_date = datetime.now(UTC) - completed_at_str = order_data.get("completedAt") - if completed_at_str: - try: - if completed_at_str.endswith("Z"): - completed_at_str = completed_at_str[:-1] + "+00:00" - order_date = datetime.fromisoformat(completed_at_str) - except (ValueError, TypeError): - pass - - # Parse total amount (convert to cents) - total_str = order_data.get("total", "0") - try: - # Handle format like "99.99 EUR" - total_euros = float(str(total_str).split()[0]) - total_amount_cents = euros_to_cents(total_euros) - except (ValueError, IndexError): - total_amount_cents = 0 - - # Map Letzshop state to status - letzshop_state = shipment_data.get("state", "unconfirmed") - status_mapping = { - "unconfirmed": "pending", - "confirmed": "processing", - "declined": "cancelled", - } - status = status_mapping.get(letzshop_state, "pending") - - # Parse tracking info if available - tracking_data = shipment_data.get("tracking") or {} - tracking_number = tracking_data.get("code") or tracking_data.get("number") - tracking_provider = tracking_data.get("provider") - # Handle carrier object format: tracking { carrier { name code } } - if not tracking_provider and tracking_data.get("carrier"): - carrier = tracking_data.get("carrier", {}) - tracking_provider = carrier.get("code") or carrier.get("name") - - # Parse shipment number and carrier - shipment_number = shipment_data.get("number") # e.g., H74683403433 - shipping_carrier = None - shipment_data_obj = shipment_data.get("data") or {} - if shipment_data_obj: - # Carrier is determined by __typename (Greco, Colissimo, XpressLogistics) - shipping_carrier = shipment_data_obj.get("__typename", "").lower() - - # ===================================================================== - # PHASE 2: All validation passed - now create database records - # ===================================================================== - - # Find or create customer (inactive) - customer = self.find_or_create_customer( - db=db, - vendor_id=vendor_id, - email=customer_email, - first_name=ship_first_name, - last_name=ship_last_name, - is_active=False, - ) - - # Create order - order = Order( - vendor_id=vendor_id, - customer_id=customer.id, - order_number=order_number, - channel="letzshop", - # External references - external_order_id=order_data.get("id"), - external_shipment_id=shipment_data.get("id"), - external_order_number=letzshop_order_number, - external_data=shipment_data, - # Status - status=status, - # Financials (in cents) - total_amount_cents=total_amount_cents, - currency="EUR", - # Customer snapshot - customer_first_name=ship_first_name, - customer_last_name=ship_last_name, - customer_email=customer_email, - customer_locale=order_data.get("locale"), - # Shipping address snapshot - ship_first_name=ship_first_name, - ship_last_name=ship_last_name, - ship_company=ship_address.get("company"), - ship_address_line_1=ship_address.get("streetName", "") or "", - ship_address_line_2=ship_address.get("streetNumber"), - ship_city=ship_address.get("city", "") or "", - ship_postal_code=ship_address.get("postalCode", "") or "", - ship_country_iso=ship_country.get("iso", "") or "", - # Billing address snapshot - bill_first_name=bill_address.get("firstName", "") or ship_first_name, - bill_last_name=bill_address.get("lastName", "") or ship_last_name, - bill_company=bill_address.get("company"), - bill_address_line_1=bill_address.get("streetName", "") or "", - bill_address_line_2=bill_address.get("streetNumber"), - bill_city=bill_address.get("city", "") or "", - bill_postal_code=bill_address.get("postalCode", "") or "", - bill_country_iso=bill_country.get("iso", "") or "", - # Dates - order_date=order_date, - confirmed_at=datetime.now(UTC) if status == "processing" else None, - cancelled_at=datetime.now(UTC) if status == "cancelled" else None, - # Tracking (if available from Letzshop) - tracking_number=tracking_number, - tracking_provider=tracking_provider, - # Shipment info - shipment_number=shipment_number, - shipping_carrier=shipping_carrier, - ) - - db.add(order) - db.flush() - - # Create order items from inventory units - exceptions_created = 0 - for unit in inventory_units: - variant = unit.get("variant", {}) or {} - product_info = variant.get("product", {}) or {} - trade_id = variant.get("tradeId", {}) or {} - product_name_dict = product_info.get("name", {}) or {} - - gtin = trade_id.get("number") - gtin_type = trade_id.get("parser") - - # Get product from map, or use placeholder if not found - product = products_by_gtin.get(gtin) - needs_product_match = False - - if not product: - # Use placeholder for unmatched items - product = placeholder - needs_product_match = True - - # Get product name from marketplace data - product_name = ( - product_name_dict.get("en") - or product_name_dict.get("fr") - or product_name_dict.get("de") - or str(product_name_dict) - ) - - # Get price (convert to cents) - unit_price_cents = 0 - price_str = variant.get("price", "0") - try: - price_euros = float(str(price_str).split()[0]) - unit_price_cents = euros_to_cents(price_euros) - except (ValueError, IndexError): - pass - - # Map item state - item_state = unit.get("state") # unconfirmed, confirmed_available, etc. - - order_item = OrderItem( - order_id=order.id, - product_id=product.id if product else None, - product_name=product_name, - product_sku=variant.get("sku"), - gtin=gtin, - gtin_type=gtin_type, - quantity=1, # Letzshop uses individual inventory units - unit_price_cents=unit_price_cents, - total_price_cents=unit_price_cents, # qty=1 so same as unit - external_item_id=unit.get("id"), - external_variant_id=variant.get("id"), - item_state=item_state, - needs_product_match=needs_product_match, - ) - db.add(order_item) - db.flush() # Get order_item.id for exception creation - - # Create exception record for unmatched items - if needs_product_match: - order_item_exception_service.create_exception( - db=db, - order_item=order_item, - vendor_id=vendor_id, - original_gtin=gtin, - original_product_name=product_name, - original_sku=variant.get("sku"), - exception_type="product_not_found", - ) - exceptions_created += 1 - - db.flush() - db.refresh(order) - - if exceptions_created > 0: - logger.info( - f"Created {exceptions_created} order item exception(s) for " - f"order {order.order_number}" - ) - - # Increment order count for subscription tracking - subscription_service.increment_order_count(db, vendor_id) - - logger.info( - f"Letzshop order {order.order_number} created for vendor {vendor_id}, " - f"status: {status}, items: {len(inventory_units)}" - ) - - return order - - # ========================================================================= - # Order Retrieval - # ========================================================================= - - def get_order(self, db: Session, vendor_id: int, order_id: int) -> Order: - """Get order by ID within vendor scope.""" - order = ( - db.query(Order) - .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id)) - .first() - ) - - if not order: - raise OrderNotFoundException(str(order_id)) - - return order - - def get_order_by_external_shipment_id( - self, - db: Session, - vendor_id: int, - shipment_id: str, - ) -> Order | None: - """Get order by external shipment ID (for Letzshop).""" - return ( - db.query(Order) - .filter( - and_( - Order.vendor_id == vendor_id, - Order.external_shipment_id == shipment_id, - ) - ) - .first() - ) - - def get_vendor_orders( - self, - db: Session, - vendor_id: int, - skip: int = 0, - limit: int = 50, - status: str | None = None, - channel: str | None = None, - search: str | None = None, - customer_id: int | None = None, - ) -> tuple[list[Order], int]: - """ - Get orders for vendor with filtering. - - Args: - db: Database session - vendor_id: Vendor ID - skip: Pagination offset - limit: Pagination limit - status: Filter by status - channel: Filter by channel (direct, letzshop) - search: Search by order number, customer name, or email - customer_id: Filter by customer ID - - Returns: - Tuple of (orders, total_count) - """ - query = db.query(Order).filter(Order.vendor_id == vendor_id) - - if status: - query = query.filter(Order.status == status) - - if channel: - query = query.filter(Order.channel == channel) - - if customer_id: - query = query.filter(Order.customer_id == customer_id) - - if search: - search_term = f"%{search}%" - query = query.filter( - or_( - Order.order_number.ilike(search_term), - Order.external_order_number.ilike(search_term), - Order.customer_email.ilike(search_term), - Order.customer_first_name.ilike(search_term), - Order.customer_last_name.ilike(search_term), - ) - ) - - # Order by most recent first - query = query.order_by(Order.order_date.desc()) - - total = query.count() - orders = query.offset(skip).limit(limit).all() - - return orders, total - - def get_customer_orders( - self, - db: Session, - vendor_id: int, - customer_id: int, - skip: int = 0, - limit: int = 50, - ) -> tuple[list[Order], int]: - """ - Get orders for a specific customer. - - Used by shop frontend for customer order history. - - Args: - db: Database session - vendor_id: Vendor ID - customer_id: Customer ID - skip: Pagination offset - limit: Pagination limit - - Returns: - Tuple of (orders, total_count) - """ - return self.get_vendor_orders( - db=db, - vendor_id=vendor_id, - skip=skip, - limit=limit, - customer_id=customer_id, - ) - - def get_order_stats(self, db: Session, vendor_id: int) -> dict[str, int]: - """ - Get order counts by status for a vendor. - - Returns: - Dict with counts for each status. - """ - status_counts = ( - db.query(Order.status, func.count(Order.id).label("count")) - .filter(Order.vendor_id == vendor_id) - .group_by(Order.status) - .all() - ) - - stats = { - "pending": 0, - "processing": 0, - "partially_shipped": 0, - "shipped": 0, - "delivered": 0, - "cancelled": 0, - "refunded": 0, - "total": 0, - } - - for status, count in status_counts: - if status in stats: - stats[status] = count - stats["total"] += count - - # Also count by channel - channel_counts = ( - db.query(Order.channel, func.count(Order.id).label("count")) - .filter(Order.vendor_id == vendor_id) - .group_by(Order.channel) - .all() - ) - - for channel, count in channel_counts: - stats[f"{channel}_orders"] = count - - return stats - - # ========================================================================= - # Order Updates - # ========================================================================= - - def update_order_status( - self, - db: Session, - vendor_id: int, - order_id: int, - order_update: OrderUpdate, - ) -> Order: - """ - Update order status and tracking information. - - This method now includes automatic inventory management: - - processing: Reserves inventory for order items - - shipped: Fulfills (deducts) inventory - - cancelled: Releases reserved inventory - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - order_update: Update data - - Returns: - Updated Order object - """ - order = self.get_order(db, vendor_id, order_id) - - now = datetime.now(UTC) - old_status = order.status - - if order_update.status: - order.status = order_update.status - - # Update timestamps based on status - if order_update.status == "processing" and not order.confirmed_at: - order.confirmed_at = now - elif order_update.status == "partially_shipped": - # partially_shipped doesn't set shipped_at yet - pass - elif order_update.status == "shipped" and not order.shipped_at: - order.shipped_at = now - elif order_update.status == "delivered" and not order.delivered_at: - order.delivered_at = now - elif order_update.status == "cancelled" and not order.cancelled_at: - order.cancelled_at = now - - # Handle inventory operations based on status change - try: - inventory_result = order_inventory_service.handle_status_change( - db=db, - vendor_id=vendor_id, - order_id=order_id, - old_status=old_status, - new_status=order_update.status, - ) - if inventory_result: - logger.info( - f"Order {order.order_number} inventory update: " - f"{inventory_result.get('reserved_count', 0)} reserved, " - f"{inventory_result.get('fulfilled_count', 0)} fulfilled, " - f"{inventory_result.get('released_count', 0)} released" - ) - except Exception as e: - # Log inventory errors but don't fail the status update - # Inventory can be adjusted manually if needed - logger.warning( - f"Order {order.order_number} inventory operation failed: {e}" - ) - - if order_update.tracking_number: - order.tracking_number = order_update.tracking_number - - if order_update.tracking_provider: - order.tracking_provider = order_update.tracking_provider - - if order_update.internal_notes: - order.internal_notes = order_update.internal_notes - - order.updated_at = now - db.flush() - db.refresh(order) - - logger.info(f"Order {order.order_number} updated: status={order.status}") - - return order - - def set_order_tracking( - self, - db: Session, - vendor_id: int, - order_id: int, - tracking_number: str, - tracking_provider: str, - ) -> Order: - """ - Set tracking information and mark as shipped. - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - tracking_number: Tracking number - tracking_provider: Shipping provider - - Returns: - Updated Order object - """ - order = self.get_order(db, vendor_id, order_id) - - now = datetime.now(UTC) - order.tracking_number = tracking_number - order.tracking_provider = tracking_provider - order.status = "shipped" - order.shipped_at = now - order.updated_at = now - - db.flush() - db.refresh(order) - - logger.info( - f"Order {order.order_number} shipped: " - f"{tracking_provider} {tracking_number}" - ) - - return order - - def update_item_state( - self, - db: Session, - vendor_id: int, - order_id: int, - item_id: int, - state: str, - ) -> OrderItem: - """ - Update the state of an order item (for marketplace confirmation). - - Args: - db: Database session - vendor_id: Vendor ID - order_id: Order ID - item_id: Item ID - state: New state (confirmed_available, confirmed_unavailable) - - Returns: - Updated OrderItem object - """ - order = self.get_order(db, vendor_id, order_id) - - item = ( - db.query(OrderItem) - .filter( - and_( - OrderItem.id == item_id, - OrderItem.order_id == order.id, - ) - ) - .first() - ) - - if not item: - raise ValidationException(f"Order item {item_id} not found") - - item.item_state = state - item.updated_at = datetime.now(UTC) - - # Check if all items are processed - all_items = db.query(OrderItem).filter(OrderItem.order_id == order.id).all() - all_confirmed = all( - i.item_state in ("confirmed_available", "confirmed_unavailable") - for i in all_items - ) - - if all_confirmed: - has_available = any( - i.item_state == "confirmed_available" for i in all_items - ) - all_unavailable = all( - i.item_state == "confirmed_unavailable" for i in all_items - ) - - now = datetime.now(UTC) - if all_unavailable: - order.status = "cancelled" - order.cancelled_at = now - elif has_available: - order.status = "processing" - order.confirmed_at = now - - order.updated_at = now - - db.flush() - db.refresh(item) - - return item - - # ========================================================================= - # Admin Methods (cross-vendor) - # ========================================================================= - - def get_all_orders_admin( - self, - db: Session, - skip: int = 0, - limit: int = 50, - vendor_id: int | None = None, - status: str | None = None, - channel: str | None = None, - search: str | None = None, - ) -> tuple[list[dict], int]: - """ - Get orders across all vendors for admin. - - Args: - db: Database session - skip: Pagination offset - limit: Pagination limit - vendor_id: Filter by vendor - status: Filter by status - channel: Filter by channel - search: Search by order number or customer - - Returns: - Tuple of (orders with vendor info, total_count) - """ - query = db.query(Order).join(Vendor) - - if vendor_id: - query = query.filter(Order.vendor_id == vendor_id) - - if status: - query = query.filter(Order.status == status) - - if channel: - query = query.filter(Order.channel == channel) - - if search: - search_term = f"%{search}%" - query = query.filter( - or_( - Order.order_number.ilike(search_term), - Order.external_order_number.ilike(search_term), - Order.customer_email.ilike(search_term), - Order.customer_first_name.ilike(search_term), - Order.customer_last_name.ilike(search_term), - ) - ) - - query = query.order_by(Order.order_date.desc()) - - total = query.count() - orders = query.offset(skip).limit(limit).all() - - result = [] - for order in orders: - item_count = len(order.items) if order.items else 0 - - result.append( - { - "id": order.id, - "vendor_id": order.vendor_id, - "vendor_name": order.vendor.name if order.vendor else None, - "vendor_code": order.vendor.vendor_code if order.vendor else None, - "customer_id": order.customer_id, - "customer_full_name": order.customer_full_name, - "customer_email": order.customer_email, - "order_number": order.order_number, - "channel": order.channel, - "status": order.status, - "external_order_number": order.external_order_number, - "external_shipment_id": order.external_shipment_id, - "subtotal": order.subtotal, - "tax_amount": order.tax_amount, - "shipping_amount": order.shipping_amount, - "discount_amount": order.discount_amount, - "total_amount": order.total_amount, - "currency": order.currency, - "ship_country_iso": order.ship_country_iso, - "tracking_number": order.tracking_number, - "tracking_provider": order.tracking_provider, - "item_count": item_count, - "order_date": order.order_date, - "confirmed_at": order.confirmed_at, - "shipped_at": order.shipped_at, - "delivered_at": order.delivered_at, - "cancelled_at": order.cancelled_at, - "created_at": order.created_at, - "updated_at": order.updated_at, - } - ) - - return result, total - - def get_order_stats_admin(self, db: Session) -> dict: - """Get platform-wide order statistics.""" - # Get status counts - status_counts = ( - db.query(Order.status, func.count(Order.id)) - .group_by(Order.status) - .all() - ) - - stats = { - "total_orders": 0, - "pending_orders": 0, - "processing_orders": 0, - "partially_shipped_orders": 0, - "shipped_orders": 0, - "delivered_orders": 0, - "cancelled_orders": 0, - "refunded_orders": 0, - "total_revenue": 0.0, - "direct_orders": 0, - "letzshop_orders": 0, - "vendors_with_orders": 0, - } - - for status, count in status_counts: - stats["total_orders"] += count - key = f"{status}_orders" - if key in stats: - stats[key] = count - - # Get channel counts - channel_counts = ( - db.query(Order.channel, func.count(Order.id)) - .group_by(Order.channel) - .all() - ) - - for channel, count in channel_counts: - key = f"{channel}_orders" - if key in stats: - stats[key] = count - - # Get total revenue (from delivered orders) - convert cents to euros - revenue_cents = ( - db.query(func.sum(Order.total_amount_cents)) - .filter(Order.status == "delivered") - .scalar() - ) - stats["total_revenue"] = cents_to_euros(revenue_cents) if revenue_cents else 0.0 - - # Count vendors with orders - vendors_count = ( - db.query(func.count(func.distinct(Order.vendor_id))).scalar() or 0 - ) - stats["vendors_with_orders"] = vendors_count - - return stats - - def get_order_by_id_admin(self, db: Session, order_id: int) -> Order: - """Get order by ID without vendor scope (admin only).""" - order = db.query(Order).filter(Order.id == order_id).first() - - if not order: - raise OrderNotFoundException(str(order_id)) - - return order - - def get_vendors_with_orders_admin(self, db: Session) -> list[dict]: - """Get list of vendors that have orders (admin only).""" - from models.database.vendor import Vendor - - # Query vendors with order counts - results = ( - db.query( - Vendor.id, - Vendor.name, - Vendor.vendor_code, - func.count(Order.id).label("order_count"), - ) - .join(Order, Order.vendor_id == Vendor.id) - .group_by(Vendor.id, Vendor.name, Vendor.vendor_code) - .order_by(func.count(Order.id).desc()) - .all() - ) - - return [ - { - "id": row.id, - "name": row.name, - "vendor_code": row.vendor_code, - "order_count": row.order_count, - } - for row in results - ] - - def mark_as_shipped_admin( - self, - db: Session, - order_id: int, - tracking_number: str | None = None, - tracking_url: str | None = None, - shipping_carrier: str | None = None, - ) -> Order: - """ - Mark an order as shipped with optional tracking info (admin only). - - Args: - db: Database session - order_id: Order ID - tracking_number: Optional tracking number - tracking_url: Optional full tracking URL - shipping_carrier: Optional carrier code (greco, colissimo, etc.) - - Returns: - Updated order - """ - order = db.query(Order).filter(Order.id == order_id).first() - - if not order: - raise OrderNotFoundException(str(order_id)) - - order.status = "shipped" - order.shipped_at = datetime.now(UTC) - order.updated_at = datetime.now(UTC) - - if tracking_number: - order.tracking_number = tracking_number - if tracking_url: - order.tracking_url = tracking_url - if shipping_carrier: - order.shipping_carrier = shipping_carrier - - logger.info( - f"Order {order.order_number} marked as shipped. " - f"Tracking: {tracking_number or 'N/A'}, Carrier: {shipping_carrier or 'N/A'}" - ) - - return order - - def get_shipping_label_info_admin( - self, - db: Session, - order_id: int, - ) -> dict[str, Any]: - """ - Get shipping label information for an order (admin only). - - Returns shipment number, carrier, and generated label URL - based on carrier settings. - """ - from app.services.admin_settings_service import admin_settings_service - - order = db.query(Order).filter(Order.id == order_id).first() - - if not order: - raise OrderNotFoundException(str(order_id)) - - label_url = None - carrier = order.shipping_carrier - - # Generate label URL based on carrier - if order.shipment_number and carrier: - # Get carrier label URL prefix from settings - setting_key = f"carrier_{carrier}_label_url" - prefix = admin_settings_service.get_setting_value(db, setting_key) - - if prefix: - label_url = prefix + order.shipment_number - - return { - "shipment_number": order.shipment_number, - "shipping_carrier": carrier, - "label_url": label_url, - "tracking_number": order.tracking_number, - "tracking_url": order.tracking_url, - } - - -# Create service instance -order_service = OrderService() +__all__ = [ + "order_service", + "OrderService", +] diff --git a/app/services/test_runner_service.py b/app/services/test_runner_service.py index 573299d2..bafc0d6d 100644 --- a/app/services/test_runner_service.py +++ b/app/services/test_runner_service.py @@ -1,507 +1,23 @@ +# app/services/test_runner_service.py """ -Test Runner Service -Service for running pytest and storing results +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/dev_tools/services/test_runner_service.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.dev_tools.services import test_runner_service """ -import json -import logging -import re -import subprocess -import tempfile -from datetime import UTC, datetime -from pathlib import Path - -from sqlalchemy import desc, func -from sqlalchemy.orm import Session - -from models.database.test_run import TestCollection, TestResult, TestRun - -logger = logging.getLogger(__name__) - - -class TestRunnerService: - """Service for managing pytest test runs""" - - def __init__(self): - self.project_root = Path(__file__).parent.parent.parent - - def create_test_run( - self, - db: Session, - test_path: str = "tests", - triggered_by: str = "manual", - ) -> TestRun: - """Create a test run record without executing tests""" - test_run = TestRun( - timestamp=datetime.now(UTC), - triggered_by=triggered_by, - test_path=test_path, - status="running", - git_commit_hash=self._get_git_commit(), - git_branch=self._get_git_branch(), - ) - db.add(test_run) - db.flush() - return test_run - - def run_tests( - self, - db: Session, - test_path: str = "tests", - triggered_by: str = "manual", - extra_args: list[str] | None = None, - ) -> TestRun: - """ - Run pytest synchronously and store results in database - - Args: - db: Database session - test_path: Path to tests (relative to project root) - triggered_by: Who triggered the run - extra_args: Additional pytest arguments - - Returns: - TestRun object with results - """ - test_run = self.create_test_run(db, test_path, triggered_by) - self._execute_tests(db, test_run, test_path, extra_args) - return test_run - - def _execute_tests( - self, - db: Session, - test_run: TestRun, - test_path: str, - extra_args: list[str] | None, - ) -> None: - """Execute pytest and update the test run record""" - try: - # Build pytest command with JSON output - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: - json_report_path = f.name - - pytest_args = [ - "python", - "-m", - "pytest", - test_path, - "--json-report", - f"--json-report-file={json_report_path}", - "-v", - "--tb=short", - ] - - if extra_args: - pytest_args.extend(extra_args) - - test_run.pytest_args = " ".join(pytest_args) - - # Run pytest - start_time = datetime.now(UTC) - result = subprocess.run( - pytest_args, - cwd=str(self.project_root), - capture_output=True, - text=True, - timeout=600, # 10 minute timeout - ) - end_time = datetime.now(UTC) - - test_run.duration_seconds = (end_time - start_time).total_seconds() - - # Parse JSON report - try: - with open(json_report_path) as f: - report = json.load(f) - - self._process_json_report(db, test_run, report) - except (FileNotFoundError, json.JSONDecodeError) as e: - # Fallback to parsing stdout if JSON report failed - logger.warning(f"JSON report unavailable ({e}), parsing stdout") - self._parse_pytest_output(test_run, result.stdout, result.stderr) - finally: - # Clean up temp file - try: - Path(json_report_path).unlink() - except Exception: - pass - - # Set final status - if test_run.failed > 0 or test_run.errors > 0: - test_run.status = "failed" - else: - test_run.status = "passed" - - except subprocess.TimeoutExpired: - test_run.status = "error" - logger.error("Pytest run timed out") - except Exception as e: - test_run.status = "error" - logger.error(f"Error running tests: {e}") - - def _process_json_report(self, db: Session, test_run: TestRun, report: dict): - """Process pytest-json-report output""" - summary = report.get("summary", {}) - - test_run.total_tests = summary.get("total", 0) - test_run.passed = summary.get("passed", 0) - test_run.failed = summary.get("failed", 0) - test_run.errors = summary.get("error", 0) - test_run.skipped = summary.get("skipped", 0) - test_run.xfailed = summary.get("xfailed", 0) - test_run.xpassed = summary.get("xpassed", 0) - - # Process individual test results - tests = report.get("tests", []) - for test in tests: - node_id = test.get("nodeid", "") - outcome = test.get("outcome", "unknown") - - # Parse node_id to get file, class, function - test_file, test_class, test_name = self._parse_node_id(node_id) - - # Get failure details - error_message = None - traceback = None - if outcome in ("failed", "error"): - call_info = test.get("call", {}) - if "longrepr" in call_info: - traceback = call_info["longrepr"] - # Extract error message from traceback - if isinstance(traceback, str): - lines = traceback.strip().split("\n") - if lines: - error_message = lines[-1][:500] # Last line, limited length - - test_result = TestResult( - run_id=test_run.id, - node_id=node_id, - test_name=test_name, - test_file=test_file, - test_class=test_class, - outcome=outcome, - duration_seconds=test.get("duration", 0.0), - error_message=error_message, - traceback=traceback, - markers=test.get("keywords", []), - ) - db.add(test_result) - - def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]: - """Parse pytest node_id into file, class, function""" - # Format: tests/unit/test_foo.py::TestClass::test_method - # or: tests/unit/test_foo.py::test_function - parts = node_id.split("::") - - test_file = parts[0] if parts else "" - test_class = None - test_name = parts[-1] if parts else "" - - if len(parts) == 3: - test_class = parts[1] - elif len(parts) == 2: - # Could be Class::method or file::function - if parts[1].startswith("Test"): - test_class = parts[1] - test_name = parts[1] - - # Handle parametrized tests - if "[" in test_name: - test_name = test_name.split("[")[0] - - return test_file, test_class, test_name - - def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str): - """Fallback parser for pytest text output""" - # Parse summary line like: "10 passed, 2 failed, 1 skipped" - summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)" - - for match in re.finditer(summary_pattern, stdout): - count = int(match.group(1)) - status = match.group(2) - - if status == "passed": - test_run.passed = count - elif status == "failed": - test_run.failed = count - elif status == "error": - test_run.errors = count - elif status == "skipped": - test_run.skipped = count - elif status == "xfailed": - test_run.xfailed = count - elif status == "xpassed": - test_run.xpassed = count - - test_run.total_tests = ( - test_run.passed - + test_run.failed - + test_run.errors - + test_run.skipped - + test_run.xfailed - + test_run.xpassed - ) - - def _get_git_commit(self) -> str | None: - """Get current git commit hash""" - try: - result = subprocess.run( - ["git", "rev-parse", "HEAD"], - cwd=str(self.project_root), - capture_output=True, - text=True, - timeout=5, - ) - return result.stdout.strip()[:40] if result.returncode == 0 else None - except: - return None - - def _get_git_branch(self) -> str | None: - """Get current git branch""" - try: - result = subprocess.run( - ["git", "rev-parse", "--abbrev-ref", "HEAD"], - cwd=str(self.project_root), - capture_output=True, - text=True, - timeout=5, - ) - return result.stdout.strip() if result.returncode == 0 else None - except: - return None - - def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]: - """Get recent test run history""" - return db.query(TestRun).order_by(desc(TestRun.timestamp)).limit(limit).all() - - def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None: - """Get a specific test run with results""" - return db.query(TestRun).filter(TestRun.id == run_id).first() - - def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]: - """Get failed tests from a run""" - return ( - db.query(TestResult) - .filter( - TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"]) - ) - .all() - ) - - def get_run_results( - self, db: Session, run_id: int, outcome: str | None = None - ) -> list[TestResult]: - """Get test results for a specific run, optionally filtered by outcome""" - query = db.query(TestResult).filter(TestResult.run_id == run_id) - - if outcome: - query = query.filter(TestResult.outcome == outcome) - - return query.all() - - def get_dashboard_stats(self, db: Session) -> dict: - """Get statistics for the testing dashboard""" - # Get latest run - latest_run = ( - db.query(TestRun) - .filter(TestRun.status != "running") - .order_by(desc(TestRun.timestamp)) - .first() - ) - - # Get test collection info (or calculate from latest run) - collection = ( - db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first() - ) - - # Get trend data (last 10 runs) - trend_runs = ( - db.query(TestRun) - .filter(TestRun.status != "running") - .order_by(desc(TestRun.timestamp)) - .limit(10) - .all() - ) - - # Calculate stats by category from latest run - by_category = {} - if latest_run: - results = ( - db.query(TestResult).filter(TestResult.run_id == latest_run.id).all() - ) - for result in results: - # Categorize by test path - if "unit" in result.test_file: - category = "Unit Tests" - elif "integration" in result.test_file: - category = "Integration Tests" - elif "performance" in result.test_file: - category = "Performance Tests" - elif "system" in result.test_file: - category = "System Tests" - else: - category = "Other" - - if category not in by_category: - by_category[category] = {"total": 0, "passed": 0, "failed": 0} - by_category[category]["total"] += 1 - if result.outcome == "passed": - by_category[category]["passed"] += 1 - elif result.outcome in ("failed", "error"): - by_category[category]["failed"] += 1 - - # Get top failing tests (across recent runs) - top_failing = ( - db.query( - TestResult.test_name, - TestResult.test_file, - func.count(TestResult.id).label("failure_count"), - ) - .filter(TestResult.outcome.in_(["failed", "error"])) - .group_by(TestResult.test_name, TestResult.test_file) - .order_by(desc("failure_count")) - .limit(10) - .all() - ) - - return { - # Current run stats - "total_tests": latest_run.total_tests if latest_run else 0, - "passed": latest_run.passed if latest_run else 0, - "failed": latest_run.failed if latest_run else 0, - "errors": latest_run.errors if latest_run else 0, - "skipped": latest_run.skipped if latest_run else 0, - "pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0, - "duration_seconds": round(latest_run.duration_seconds, 2) - if latest_run - else 0, - "coverage_percent": latest_run.coverage_percent if latest_run else None, - "last_run": latest_run.timestamp.isoformat() if latest_run else None, - "last_run_status": latest_run.status if latest_run else None, - # Collection stats - "total_test_files": collection.total_files if collection else 0, - "collected_tests": collection.total_tests if collection else 0, - "unit_tests": collection.unit_tests if collection else 0, - "integration_tests": collection.integration_tests if collection else 0, - "performance_tests": collection.performance_tests if collection else 0, - "system_tests": collection.system_tests if collection else 0, - "last_collected": collection.collected_at.isoformat() - if collection - else None, - # Trend data - "trend": [ - { - "timestamp": run.timestamp.isoformat(), - "total": run.total_tests, - "passed": run.passed, - "failed": run.failed, - "pass_rate": round(run.pass_rate, 1), - "duration": round(run.duration_seconds, 1), - } - for run in reversed(trend_runs) - ], - # By category - "by_category": by_category, - # Top failing tests - "top_failing": [ - { - "test_name": t.test_name, - "test_file": t.test_file, - "failure_count": t.failure_count, - } - for t in top_failing - ], - } - - def collect_tests(self, db: Session) -> TestCollection: - """Collect test information without running tests""" - collection = TestCollection( - collected_at=datetime.now(UTC), - ) - - try: - # Run pytest --collect-only with JSON report - with tempfile.NamedTemporaryFile( - mode="w", suffix=".json", delete=False - ) as f: - json_report_path = f.name - - result = subprocess.run( - [ - "python", - "-m", - "pytest", - "--collect-only", - "--json-report", - f"--json-report-file={json_report_path}", - "tests", - ], - cwd=str(self.project_root), - capture_output=True, - text=True, - timeout=120, - ) - - # Parse JSON report - json_path = Path(json_report_path) - if json_path.exists(): - with open(json_path) as f: - report = json.load(f) - - # Get total from summary - collection.total_tests = report.get("summary", {}).get("collected", 0) - - # Parse collectors to get test files and counts - test_files = {} - for collector in report.get("collectors", []): - for item in collector.get("result", []): - if item.get("type") == "Function": - node_id = item.get("nodeid", "") - if "::" in node_id: - file_path = node_id.split("::")[0] - if file_path not in test_files: - test_files[file_path] = 0 - test_files[file_path] += 1 - - # Count files and categorize - for file_path, count in test_files.items(): - collection.total_files += 1 - - if "/unit/" in file_path or file_path.startswith("tests/unit"): - collection.unit_tests += count - elif "/integration/" in file_path or file_path.startswith( - "tests/integration" - ): - collection.integration_tests += count - elif "/performance/" in file_path or file_path.startswith( - "tests/performance" - ): - collection.performance_tests += count - elif "/system/" in file_path or file_path.startswith( - "tests/system" - ): - collection.system_tests += count - - collection.test_files = [ - {"file": f, "count": c} - for f, c in sorted(test_files.items(), key=lambda x: -x[1]) - ] - - # Cleanup - json_path.unlink(missing_ok=True) - - logger.info( - f"Collected {collection.total_tests} tests from {collection.total_files} files" - ) - - except Exception as e: - logger.error(f"Error collecting tests: {e}", exc_info=True) - - db.add(collection) - return collection - - -# Singleton instance -test_runner_service = TestRunnerService() +from app.modules.dev_tools.services.test_runner_service import ( + test_runner_service, + TestRunnerService, +) + +__all__ = [ + "test_runner_service", + "TestRunnerService", +] diff --git a/models/database/customer.py b/models/database/customer.py index 4583ddd4..6986b857 100644 --- a/models/database/customer.py +++ b/models/database/customer.py @@ -1,82 +1,23 @@ -from sqlalchemy import ( - JSON, - Boolean, - Column, - DateTime, - ForeignKey, - Integer, - Numeric, - String, +# models/database/customer.py +""" +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/customers/models/customer.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.customers.models import Customer, CustomerAddress +""" + +from app.modules.customers.models.customer import ( + Customer, + CustomerAddress, ) -from sqlalchemy.orm import relationship -from app.core.database import Base - -from .base import TimestampMixin - - -class Customer(Base, TimestampMixin): - __tablename__ = "customers" - - id = Column(Integer, primary_key=True, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) - email = Column( - String(255), nullable=False, index=True - ) # Unique within vendor scope - hashed_password = Column(String(255), nullable=False) - first_name = Column(String(100)) - last_name = Column(String(100)) - phone = Column(String(50)) - customer_number = Column( - String(100), nullable=False, index=True - ) # Vendor-specific ID - preferences = Column(JSON, default=dict) - marketing_consent = Column(Boolean, default=False) - last_order_date = Column(DateTime) - total_orders = Column(Integer, default=0) - total_spent = Column(Numeric(10, 2), default=0) - is_active = Column(Boolean, default=True, nullable=False) - - # Language preference (NULL = use vendor storefront_language default) - # Supported: en, fr, de, lb - preferred_language = Column(String(5), nullable=True) - - # Relationships - vendor = relationship("Vendor", back_populates="customers") - addresses = relationship("CustomerAddress", back_populates="customer") - orders = relationship("Order", back_populates="customer") - - def __repr__(self): - return f"" - - @property - def full_name(self): - if self.first_name and self.last_name: - return f"{self.first_name} {self.last_name}" - return self.email - - -class CustomerAddress(Base, TimestampMixin): - __tablename__ = "customer_addresses" - - id = Column(Integer, primary_key=True, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False) - customer_id = Column(Integer, ForeignKey("customers.id"), nullable=False) - address_type = Column(String(50), nullable=False) # 'billing', 'shipping' - first_name = Column(String(100), nullable=False) - last_name = Column(String(100), nullable=False) - company = Column(String(200)) - address_line_1 = Column(String(255), nullable=False) - address_line_2 = Column(String(255)) - city = Column(String(100), nullable=False) - postal_code = Column(String(20), nullable=False) - country_name = Column(String(100), nullable=False) - country_iso = Column(String(5), nullable=False) - is_default = Column(Boolean, default=False) - - # Relationships - vendor = relationship("Vendor") - customer = relationship("Customer", back_populates="addresses") - - def __repr__(self): - return f"" +__all__ = [ + "Customer", + "CustomerAddress", +] diff --git a/models/database/inventory.py b/models/database/inventory.py index 69ec0c6e..7d83a08c 100644 --- a/models/database/inventory.py +++ b/models/database/inventory.py @@ -1,61 +1,19 @@ # models/database/inventory.py """ -Inventory model for tracking stock at warehouse/bin locations. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Each entry represents a quantity of a product at a specific bin location -within a warehouse. Products can be scattered across multiple bins. +The canonical implementation is now in: + app/modules/inventory/models/inventory.py -Example: - Warehouse: "strassen" - Bin: "SA-10-02" - Product: GTIN 4007817144145 - Quantity: 3 +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.inventory.models import Inventory """ -from sqlalchemy import Column, ForeignKey, Index, Integer, String, UniqueConstraint -from sqlalchemy.orm import relationship +from app.modules.inventory.models.inventory import Inventory -from app.core.database import Base -from models.database.base import TimestampMixin - - -class Inventory(Base, TimestampMixin): - __tablename__ = "inventory" - - id = Column(Integer, primary_key=True, index=True) - product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) - - # Location: warehouse + bin - warehouse = Column(String, nullable=False, default="strassen", index=True) - bin_location = Column(String, nullable=False, index=True) # e.g., "SA-10-02" - - # Legacy field - kept for backward compatibility, will be removed - location = Column(String, index=True) - - quantity = Column(Integer, nullable=False, default=0) - reserved_quantity = Column(Integer, default=0) - - # Keep GTIN for reference/reporting (matches Product.gtin) - gtin = Column(String, index=True) - - # Relationships - product = relationship("Product", back_populates="inventory_entries") - vendor = relationship("Vendor") - - # Constraints - __table_args__ = ( - UniqueConstraint( - "product_id", "warehouse", "bin_location", name="uq_inventory_product_warehouse_bin" - ), - Index("idx_inventory_vendor_product", "vendor_id", "product_id"), - Index("idx_inventory_warehouse_bin", "warehouse", "bin_location"), - ) - - def __repr__(self): - return f"" - - @property - def available_quantity(self): - """Calculate available quantity (total - reserved).""" - return max(0, self.quantity - self.reserved_quantity) +__all__ = [ + "Inventory", +] diff --git a/models/database/inventory_transaction.py b/models/database/inventory_transaction.py index 49c6ba08..725c0a5a 100644 --- a/models/database/inventory_transaction.py +++ b/models/database/inventory_transaction.py @@ -1,170 +1,23 @@ # models/database/inventory_transaction.py """ -Inventory Transaction Model - Audit trail for all stock movements. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This model tracks every change to inventory quantities, providing: -- Complete audit trail for compliance and debugging -- Order-linked transactions for traceability -- Support for different transaction types (reserve, fulfill, adjust, etc.) +The canonical implementation is now in: + app/modules/inventory/models/inventory_transaction.py -All stock movements should create a transaction record. +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.inventory.models import InventoryTransaction, TransactionType """ -from datetime import UTC, datetime -from enum import Enum - -from sqlalchemy import ( - Column, - DateTime, - Enum as SQLEnum, - ForeignKey, - Index, - Integer, - String, - Text, +from app.modules.inventory.models.inventory_transaction import ( + InventoryTransaction, + TransactionType, ) -from sqlalchemy.orm import relationship -from app.core.database import Base - - -class TransactionType(str, Enum): - """Types of inventory transactions.""" - - # Order-related - RESERVE = "reserve" # Stock reserved for order - FULFILL = "fulfill" # Reserved stock consumed (shipped) - RELEASE = "release" # Reserved stock released (cancelled) - - # Manual adjustments - ADJUST = "adjust" # Manual adjustment (+/-) - SET = "set" # Set to exact quantity - - # Imports - IMPORT = "import" # Initial import/sync - - # Returns - RETURN = "return" # Stock returned from customer - - -class InventoryTransaction(Base): - """ - Audit log for inventory movements. - - Every change to inventory quantity creates a transaction record, - enabling complete traceability of stock levels over time. - """ - - __tablename__ = "inventory_transactions" - - id = Column(Integer, primary_key=True, index=True) - - # Core references - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) - product_id = Column(Integer, ForeignKey("products.id"), nullable=False, index=True) - inventory_id = Column( - Integer, ForeignKey("inventory.id"), nullable=True, index=True - ) - - # Transaction details - transaction_type = Column( - SQLEnum(TransactionType), nullable=False, index=True - ) - quantity_change = Column(Integer, nullable=False) # Positive = add, negative = remove - - # Quantities after transaction (snapshot) - quantity_after = Column(Integer, nullable=False) - reserved_after = Column(Integer, nullable=False, default=0) - - # Location context - location = Column(String, nullable=True) - warehouse = Column(String, nullable=True) - - # Order reference (for order-related transactions) - order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True) - order_number = Column(String, nullable=True) - - # Audit fields - reason = Column(Text, nullable=True) # Human-readable reason - created_by = Column(String, nullable=True) # User/system that created - created_at = Column( - DateTime(timezone=True), - default=lambda: datetime.now(UTC), - nullable=False, - index=True, - ) - - # Relationships - vendor = relationship("Vendor") - product = relationship("Product") - inventory = relationship("Inventory") - order = relationship("Order") - - # Indexes for common queries - __table_args__ = ( - Index("idx_inv_tx_vendor_product", "vendor_id", "product_id"), - Index("idx_inv_tx_vendor_created", "vendor_id", "created_at"), - Index("idx_inv_tx_order", "order_id"), - Index("idx_inv_tx_type_created", "transaction_type", "created_at"), - ) - - def __repr__(self) -> str: - return ( - f"" - ) - - @classmethod - def create_transaction( - cls, - vendor_id: int, - product_id: int, - transaction_type: TransactionType, - quantity_change: int, - quantity_after: int, - reserved_after: int = 0, - inventory_id: int | None = None, - location: str | None = None, - warehouse: str | None = None, - order_id: int | None = None, - order_number: str | None = None, - reason: str | None = None, - created_by: str | None = None, - ) -> "InventoryTransaction": - """ - Factory method to create a transaction record. - - Args: - vendor_id: Vendor ID - product_id: Product ID - transaction_type: Type of transaction - quantity_change: Change in quantity (positive = add, negative = remove) - quantity_after: Total quantity after this transaction - reserved_after: Reserved quantity after this transaction - inventory_id: Optional inventory record ID - location: Optional location - warehouse: Optional warehouse - order_id: Optional order ID (for order-related transactions) - order_number: Optional order number for display - reason: Optional human-readable reason - created_by: Optional user/system identifier - - Returns: - InventoryTransaction instance (not yet added to session) - """ - return cls( - vendor_id=vendor_id, - product_id=product_id, - inventory_id=inventory_id, - transaction_type=transaction_type, - quantity_change=quantity_change, - quantity_after=quantity_after, - reserved_after=reserved_after, - location=location, - warehouse=warehouse, - order_id=order_id, - order_number=order_number, - reason=reason, - created_by=created_by, - ) +__all__ = [ + "InventoryTransaction", + "TransactionType", +] diff --git a/models/database/invoice.py b/models/database/invoice.py index e0435617..1979c70c 100644 --- a/models/database/invoice.py +++ b/models/database/invoice.py @@ -1,215 +1,27 @@ # models/database/invoice.py """ -Invoice database models for the OMS. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Provides models for: -- VendorInvoiceSettings: Per-vendor invoice configuration (company details, VAT, numbering) -- Invoice: Invoice records with snapshots of seller/buyer details +The canonical implementation is now in: + app/modules/orders/models/invoice.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.models import Invoice, InvoiceStatus, VATRegime, VendorInvoiceSettings """ -import enum - -from sqlalchemy import ( - Boolean, - Column, - DateTime, - ForeignKey, - Index, - Integer, - Numeric, - String, - Text, +from app.modules.orders.models.invoice import ( + Invoice, + InvoiceStatus, + VATRegime, + VendorInvoiceSettings, ) -from sqlalchemy.dialects.sqlite import JSON -from sqlalchemy.orm import relationship -from app.core.database import Base -from models.database.base import TimestampMixin - - -class VendorInvoiceSettings(Base, TimestampMixin): - """ - Per-vendor invoice configuration. - - Stores company details, VAT number, invoice numbering preferences, - and payment information for invoice generation. - - One-to-one relationship with Vendor. - """ - - __tablename__ = "vendor_invoice_settings" - - id = Column(Integer, primary_key=True, index=True) - vendor_id = Column( - Integer, ForeignKey("vendors.id"), unique=True, nullable=False, index=True - ) - - # Legal company details for invoice header - company_name = Column(String(255), nullable=False) # Legal name for invoices - company_address = Column(String(255), nullable=True) # Street address - company_city = Column(String(100), nullable=True) - company_postal_code = Column(String(20), nullable=True) - company_country = Column(String(2), nullable=False, default="LU") # ISO country code - - # VAT information - vat_number = Column(String(50), nullable=True) # e.g., "LU12345678" - is_vat_registered = Column(Boolean, default=True, nullable=False) - - # OSS (One-Stop-Shop) for EU VAT - is_oss_registered = Column(Boolean, default=False, nullable=False) - oss_registration_country = Column(String(2), nullable=True) # ISO country code - - # Invoice numbering - invoice_prefix = Column(String(20), default="INV", nullable=False) - invoice_next_number = Column(Integer, default=1, nullable=False) - invoice_number_padding = Column(Integer, default=5, nullable=False) # e.g., INV00001 - - # Payment information - payment_terms = Column(Text, nullable=True) # e.g., "Payment due within 30 days" - bank_name = Column(String(255), nullable=True) - bank_iban = Column(String(50), nullable=True) - bank_bic = Column(String(20), nullable=True) - - # Invoice footer - footer_text = Column(Text, nullable=True) # Custom footer text - - # Default VAT rate for Luxembourg invoices (17% standard) - default_vat_rate = Column(Numeric(5, 2), default=17.00, nullable=False) - - # Relationships - vendor = relationship("Vendor", back_populates="invoice_settings") - - def __repr__(self): - return f"" - - def get_next_invoice_number(self) -> str: - """Generate the next invoice number and increment counter.""" - number = str(self.invoice_next_number).zfill(self.invoice_number_padding) - return f"{self.invoice_prefix}{number}" - - -class InvoiceStatus(str, enum.Enum): - """Invoice status enumeration.""" - - DRAFT = "draft" - ISSUED = "issued" - PAID = "paid" - CANCELLED = "cancelled" - - -class VATRegime(str, enum.Enum): - """VAT regime for invoice calculation.""" - - DOMESTIC = "domestic" # Same country as seller - OSS = "oss" # EU cross-border with OSS registration - REVERSE_CHARGE = "reverse_charge" # B2B with valid VAT number - ORIGIN = "origin" # Cross-border without OSS (use origin VAT) - EXEMPT = "exempt" # VAT exempt - - -class Invoice(Base, TimestampMixin): - """ - Invoice record with snapshots of seller/buyer details. - - Stores complete invoice data including snapshots of seller and buyer - details at time of creation for audit purposes. - """ - - __tablename__ = "invoices" - - id = Column(Integer, primary_key=True, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) - order_id = Column(Integer, ForeignKey("orders.id"), nullable=True, index=True) - - # Invoice identification - invoice_number = Column(String(50), nullable=False) - invoice_date = Column(DateTime(timezone=True), nullable=False) - - # Status - status = Column(String(20), default=InvoiceStatus.DRAFT.value, nullable=False) - - # Seller details snapshot (captured at invoice creation) - seller_details = Column(JSON, nullable=False) - # Structure: { - # "company_name": str, - # "address": str, - # "city": str, - # "postal_code": str, - # "country": str, - # "vat_number": str | None - # } - - # Buyer details snapshot (captured at invoice creation) - buyer_details = Column(JSON, nullable=False) - # Structure: { - # "name": str, - # "email": str, - # "address": str, - # "city": str, - # "postal_code": str, - # "country": str, - # "vat_number": str | None (for B2B) - # } - - # Line items snapshot - line_items = Column(JSON, nullable=False) - # Structure: [{ - # "description": str, - # "quantity": int, - # "unit_price_cents": int, - # "total_cents": int, - # "sku": str | None, - # "ean": str | None - # }] - - # VAT information - vat_regime = Column(String(20), default=VATRegime.DOMESTIC.value, nullable=False) - destination_country = Column(String(2), nullable=True) # For OSS invoices - vat_rate = Column(Numeric(5, 2), nullable=False) # e.g., 17.00 for 17% - vat_rate_label = Column(String(50), nullable=True) # e.g., "Luxembourg Standard VAT" - - # Amounts (stored in cents for precision) - currency = Column(String(3), default="EUR", nullable=False) - subtotal_cents = Column(Integer, nullable=False) # Before VAT - vat_amount_cents = Column(Integer, nullable=False) # VAT amount - total_cents = Column(Integer, nullable=False) # After VAT - - # Payment information - payment_terms = Column(Text, nullable=True) - bank_details = Column(JSON, nullable=True) # IBAN, BIC snapshot - footer_text = Column(Text, nullable=True) - - # PDF storage - pdf_generated_at = Column(DateTime(timezone=True), nullable=True) - pdf_path = Column(String(500), nullable=True) # Path to stored PDF - - # Notes - notes = Column(Text, nullable=True) # Internal notes - - # Relationships - vendor = relationship("Vendor", back_populates="invoices") - order = relationship("Order", back_populates="invoices") - - __table_args__ = ( - Index("idx_invoice_vendor_number", "vendor_id", "invoice_number", unique=True), - Index("idx_invoice_vendor_date", "vendor_id", "invoice_date"), - Index("idx_invoice_status", "vendor_id", "status"), - ) - - def __repr__(self): - return f"" - - @property - def subtotal(self) -> float: - """Get subtotal in EUR.""" - return self.subtotal_cents / 100 - - @property - def vat_amount(self) -> float: - """Get VAT amount in EUR.""" - return self.vat_amount_cents / 100 - - @property - def total(self) -> float: - """Get total in EUR.""" - return self.total_cents / 100 +__all__ = [ + "Invoice", + "InvoiceStatus", + "VATRegime", + "VendorInvoiceSettings", +] diff --git a/models/database/message.py b/models/database/message.py index 2ed2fecc..212c0381 100644 --- a/models/database/message.py +++ b/models/database/message.py @@ -1,272 +1,31 @@ # models/database/message.py """ -Messaging system database models. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports three communication channels: -- Admin <-> Vendor -- Vendor <-> Customer -- Admin <-> Customer +The canonical implementation is now in: + app/modules/messaging/models/message.py -Multi-tenant isolation is enforced via vendor_id for conversations -involving customers. +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.models import message """ -import enum -from datetime import datetime - -from sqlalchemy import ( - Boolean, - Column, - DateTime, - Enum, - ForeignKey, - Index, - Integer, - String, - Text, - UniqueConstraint, +from app.modules.messaging.models.message import ( + Conversation, + ConversationParticipant, + ConversationType, + Message, + MessageAttachment, + ParticipantType, ) -from sqlalchemy.orm import relationship -from app.core.database import Base -from models.database.base import TimestampMixin - - -class ConversationType(str, enum.Enum): - """Defines the three supported conversation channels.""" - - ADMIN_VENDOR = "admin_vendor" - VENDOR_CUSTOMER = "vendor_customer" - ADMIN_CUSTOMER = "admin_customer" - - -class ParticipantType(str, enum.Enum): - """Type of participant in a conversation.""" - - ADMIN = "admin" # User with role="admin" - VENDOR = "vendor" # User with role="vendor" (via VendorUser) - CUSTOMER = "customer" # Customer model - - -def _enum_values(enum_class): - """Extract enum values for SQLAlchemy Enum column.""" - return [e.value for e in enum_class] - - -class Conversation(Base, TimestampMixin): - """ - Represents a threaded conversation between participants. - - Multi-tenancy: vendor_id is required for vendor_customer and admin_customer - conversations to ensure customer data isolation. - """ - - __tablename__ = "conversations" - - id = Column(Integer, primary_key=True, index=True) - - # Conversation type determines participant structure - conversation_type = Column( - Enum(ConversationType, values_callable=_enum_values), - nullable=False, - index=True, - ) - - # Subject line for the conversation thread - subject = Column(String(500), nullable=False) - - # For vendor_customer and admin_customer conversations - # Required for multi-tenant data isolation - vendor_id = Column( - Integer, - ForeignKey("vendors.id"), - nullable=True, - index=True, - ) - - # Status flags - is_closed = Column(Boolean, default=False, nullable=False) - closed_at = Column(DateTime, nullable=True) - closed_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True) - closed_by_id = Column(Integer, nullable=True) - - # Last activity tracking for sorting - last_message_at = Column(DateTime, nullable=True, index=True) - message_count = Column(Integer, default=0, nullable=False) - - # Relationships - vendor = relationship("Vendor", foreign_keys=[vendor_id]) - participants = relationship( - "ConversationParticipant", - back_populates="conversation", - cascade="all, delete-orphan", - ) - messages = relationship( - "Message", - back_populates="conversation", - cascade="all, delete-orphan", - order_by="Message.created_at", - ) - - # Indexes for common queries - __table_args__ = ( - Index("ix_conversations_type_vendor", "conversation_type", "vendor_id"), - ) - - def __repr__(self) -> str: - return ( - f"" - ) - - -class ConversationParticipant(Base, TimestampMixin): - """ - Links participants (users or customers) to conversations. - - Polymorphic relationship: - - participant_type="admin" or "vendor": references users.id - - participant_type="customer": references customers.id - """ - - __tablename__ = "conversation_participants" - - id = Column(Integer, primary_key=True, index=True) - conversation_id = Column( - Integer, - ForeignKey("conversations.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - - # Polymorphic participant reference - participant_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False) - participant_id = Column(Integer, nullable=False, index=True) - - # For vendor participants, track which vendor they represent - vendor_id = Column( - Integer, - ForeignKey("vendors.id"), - nullable=True, - ) - - # Unread tracking per participant - unread_count = Column(Integer, default=0, nullable=False) - last_read_at = Column(DateTime, nullable=True) - - # Notification preferences for this conversation - email_notifications = Column(Boolean, default=True, nullable=False) - muted = Column(Boolean, default=False, nullable=False) - - # Relationships - conversation = relationship("Conversation", back_populates="participants") - vendor = relationship("Vendor", foreign_keys=[vendor_id]) - - __table_args__ = ( - UniqueConstraint( - "conversation_id", - "participant_type", - "participant_id", - name="uq_conversation_participant", - ), - Index( - "ix_participant_lookup", - "participant_type", - "participant_id", - ), - ) - - def __repr__(self) -> str: - return ( - f"" - ) - - -class Message(Base, TimestampMixin): - """ - Individual message within a conversation thread. - - Sender polymorphism follows same pattern as participant. - """ - - __tablename__ = "messages" - - id = Column(Integer, primary_key=True, index=True) - conversation_id = Column( - Integer, - ForeignKey("conversations.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - - # Polymorphic sender reference - sender_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=False) - sender_id = Column(Integer, nullable=False, index=True) - - # Message content - content = Column(Text, nullable=False) - - # System messages (e.g., "conversation closed") - is_system_message = Column(Boolean, default=False, nullable=False) - - # Soft delete for moderation - is_deleted = Column(Boolean, default=False, nullable=False) - deleted_at = Column(DateTime, nullable=True) - deleted_by_type = Column(Enum(ParticipantType, values_callable=_enum_values), nullable=True) - deleted_by_id = Column(Integer, nullable=True) - - # Relationships - conversation = relationship("Conversation", back_populates="messages") - attachments = relationship( - "MessageAttachment", - back_populates="message", - cascade="all, delete-orphan", - ) - - __table_args__ = ( - Index("ix_messages_conversation_created", "conversation_id", "created_at"), - ) - - def __repr__(self) -> str: - return ( - f"" - ) - - -class MessageAttachment(Base, TimestampMixin): - """ - File attachments for messages. - - Files are stored in platform storage (local/S3) with references here. - """ - - __tablename__ = "message_attachments" - - id = Column(Integer, primary_key=True, index=True) - message_id = Column( - Integer, - ForeignKey("messages.id", ondelete="CASCADE"), - nullable=False, - index=True, - ) - - # File metadata - filename = Column(String(255), nullable=False) - original_filename = Column(String(255), nullable=False) - file_path = Column(String(1000), nullable=False) # Storage path - file_size = Column(Integer, nullable=False) # Size in bytes - mime_type = Column(String(100), nullable=False) - - # For image attachments - is_image = Column(Boolean, default=False, nullable=False) - image_width = Column(Integer, nullable=True) - image_height = Column(Integer, nullable=True) - thumbnail_path = Column(String(1000), nullable=True) - - # Relationships - message = relationship("Message", back_populates="attachments") - - def __repr__(self) -> str: - return f"" +__all__ = [ + "Conversation", + "ConversationParticipant", + "ConversationType", + "Message", + "MessageAttachment", + "ParticipantType", +] diff --git a/models/database/order.py b/models/database/order.py index b8c66eba..ee1b2b30 100644 --- a/models/database/order.py +++ b/models/database/order.py @@ -1,406 +1,20 @@ # models/database/order.py """ -Unified Order model for all sales channels. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports: -- Direct orders (from vendor's own storefront) -- Marketplace orders (Letzshop, etc.) +The canonical implementation is now in: + app/modules/orders/models/order.py -Design principles: -- Customer/address data is snapshotted at order time (preserves history) -- customer_id FK links to Customer record (may be inactive for marketplace imports) -- channel field distinguishes order source -- external_* fields store marketplace-specific references +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: -Money values are stored as integer cents (e.g., €105.91 = 10591). -See docs/architecture/money-handling.md for details. + from app.modules.orders.models import Order, OrderItem """ -from sqlalchemy import ( - Boolean, - Column, - DateTime, - ForeignKey, - Index, - Integer, - Numeric, - String, - Text, -) -from typing import TYPE_CHECKING +from app.modules.orders.models.order import Order, OrderItem -if TYPE_CHECKING: - from models.database.order_item_exception import OrderItemException -from sqlalchemy.dialects.sqlite import JSON -from sqlalchemy.orm import relationship - -from app.core.database import Base -from app.utils.money import cents_to_euros, euros_to_cents -from models.database.base import TimestampMixin - - -class Order(Base, TimestampMixin): - """ - Unified order model for all sales channels. - - Stores orders from direct sales and marketplaces (Letzshop, etc.) - with snapshotted customer and address data. - - All monetary amounts are stored as integer cents for precision. - """ - - __tablename__ = "orders" - - id = Column(Integer, primary_key=True, index=True) - vendor_id = Column(Integer, ForeignKey("vendors.id"), nullable=False, index=True) - customer_id = Column( - Integer, ForeignKey("customers.id"), nullable=False, index=True - ) - order_number = Column(String(100), nullable=False, unique=True, index=True) - - # === Channel/Source === - channel = Column( - String(50), default="direct", nullable=False, index=True - ) # direct, letzshop - - # External references (for marketplace orders) - external_order_id = Column( - String(100), nullable=True, index=True - ) # Marketplace order ID - external_shipment_id = Column( - String(100), nullable=True, index=True - ) # Marketplace shipment ID - external_order_number = Column(String(100), nullable=True) # Marketplace order # - external_data = Column(JSON, nullable=True) # Raw marketplace data for debugging - - # === Status === - # pending: awaiting confirmation - # processing: confirmed, being prepared - # shipped: shipped with tracking - # delivered: delivered to customer - # cancelled: order cancelled/declined - # refunded: order refunded - status = Column(String(50), nullable=False, default="pending", index=True) - - # === Financials (stored as integer cents) === - subtotal_cents = Column(Integer, nullable=True) # May not be available from marketplace - tax_amount_cents = Column(Integer, nullable=True) - shipping_amount_cents = Column(Integer, nullable=True) - discount_amount_cents = Column(Integer, nullable=True) - total_amount_cents = Column(Integer, nullable=False) - currency = Column(String(10), default="EUR") - - # === VAT Information === - # VAT regime: domestic, oss, reverse_charge, origin, exempt - vat_regime = Column(String(20), nullable=True) - # VAT rate as percentage (e.g., 17.00 for 17%) - vat_rate = Column(Numeric(5, 2), nullable=True) - # Human-readable VAT label (e.g., "Luxembourg VAT 17%") - vat_rate_label = Column(String(100), nullable=True) - # Destination country for cross-border sales (ISO code) - vat_destination_country = Column(String(2), nullable=True) - - # === Customer Snapshot (preserved at order time) === - customer_first_name = Column(String(100), nullable=False) - customer_last_name = Column(String(100), nullable=False) - customer_email = Column(String(255), nullable=False) - customer_phone = Column(String(50), nullable=True) - customer_locale = Column(String(10), nullable=True) # en, fr, de, lb - - # === Shipping Address Snapshot === - ship_first_name = Column(String(100), nullable=False) - ship_last_name = Column(String(100), nullable=False) - ship_company = Column(String(200), nullable=True) - ship_address_line_1 = Column(String(255), nullable=False) - ship_address_line_2 = Column(String(255), nullable=True) - ship_city = Column(String(100), nullable=False) - ship_postal_code = Column(String(20), nullable=False) - ship_country_iso = Column(String(5), nullable=False) - - # === Billing Address Snapshot === - bill_first_name = Column(String(100), nullable=False) - bill_last_name = Column(String(100), nullable=False) - bill_company = Column(String(200), nullable=True) - bill_address_line_1 = Column(String(255), nullable=False) - bill_address_line_2 = Column(String(255), nullable=True) - bill_city = Column(String(100), nullable=False) - bill_postal_code = Column(String(20), nullable=False) - bill_country_iso = Column(String(5), nullable=False) - - # === Tracking === - shipping_method = Column(String(100), nullable=True) - tracking_number = Column(String(100), nullable=True) - tracking_provider = Column(String(100), nullable=True) - tracking_url = Column(String(500), nullable=True) # Full tracking URL - shipment_number = Column(String(100), nullable=True) # Carrier shipment number (e.g., H74683403433) - shipping_carrier = Column(String(50), nullable=True) # Carrier code (greco, colissimo, etc.) - - # === Notes === - customer_notes = Column(Text, nullable=True) - internal_notes = Column(Text, nullable=True) - - # === Timestamps === - order_date = Column( - DateTime(timezone=True), nullable=False - ) # When customer placed order - confirmed_at = Column(DateTime(timezone=True), nullable=True) - shipped_at = Column(DateTime(timezone=True), nullable=True) - delivered_at = Column(DateTime(timezone=True), nullable=True) - cancelled_at = Column(DateTime(timezone=True), nullable=True) - - # === Relationships === - vendor = relationship("Vendor") - customer = relationship("Customer", back_populates="orders") - items = relationship( - "OrderItem", back_populates="order", cascade="all, delete-orphan" - ) - invoices = relationship( - "Invoice", back_populates="order", cascade="all, delete-orphan" - ) - - # Composite indexes for common queries - __table_args__ = ( - Index("idx_order_vendor_status", "vendor_id", "status"), - Index("idx_order_vendor_channel", "vendor_id", "channel"), - Index("idx_order_vendor_date", "vendor_id", "order_date"), - ) - - def __repr__(self): - return f"" - - # === PRICE PROPERTIES (Euro convenience accessors) === - - @property - def subtotal(self) -> float | None: - """Get subtotal in euros.""" - if self.subtotal_cents is not None: - return cents_to_euros(self.subtotal_cents) - return None - - @subtotal.setter - def subtotal(self, value: float | None): - """Set subtotal from euros.""" - self.subtotal_cents = euros_to_cents(value) if value is not None else None - - @property - def tax_amount(self) -> float | None: - """Get tax amount in euros.""" - if self.tax_amount_cents is not None: - return cents_to_euros(self.tax_amount_cents) - return None - - @tax_amount.setter - def tax_amount(self, value: float | None): - """Set tax amount from euros.""" - self.tax_amount_cents = euros_to_cents(value) if value is not None else None - - @property - def shipping_amount(self) -> float | None: - """Get shipping amount in euros.""" - if self.shipping_amount_cents is not None: - return cents_to_euros(self.shipping_amount_cents) - return None - - @shipping_amount.setter - def shipping_amount(self, value: float | None): - """Set shipping amount from euros.""" - self.shipping_amount_cents = euros_to_cents(value) if value is not None else None - - @property - def discount_amount(self) -> float | None: - """Get discount amount in euros.""" - if self.discount_amount_cents is not None: - return cents_to_euros(self.discount_amount_cents) - return None - - @discount_amount.setter - def discount_amount(self, value: float | None): - """Set discount amount from euros.""" - self.discount_amount_cents = euros_to_cents(value) if value is not None else None - - @property - def total_amount(self) -> float: - """Get total amount in euros.""" - return cents_to_euros(self.total_amount_cents) - - @total_amount.setter - def total_amount(self, value: float): - """Set total amount from euros.""" - self.total_amount_cents = euros_to_cents(value) - - # === NAME PROPERTIES === - - @property - def customer_full_name(self) -> str: - """Customer full name from snapshot.""" - return f"{self.customer_first_name} {self.customer_last_name}".strip() - - @property - def ship_full_name(self) -> str: - """Shipping address full name.""" - return f"{self.ship_first_name} {self.ship_last_name}".strip() - - @property - def bill_full_name(self) -> str: - """Billing address full name.""" - return f"{self.bill_first_name} {self.bill_last_name}".strip() - - @property - def is_marketplace_order(self) -> bool: - """Check if this is a marketplace order.""" - return self.channel != "direct" - - @property - def is_fully_shipped(self) -> bool: - """Check if all items are fully shipped.""" - if not self.items: - return False - return all(item.is_fully_shipped for item in self.items) - - @property - def is_partially_shipped(self) -> bool: - """Check if some items are shipped but not all.""" - if not self.items: - return False - has_shipped = any(item.shipped_quantity > 0 for item in self.items) - all_shipped = all(item.is_fully_shipped for item in self.items) - return has_shipped and not all_shipped - - @property - def shipped_item_count(self) -> int: - """Count of fully shipped items.""" - return sum(1 for item in self.items if item.is_fully_shipped) - - @property - def total_shipped_units(self) -> int: - """Total quantity shipped across all items.""" - return sum(item.shipped_quantity for item in self.items) - - @property - def total_ordered_units(self) -> int: - """Total quantity ordered across all items.""" - return sum(item.quantity for item in self.items) - - -class OrderItem(Base, TimestampMixin): - """ - Individual items in an order. - - Stores product snapshot at time of order plus external references - for marketplace items. - - All monetary amounts are stored as integer cents for precision. - """ - - __tablename__ = "order_items" - - id = Column(Integer, primary_key=True, index=True) - order_id = Column(Integer, ForeignKey("orders.id"), nullable=False, index=True) - product_id = Column(Integer, ForeignKey("products.id"), nullable=False) - - # === Product Snapshot (preserved at order time) === - product_name = Column(String(255), nullable=False) - product_sku = Column(String(100), nullable=True) - gtin = Column(String(50), nullable=True) # EAN/UPC/ISBN etc. - gtin_type = Column(String(20), nullable=True) # ean13, upc, isbn, etc. - - # === Pricing (stored as integer cents) === - quantity = Column(Integer, nullable=False) - unit_price_cents = Column(Integer, nullable=False) - total_price_cents = Column(Integer, nullable=False) - - # === External References (for marketplace items) === - external_item_id = Column(String(100), nullable=True) # e.g., Letzshop inventory unit ID - external_variant_id = Column(String(100), nullable=True) # e.g., Letzshop variant ID - - # === Item State (for marketplace confirmation flow) === - # confirmed_available: item confirmed and available - # confirmed_unavailable: item confirmed but not available (declined) - item_state = Column(String(50), nullable=True) - - # === Inventory Tracking === - inventory_reserved = Column(Boolean, default=False) - inventory_fulfilled = Column(Boolean, default=False) - - # === Shipment Tracking === - shipped_quantity = Column(Integer, default=0, nullable=False) # Units shipped so far - - # === Exception Tracking === - # True if product was not found by GTIN during import (linked to placeholder) - needs_product_match = Column(Boolean, default=False, index=True) - - # === Relationships === - order = relationship("Order", back_populates="items") - product = relationship("Product") - exception = relationship( - "OrderItemException", - back_populates="order_item", - uselist=False, - cascade="all, delete-orphan", - ) - - def __repr__(self): - return f"" - - # === PRICE PROPERTIES (Euro convenience accessors) === - - @property - def unit_price(self) -> float: - """Get unit price in euros.""" - return cents_to_euros(self.unit_price_cents) - - @unit_price.setter - def unit_price(self, value: float): - """Set unit price from euros.""" - self.unit_price_cents = euros_to_cents(value) - - @property - def total_price(self) -> float: - """Get total price in euros.""" - return cents_to_euros(self.total_price_cents) - - @total_price.setter - def total_price(self, value: float): - """Set total price from euros.""" - self.total_price_cents = euros_to_cents(value) - - # === STATUS PROPERTIES === - - @property - def is_confirmed(self) -> bool: - """Check if item has been confirmed (available or unavailable).""" - return self.item_state in ("confirmed_available", "confirmed_unavailable") - - @property - def is_available(self) -> bool: - """Check if item is confirmed as available.""" - return self.item_state == "confirmed_available" - - @property - def is_declined(self) -> bool: - """Check if item was declined (unavailable).""" - return self.item_state == "confirmed_unavailable" - - @property - def has_unresolved_exception(self) -> bool: - """Check if item has an unresolved exception blocking confirmation.""" - if not self.exception: - return False - return self.exception.blocks_confirmation - - # === SHIPMENT PROPERTIES === - - @property - def remaining_quantity(self) -> int: - """Quantity not yet shipped.""" - return max(0, self.quantity - self.shipped_quantity) - - @property - def is_fully_shipped(self) -> bool: - """Check if all units have been shipped.""" - return self.shipped_quantity >= self.quantity - - @property - def is_partially_shipped(self) -> bool: - """Check if some but not all units have been shipped.""" - return 0 < self.shipped_quantity < self.quantity +__all__ = [ + "Order", + "OrderItem", +] diff --git a/models/database/order_item_exception.py b/models/database/order_item_exception.py index 07832357..4a2fc0cb 100644 --- a/models/database/order_item_exception.py +++ b/models/database/order_item_exception.py @@ -1,117 +1,19 @@ # models/database/order_item_exception.py """ -Order Item Exception model for tracking unmatched products during marketplace imports. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -When a marketplace order contains a GTIN that doesn't match any product in the -vendor's catalog, the order is still imported but the item is linked to a -placeholder product and an exception is recorded here for resolution. +The canonical implementation is now in: + app/modules/orders/models/order_item_exception.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.models import OrderItemException """ -from sqlalchemy import ( - Column, - DateTime, - ForeignKey, - Index, - Integer, - String, - Text, -) -from sqlalchemy.orm import relationship +from app.modules.orders.models.order_item_exception import OrderItemException -from app.core.database import Base -from models.database.base import TimestampMixin - - -class OrderItemException(Base, TimestampMixin): - """ - Tracks unmatched order items requiring admin/vendor resolution. - - When a marketplace order is imported and a product cannot be found by GTIN, - the order item is linked to a placeholder product and this exception record - is created. The order cannot be confirmed until all exceptions are resolved. - """ - - __tablename__ = "order_item_exceptions" - - id = Column(Integer, primary_key=True, index=True) - - # Link to the order item (one-to-one) - order_item_id = Column( - Integer, - ForeignKey("order_items.id", ondelete="CASCADE"), - nullable=False, - unique=True, - ) - - # Vendor ID for efficient querying (denormalized from order) - vendor_id = Column( - Integer, ForeignKey("vendors.id"), nullable=False, index=True - ) - - # Original data from marketplace (preserved for matching) - original_gtin = Column(String(50), nullable=True, index=True) - original_product_name = Column(String(500), nullable=True) - original_sku = Column(String(100), nullable=True) - - # Exception classification - # product_not_found: GTIN not in vendor catalog - # gtin_mismatch: GTIN format issue - # duplicate_gtin: Multiple products with same GTIN - exception_type = Column( - String(50), nullable=False, default="product_not_found" - ) - - # Resolution status - # pending: Awaiting resolution - # resolved: Product has been assigned - # ignored: Marked as ignored (still blocks confirmation) - status = Column(String(50), nullable=False, default="pending", index=True) - - # Resolution details (populated when resolved) - resolved_product_id = Column( - Integer, ForeignKey("products.id"), nullable=True - ) - resolved_at = Column(DateTime(timezone=True), nullable=True) - resolved_by = Column(Integer, ForeignKey("users.id"), nullable=True) - resolution_notes = Column(Text, nullable=True) - - # Relationships - order_item = relationship("OrderItem", back_populates="exception") - vendor = relationship("Vendor") - resolved_product = relationship("Product") - resolver = relationship("User") - - # Composite indexes for common queries - __table_args__ = ( - Index("idx_exception_vendor_status", "vendor_id", "status"), - Index("idx_exception_gtin", "vendor_id", "original_gtin"), - ) - - def __repr__(self): - return ( - f"" - ) - - @property - def is_pending(self) -> bool: - """Check if exception is pending resolution.""" - return self.status == "pending" - - @property - def is_resolved(self) -> bool: - """Check if exception has been resolved.""" - return self.status == "resolved" - - @property - def is_ignored(self) -> bool: - """Check if exception has been ignored.""" - return self.status == "ignored" - - @property - def blocks_confirmation(self) -> bool: - """Check if this exception blocks order confirmation.""" - # Both pending and ignored exceptions block confirmation - return self.status in ("pending", "ignored") +__all__ = [ + "OrderItemException", +] diff --git a/models/database/password_reset_token.py b/models/database/password_reset_token.py index b46cef5c..48e38285 100644 --- a/models/database/password_reset_token.py +++ b/models/database/password_reset_token.py @@ -1,85 +1,19 @@ -import hashlib -import secrets -from datetime import datetime, timedelta +# models/database/password_reset_token.py +""" +LEGACY LOCATION - Re-exports from module for backwards compatibility. -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import Session, relationship +The canonical implementation is now in: + app/modules/customers/models/password_reset_token.py -from app.core.database import Base +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + from app.modules.customers.models import PasswordResetToken +""" -class PasswordResetToken(Base): - """Password reset token for customer accounts. +from app.modules.customers.models.password_reset_token import PasswordResetToken - Security: - - Tokens are stored as SHA256 hashes, not plaintext - - Tokens expire after 1 hour - - Only one active token per customer (old tokens invalidated on new request) - """ - - __tablename__ = "password_reset_tokens" - - # Token expiry in hours - TOKEN_EXPIRY_HOURS = 1 - - id = Column(Integer, primary_key=True, index=True) - customer_id = Column(Integer, ForeignKey("customers.id", ondelete="CASCADE"), nullable=False) - token_hash = Column(String(64), nullable=False, index=True) - expires_at = Column(DateTime, nullable=False) - used_at = Column(DateTime, nullable=True) - created_at = Column(DateTime, default=datetime.utcnow, nullable=False) - - # Relationships - customer = relationship("Customer") - - def __repr__(self): - return f"" - - @staticmethod - def hash_token(token: str) -> str: - """Hash a token using SHA256.""" - return hashlib.sha256(token.encode()).hexdigest() - - @classmethod - def create_for_customer(cls, db: Session, customer_id: int) -> str: - """Create a new password reset token for a customer. - - Invalidates any existing tokens for the customer. - Returns the plaintext token (to be sent via email). - """ - # Invalidate existing tokens for this customer - db.query(cls).filter( - cls.customer_id == customer_id, - cls.used_at.is_(None), - ).delete() - - # Generate new token - plaintext_token = secrets.token_urlsafe(32) - token_hash = cls.hash_token(plaintext_token) - - # Create token record - token = cls( - customer_id=customer_id, - token_hash=token_hash, - expires_at=datetime.utcnow() + timedelta(hours=cls.TOKEN_EXPIRY_HOURS), - ) - db.add(token) - db.flush() - - return plaintext_token - - @classmethod - def find_valid_token(cls, db: Session, plaintext_token: str) -> "PasswordResetToken | None": - """Find a valid (not expired, not used) token.""" - token_hash = cls.hash_token(plaintext_token) - - return db.query(cls).filter( - cls.token_hash == token_hash, - cls.expires_at > datetime.utcnow(), - cls.used_at.is_(None), - ).first() - - def mark_used(self, db: Session) -> None: - """Mark this token as used.""" - self.used_at = datetime.utcnow() - db.flush() +__all__ = [ + "PasswordResetToken", +] diff --git a/models/schema/customer.py b/models/schema/customer.py index dc1f4a7f..34c9a954 100644 --- a/models/schema/customer.py +++ b/models/schema/customer.py @@ -1,333 +1,69 @@ # models/schema/customer.py """ -Pydantic schema for customer-related operations. +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/customers/schemas/customer.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.customers.schemas import CustomerRegister, CustomerResponse """ -from datetime import datetime -from decimal import Decimal - -from pydantic import BaseModel, EmailStr, Field, field_validator - -# ============================================================================ -# Customer Registration & Authentication -# ============================================================================ - - -class CustomerRegister(BaseModel): - """Schema for customer registration.""" - - email: EmailStr = Field(..., description="Customer email address") - password: str = Field( - ..., min_length=8, description="Password (minimum 8 characters)" - ) - first_name: str = Field(..., min_length=1, max_length=100) - last_name: str = Field(..., min_length=1, max_length=100) - phone: str | None = Field(None, max_length=50) - marketing_consent: bool = Field(default=False) - preferred_language: str | None = Field( - None, description="Preferred language (en, fr, de, lb)" - ) - - @field_validator("email") - @classmethod - def email_lowercase(cls, v: str) -> str: - """Convert email to lowercase.""" - return v.lower() - - @field_validator("password") - @classmethod - def password_strength(cls, v: str) -> str: - """Validate password strength.""" - if len(v) < 8: - raise ValueError("Password must be at least 8 characters") - if not any(char.isdigit() for char in v): - raise ValueError("Password must contain at least one digit") - if not any(char.isalpha() for char in v): - raise ValueError("Password must contain at least one letter") - return v - - -class CustomerUpdate(BaseModel): - """Schema for updating customer profile.""" - - email: EmailStr | None = None - first_name: str | None = Field(None, min_length=1, max_length=100) - last_name: str | None = Field(None, min_length=1, max_length=100) - phone: str | None = Field(None, max_length=50) - marketing_consent: bool | None = None - preferred_language: str | None = Field( - None, description="Preferred language (en, fr, de, lb)" - ) - - @field_validator("email") - @classmethod - def email_lowercase(cls, v: str | None) -> str | None: - """Convert email to lowercase.""" - return v.lower() if v else None - - -class CustomerPasswordChange(BaseModel): - """Schema for customer password change.""" - - current_password: str = Field(..., description="Current password") - new_password: str = Field( - ..., min_length=8, description="New password (minimum 8 characters)" - ) - confirm_password: str = Field(..., description="Confirm new password") - - @field_validator("new_password") - @classmethod - def password_strength(cls, v: str) -> str: - """Validate password strength.""" - if len(v) < 8: - raise ValueError("Password must be at least 8 characters") - if not any(char.isdigit() for char in v): - raise ValueError("Password must contain at least one digit") - if not any(char.isalpha() for char in v): - raise ValueError("Password must contain at least one letter") - return v - - -# ============================================================================ -# Customer Response -# ============================================================================ - - -class CustomerResponse(BaseModel): - """Schema for customer response (excludes password).""" - - id: int - vendor_id: int - email: str - first_name: str | None - last_name: str | None - phone: str | None - customer_number: str - marketing_consent: bool - preferred_language: str | None - last_order_date: datetime | None - total_orders: int - total_spent: Decimal - is_active: bool - created_at: datetime - updated_at: datetime - - model_config = {"from_attributes": True} - - @property - def full_name(self) -> str: - """Get customer full name.""" - if self.first_name and self.last_name: - return f"{self.first_name} {self.last_name}" - return self.email - - -class CustomerListResponse(BaseModel): - """Schema for paginated customer list.""" - - customers: list[CustomerResponse] - total: int - page: int - per_page: int - total_pages: int - - -# ============================================================================ -# Customer Address -# ============================================================================ - - -class CustomerAddressCreate(BaseModel): - """Schema for creating customer address.""" - - address_type: str = Field(..., pattern="^(billing|shipping)$") - first_name: str = Field(..., min_length=1, max_length=100) - last_name: str = Field(..., min_length=1, max_length=100) - company: str | None = Field(None, max_length=200) - address_line_1: str = Field(..., min_length=1, max_length=255) - address_line_2: str | None = Field(None, max_length=255) - city: str = Field(..., min_length=1, max_length=100) - postal_code: str = Field(..., min_length=1, max_length=20) - country_name: str = Field(..., min_length=2, max_length=100) - country_iso: str = Field(..., min_length=2, max_length=5) - is_default: bool = Field(default=False) - - -class CustomerAddressUpdate(BaseModel): - """Schema for updating customer address.""" - - address_type: str | None = Field(None, pattern="^(billing|shipping)$") - first_name: str | None = Field(None, min_length=1, max_length=100) - last_name: str | None = Field(None, min_length=1, max_length=100) - company: str | None = Field(None, max_length=200) - address_line_1: str | None = Field(None, min_length=1, max_length=255) - address_line_2: str | None = Field(None, max_length=255) - city: str | None = Field(None, min_length=1, max_length=100) - postal_code: str | None = Field(None, min_length=1, max_length=20) - country_name: str | None = Field(None, min_length=2, max_length=100) - country_iso: str | None = Field(None, min_length=2, max_length=5) - is_default: bool | None = None - - -class CustomerAddressResponse(BaseModel): - """Schema for customer address response.""" - - id: int - vendor_id: int - customer_id: int - address_type: str - first_name: str - last_name: str - company: str | None - address_line_1: str - address_line_2: str | None - city: str - postal_code: str - country_name: str - country_iso: str - is_default: bool - created_at: datetime - updated_at: datetime - - model_config = {"from_attributes": True} - - -class CustomerAddressListResponse(BaseModel): - """Schema for customer address list response.""" - - addresses: list[CustomerAddressResponse] - total: int - - -# ============================================================================ -# Customer Preferences -# ============================================================================ - - -class CustomerPreferencesUpdate(BaseModel): - """Schema for updating customer preferences.""" - - marketing_consent: bool | None = None - preferred_language: str | None = Field( - None, description="Preferred language (en, fr, de, lb)" - ) - currency: str | None = Field(None, max_length=3) - notification_preferences: dict[str, bool] | None = None - - -# ============================================================================ -# Vendor Customer Management Response Schemas -# ============================================================================ - - -class CustomerMessageResponse(BaseModel): - """Simple message response for customer operations.""" - - message: str - - -class VendorCustomerListResponse(BaseModel): - """Schema for vendor customer list with skip/limit pagination.""" - - customers: list[CustomerResponse] = [] - total: int = 0 - skip: int = 0 - limit: int = 100 - message: str | None = None - - -class CustomerDetailResponse(BaseModel): - """Detailed customer response for vendor management.""" - - id: int | None = None - vendor_id: int | None = None - email: str | None = None - first_name: str | None = None - last_name: str | None = None - phone: str | None = None - customer_number: str | None = None - marketing_consent: bool | None = None - preferred_language: str | None = None - last_order_date: datetime | None = None - total_orders: int | None = None - total_spent: Decimal | None = None - is_active: bool | None = None - created_at: datetime | None = None - updated_at: datetime | None = None - message: str | None = None - - model_config = {"from_attributes": True} - - -class CustomerOrderInfo(BaseModel): - """Basic order info for customer order history.""" - - id: int - order_number: str - status: str - total: Decimal - created_at: datetime - - -class CustomerOrdersResponse(BaseModel): - """Response for customer order history.""" - - orders: list[CustomerOrderInfo] = [] - total: int = 0 - message: str | None = None - - -class CustomerStatisticsResponse(BaseModel): - """Response for customer statistics.""" - - total: int = 0 - active: int = 0 - inactive: int = 0 - with_orders: int = 0 - total_spent: float = 0.0 - total_orders: int = 0 - avg_order_value: float = 0.0 - - -# ============================================================================ -# Admin Customer Management Response Schemas -# ============================================================================ - - -class AdminCustomerItem(BaseModel): - """Admin customer list item with vendor info.""" - - id: int - vendor_id: int - email: str - first_name: str | None = None - last_name: str | None = None - phone: str | None = None - customer_number: str - marketing_consent: bool = False - preferred_language: str | None = None - last_order_date: datetime | None = None - total_orders: int = 0 - total_spent: float = 0.0 - is_active: bool = True - created_at: datetime - updated_at: datetime - vendor_name: str | None = None - vendor_code: str | None = None - - model_config = {"from_attributes": True} - - -class CustomerListResponse(BaseModel): - """Admin paginated customer list with skip/limit.""" - - customers: list[AdminCustomerItem] = [] - total: int = 0 - skip: int = 0 - limit: int = 20 - - -class CustomerDetailResponse(AdminCustomerItem): - """Detailed customer response for admin.""" - - pass +from app.modules.customers.schemas.customer import ( + # Registration & Authentication + CustomerRegister, + CustomerUpdate, + CustomerPasswordChange, + # Customer Response + CustomerResponse, + CustomerListResponse, + # Address + CustomerAddressCreate, + CustomerAddressUpdate, + CustomerAddressResponse, + CustomerAddressListResponse, + # Preferences + CustomerPreferencesUpdate, + # Vendor Management + CustomerMessageResponse, + VendorCustomerListResponse, + CustomerDetailResponse, + CustomerOrderInfo, + CustomerOrdersResponse, + CustomerStatisticsResponse, + # Admin Management + AdminCustomerItem, + AdminCustomerListResponse, + AdminCustomerDetailResponse, +) + +__all__ = [ + # Registration & Authentication + "CustomerRegister", + "CustomerUpdate", + "CustomerPasswordChange", + # Customer Response + "CustomerResponse", + "CustomerListResponse", + # Address + "CustomerAddressCreate", + "CustomerAddressUpdate", + "CustomerAddressResponse", + "CustomerAddressListResponse", + # Preferences + "CustomerPreferencesUpdate", + # Vendor Management + "CustomerMessageResponse", + "VendorCustomerListResponse", + "CustomerDetailResponse", + "CustomerOrderInfo", + "CustomerOrdersResponse", + "CustomerStatisticsResponse", + # Admin Management + "AdminCustomerItem", + "AdminCustomerListResponse", + "AdminCustomerDetailResponse", +] diff --git a/models/schema/inventory.py b/models/schema/inventory.py index 007c64d5..78c46c64 100644 --- a/models/schema/inventory.py +++ b/models/schema/inventory.py @@ -1,294 +1,85 @@ # models/schema/inventory.py -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - - -class InventoryBase(BaseModel): - product_id: int = Field(..., description="Product ID in vendor catalog") - location: str = Field(..., description="Storage location") - - -class InventoryCreate(InventoryBase): - """Set exact inventory quantity (replaces existing).""" - - quantity: int = Field(..., description="Exact inventory quantity", ge=0) - - -class InventoryAdjust(InventoryBase): - """Add or remove inventory quantity.""" - - quantity: int = Field( - ..., description="Quantity to add (positive) or remove (negative)" - ) - - -class InventoryUpdate(BaseModel): - """Update inventory fields.""" - - quantity: int | None = Field(None, ge=0) - reserved_quantity: int | None = Field(None, ge=0) - location: str | None = None - - -class InventoryReserve(BaseModel): - """Reserve inventory for orders.""" - - product_id: int - location: str - quantity: int = Field(..., gt=0) - - -class InventoryResponse(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: int - product_id: int - vendor_id: int - location: str - quantity: int - reserved_quantity: int - gtin: str | None - created_at: datetime - updated_at: datetime - - @property - def available_quantity(self): - return max(0, self.quantity - self.reserved_quantity) - - -class InventoryLocationResponse(BaseModel): - location: str - quantity: int - reserved_quantity: int - available_quantity: int - - -class ProductInventorySummary(BaseModel): - """Inventory summary for a product.""" - - product_id: int - vendor_id: int - product_sku: str | None - product_title: str - total_quantity: int - total_reserved: int - total_available: int - locations: list[InventoryLocationResponse] - - -class InventoryListResponse(BaseModel): - inventories: list[InventoryResponse] - total: int - skip: int - limit: int - - -class InventoryMessageResponse(BaseModel): - """Simple message response for inventory operations.""" - - message: str - - -class InventorySummaryResponse(BaseModel): - """Inventory summary response for marketplace product service.""" - - gtin: str - total_quantity: int - locations: list[InventoryLocationResponse] - - -# ============================================================================ -# Admin Inventory Schemas -# ============================================================================ - - -class AdminInventoryCreate(BaseModel): - """Admin version of inventory create - requires explicit vendor_id.""" - - vendor_id: int = Field(..., description="Target vendor ID") - product_id: int = Field(..., description="Product ID in vendor catalog") - location: str = Field(..., description="Storage location") - quantity: int = Field(..., description="Exact inventory quantity", ge=0) - - -class AdminInventoryAdjust(BaseModel): - """Admin version of inventory adjust - requires explicit vendor_id.""" - - vendor_id: int = Field(..., description="Target vendor ID") - product_id: int = Field(..., description="Product ID in vendor catalog") - location: str = Field(..., description="Storage location") - quantity: int = Field( - ..., description="Quantity to add (positive) or remove (negative)" - ) - reason: str | None = Field(None, description="Reason for adjustment") - - -class AdminInventoryItem(BaseModel): - """Inventory item with vendor info for admin list view.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - product_id: int - vendor_id: int - vendor_name: str | None = None - vendor_code: str | None = None - product_title: str | None = None - product_sku: str | None = None - location: str - quantity: int - reserved_quantity: int - available_quantity: int - gtin: str | None = None - created_at: datetime - updated_at: datetime - - -class AdminInventoryListResponse(BaseModel): - """Cross-vendor inventory list for admin.""" - - inventories: list[AdminInventoryItem] - total: int - skip: int - limit: int - vendor_filter: int | None = None - location_filter: str | None = None - - -class AdminInventoryStats(BaseModel): - """Inventory statistics for admin dashboard.""" - - total_entries: int - total_quantity: int - total_reserved: int - total_available: int - low_stock_count: int - vendors_with_inventory: int - unique_locations: int - - -class AdminLowStockItem(BaseModel): - """Low stock item for admin alerts.""" - - id: int - product_id: int - vendor_id: int - vendor_name: str | None = None - product_title: str | None = None - location: str - quantity: int - reserved_quantity: int - available_quantity: int - - -class AdminVendorWithInventory(BaseModel): - """Vendor with inventory entries.""" - - id: int - name: str - vendor_code: str - - -class AdminVendorsWithInventoryResponse(BaseModel): - """Response for vendors with inventory list.""" - - vendors: list[AdminVendorWithInventory] - - -class AdminInventoryLocationsResponse(BaseModel): - """Response for unique inventory locations.""" - - locations: list[str] - - -# ============================================================================ -# Inventory Transaction Schemas -# ============================================================================ - - -class InventoryTransactionResponse(BaseModel): - """Single inventory transaction record.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - product_id: int - inventory_id: int | None = None - transaction_type: str - quantity_change: int - quantity_after: int - reserved_after: int - location: str | None = None - warehouse: str | None = None - order_id: int | None = None - order_number: str | None = None - reason: str | None = None - created_by: str | None = None - created_at: datetime - - -class InventoryTransactionWithProduct(InventoryTransactionResponse): - """Transaction with product details for list views.""" - - product_title: str | None = None - product_sku: str | None = None - - -class InventoryTransactionListResponse(BaseModel): - """Paginated list of inventory transactions.""" - - transactions: list[InventoryTransactionWithProduct] - total: int - skip: int - limit: int - - -class ProductTransactionHistoryResponse(BaseModel): - """Transaction history for a specific product.""" - - product_id: int - product_title: str | None = None - product_sku: str | None = None - current_quantity: int - current_reserved: int - transactions: list[InventoryTransactionResponse] - total: int - - -class OrderTransactionHistoryResponse(BaseModel): - """Transaction history for a specific order.""" - - order_id: int - order_number: str - transactions: list[InventoryTransactionWithProduct] - - -# ============================================================================ -# Admin Inventory Transaction Schemas -# ============================================================================ - - -class AdminInventoryTransactionItem(InventoryTransactionWithProduct): - """Transaction with vendor details for admin views.""" - - vendor_name: str | None = None - vendor_code: str | None = None - - -class AdminInventoryTransactionListResponse(BaseModel): - """Paginated list of transactions for admin.""" - - transactions: list[AdminInventoryTransactionItem] - total: int - skip: int - limit: int - - -class AdminTransactionStatsResponse(BaseModel): - """Transaction statistics for admin dashboard.""" - - total_transactions: int - transactions_today: int - by_type: dict[str, int] +""" +LEGACY LOCATION - Re-exports from module for backwards compatibility. + +The canonical implementation is now in: + app/modules/inventory/schemas/inventory.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.inventory.schemas import InventoryCreate, InventoryResponse +""" + +from app.modules.inventory.schemas.inventory import ( + # Base schemas + InventoryBase, + InventoryCreate, + InventoryAdjust, + InventoryUpdate, + InventoryReserve, + # Response schemas + InventoryResponse, + InventoryLocationResponse, + ProductInventorySummary, + InventoryListResponse, + InventoryMessageResponse, + InventorySummaryResponse, + # Admin schemas + AdminInventoryCreate, + AdminInventoryAdjust, + AdminInventoryItem, + AdminInventoryListResponse, + AdminInventoryStats, + AdminLowStockItem, + AdminVendorWithInventory, + AdminVendorsWithInventoryResponse, + AdminInventoryLocationsResponse, + # Transaction schemas + InventoryTransactionResponse, + InventoryTransactionWithProduct, + InventoryTransactionListResponse, + ProductTransactionHistoryResponse, + OrderTransactionHistoryResponse, + # Admin transaction schemas + AdminInventoryTransactionItem, + AdminInventoryTransactionListResponse, + AdminTransactionStatsResponse, +) + +__all__ = [ + # Base schemas + "InventoryBase", + "InventoryCreate", + "InventoryAdjust", + "InventoryUpdate", + "InventoryReserve", + # Response schemas + "InventoryResponse", + "InventoryLocationResponse", + "ProductInventorySummary", + "InventoryListResponse", + "InventoryMessageResponse", + "InventorySummaryResponse", + # Admin schemas + "AdminInventoryCreate", + "AdminInventoryAdjust", + "AdminInventoryItem", + "AdminInventoryListResponse", + "AdminInventoryStats", + "AdminLowStockItem", + "AdminVendorWithInventory", + "AdminVendorsWithInventoryResponse", + "AdminInventoryLocationsResponse", + # Transaction schemas + "InventoryTransactionResponse", + "InventoryTransactionWithProduct", + "InventoryTransactionListResponse", + "ProductTransactionHistoryResponse", + "OrderTransactionHistoryResponse", + # Admin transaction schemas + "AdminInventoryTransactionItem", + "AdminInventoryTransactionListResponse", + "AdminTransactionStatsResponse", +] diff --git a/models/schema/invoice.py b/models/schema/invoice.py index 768fc49d..27de27b3 100644 --- a/models/schema/invoice.py +++ b/models/schema/invoice.py @@ -1,310 +1,61 @@ # models/schema/invoice.py """ -Pydantic schemas for invoice operations. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports invoice settings management and invoice generation. +The canonical implementation is now in: + app/modules/orders/schemas/invoice.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.schemas import invoice """ -from datetime import datetime -from decimal import Decimal - -from pydantic import BaseModel, ConfigDict, Field - -# ============================================================================ -# Invoice Settings Schemas -# ============================================================================ - - -class VendorInvoiceSettingsCreate(BaseModel): - """Schema for creating vendor invoice settings.""" - - company_name: str = Field(..., min_length=1, max_length=255) - company_address: str | None = Field(None, max_length=255) - company_city: str | None = Field(None, max_length=100) - company_postal_code: str | None = Field(None, max_length=20) - company_country: str = Field(default="LU", min_length=2, max_length=2) - - vat_number: str | None = Field(None, max_length=50) - is_vat_registered: bool = True - - is_oss_registered: bool = False - oss_registration_country: str | None = Field(None, min_length=2, max_length=2) - - invoice_prefix: str = Field(default="INV", max_length=20) - invoice_number_padding: int = Field(default=5, ge=1, le=10) - - payment_terms: str | None = None - bank_name: str | None = Field(None, max_length=255) - bank_iban: str | None = Field(None, max_length=50) - bank_bic: str | None = Field(None, max_length=20) - - footer_text: str | None = None - default_vat_rate: Decimal = Field(default=Decimal("17.00"), ge=0, le=100) - - -class VendorInvoiceSettingsUpdate(BaseModel): - """Schema for updating vendor invoice settings.""" - - company_name: str | None = Field(None, min_length=1, max_length=255) - company_address: str | None = Field(None, max_length=255) - company_city: str | None = Field(None, max_length=100) - company_postal_code: str | None = Field(None, max_length=20) - company_country: str | None = Field(None, min_length=2, max_length=2) - - vat_number: str | None = None - is_vat_registered: bool | None = None - - is_oss_registered: bool | None = None - oss_registration_country: str | None = None - - invoice_prefix: str | None = Field(None, max_length=20) - invoice_number_padding: int | None = Field(None, ge=1, le=10) - - payment_terms: str | None = None - bank_name: str | None = Field(None, max_length=255) - bank_iban: str | None = Field(None, max_length=50) - bank_bic: str | None = Field(None, max_length=20) - - footer_text: str | None = None - default_vat_rate: Decimal | None = Field(None, ge=0, le=100) - - -class VendorInvoiceSettingsResponse(BaseModel): - """Schema for vendor invoice settings response.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - - company_name: str - company_address: str | None - company_city: str | None - company_postal_code: str | None - company_country: str - - vat_number: str | None - is_vat_registered: bool - - is_oss_registered: bool - oss_registration_country: str | None - - invoice_prefix: str - invoice_next_number: int - invoice_number_padding: int - - payment_terms: str | None - bank_name: str | None - bank_iban: str | None - bank_bic: str | None - - footer_text: str | None - default_vat_rate: Decimal - - created_at: datetime - updated_at: datetime - - -# ============================================================================ -# Invoice Line Item Schemas -# ============================================================================ - - -class InvoiceLineItem(BaseModel): - """Schema for invoice line item.""" - - description: str - quantity: int = Field(..., ge=1) - unit_price_cents: int - total_cents: int - sku: str | None = None - ean: str | None = None - - -class InvoiceLineItemResponse(BaseModel): - """Schema for invoice line item in response.""" - - description: str - quantity: int - unit_price_cents: int - total_cents: int - sku: str | None = None - ean: str | None = None - - @property - def unit_price(self) -> float: - return self.unit_price_cents / 100 - - @property - def total(self) -> float: - return self.total_cents / 100 - - -# ============================================================================ -# Invoice Address Schemas -# ============================================================================ - - -class InvoiceSellerDetails(BaseModel): - """Seller details for invoice.""" - - company_name: str - address: str | None = None - city: str | None = None - postal_code: str | None = None - country: str - vat_number: str | None = None - - -class InvoiceBuyerDetails(BaseModel): - """Buyer details for invoice.""" - - name: str - email: str | None = None - address: str | None = None - city: str | None = None - postal_code: str | None = None - country: str - vat_number: str | None = None # For B2B - - -# ============================================================================ -# Invoice Schemas -# ============================================================================ - - -class InvoiceCreate(BaseModel): - """Schema for creating an invoice from an order.""" - - order_id: int - notes: str | None = None - - -class InvoiceManualCreate(BaseModel): - """Schema for creating a manual invoice (without order).""" - - buyer_details: InvoiceBuyerDetails - line_items: list[InvoiceLineItem] - notes: str | None = None - payment_terms: str | None = None - - -class InvoiceResponse(BaseModel): - """Schema for invoice response.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - order_id: int | None - - invoice_number: str - invoice_date: datetime - status: str - - seller_details: dict - buyer_details: dict - line_items: list[dict] - - vat_regime: str - destination_country: str | None - vat_rate: Decimal - vat_rate_label: str | None - - currency: str - subtotal_cents: int - vat_amount_cents: int - total_cents: int - - payment_terms: str | None - bank_details: dict | None - footer_text: str | None - - pdf_generated_at: datetime | None - pdf_path: str | None - - notes: str | None - - created_at: datetime - updated_at: datetime - - @property - def subtotal(self) -> float: - return self.subtotal_cents / 100 - - @property - def vat_amount(self) -> float: - return self.vat_amount_cents / 100 - - @property - def total(self) -> float: - return self.total_cents / 100 - - -class InvoiceListResponse(BaseModel): - """Schema for invoice list response (summary).""" - - model_config = ConfigDict(from_attributes=True) - - id: int - invoice_number: str - invoice_date: datetime - status: str - currency: str - total_cents: int - order_id: int | None - - # Buyer name for display - buyer_name: str | None = None - - @property - def total(self) -> float: - return self.total_cents / 100 - - -class InvoiceStatusUpdate(BaseModel): - """Schema for updating invoice status.""" - - status: str = Field(..., pattern="^(draft|issued|paid|cancelled)$") - - -# ============================================================================ -# Paginated Response -# ============================================================================ - - -class InvoiceListPaginatedResponse(BaseModel): - """Paginated invoice list response.""" - - items: list[InvoiceListResponse] - total: int - page: int - per_page: int - pages: int - - -# ============================================================================ -# PDF Response -# ============================================================================ - - -class InvoicePDFGeneratedResponse(BaseModel): - """Response for PDF generation.""" - - pdf_path: str - message: str = "PDF generated successfully" - - -class InvoiceStatsResponse(BaseModel): - """Invoice statistics response.""" - - total_invoices: int - total_revenue_cents: int - draft_count: int - issued_count: int - paid_count: int - cancelled_count: int - - @property - def total_revenue(self) -> float: - return self.total_revenue_cents / 100 +from app.modules.orders.schemas.invoice import ( + # Invoice settings schemas + VendorInvoiceSettingsCreate, + VendorInvoiceSettingsUpdate, + VendorInvoiceSettingsResponse, + # Line item schemas + InvoiceLineItem, + InvoiceLineItemResponse, + # Address schemas + InvoiceSellerDetails, + InvoiceBuyerDetails, + # Invoice CRUD schemas + InvoiceCreate, + InvoiceManualCreate, + InvoiceResponse, + InvoiceListResponse, + InvoiceStatusUpdate, + # Pagination + InvoiceListPaginatedResponse, + # PDF + InvoicePDFGeneratedResponse, + InvoiceStatsResponse, +) + +__all__ = [ + # Invoice settings schemas + "VendorInvoiceSettingsCreate", + "VendorInvoiceSettingsUpdate", + "VendorInvoiceSettingsResponse", + # Line item schemas + "InvoiceLineItem", + "InvoiceLineItemResponse", + # Address schemas + "InvoiceSellerDetails", + "InvoiceBuyerDetails", + # Invoice CRUD schemas + "InvoiceCreate", + "InvoiceManualCreate", + "InvoiceResponse", + "InvoiceListResponse", + "InvoiceStatusUpdate", + # Pagination + "InvoiceListPaginatedResponse", + # PDF + "InvoicePDFGeneratedResponse", + "InvoiceStatsResponse", +] diff --git a/models/schema/message.py b/models/schema/message.py index 7f0682e2..4db01eab 100644 --- a/models/schema/message.py +++ b/models/schema/message.py @@ -1,308 +1,83 @@ # models/schema/message.py """ -Pydantic schemas for the messaging system. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports three communication channels: -- Admin <-> Vendor -- Vendor <-> Customer -- Admin <-> Customer +The canonical implementation is now in: + app/modules/messaging/schemas/message.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.schemas import message """ -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - -from models.database.message import ConversationType, ParticipantType - - -# ============================================================================ -# Attachment Schemas -# ============================================================================ - - -class AttachmentResponse(BaseModel): - """Schema for message attachment in responses.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - filename: str - original_filename: str - file_size: int - mime_type: str - is_image: bool - image_width: int | None = None - image_height: int | None = None - download_url: str | None = None - thumbnail_url: str | None = None - - @property - def file_size_display(self) -> str: - """Human-readable file size.""" - if self.file_size < 1024: - return f"{self.file_size} B" - elif self.file_size < 1024 * 1024: - return f"{self.file_size / 1024:.1f} KB" - else: - return f"{self.file_size / 1024 / 1024:.1f} MB" - - -# ============================================================================ -# Message Schemas -# ============================================================================ - - -class MessageCreate(BaseModel): - """Schema for sending a new message.""" - - content: str = Field(..., min_length=1, max_length=10000) - - -class MessageResponse(BaseModel): - """Schema for a single message in responses.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - conversation_id: int - sender_type: ParticipantType - sender_id: int - content: str - is_system_message: bool - is_deleted: bool - created_at: datetime - - # Enriched sender info (populated by API) - sender_name: str | None = None - sender_email: str | None = None - - # Attachments - attachments: list[AttachmentResponse] = [] - - -# ============================================================================ -# Participant Schemas -# ============================================================================ - - -class ParticipantInfo(BaseModel): - """Schema for participant information.""" - - id: int - type: ParticipantType - name: str - email: str | None = None - avatar_url: str | None = None - - -class ParticipantResponse(BaseModel): - """Schema for conversation participant in responses.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - participant_type: ParticipantType - participant_id: int - unread_count: int - last_read_at: datetime | None - email_notifications: bool - muted: bool - - # Enriched info (populated by API) - participant_info: ParticipantInfo | None = None - - -# ============================================================================ -# Conversation Schemas -# ============================================================================ - - -class ConversationCreate(BaseModel): - """Schema for creating a new conversation.""" - - conversation_type: ConversationType - subject: str = Field(..., min_length=1, max_length=500) - recipient_type: ParticipantType - recipient_id: int - vendor_id: int | None = None - initial_message: str | None = Field(None, min_length=1, max_length=10000) - - -class ConversationSummary(BaseModel): - """Schema for conversation in list views.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - conversation_type: ConversationType - subject: str - vendor_id: int | None = None - is_closed: bool - closed_at: datetime | None - last_message_at: datetime | None - message_count: int - created_at: datetime - - # Unread count for current user (from participant) - unread_count: int = 0 - - # Other participant info (enriched by API) - other_participant: ParticipantInfo | None = None - - # Last message preview - last_message_preview: str | None = None - - -class ConversationDetailResponse(BaseModel): - """Schema for full conversation detail with messages.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - conversation_type: ConversationType - subject: str - vendor_id: int | None = None - is_closed: bool - closed_at: datetime | None - closed_by_type: ParticipantType | None = None - closed_by_id: int | None = None - last_message_at: datetime | None - message_count: int - created_at: datetime - updated_at: datetime - - # Participants with enriched info - participants: list[ParticipantResponse] = [] - - # Messages ordered by created_at - messages: list[MessageResponse] = [] - - # Current user's unread count - unread_count: int = 0 - - # Vendor info if applicable - vendor_name: str | None = None - - -class ConversationListResponse(BaseModel): - """Schema for paginated conversation list.""" - - conversations: list[ConversationSummary] - total: int - total_unread: int - skip: int - limit: int - - -# ============================================================================ -# Unread Count Schemas -# ============================================================================ - - -class UnreadCountResponse(BaseModel): - """Schema for unread message count (for header badge).""" - - total_unread: int - - -# ============================================================================ -# Notification Preferences Schemas -# ============================================================================ - - -class NotificationPreferencesUpdate(BaseModel): - """Schema for updating notification preferences.""" - - email_notifications: bool | None = None - muted: bool | None = None - - -# ============================================================================ -# Conversation Action Schemas -# ============================================================================ - - -class CloseConversationResponse(BaseModel): - """Response after closing a conversation.""" - - success: bool - message: str - conversation_id: int - - -class ReopenConversationResponse(BaseModel): - """Response after reopening a conversation.""" - - success: bool - message: str - conversation_id: int - - -class MarkReadResponse(BaseModel): - """Response after marking conversation as read.""" - - success: bool - conversation_id: int - unread_count: int - - -# ============================================================================ -# Recipient Selection Schemas (for compose modal) -# ============================================================================ - - -class RecipientOption(BaseModel): - """Schema for a selectable recipient in compose modal.""" - - id: int - type: ParticipantType - name: str - email: str | None = None - vendor_id: int | None = None # For vendor users - vendor_name: str | None = None - - -class RecipientListResponse(BaseModel): - """Schema for list of available recipients.""" - - recipients: list[RecipientOption] - total: int - - -# ============================================================================ -# Admin-specific Schemas -# ============================================================================ - - -class AdminConversationSummary(ConversationSummary): - """Extended conversation summary with vendor info for admin views.""" - - vendor_name: str | None = None - vendor_code: str | None = None - - -class AdminConversationListResponse(BaseModel): - """Schema for admin conversation list with vendor info.""" - - conversations: list[AdminConversationSummary] - total: int - total_unread: int - skip: int - limit: int - - -class AdminMessageStats(BaseModel): - """Messaging statistics for admin dashboard.""" - - total_conversations: int = 0 - open_conversations: int = 0 - closed_conversations: int = 0 - total_messages: int = 0 - - # By type - admin_vendor_conversations: int = 0 - vendor_customer_conversations: int = 0 - admin_customer_conversations: int = 0 - - # Unread - unread_admin: int = 0 +from app.modules.messaging.schemas.message import ( + # Attachment schemas + AttachmentResponse, + # Message schemas + MessageCreate, + MessageResponse, + # Participant schemas + ParticipantInfo, + ParticipantResponse, + # Conversation schemas + ConversationCreate, + ConversationSummary, + ConversationDetailResponse, + ConversationListResponse, + ConversationResponse, + # Unread count + UnreadCountResponse, + # Notification preferences + NotificationPreferencesUpdate, + # Conversation actions + CloseConversationResponse, + ReopenConversationResponse, + MarkReadResponse, + # Recipient selection + RecipientOption, + RecipientListResponse, + # Admin schemas + AdminConversationSummary, + AdminConversationListResponse, + AdminMessageStats, +) + +# Re-export enums from models for backward compatibility +from app.modules.messaging.models.message import ConversationType, ParticipantType + +__all__ = [ + # Attachment schemas + "AttachmentResponse", + # Message schemas + "MessageCreate", + "MessageResponse", + # Participant schemas + "ParticipantInfo", + "ParticipantResponse", + # Conversation schemas + "ConversationCreate", + "ConversationSummary", + "ConversationDetailResponse", + "ConversationListResponse", + "ConversationResponse", + # Unread count + "UnreadCountResponse", + # Notification preferences + "NotificationPreferencesUpdate", + # Conversation actions + "CloseConversationResponse", + "ReopenConversationResponse", + "MarkReadResponse", + # Recipient selection + "RecipientOption", + "RecipientListResponse", + # Admin schemas + "AdminConversationSummary", + "AdminConversationListResponse", + "AdminMessageStats", + # Enums + "ConversationType", + "ParticipantType", +] diff --git a/models/schema/notification.py b/models/schema/notification.py index 8a920006..0b118bfd 100644 --- a/models/schema/notification.py +++ b/models/schema/notification.py @@ -1,152 +1,53 @@ # models/schema/notification.py """ -Notification Pydantic schemas for API validation and responses. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -This module provides schemas for: -- Vendor notifications (list, read, delete) -- Notification settings management -- Notification email templates -- Unread counts and statistics +The canonical implementation is now in: + app/modules/messaging/schemas/notification.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.messaging.schemas import notification """ -from datetime import datetime -from typing import Any +from app.modules.messaging.schemas.notification import ( + # Response schemas + MessageResponse, + UnreadCountResponse, + # Notification schemas + NotificationResponse, + NotificationListResponse, + # Settings schemas + NotificationSettingsResponse, + NotificationSettingsUpdate, + # Template schemas + NotificationTemplateResponse, + NotificationTemplateListResponse, + NotificationTemplateUpdate, + # Test notification + TestNotificationRequest, + # Alert statistics + AlertStatisticsResponse, +) -from pydantic import BaseModel, Field - -# ============================================================================ -# SHARED RESPONSE SCHEMAS -# ============================================================================ - - -class MessageResponse(BaseModel): - """Generic message response for simple operations.""" - - message: str - - -class UnreadCountResponse(BaseModel): - """Response for unread notification count.""" - - unread_count: int - message: str | None = None - - -# ============================================================================ -# NOTIFICATION SCHEMAS -# ============================================================================ - - -class NotificationResponse(BaseModel): - """Single notification response.""" - - id: int - type: str - title: str - message: str - is_read: bool - read_at: datetime | None = None - priority: str = "normal" - action_url: str | None = None - metadata: dict[str, Any] | None = None - created_at: datetime - - model_config = {"from_attributes": True} - - -class NotificationListResponse(BaseModel): - """Paginated list of notifications.""" - - notifications: list[NotificationResponse] = [] - total: int = 0 - unread_count: int = 0 - message: str | None = None - - -# ============================================================================ -# NOTIFICATION SETTINGS SCHEMAS -# ============================================================================ - - -class NotificationSettingsResponse(BaseModel): - """Notification preferences response.""" - - email_notifications: bool = True - in_app_notifications: bool = True - notification_types: dict[str, bool] = Field(default_factory=dict) - message: str | None = None - - -class NotificationSettingsUpdate(BaseModel): - """Request model for updating notification settings.""" - - email_notifications: bool | None = None - in_app_notifications: bool | None = None - notification_types: dict[str, bool] | None = None - - -# ============================================================================ -# NOTIFICATION TEMPLATE SCHEMAS -# ============================================================================ - - -class NotificationTemplateResponse(BaseModel): - """Single notification template response.""" - - id: int - name: str - type: str - subject: str - body_html: str | None = None - body_text: str | None = None - variables: list[str] = Field(default_factory=list) - is_active: bool = True - created_at: datetime - updated_at: datetime | None = None - - model_config = {"from_attributes": True} - - -class NotificationTemplateListResponse(BaseModel): - """List of notification templates.""" - - templates: list[NotificationTemplateResponse] = [] - message: str | None = None - - -class NotificationTemplateUpdate(BaseModel): - """Request model for updating notification template.""" - - subject: str | None = Field(None, max_length=200) - body_html: str | None = None - body_text: str | None = None - is_active: bool | None = None - - -# ============================================================================ -# TEST NOTIFICATION SCHEMA -# ============================================================================ - - -class TestNotificationRequest(BaseModel): - """Request model for sending test notification.""" - - template_id: int | None = Field(None, description="Template to use") - email: str | None = Field(None, description="Override recipient email") - notification_type: str = Field( - default="test", description="Type of notification to send" - ) - - -# ============================================================================ -# ADMIN ALERT STATISTICS SCHEMA -# ============================================================================ - - -class AlertStatisticsResponse(BaseModel): - """Response for alert statistics.""" - - total_alerts: int = 0 - active_alerts: int = 0 - critical_alerts: int = 0 - resolved_today: int = 0 +__all__ = [ + # Response schemas + "MessageResponse", + "UnreadCountResponse", + # Notification schemas + "NotificationResponse", + "NotificationListResponse", + # Settings schemas + "NotificationSettingsResponse", + "NotificationSettingsUpdate", + # Template schemas + "NotificationTemplateResponse", + "NotificationTemplateListResponse", + "NotificationTemplateUpdate", + # Test notification + "TestNotificationRequest", + # Alert statistics + "AlertStatisticsResponse", +] diff --git a/models/schema/order.py b/models/schema/order.py index 59ecfd9b..ca299a07 100644 --- a/models/schema/order.py +++ b/models/schema/order.py @@ -1,584 +1,89 @@ # models/schema/order.py """ -Pydantic schemas for unified order operations. +LEGACY LOCATION - Re-exports from module for backwards compatibility. -Supports both direct orders and marketplace orders (Letzshop, etc.) -with snapshotted customer and address data. +The canonical implementation is now in: + app/modules/orders/schemas/order.py + +This file exists to maintain backwards compatibility with code that +imports from the old location. All new code should import directly +from the module: + + from app.modules.orders.schemas import order """ -from datetime import datetime - -from pydantic import BaseModel, ConfigDict, Field - -# ============================================================================ -# Address Snapshot Schemas -# ============================================================================ - - -class AddressSnapshot(BaseModel): - """Address snapshot for order creation.""" - - first_name: str = Field(..., min_length=1, max_length=100) - last_name: str = Field(..., min_length=1, max_length=100) - company: str | None = Field(None, max_length=200) - address_line_1: str = Field(..., min_length=1, max_length=255) - address_line_2: str | None = Field(None, max_length=255) - city: str = Field(..., min_length=1, max_length=100) - postal_code: str = Field(..., min_length=1, max_length=20) - country_iso: str = Field(..., min_length=2, max_length=5) - - -class AddressSnapshotResponse(BaseModel): - """Address snapshot in order response.""" - - first_name: str - last_name: str - company: str | None - address_line_1: str - address_line_2: str | None - city: str - postal_code: str - country_iso: str - - @property - def full_name(self) -> str: - return f"{self.first_name} {self.last_name}".strip() - - -# ============================================================================ -# Order Item Schemas -# ============================================================================ - - -class OrderItemCreate(BaseModel): - """Schema for creating an order item.""" - - product_id: int - quantity: int = Field(..., ge=1) - - -class OrderItemExceptionBrief(BaseModel): - """Brief exception info for embedding in order item responses.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - original_gtin: str | None - original_product_name: str | None - exception_type: str - status: str - resolved_product_id: int | None - - -class OrderItemResponse(BaseModel): - """Schema for order item response.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - order_id: int - product_id: int - product_name: str - product_sku: str | None - gtin: str | None - gtin_type: str | None - quantity: int - unit_price: float - total_price: float - - # External references (for marketplace items) - external_item_id: str | None = None - external_variant_id: str | None = None - - # Item state (for marketplace confirmation flow) - item_state: str | None = None - - # Inventory tracking - inventory_reserved: bool - inventory_fulfilled: bool - - # Exception tracking - needs_product_match: bool = False - exception: OrderItemExceptionBrief | None = None - - created_at: datetime - updated_at: datetime - - @property - def is_confirmed(self) -> bool: - """Check if item has been confirmed (available or unavailable).""" - return self.item_state in ("confirmed_available", "confirmed_unavailable") - - @property - def is_available(self) -> bool: - """Check if item is confirmed as available.""" - return self.item_state == "confirmed_available" - - @property - def is_declined(self) -> bool: - """Check if item was declined (unavailable).""" - return self.item_state == "confirmed_unavailable" - - @property - def has_unresolved_exception(self) -> bool: - """Check if item has an unresolved exception blocking confirmation.""" - if not self.exception: - return False - return self.exception.status in ("pending", "ignored") - - -# ============================================================================ -# Customer Snapshot Schemas -# ============================================================================ - - -class CustomerSnapshot(BaseModel): - """Customer snapshot for order creation.""" - - first_name: str = Field(..., min_length=1, max_length=100) - last_name: str = Field(..., min_length=1, max_length=100) - email: str = Field(..., max_length=255) - phone: str | None = Field(None, max_length=50) - locale: str | None = Field(None, max_length=10) - - -class CustomerSnapshotResponse(BaseModel): - """Customer snapshot in order response.""" - - first_name: str - last_name: str - email: str - phone: str | None - locale: str | None - - @property - def full_name(self) -> str: - return f"{self.first_name} {self.last_name}".strip() - - -# ============================================================================ -# Order Create/Update Schemas -# ============================================================================ - - -class OrderCreate(BaseModel): - """Schema for creating an order (direct channel).""" - - customer_id: int | None = None # Optional for guest checkout - items: list[OrderItemCreate] = Field(..., min_length=1) - - # Customer info snapshot - customer: CustomerSnapshot - - # Addresses (snapshots) - shipping_address: AddressSnapshot - billing_address: AddressSnapshot | None = None # Use shipping if not provided - - # Optional fields - shipping_method: str | None = None - customer_notes: str | None = Field(None, max_length=1000) - - # Cart/session info - session_id: str | None = None - - -class OrderUpdate(BaseModel): - """Schema for updating order status.""" - - status: str | None = Field( - None, pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" - ) - tracking_number: str | None = None - tracking_provider: str | None = None - internal_notes: str | None = None - - -class OrderTrackingUpdate(BaseModel): - """Schema for setting tracking information.""" - - tracking_number: str = Field(..., min_length=1, max_length=100) - tracking_provider: str = Field(..., min_length=1, max_length=100) - - -class OrderItemStateUpdate(BaseModel): - """Schema for updating item state (marketplace confirmation).""" - - item_id: int - state: str = Field(..., pattern="^(confirmed_available|confirmed_unavailable)$") - - -# ============================================================================ -# Order Response Schemas -# ============================================================================ - - -class OrderResponse(BaseModel): - """Schema for order response.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - customer_id: int - order_number: str - - # Channel/Source - channel: str - external_order_id: str | None = None - external_shipment_id: str | None = None - external_order_number: str | None = None - - # Status - status: str - - # Financial - subtotal: float | None - tax_amount: float | None - shipping_amount: float | None - discount_amount: float | None - total_amount: float - currency: str - - # VAT information - vat_regime: str | None = None - vat_rate: float | None = None - vat_rate_label: str | None = None - vat_destination_country: str | None = None - - # Customer snapshot - customer_first_name: str - customer_last_name: str - customer_email: str - customer_phone: str | None - customer_locale: str | None - - # Shipping address snapshot - ship_first_name: str - ship_last_name: str - ship_company: str | None - ship_address_line_1: str - ship_address_line_2: str | None - ship_city: str - ship_postal_code: str - ship_country_iso: str - - # Billing address snapshot - bill_first_name: str - bill_last_name: str - bill_company: str | None - bill_address_line_1: str - bill_address_line_2: str | None - bill_city: str - bill_postal_code: str - bill_country_iso: str - - # Tracking - shipping_method: str | None - tracking_number: str | None - tracking_provider: str | None - tracking_url: str | None = None - shipment_number: str | None = None - shipping_carrier: str | None = None - - # Notes - customer_notes: str | None - internal_notes: str | None - - # Timestamps - order_date: datetime - confirmed_at: datetime | None - shipped_at: datetime | None - delivered_at: datetime | None - cancelled_at: datetime | None - created_at: datetime - updated_at: datetime - - @property - def customer_full_name(self) -> str: - return f"{self.customer_first_name} {self.customer_last_name}".strip() - - @property - def ship_full_name(self) -> str: - return f"{self.ship_first_name} {self.ship_last_name}".strip() - - @property - def is_marketplace_order(self) -> bool: - return self.channel != "direct" - - -class OrderDetailResponse(OrderResponse): - """Schema for detailed order response with items.""" - - items: list[OrderItemResponse] = [] - - # Vendor info (enriched by API) - vendor_name: str | None = None - vendor_code: str | None = None - - -class OrderListResponse(BaseModel): - """Schema for paginated order list.""" - - orders: list[OrderResponse] - total: int - skip: int - limit: int - - -# ============================================================================ -# Order List Item (Simplified for list views) -# ============================================================================ - - -class OrderListItem(BaseModel): - """Simplified order item for list views.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - order_number: str - channel: str - status: str - - # External references - external_order_number: str | None = None - - # Customer - customer_full_name: str - customer_email: str - - # Financial - total_amount: float - currency: str - - # Shipping - ship_country_iso: str - - # Tracking - tracking_number: str | None - tracking_provider: str | None - tracking_url: str | None = None - shipment_number: str | None = None - shipping_carrier: str | None = None - - # Item count - item_count: int = 0 - - # Timestamps - order_date: datetime - confirmed_at: datetime | None - shipped_at: datetime | None - - -# ============================================================================ -# Admin Order Schemas -# ============================================================================ - - -class AdminOrderItem(BaseModel): - """Order item with vendor info for admin list view.""" - - model_config = ConfigDict(from_attributes=True) - - id: int - vendor_id: int - vendor_name: str | None = None - vendor_code: str | None = None - customer_id: int - order_number: str - channel: str - status: str - - # External references - external_order_number: str | None = None - external_shipment_id: str | None = None - - # Customer snapshot - customer_full_name: str - customer_email: str - - # Financial - subtotal: float | None - tax_amount: float | None - shipping_amount: float | None - discount_amount: float | None - total_amount: float - currency: str - - # VAT information - vat_regime: str | None = None - vat_rate: float | None = None - vat_rate_label: str | None = None - vat_destination_country: str | None = None - - # Shipping - ship_country_iso: str - tracking_number: str | None - tracking_provider: str | None - tracking_url: str | None = None - shipment_number: str | None = None - shipping_carrier: str | None = None - - # Item count - item_count: int = 0 - - # Timestamps - order_date: datetime - confirmed_at: datetime | None - shipped_at: datetime | None - delivered_at: datetime | None - cancelled_at: datetime | None - created_at: datetime - updated_at: datetime - - -class AdminOrderListResponse(BaseModel): - """Cross-vendor order list for admin.""" - - orders: list[AdminOrderItem] - total: int - skip: int - limit: int - - -class AdminOrderStats(BaseModel): - """Order statistics for admin dashboard.""" - - total_orders: int = 0 - pending_orders: int = 0 - processing_orders: int = 0 - shipped_orders: int = 0 - delivered_orders: int = 0 - cancelled_orders: int = 0 - refunded_orders: int = 0 - total_revenue: float = 0.0 - - # By channel - direct_orders: int = 0 - letzshop_orders: int = 0 - - # Vendors - vendors_with_orders: int = 0 - - -class AdminOrderStatusUpdate(BaseModel): - """Admin version of status update with reason.""" - - status: str = Field( - ..., pattern="^(pending|processing|shipped|delivered|cancelled|refunded)$" - ) - tracking_number: str | None = None - tracking_provider: str | None = None - reason: str | None = Field(None, description="Reason for status change") - - -class AdminVendorWithOrders(BaseModel): - """Vendor with order count.""" - - id: int - name: str - vendor_code: str - order_count: int = 0 - - -class AdminVendorsWithOrdersResponse(BaseModel): - """Response for vendors with orders list.""" - - vendors: list[AdminVendorWithOrders] - - -# ============================================================================ -# Letzshop-specific Schemas -# ============================================================================ - - -class LetzshopOrderImport(BaseModel): - """Schema for importing a Letzshop order from shipment data.""" - - shipment_id: str - order_id: str - order_number: str - order_date: datetime - - # Customer - customer_email: str - customer_locale: str | None = None - - # Shipping address - ship_first_name: str - ship_last_name: str - ship_company: str | None = None - ship_address_line_1: str - ship_address_line_2: str | None = None - ship_city: str - ship_postal_code: str - ship_country_iso: str - - # Billing address - bill_first_name: str - bill_last_name: str - bill_company: str | None = None - bill_address_line_1: str - bill_address_line_2: str | None = None - bill_city: str - bill_postal_code: str - bill_country_iso: str - - # Totals - total_amount: float - currency: str = "EUR" - - # State - letzshop_state: str # unconfirmed, confirmed, declined - - # Items - inventory_units: list[dict] - - # Raw data - raw_data: dict | None = None - - -class LetzshopShippingInfo(BaseModel): - """Shipping info retrieved from Letzshop.""" - - tracking_number: str - tracking_provider: str - shipment_id: str - - -class LetzshopOrderConfirmItem(BaseModel): - """Schema for confirming/declining a single item.""" - - item_id: int - external_item_id: str - action: str = Field(..., pattern="^(confirm|decline)$") - - -class LetzshopOrderConfirmRequest(BaseModel): - """Schema for confirming/declining order items.""" - - items: list[LetzshopOrderConfirmItem] - - -# ============================================================================ -# Mark as Shipped Schemas -# ============================================================================ - - -class MarkAsShippedRequest(BaseModel): - """Schema for marking an order as shipped with tracking info.""" - - tracking_number: str | None = Field(None, max_length=100) - tracking_url: str | None = Field(None, max_length=500) - shipping_carrier: str | None = Field(None, max_length=50) - - -class ShippingLabelInfo(BaseModel): - """Shipping label information for an order.""" - - shipment_number: str | None = None - shipping_carrier: str | None = None - label_url: str | None = None - tracking_number: str | None = None - tracking_url: str | None = None +from app.modules.orders.schemas.order import ( + # Address schemas + AddressSnapshot, + AddressSnapshotResponse, + # Order item schemas + OrderItemCreate, + OrderItemExceptionBrief, + OrderItemResponse, + # Customer schemas + CustomerSnapshot, + CustomerSnapshotResponse, + # Order CRUD schemas + OrderCreate, + OrderUpdate, + OrderTrackingUpdate, + OrderItemStateUpdate, + # Order response schemas + OrderResponse, + OrderDetailResponse, + OrderListResponse, + OrderListItem, + # Admin schemas + AdminOrderItem, + AdminOrderListResponse, + AdminOrderStats, + AdminOrderStatusUpdate, + AdminVendorWithOrders, + AdminVendorsWithOrdersResponse, + # Letzshop schemas + LetzshopOrderImport, + LetzshopShippingInfo, + LetzshopOrderConfirmItem, + LetzshopOrderConfirmRequest, + # Shipping schemas + MarkAsShippedRequest, + ShippingLabelInfo, +) + +__all__ = [ + # Address schemas + "AddressSnapshot", + "AddressSnapshotResponse", + # Order item schemas + "OrderItemCreate", + "OrderItemExceptionBrief", + "OrderItemResponse", + # Customer schemas + "CustomerSnapshot", + "CustomerSnapshotResponse", + # Order CRUD schemas + "OrderCreate", + "OrderUpdate", + "OrderTrackingUpdate", + "OrderItemStateUpdate", + # Order response schemas + "OrderResponse", + "OrderDetailResponse", + "OrderListResponse", + "OrderListItem", + # Admin schemas + "AdminOrderItem", + "AdminOrderListResponse", + "AdminOrderStats", + "AdminOrderStatusUpdate", + "AdminVendorWithOrders", + "AdminVendorsWithOrdersResponse", + # Letzshop schemas + "LetzshopOrderImport", + "LetzshopShippingInfo", + "LetzshopOrderConfirmItem", + "LetzshopOrderConfirmRequest", + # Shipping schemas + "MarkAsShippedRequest", + "ShippingLabelInfo", +]