refactor: migrate modules from re-exports to canonical implementations

Move actual code implementations into module directories:
- orders: 5 services, 4 models, order/invoice schemas
- inventory: 3 services, 2 models, 30+ schemas
- customers: 3 services, 2 models, customer schemas
- messaging: 3 services, 2 models, message/notification schemas
- monitoring: background_tasks_service
- marketplace: 5+ services including letzshop submodule
- dev_tools: code_quality_service, test_runner_service
- billing: billing_service
- contracts: definition.py

Legacy files in app/services/, models/database/, models/schema/
now re-export from canonical module locations for backwards
compatibility. Architecture validator passes with 0 errors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-29 21:28:56 +01:00
parent b5a803cde8
commit de83875d0a
99 changed files with 19413 additions and 15357 deletions

View File

@@ -1,242 +1,23 @@
# app/services/admin_customer_service.py
"""
Admin customer management service.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Handles customer operations for admin users across all vendors.
The canonical implementation is now in:
app/modules/customers/services/admin_customer_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.customers.services import admin_customer_service
"""
import logging
from typing import Any
from app.modules.customers.services.admin_customer_service import (
admin_customer_service,
AdminCustomerService,
)
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.exceptions.customer import CustomerNotFoundException
from models.database.customer import Customer
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class AdminCustomerService:
    """Service for admin-level customer management across vendors.

    All methods use ``db.flush()`` rather than ``commit()`` — transaction
    ownership stays with the caller.
    """

    @staticmethod
    def _serialize_customer(
        customer: "Customer",
        vendor_name: str,
        vendor_code: str,
    ) -> dict[str, Any]:
        """Flatten a Customer ORM row plus its vendor columns into an API dict.

        Shared by list_customers() and get_customer() so the response shape
        cannot drift between the two endpoints.
        """
        return {
            "id": customer.id,
            "vendor_id": customer.vendor_id,
            "email": customer.email,
            "first_name": customer.first_name,
            "last_name": customer.last_name,
            "phone": customer.phone,
            "customer_number": customer.customer_number,
            "marketing_consent": customer.marketing_consent,
            "preferred_language": customer.preferred_language,
            "last_order_date": customer.last_order_date,
            "total_orders": customer.total_orders,
            # Decimal/NULL -> float for JSON serialization; NULL becomes 0.
            "total_spent": float(customer.total_spent) if customer.total_spent else 0,
            "is_active": customer.is_active,
            "created_at": customer.created_at,
            "updated_at": customer.updated_at,
            "vendor_name": vendor_name,
            "vendor_code": vendor_code,
        }

    def list_customers(
        self,
        db: Session,
        vendor_id: int | None = None,
        search: str | None = None,
        is_active: bool | None = None,
        skip: int = 0,
        limit: int = 20,
    ) -> tuple[list[dict[str, Any]], int]:
        """
        Get paginated list of customers across all vendors.

        Args:
            db: Database session
            vendor_id: Optional vendor ID filter
            search: Search by email, name, or customer number
            is_active: Filter by active status
            skip: Number of records to skip
            limit: Maximum records to return

        Returns:
            Tuple of (customers list, total count)
        """
        # Base query joined to Vendor so vendor name/code can be selected.
        query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id)

        # Apply filters. `is not None` (not truthiness) so a vendor_id of 0
        # is not silently treated as "no filter".
        if vendor_id is not None:
            query = query.filter(Customer.vendor_id == vendor_id)
        if search:
            search_term = f"%{search}%"
            query = query.filter(
                (Customer.email.ilike(search_term))
                | (Customer.first_name.ilike(search_term))
                | (Customer.last_name.ilike(search_term))
                | (Customer.customer_number.ilike(search_term))
            )
        if is_active is not None:
            query = query.filter(Customer.is_active == is_active)

        # Total count before pagination is applied.
        total = query.count()

        # Paginated results with vendor info, newest customers first.
        customers = (
            query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
            .order_by(Customer.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

        # Each row is (Customer, vendor_name, vendor_code).
        result = [self._serialize_customer(row[0], row[1], row[2]) for row in customers]
        return result, total

    def get_customer_stats(
        self,
        db: Session,
        vendor_id: int | None = None,
    ) -> dict[str, Any]:
        """
        Get customer statistics.

        Args:
            db: Database session
            vendor_id: Optional vendor ID filter

        Returns:
            Dict with customer statistics
        """
        query = db.query(Customer)
        # `is not None` so vendor_id=0 still filters (see list_customers).
        if vendor_id is not None:
            query = query.filter(Customer.vendor_id == vendor_id)

        total = query.count()
        active = query.filter(Customer.is_active == True).count()  # noqa: E712
        inactive = query.filter(Customer.is_active == False).count()  # noqa: E712
        with_orders = query.filter(Customer.total_orders > 0).count()

        # Total spent across all matching customers (SUM returns NULL on
        # empty sets, hence the fallback to 0).
        total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar()
        total_spent = float(total_spent_result) if total_spent_result else 0

        # Average order value across all orders of all matching customers.
        total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar()
        total_orders = int(total_orders_result) if total_orders_result else 0
        avg_order_value = total_spent / total_orders if total_orders > 0 else 0

        return {
            "total": total,
            "active": active,
            "inactive": inactive,
            "with_orders": with_orders,
            "total_spent": total_spent,
            "total_orders": total_orders,
            "avg_order_value": round(avg_order_value, 2),
        }

    def get_customer(
        self,
        db: Session,
        customer_id: int,
    ) -> dict[str, Any]:
        """
        Get customer details by ID.

        Args:
            db: Database session
            customer_id: Customer ID

        Returns:
            Customer dict with vendor info

        Raises:
            CustomerNotFoundException: If customer not found
        """
        result = (
            db.query(Customer)
            .join(Vendor, Customer.vendor_id == Vendor.id)
            .add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
            .filter(Customer.id == customer_id)
            .first()
        )
        if not result:
            raise CustomerNotFoundException(str(customer_id))

        # result is (Customer, vendor_name, vendor_code).
        return self._serialize_customer(result[0], result[1], result[2])

    def toggle_customer_status(
        self,
        db: Session,
        customer_id: int,
        admin_email: str,
    ) -> dict[str, Any]:
        """
        Toggle customer active status.

        Args:
            db: Database session
            customer_id: Customer ID
            admin_email: Admin user email for logging

        Returns:
            Dict with customer ID, new status, and message

        Raises:
            CustomerNotFoundException: If customer not found
        """
        customer = db.query(Customer).filter(Customer.id == customer_id).first()
        if not customer:
            raise CustomerNotFoundException(str(customer_id))

        customer.is_active = not customer.is_active
        db.flush()
        db.refresh(customer)

        status = "activated" if customer.is_active else "deactivated"
        logger.info(f"Customer {customer.email} {status} by admin {admin_email}")

        return {
            "id": customer.id,
            "is_active": customer.is_active,
            "message": f"Customer {status} successfully",
        }


# Singleton instance
admin_customer_service = AdminCustomerService()

__all__ = [
    "admin_customer_service",
    "AdminCustomerService",
]

View File

@@ -1,701 +1,37 @@
# app/services/admin_notification_service.py
"""
Admin notification service.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Provides functionality for:
- Creating and managing admin notifications
- Managing platform alerts
- Notification statistics and queries
The canonical implementation is now in:
app/modules/messaging/services/admin_notification_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.messaging.services import admin_notification_service
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from sqlalchemy import and_, case, func, or_
from sqlalchemy.orm import Session
from models.database.admin import AdminNotification, PlatformAlert
from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate
logger = logging.getLogger(__name__)
# ============================================================================
# NOTIFICATION TYPES
# ============================================================================
class NotificationType:
    """Notification type constants.

    String values are stored in ``AdminNotification.type`` (see
    ``create_notification``), so they must stay stable across releases.
    """

    SYSTEM_ALERT = "system_alert"
    IMPORT_FAILURE = "import_failure"
    EXPORT_FAILURE = "export_failure"
    ORDER_SYNC_FAILURE = "order_sync_failure"
    VENDOR_ISSUE = "vendor_issue"
    CUSTOMER_MESSAGE = "customer_message"
    VENDOR_MESSAGE = "vendor_message"
    SECURITY_ALERT = "security_alert"
    PERFORMANCE_ALERT = "performance_alert"
    ORDER_EXCEPTION = "order_exception"
    CRITICAL_ERROR = "critical_error"
class Priority:
    """Priority level constants.

    Stored in ``AdminNotification.priority`` and used for sort ordering in
    notification queries (critical sorts first).
    """

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"
    CRITICAL = "critical"
class AlertType:
    """Platform alert type constants (stored in ``PlatformAlert.alert_type``)."""

    SECURITY = "security"
    PERFORMANCE = "performance"
    CAPACITY = "capacity"
    INTEGRATION = "integration"
    DATABASE = "database"
    SYSTEM = "system"
class Severity:
    """Alert severity constants.

    Stored in ``PlatformAlert.severity``; ``CRITICAL`` is also used by the
    alert queries to compute critical counts.
    """

    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
# ============================================================================
# ADMIN NOTIFICATION SERVICE
# ============================================================================
class AdminNotificationService:
    """Service for managing admin notifications.

    Uses ``db.flush()`` rather than ``commit()`` throughout — transaction
    ownership stays with the caller.
    """

    @staticmethod
    def _priority_case():
        """SQL CASE expression ranking priorities critical(1) .. low(4).

        Shared by get_notifications() and get_recent_notifications() so the
        two listings cannot drift apart in their sort order.
        """
        return case(
            (AdminNotification.priority == "critical", 1),
            (AdminNotification.priority == "high", 2),
            (AdminNotification.priority == "normal", 3),
            (AdminNotification.priority == "low", 4),
            else_=5,
        )

    def create_notification(
        self,
        db: Session,
        notification_type: str,
        title: str,
        message: str,
        priority: str = Priority.NORMAL,
        action_required: bool = False,
        action_url: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """
        Create a new admin notification.

        Args:
            db: Database session
            notification_type: Type of notification
            title: Notification title
            message: Notification message
            priority: Priority level (low, normal, high, critical)
            action_required: Whether action is required
            action_url: URL to relevant admin page
            metadata: Additional contextual data

        Returns:
            Created AdminNotification
        """
        notification = AdminNotification(
            type=notification_type,
            title=title,
            message=message,
            priority=priority,
            action_required=action_required,
            action_url=action_url,
            notification_metadata=metadata,
        )
        db.add(notification)
        db.flush()
        logger.info(
            f"Created notification: {notification_type} - {title} (priority: {priority})"
        )
        return notification

    def create_from_schema(
        self,
        db: Session,
        data: AdminNotificationCreate,
    ) -> AdminNotification:
        """Create notification from Pydantic schema."""
        return self.create_notification(
            db=db,
            notification_type=data.type,
            title=data.title,
            message=data.message,
            priority=data.priority,
            action_required=data.action_required,
            action_url=data.action_url,
            metadata=data.metadata,
        )

    def get_notifications(
        self,
        db: Session,
        priority: str | None = None,
        is_read: bool | None = None,
        notification_type: str | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list[AdminNotification], int, int]:
        """
        Get paginated admin notifications.

        Returns:
            Tuple of (notifications, total_count, unread_count)

        Note:
            ``unread_count`` is deliberately computed over ALL notifications,
            not just those matching the filters (it feeds the global badge).
        """
        query = db.query(AdminNotification)

        # Apply filters
        if priority:
            query = query.filter(AdminNotification.priority == priority)
        if is_read is not None:
            query = query.filter(AdminNotification.is_read == is_read)
        if notification_type:
            query = query.filter(AdminNotification.type == notification_type)

        # Get counts
        total = query.count()
        unread_count = (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .count()
        )

        # Paginated results: unread first, then by priority, then newest.
        notifications = (
            query.order_by(
                AdminNotification.is_read,  # Unread first
                self._priority_case(),
                AdminNotification.created_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .all()
        )
        return notifications, total, unread_count

    def get_unread_count(self, db: Session) -> int:
        """Get count of unread notifications."""
        return (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .count()
        )

    def get_recent_notifications(
        self,
        db: Session,
        limit: int = 5,
    ) -> list[AdminNotification]:
        """Get recent unread notifications for header dropdown."""
        return (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .order_by(self._priority_case(), AdminNotification.created_at.desc())
            .limit(limit)
            .all()
        )

    def mark_as_read(
        self,
        db: Session,
        notification_id: int,
        user_id: int,
    ) -> AdminNotification | None:
        """Mark a notification as read.

        Returns the notification, or None if it does not exist. Already-read
        notifications are returned unchanged (read_at/read_by are preserved).
        """
        notification = (
            db.query(AdminNotification)
            .filter(AdminNotification.id == notification_id)
            .first()
        )
        if notification and not notification.is_read:
            notification.is_read = True
            # NOTE(review): naive UTC timestamp, consistent with the rest of
            # this module — confirm before mixing with tz-aware values.
            notification.read_at = datetime.utcnow()
            notification.read_by_user_id = user_id
            db.flush()
        return notification

    def mark_all_as_read(
        self,
        db: Session,
        user_id: int,
    ) -> int:
        """Mark all unread notifications as read. Returns count of updated."""
        now = datetime.utcnow()
        count = (
            db.query(AdminNotification)
            .filter(AdminNotification.is_read == False)  # noqa: E712
            .update(
                {
                    AdminNotification.is_read: True,
                    AdminNotification.read_at: now,
                    AdminNotification.read_by_user_id: user_id,
                }
            )
        )
        db.flush()
        return count

    def delete_old_notifications(
        self,
        db: Session,
        days: int = 30,
    ) -> int:
        """Delete READ notifications older than specified days.

        Unread notifications are never deleted here, regardless of age.
        Returns the number of rows deleted.
        """
        cutoff = datetime.utcnow() - timedelta(days=days)
        count = (
            db.query(AdminNotification)
            .filter(
                and_(
                    AdminNotification.is_read == True,  # noqa: E712
                    AdminNotification.created_at < cutoff,
                )
            )
            .delete()
        )
        db.flush()
        return count

    def delete_notification(
        self,
        db: Session,
        notification_id: int,
    ) -> bool:
        """
        Delete a notification by ID.

        Returns:
            True if notification was deleted, False if not found.
        """
        notification = (
            db.query(AdminNotification)
            .filter(AdminNotification.id == notification_id)
            .first()
        )
        if notification:
            db.delete(notification)
            db.flush()
            logger.info(f"Deleted notification {notification_id}")
            return True
        return False

    # =========================================================================
    # CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES
    # =========================================================================

    def notify_import_failure(
        self,
        db: Session,
        vendor_name: str,
        job_id: int,
        error_message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create notification for import job failure."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.IMPORT_FAILURE,
            title=f"Import Failed: {vendor_name}",
            message=error_message,
            priority=Priority.HIGH,
            action_required=True,
            # Deep-link to the vendor's jobs tab when we know the vendor.
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
            if vendor_id
            else "/admin/marketplace",
            metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id},
        )

    def notify_order_sync_failure(
        self,
        db: Session,
        vendor_name: str,
        error_message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create notification for order sync failure."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.ORDER_SYNC_FAILURE,
            title=f"Order Sync Failed: {vendor_name}",
            message=error_message,
            priority=Priority.HIGH,
            action_required=True,
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
            if vendor_id
            else "/admin/marketplace/letzshop",
            metadata={"vendor_name": vendor_name, "vendor_id": vendor_id},
        )

    def notify_order_exception(
        self,
        db: Session,
        vendor_name: str,
        order_number: str,
        exception_count: int,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create notification for order item exceptions."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.ORDER_EXCEPTION,
            title=f"Order Exception: {order_number}",
            message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})",
            priority=Priority.NORMAL,
            action_required=True,
            action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions"
            if vendor_id
            else "/admin/marketplace/letzshop",
            metadata={
                "vendor_name": vendor_name,
                "order_number": order_number,
                "exception_count": exception_count,
                "vendor_id": vendor_id,
            },
        )

    def notify_critical_error(
        self,
        db: Session,
        error_type: str,
        error_message: str,
        details: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """Create notification for critical application errors."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.CRITICAL_ERROR,
            title=f"Critical Error: {error_type}",
            message=error_message,
            priority=Priority.CRITICAL,
            action_required=True,
            action_url="/admin/logs",
            metadata=details,
        )

    def notify_vendor_issue(
        self,
        db: Session,
        vendor_name: str,
        issue_type: str,
        message: str,
        vendor_id: int | None = None,
    ) -> AdminNotification:
        """Create notification for vendor-related issues."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.VENDOR_ISSUE,
            title=f"Vendor Issue: {vendor_name}",
            message=message,
            priority=Priority.HIGH,
            action_required=True,
            action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors",
            metadata={
                "vendor_name": vendor_name,
                "issue_type": issue_type,
                "vendor_id": vendor_id,
            },
        )

    def notify_security_alert(
        self,
        db: Session,
        title: str,
        message: str,
        details: dict[str, Any] | None = None,
    ) -> AdminNotification:
        """Create notification for security-related alerts."""
        return self.create_notification(
            db=db,
            notification_type=NotificationType.SECURITY_ALERT,
            title=title,
            message=message,
            priority=Priority.CRITICAL,
            action_required=True,
            action_url="/admin/audit",
            metadata=details,
        )
# ============================================================================
# PLATFORM ALERT SERVICE
# ============================================================================
class PlatformAlertService:
    """Service for managing platform-wide alerts.

    Uses ``db.flush()`` rather than ``commit()`` throughout — transaction
    ownership stays with the caller.
    """

    def create_alert(
        self,
        db: Session,
        alert_type: str,
        severity: str,
        title: str,
        description: str | None = None,
        affected_vendors: list[int] | None = None,
        affected_systems: list[str] | None = None,
        auto_generated: bool = True,
    ) -> PlatformAlert:
        """Create a new platform alert.

        Both occurrence timestamps start at "now"; repeated occurrences are
        tracked via increment_occurrence().
        """
        now = datetime.utcnow()
        alert = PlatformAlert(
            alert_type=alert_type,
            severity=severity,
            title=title,
            description=description,
            affected_vendors=affected_vendors,
            affected_systems=affected_systems,
            auto_generated=auto_generated,
            first_occurred_at=now,
            last_occurred_at=now,
        )
        db.add(alert)
        db.flush()
        logger.info(f"Created platform alert: {alert_type} - {title} ({severity})")
        return alert

    def create_from_schema(
        self,
        db: Session,
        data: PlatformAlertCreate,
    ) -> PlatformAlert:
        """Create alert from Pydantic schema."""
        return self.create_alert(
            db=db,
            alert_type=data.alert_type,
            severity=data.severity,
            title=data.title,
            description=data.description,
            affected_vendors=data.affected_vendors,
            affected_systems=data.affected_systems,
            auto_generated=data.auto_generated,
        )

    def get_alerts(
        self,
        db: Session,
        severity: str | None = None,
        alert_type: str | None = None,
        is_resolved: bool | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list[PlatformAlert], int, int, int]:
        """
        Get paginated platform alerts.

        Returns:
            Tuple of (alerts, total_count, active_count, critical_count)

        Note:
            ``active_count`` and ``critical_count`` are computed over ALL
            alerts, not just those matching the filters.
        """
        query = db.query(PlatformAlert)

        # Apply filters
        if severity:
            query = query.filter(PlatformAlert.severity == severity)
        if alert_type:
            query = query.filter(PlatformAlert.alert_type == alert_type)
        if is_resolved is not None:
            query = query.filter(PlatformAlert.is_resolved == is_resolved)

        # Get counts
        total = query.count()
        active_count = (
            db.query(PlatformAlert)
            .filter(PlatformAlert.is_resolved == False)  # noqa: E712
            .count()
        )
        critical_count = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == False,  # noqa: E712
                    PlatformAlert.severity == Severity.CRITICAL,
                )
            )
            .count()
        )

        # Paginated results: unresolved first, then by severity rank,
        # then most recently occurred.
        severity_order = case(
            (PlatformAlert.severity == "critical", 1),
            (PlatformAlert.severity == "error", 2),
            (PlatformAlert.severity == "warning", 3),
            (PlatformAlert.severity == "info", 4),
            else_=5,
        )
        alerts = (
            query.order_by(
                PlatformAlert.is_resolved,  # Unresolved first
                severity_order,
                PlatformAlert.last_occurred_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .all()
        )
        return alerts, total, active_count, critical_count

    def resolve_alert(
        self,
        db: Session,
        alert_id: int,
        user_id: int,
        resolution_notes: str | None = None,
    ) -> PlatformAlert | None:
        """Resolve a platform alert.

        Returns the alert, or None if it does not exist. Already-resolved
        alerts are returned unchanged (resolution metadata is preserved).
        """
        alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
        if alert and not alert.is_resolved:
            alert.is_resolved = True
            alert.resolved_at = datetime.utcnow()
            alert.resolved_by_user_id = user_id
            alert.resolution_notes = resolution_notes
            db.flush()
            logger.info(f"Resolved platform alert {alert_id}")
        return alert

    def get_statistics(self, db: Session) -> dict[str, int]:
        """Get alert statistics (totals, active, critical, resolved today)."""
        # Midnight "today" in naive UTC, matching the naive timestamps
        # written elsewhere in this service.
        today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        total = db.query(PlatformAlert).count()
        active = (
            db.query(PlatformAlert)
            .filter(PlatformAlert.is_resolved == False)  # noqa: E712
            .count()
        )
        critical = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == False,  # noqa: E712
                    PlatformAlert.severity == Severity.CRITICAL,
                )
            )
            .count()
        )
        resolved_today = (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.is_resolved == True,  # noqa: E712
                    PlatformAlert.resolved_at >= today_start,
                )
            )
            .count()
        )
        return {
            "total_alerts": total,
            "active_alerts": active,
            "critical_alerts": critical,
            "resolved_today": resolved_today,
        }

    def increment_occurrence(
        self,
        db: Session,
        alert_id: int,
    ) -> PlatformAlert | None:
        """Increment occurrence count for repeated alert.

        Returns the alert, or None if it does not exist.
        """
        alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
        if alert:
            alert.occurrence_count += 1
            alert.last_occurred_at = datetime.utcnow()
            db.flush()
        return alert

    def find_similar_active_alert(
        self,
        db: Session,
        alert_type: str,
        title: str,
    ) -> PlatformAlert | None:
        """Find an active alert with same type and title.

        Used for de-duplication by create_or_increment_alert().
        """
        return (
            db.query(PlatformAlert)
            .filter(
                and_(
                    PlatformAlert.alert_type == alert_type,
                    PlatformAlert.title == title,
                    PlatformAlert.is_resolved == False,  # noqa: E712
                )
            )
            .first()
        )

    def create_or_increment_alert(
        self,
        db: Session,
        alert_type: str,
        severity: str,
        title: str,
        description: str | None = None,
        affected_vendors: list[int] | None = None,
        affected_systems: list[str] | None = None,
    ) -> PlatformAlert:
        """Create alert or increment occurrence if similar exists.

        Matching is on (alert_type, title) among unresolved alerts only;
        severity/description of an existing alert are NOT updated.
        """
        existing = self.find_similar_active_alert(db, alert_type, title)
        if existing:
            self.increment_occurrence(db, existing.id)
            return existing
        return self.create_alert(
            db=db,
            alert_type=alert_type,
            severity=severity,
            title=title,
            description=description,
            affected_vendors=affected_vendors,
            affected_systems=affected_systems,
        )


# Singleton instances
admin_notification_service = AdminNotificationService()
platform_alert_service = PlatformAlertService()
from app.modules.messaging.services.admin_notification_service import (
admin_notification_service,
AdminNotificationService,
platform_alert_service,
PlatformAlertService,
# Constants
NotificationType,
Priority,
AlertType,
Severity,
)
__all__ = [
"admin_notification_service",
"AdminNotificationService",
"platform_alert_service",
"PlatformAlertService",
# Constants
"NotificationType",
"Priority",
"AlertType",
"Severity",
]

View File

@@ -1,194 +1,23 @@
# app/services/background_tasks_service.py
"""
Background Tasks Service
Service for monitoring background tasks across the system
LEGACY LOCATION - Re-exports from module for backwards compatibility.
The canonical implementation is now in:
app/modules/monitoring/services/background_tasks_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.monitoring.services import background_tasks_service
"""
from datetime import UTC, datetime
from app.modules.monitoring.services.background_tasks_service import (
background_tasks_service,
BackgroundTasksService,
)
from sqlalchemy import case, desc, func
from sqlalchemy.orm import Session
from models.database.architecture_scan import ArchitectureScan
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.test_run import TestRun
class BackgroundTasksService:
    """Service for monitoring background tasks.

    Read-only aggregation over import jobs, test runs, and code-quality
    (architecture) scans. All "today" windows use midnight UTC.
    """

    def get_import_jobs(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[MarketplaceImportJob]:
        """Get import jobs with optional status filter, newest first."""
        query = db.query(MarketplaceImportJob)
        if status:
            query = query.filter(MarketplaceImportJob.status == status)
        return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all()

    def get_test_runs(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[TestRun]:
        """Get test runs with optional status filter, newest first."""
        query = db.query(TestRun)
        if status:
            query = query.filter(TestRun.status == status)
        return query.order_by(desc(TestRun.timestamp)).limit(limit).all()

    def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]:
        """Get currently running import jobs (status == "processing")."""
        return (
            db.query(MarketplaceImportJob)
            .filter(MarketplaceImportJob.status == "processing")
            .all()
        )

    def get_running_test_runs(self, db: Session) -> list[TestRun]:
        """Get currently running test runs."""
        # noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped
        return db.query(TestRun).filter(TestRun.status == "running").all()

    def get_import_stats(self, db: Session) -> dict:
        """Get import job statistics (total/running/completed/failed/today)."""
        today_start = datetime.now(UTC).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        # Single aggregate row: conditional SUMs act as per-status counts.
        stats = db.query(
            func.count(MarketplaceImportJob.id).label("total"),
            func.sum(
                case((MarketplaceImportJob.status == "processing", 1), else_=0)
            ).label("running"),
            func.sum(
                case(
                    (
                        MarketplaceImportJob.status.in_(
                            ["completed", "completed_with_errors"]
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("completed"),
            func.sum(
                case((MarketplaceImportJob.status == "failed", 1), else_=0)
            ).label("failed"),
        ).first()
        today_count = (
            db.query(func.count(MarketplaceImportJob.id))
            .filter(MarketplaceImportJob.created_at >= today_start)
            .scalar()
            or 0
        )
        # SUM over no rows yields NULL -> normalize to 0.
        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
        }

    def get_test_run_stats(self, db: Session) -> dict:
        """Get test run statistics, including average duration in seconds."""
        today_start = datetime.now(UTC).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        stats = db.query(
            func.count(TestRun.id).label("total"),
            func.sum(case((TestRun.status == "running", 1), else_=0)).label(
                "running"
            ),
            # "completed" here means passed; failed/error are counted below.
            func.sum(case((TestRun.status == "passed", 1), else_=0)).label(
                "completed"
            ),
            func.sum(
                case((TestRun.status.in_(["failed", "error"]), 1), else_=0)
            ).label("failed"),
            func.avg(TestRun.duration_seconds).label("avg_duration"),
        ).first()
        today_count = (
            db.query(func.count(TestRun.id))
            .filter(TestRun.timestamp >= today_start)
            .scalar()
            or 0
        )
        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
            "avg_duration": round(stats.avg_duration or 0, 1),
        }

    def get_code_quality_scans(
        self, db: Session, status: str | None = None, limit: int = 50
    ) -> list[ArchitectureScan]:
        """Get code quality scans with optional status filter, newest first."""
        query = db.query(ArchitectureScan)
        if status:
            query = query.filter(ArchitectureScan.status == status)
        return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all()

    def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
        """Get currently running code quality scans (pending counts as running)."""
        return (
            db.query(ArchitectureScan)
            .filter(ArchitectureScan.status.in_(["pending", "running"]))
            .all()
        )

    def get_scan_stats(self, db: Session) -> dict:
        """Get code quality scan statistics, including average duration."""
        today_start = datetime.now(UTC).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        stats = db.query(
            func.count(ArchitectureScan.id).label("total"),
            func.sum(
                case(
                    (ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0
                )
            ).label("running"),
            func.sum(
                case(
                    (
                        ArchitectureScan.status.in_(
                            ["completed", "completed_with_warnings"]
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("completed"),
            func.sum(
                case((ArchitectureScan.status == "failed", 1), else_=0)
            ).label("failed"),
            func.avg(ArchitectureScan.duration_seconds).label("avg_duration"),
        ).first()
        today_count = (
            db.query(func.count(ArchitectureScan.id))
            .filter(ArchitectureScan.timestamp >= today_start)
            .scalar()
            or 0
        )
        return {
            "total": stats.total or 0,
            "running": stats.running or 0,
            "completed": stats.completed or 0,
            "failed": stats.failed or 0,
            "today": today_count,
            "avg_duration": round(stats.avg_duration or 0, 1),
        }


# Singleton instance
background_tasks_service = BackgroundTasksService()

__all__ = [
    "background_tasks_service",
    "BackgroundTasksService",
]

View File

@@ -1,588 +1,35 @@
# app/services/billing_service.py
"""
Billing service for subscription and payment operations.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Provides:
- Subscription status and usage queries
- Tier management
- Invoice history
- Add-on management
The canonical implementation is now in:
app/modules/billing/services/billing_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.billing.services import billing_service
"""
import logging
from datetime import datetime
from sqlalchemy.orm import Session
from app.services.stripe_service import stripe_service
from app.services.subscription_service import subscription_service
from models.database.subscription import (
AddOnProduct,
BillingHistory,
SubscriptionTier,
VendorAddOn,
VendorSubscription,
from app.modules.billing.services.billing_service import (
BillingService,
billing_service,
BillingServiceError,
PaymentSystemNotConfiguredError,
TierNotFoundError,
StripePriceNotConfiguredError,
NoActiveSubscriptionError,
SubscriptionNotCancelledError,
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class BillingServiceError(Exception):
    """Base class for all billing service errors.

    Catch this to handle any billing failure generically.
    """


class PaymentSystemNotConfiguredError(BillingServiceError):
    """Raised when Stripe is not configured."""

    def __init__(self) -> None:
        super().__init__("Payment system not configured")


class TierNotFoundError(BillingServiceError):
    """Raised when a tier is not found.

    Attributes:
        tier_code: The tier code that failed to resolve.
    """

    def __init__(self, tier_code: str) -> None:
        self.tier_code = tier_code
        message = "Tier '{}' not found".format(tier_code)
        super().__init__(message)


class StripePriceNotConfiguredError(BillingServiceError):
    """Raised when Stripe price is not configured for a tier.

    Attributes:
        tier_code: The tier code lacking a Stripe price.
    """

    def __init__(self, tier_code: str) -> None:
        self.tier_code = tier_code
        message = "Stripe price not configured for tier '{}'".format(tier_code)
        super().__init__(message)


class NoActiveSubscriptionError(BillingServiceError):
    """Raised when no active subscription exists."""

    def __init__(self) -> None:
        super().__init__("No active subscription found")


class SubscriptionNotCancelledError(BillingServiceError):
    """Raised when trying to reactivate a non-cancelled subscription."""

    def __init__(self) -> None:
        super().__init__("Subscription is not cancelled")
class BillingService:
    """Service layer for vendor billing operations.

    Coordinates Stripe calls (checkout, customer portal, subscription
    lifecycle) via ``stripe_service`` with local subscription / tier /
    add-on records via ``subscription_service`` and direct SQLAlchemy
    queries.  Methods mutate ORM objects in place and never call
    ``commit`` themselves — presumably the caller or session middleware
    owns the transaction boundary (verify against the request lifecycle).
    """
    def get_subscription_with_tier(
        self, db: Session, vendor_id: int
    ) -> tuple[VendorSubscription, SubscriptionTier | None]:
        """
        Get subscription and its tier info.
        Returns:
            Tuple of (subscription, tier) where tier may be None
        """
        # Creates the subscription row on first access (get_or_create);
        # the tier lookup can still return None if subscription.tier
        # references a code that no longer exists in the tiers table.
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)
        tier = (
            db.query(SubscriptionTier)
            .filter(SubscriptionTier.code == subscription.tier)
            .first()
        )
        return subscription, tier
    def get_available_tiers(
        self, db: Session, current_tier: str
    ) -> tuple[list[dict], dict[str, int]]:
        """
        Get all available tiers with upgrade/downgrade flags.

        Only active *and* public tiers are listed, ordered by
        ``display_order``.  Upgrade/downgrade flags are computed by
        comparing display_order to the current tier's position.
        Returns:
            Tuple of (tier_list, tier_order_map)
        """
        tiers = (
            db.query(SubscriptionTier)
            .filter(
                SubscriptionTier.is_active == True,  # noqa: E712
                SubscriptionTier.is_public == True,  # noqa: E712
            )
            .order_by(SubscriptionTier.display_order)
            .all()
        )
        tier_order = {t.code: t.display_order for t in tiers}
        # Unknown current tier falls back to order 0, so every listed tier
        # appears as an upgrade in that case.
        current_order = tier_order.get(current_tier, 0)
        tier_list = []
        for tier in tiers:
            tier_list.append({
                "code": tier.code,
                "name": tier.name,
                "description": tier.description,
                "price_monthly_cents": tier.price_monthly_cents,
                "price_annual_cents": tier.price_annual_cents,
                "orders_per_month": tier.orders_per_month,
                "products_limit": tier.products_limit,
                "team_members": tier.team_members,
                "features": tier.features or [],
                "is_current": tier.code == current_tier,
                "can_upgrade": tier.display_order > current_order,
                "can_downgrade": tier.display_order < current_order,
            })
        return tier_list, tier_order
    def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
        """
        Get a tier by its code.

        Inactive tiers are treated as missing.
        Raises:
            TierNotFoundError: If tier doesn't exist
        """
        tier = (
            db.query(SubscriptionTier)
            .filter(
                SubscriptionTier.code == tier_code,
                SubscriptionTier.is_active == True,  # noqa: E712
            )
            .first()
        )
        if not tier:
            raise TierNotFoundError(tier_code)
        return tier
    def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
        """
        Get vendor by ID.
        Raises:
            VendorNotFoundException from app.exceptions
        """
        # Local import — presumably to avoid a circular import at module
        # load time; verify before hoisting to the top of the file.
        from app.exceptions import VendorNotFoundException
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            raise VendorNotFoundException(str(vendor_id), identifier_type="id")
        return vendor
    def create_checkout_session(
        self,
        db: Session,
        vendor_id: int,
        tier_code: str,
        is_annual: bool,
        success_url: str,
        cancel_url: str,
    ) -> dict:
        """
        Create a Stripe checkout session.

        Selects the annual price when requested and configured, otherwise
        falls back to the monthly price.  New subscriptions (no existing
        Stripe subscription ID) get the configured trial period.
        Returns:
            Dict with checkout_url and session_id
        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            TierNotFoundError: If tier doesn't exist
            StripePriceNotConfiguredError: If price not configured
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()
        vendor = self.get_vendor(db, vendor_id)
        tier = self.get_tier_by_code(db, tier_code)
        price_id = (
            tier.stripe_price_annual_id
            if is_annual and tier.stripe_price_annual_id
            else tier.stripe_price_monthly_id
        )
        if not price_id:
            raise StripePriceNotConfiguredError(tier_code)
        # Check if this is a new subscription (for trial)
        existing_sub = subscription_service.get_subscription(db, vendor_id)
        trial_days = None
        if not existing_sub or not existing_sub.stripe_subscription_id:
            from app.core.config import settings
            trial_days = settings.stripe_trial_days
        session = stripe_service.create_checkout_session(
            db=db,
            vendor=vendor,
            price_id=price_id,
            success_url=success_url,
            cancel_url=cancel_url,
            trial_days=trial_days,
        )
        # Update subscription with tier info
        # NOTE(review): the local record is switched to the new tier before
        # Stripe confirms payment — presumably a webhook reconciles
        # abandoned/failed checkouts; confirm.
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)
        subscription.tier = tier_code
        subscription.is_annual = is_annual
        return {
            "checkout_url": session.url,
            "session_id": session.id,
        }
    def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict:
        """
        Create a Stripe customer portal session.
        Returns:
            Dict with portal_url
        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            NoActiveSubscriptionError: If no subscription with customer ID
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()
        subscription = subscription_service.get_subscription(db, vendor_id)
        if not subscription or not subscription.stripe_customer_id:
            raise NoActiveSubscriptionError()
        session = stripe_service.create_portal_session(
            customer_id=subscription.stripe_customer_id,
            return_url=return_url,
        )
        return {"portal_url": session.url}
    def get_invoices(
        self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20
    ) -> tuple[list[BillingHistory], int]:
        """
        Get invoice history for a vendor.

        Paginated, newest invoice first.
        Returns:
            Tuple of (invoices, total_count)
        """
        query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id)
        total = query.count()
        invoices = (
            query.order_by(BillingHistory.invoice_date.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return invoices, total
    def get_available_addons(
        self, db: Session, category: str | None = None
    ) -> list[AddOnProduct]:
        """Get available add-on products, optionally filtered by category."""
        query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True)  # noqa: E712
        if category:
            query = query.filter(AddOnProduct.category == category)
        return query.order_by(AddOnProduct.display_order).all()
    def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]:
        """Get vendor's purchased add-ons (all statuses, no filtering)."""
        return (
            db.query(VendorAddOn)
            .filter(VendorAddOn.vendor_id == vendor_id)
            .all()
        )
    def cancel_subscription(
        self, db: Session, vendor_id: int, reason: str | None, immediately: bool
    ) -> dict:
        """
        Cancel a subscription.

        The local record is marked cancelled even when Stripe is not
        configured; the Stripe call is skipped in that case.
        Returns:
            Dict with message and effective_date
        Raises:
            NoActiveSubscriptionError: If no subscription to cancel
        """
        subscription = subscription_service.get_subscription(db, vendor_id)
        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()
        if stripe_service.is_configured:
            stripe_service.cancel_subscription(
                subscription_id=subscription.stripe_subscription_id,
                immediately=immediately,
                cancellation_reason=reason,
            )
        subscription.cancelled_at = datetime.utcnow()
        subscription.cancellation_reason = reason
        # Immediate cancellation takes effect now; otherwise at period end,
        # falling back to "now" when no period end is recorded.
        effective_date = (
            datetime.utcnow().isoformat()
            if immediately
            else subscription.period_end.isoformat()
            if subscription.period_end
            else datetime.utcnow().isoformat()
        )
        return {
            "message": "Subscription cancelled successfully",
            "effective_date": effective_date,
        }
    def reactivate_subscription(self, db: Session, vendor_id: int) -> dict:
        """
        Reactivate a cancelled subscription.
        Returns:
            Dict with success message
        Raises:
            NoActiveSubscriptionError: If no subscription
            SubscriptionNotCancelledError: If not cancelled
        """
        subscription = subscription_service.get_subscription(db, vendor_id)
        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()
        if not subscription.cancelled_at:
            raise SubscriptionNotCancelledError()
        if stripe_service.is_configured:
            stripe_service.reactivate_subscription(subscription.stripe_subscription_id)
        # Clear local cancellation markers regardless of Stripe availability.
        subscription.cancelled_at = None
        subscription.cancellation_reason = None
        return {"message": "Subscription reactivated successfully"}
    def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict:
        """
        Get upcoming invoice preview.

        Returns an empty preview (zero amount, no line items) when Stripe
        is not configured or reports no upcoming invoice.
        Returns:
            Dict with amount_due_cents, currency, next_payment_date, line_items
        Raises:
            NoActiveSubscriptionError: If no subscription with customer ID
        """
        subscription = subscription_service.get_subscription(db, vendor_id)
        if not subscription or not subscription.stripe_customer_id:
            raise NoActiveSubscriptionError()
        if not stripe_service.is_configured:
            # Return empty preview if Stripe not configured
            return {
                "amount_due_cents": 0,
                "currency": "EUR",
                "next_payment_date": None,
                "line_items": [],
            }
        invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id)
        if not invoice:
            return {
                "amount_due_cents": 0,
                "currency": "EUR",
                "next_payment_date": None,
                "line_items": [],
            }
        line_items = []
        if invoice.lines and invoice.lines.data:
            for line in invoice.lines.data:
                line_items.append({
                    "description": line.description or "",
                    "amount_cents": line.amount,
                    "quantity": line.quantity or 1,
                })
        # NOTE(review): fromtimestamp() yields a naive local-time datetime;
        # confirm whether consumers expect UTC here.
        return {
            "amount_due_cents": invoice.amount_due,
            "currency": invoice.currency.upper(),
            "next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat()
            if invoice.next_payment_attempt
            else None,
            "line_items": line_items,
        }
    def change_tier(
        self,
        db: Session,
        vendor_id: int,
        new_tier_code: str,
        is_annual: bool,
    ) -> dict:
        """
        Change subscription tier (upgrade/downgrade).

        Updates Stripe first (when configured), then mirrors the change on
        the local subscription record.  The change is always reported as
        effective immediately.
        Returns:
            Dict with message, new_tier, effective_immediately
        Raises:
            TierNotFoundError: If tier doesn't exist
            NoActiveSubscriptionError: If no subscription
            StripePriceNotConfiguredError: If price not configured
        """
        subscription = subscription_service.get_subscription(db, vendor_id)
        if not subscription or not subscription.stripe_subscription_id:
            raise NoActiveSubscriptionError()
        tier = self.get_tier_by_code(db, new_tier_code)
        price_id = (
            tier.stripe_price_annual_id
            if is_annual and tier.stripe_price_annual_id
            else tier.stripe_price_monthly_id
        )
        if not price_id:
            raise StripePriceNotConfiguredError(new_tier_code)
        # Update in Stripe
        if stripe_service.is_configured:
            stripe_service.update_subscription(
                subscription_id=subscription.stripe_subscription_id,
                new_price_id=price_id,
            )
        # Update local subscription
        old_tier = subscription.tier
        subscription.tier = new_tier_code
        subscription.tier_id = tier.id
        subscription.is_annual = is_annual
        subscription.updated_at = datetime.utcnow()
        is_upgrade = self._is_upgrade(db, old_tier, new_tier_code)
        return {
            "message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}",
            "new_tier": new_tier_code,
            "effective_immediately": True,
        }
    def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool:
        """Check if tier change is an upgrade (higher display_order).

        Returns False when either tier code cannot be resolved.
        """
        old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first()
        new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first()
        if not old or not new:
            return False
        return new.display_order > old.display_order
    def purchase_addon(
        self,
        db: Session,
        vendor_id: int,
        addon_code: str,
        domain_name: str | None,
        quantity: int,
        success_url: str,
        cancel_url: str,
    ) -> dict:
        """
        Create checkout session for add-on purchase.
        Returns:
            Dict with checkout_url and session_id
        Raises:
            PaymentSystemNotConfiguredError: If Stripe not configured
            AddonNotFoundError: If addon doesn't exist
        """
        if not stripe_service.is_configured:
            raise PaymentSystemNotConfiguredError()
        addon = (
            db.query(AddOnProduct)
            .filter(
                AddOnProduct.code == addon_code,
                AddOnProduct.is_active == True,  # noqa: E712
            )
            .first()
        )
        # NOTE(review): despite the docstring, missing/unpriced add-ons raise
        # the generic BillingServiceError, not a dedicated AddonNotFoundError.
        if not addon:
            raise BillingServiceError(f"Add-on '{addon_code}' not found")
        if not addon.stripe_price_id:
            raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'")
        vendor = self.get_vendor(db, vendor_id)
        # Ensures a local subscription row exists; the returned value itself
        # is not used below.
        subscription = subscription_service.get_or_create_subscription(db, vendor_id)
        # Create checkout session for add-on
        session = stripe_service.create_checkout_session(
            db=db,
            vendor=vendor,
            price_id=addon.stripe_price_id,
            success_url=success_url,
            cancel_url=cancel_url,
            quantity=quantity,
            metadata={
                "addon_code": addon_code,
                "domain_name": domain_name or "",
            },
        )
        return {
            "checkout_url": session.url,
            "session_id": session.id,
        }
    def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict:
        """
        Cancel a purchased add-on.

        A Stripe cancellation failure is logged and swallowed — the local
        record is still marked cancelled.
        Returns:
            Dict with message and addon_code
        Raises:
            BillingServiceError: If addon not found or not owned by vendor
        """
        vendor_addon = (
            db.query(VendorAddOn)
            .filter(
                VendorAddOn.id == addon_id,
                VendorAddOn.vendor_id == vendor_id,
            )
            .first()
        )
        if not vendor_addon:
            raise BillingServiceError("Add-on not found")
        addon_code = vendor_addon.addon_product.code
        # Cancel in Stripe if applicable
        if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id:
            try:
                stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id)
            except Exception as e:
                logger.warning(f"Failed to cancel addon in Stripe: {e}")
        # Mark as cancelled
        vendor_addon.status = "cancelled"
        vendor_addon.cancelled_at = datetime.utcnow()
        return {
            "message": "Add-on cancelled successfully",
            "addon_code": addon_code,
        }
# Create service instance
# Module-level singleton shared by all importers; it holds no per-request
# state (a Session is passed into every method).
billing_service = BillingService()
__all__ = [
    "BillingService",
    "billing_service",
    "BillingServiceError",
    "PaymentSystemNotConfiguredError",
    "TierNotFoundError",
    "StripePriceNotConfiguredError",
    "NoActiveSubscriptionError",
    "SubscriptionNotCancelledError",
]

View File

@@ -1,820 +1,35 @@
# app/services/code_quality_service.py
"""
Code Quality Service
Business logic for managing code quality scans and violations
Supports multiple validator types: architecture, security, performance
LEGACY LOCATION - Re-exports from module for backwards compatibility.
The canonical implementation is now in:
app/modules/dev_tools/services/code_quality_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.dev_tools.services import code_quality_service
"""
import json
import logging
import subprocess
from datetime import datetime
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from app.exceptions import (
ScanParseException,
ScanTimeoutException,
ViolationNotFoundException,
)
from models.database.architecture_scan import (
ArchitectureScan,
ArchitectureViolation,
ViolationAssignment,
ViolationComment,
from app.modules.dev_tools.services.code_quality_service import (
code_quality_service,
CodeQualityService,
VALIDATOR_ARCHITECTURE,
VALIDATOR_SECURITY,
VALIDATOR_PERFORMANCE,
VALID_VALIDATOR_TYPES,
VALIDATOR_SCRIPTS,
VALIDATOR_NAMES,
)
logger = logging.getLogger(__name__)
# Validator type constants
# These codes are persisted on scan/violation rows (validator_type column)
# and used as API filter values — do not rename without a data migration.
VALIDATOR_ARCHITECTURE = "architecture"
VALIDATOR_SECURITY = "security"
VALIDATOR_PERFORMANCE = "performance"
VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE]
# Map validator types to their scripts
# NOTE(review): paths are relative — presumably resolved against the
# process working directory (repo root); confirm for non-root launches.
VALIDATOR_SCRIPTS = {
    VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py",
    VALIDATOR_SECURITY: "scripts/validate_security.py",
    VALIDATOR_PERFORMANCE: "scripts/validate_performance.py",
}
# Human-readable names
VALIDATOR_NAMES = {
    VALIDATOR_ARCHITECTURE: "Architecture",
    VALIDATOR_SECURITY: "Security",
    VALIDATOR_PERFORMANCE: "Performance",
}
class CodeQualityService:
    """Service for managing code quality scans and violations.

    Runs validator scripts as subprocesses, persists their JSON output as
    ArchitectureScan / ArchitectureViolation rows, and aggregates dashboard
    statistics.  Methods ``flush`` but never ``commit`` — presumably the
    caller owns the transaction (verify).
    """
    def run_scan(
        self,
        db: Session,
        triggered_by: str = "manual",
        validator_type: str = VALIDATOR_ARCHITECTURE,
    ) -> ArchitectureScan:
        """
        Run a code quality validator and store results in database
        Args:
            db: Database session
            triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd')
            validator_type: Type of validator ('architecture', 'security', 'performance')
        Returns:
            ArchitectureScan object with results
        Raises:
            ValueError: If validator_type is invalid
            ScanTimeoutException: If validator times out
            ScanParseException: If validator output cannot be parsed
        """
        if validator_type not in VALID_VALIDATOR_TYPES:
            raise ValueError(
                f"Invalid validator type: {validator_type}. "
                f"Must be one of: {VALID_VALIDATOR_TYPES}"
            )
        script_path = VALIDATOR_SCRIPTS[validator_type]
        validator_name = VALIDATOR_NAMES[validator_type]
        logger.info(
            f"Starting {validator_name} scan (triggered by: {triggered_by})"
        )
        # Get git commit hash
        git_commit = self._get_git_commit_hash()
        # Run validator with JSON output
        # NOTE(review): bare "python" relies on the PATH/venv of the server
        # process matching the one the validators need — confirm.
        start_time = datetime.now()
        try:
            result = subprocess.run(
                ["python", script_path, "--json"],
                capture_output=True,
                text=True,
                timeout=300,  # 5 minute timeout
            )
        except subprocess.TimeoutExpired:
            logger.error(f"{validator_name} scan timed out after 5 minutes")
            raise ScanTimeoutException(timeout_seconds=300)
        duration = (datetime.now() - start_time).total_seconds()
        # Parse JSON output (get only the JSON part, skip progress messages)
        try:
            # Find the JSON part in stdout
            # Validators may print progress lines before the JSON document;
            # the first line starting with "{" is treated as its beginning.
            lines = result.stdout.strip().split("\n")
            json_start = -1
            for i, line in enumerate(lines):
                if line.strip().startswith("{"):
                    json_start = i
                    break
            if json_start == -1:
                raise ValueError("No JSON output found")
            json_output = "\n".join(lines[json_start:])
            data = json.loads(json_output)
        except (json.JSONDecodeError, ValueError) as e:
            logger.error(f"Failed to parse {validator_name} validator output: {e}")
            logger.error(f"Stdout: {result.stdout}")
            logger.error(f"Stderr: {result.stderr}")
            raise ScanParseException(reason=str(e))
        # Create scan record
        scan = ArchitectureScan(
            timestamp=datetime.now(),
            validator_type=validator_type,
            total_files=data.get("files_checked", 0),
            total_violations=data.get("total_violations", 0),
            errors=data.get("errors", 0),
            warnings=data.get("warnings", 0),
            duration_seconds=duration,
            triggered_by=triggered_by,
            git_commit_hash=git_commit,
        )
        db.add(scan)
        db.flush()  # Get scan.id
        # Create violation records
        violations_data = data.get("violations", [])
        logger.info(f"Creating {len(violations_data)} {validator_name} violation records")
        for v in violations_data:
            violation = ArchitectureViolation(
                scan_id=scan.id,
                validator_type=validator_type,
                rule_id=v["rule_id"],
                rule_name=v["rule_name"],
                severity=v["severity"],
                file_path=v["file_path"],
                line_number=v["line_number"],
                message=v["message"],
                context=v.get("context", ""),
                suggestion=v.get("suggestion", ""),
                status="open",
            )
            db.add(violation)
        db.flush()
        db.refresh(scan)
        logger.info(
            f"{validator_name} scan completed: {scan.total_violations} violations found"
        )
        return scan
    def run_all_scans(
        self, db: Session, triggered_by: str = "manual"
    ) -> list[ArchitectureScan]:
        """
        Run all validators and return list of scans
        Args:
            db: Database session
            triggered_by: Who/what triggered the scan
        Returns:
            List of ArchitectureScan objects (one per validator)
        """
        results = []
        for validator_type in VALID_VALIDATOR_TYPES:
            try:
                scan = self.run_scan(db, triggered_by, validator_type)
                results.append(scan)
            except Exception as e:
                logger.error(f"Failed to run {validator_type} scan: {e}")
                # Continue with other validators even if one fails
        return results
    def get_latest_scan(
        self, db: Session, validator_type: str = None
    ) -> ArchitectureScan | None:
        """
        Get the most recent scan
        Args:
            db: Database session
            validator_type: Optional filter by validator type
        Returns:
            Most recent ArchitectureScan or None
        """
        query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
        if validator_type:
            query = query.filter(ArchitectureScan.validator_type == validator_type)
        return query.first()
    def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]:
        """
        Get the most recent scan for each validator type
        Returns:
            Dictionary mapping validator_type to its latest scan
        """
        result = {}
        for vtype in VALID_VALIDATOR_TYPES:
            scan = self.get_latest_scan(db, validator_type=vtype)
            if scan:
                result[vtype] = scan
        return result
    def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None:
        """Get scan by ID, or None if it does not exist."""
        return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first()
    def create_pending_scan(
        self, db: Session, validator_type: str, triggered_by: str
    ) -> ArchitectureScan:
        """
        Create a new scan record with pending status.
        Args:
            db: Database session
            validator_type: Type of validator (architecture, security, performance)
            triggered_by: Who triggered the scan (e.g., "manual:username")
        Returns:
            The created ArchitectureScan record with ID populated
        """
        # NOTE(review): UTC is not among this file's visible imports (only
        # ``datetime``), and every other method here uses naive
        # datetime.now() — confirm this doesn't raise NameError and that
        # mixing aware/naive timestamps is intended.
        scan = ArchitectureScan(
            timestamp=datetime.now(UTC),
            validator_type=validator_type,
            status="pending",
            triggered_by=triggered_by,
        )
        db.add(scan)
        db.flush()  # Get scan.id
        return scan
    def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
        """
        Get all currently running scans (pending or running status).
        Returns:
            List of scans with status 'pending' or 'running', newest first
        """
        return (
            db.query(ArchitectureScan)
            .filter(ArchitectureScan.status.in_(["pending", "running"]))
            .order_by(ArchitectureScan.timestamp.desc())
            .all()
        )
    def get_scan_history(
        self, db: Session, limit: int = 30, validator_type: str = None
    ) -> list[ArchitectureScan]:
        """
        Get scan history for trend graphs
        Args:
            db: Database session
            limit: Maximum number of scans to return
            validator_type: Optional filter by validator type
        Returns:
            List of ArchitectureScan objects, newest first
        """
        query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
        if validator_type:
            query = query.filter(ArchitectureScan.validator_type == validator_type)
        return query.limit(limit).all()
    def get_violations(
        self,
        db: Session,
        scan_id: int = None,
        validator_type: str = None,
        severity: str = None,
        status: str = None,
        rule_id: str = None,
        file_path: str = None,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[list[ArchitectureViolation], int]:
        """
        Get violations with filtering and pagination
        Args:
            db: Database session
            scan_id: Filter by scan ID (if None, use latest scan(s))
            validator_type: Filter by validator type
            severity: Filter by severity ('error', 'warning')
            status: Filter by status ('open', 'assigned', 'resolved', etc.)
            rule_id: Filter by rule ID
            file_path: Filter by file path (partial match)
            limit: Page size
            offset: Page offset
        Returns:
            Tuple of (violations list, total count)
        """
        # Build query
        query = db.query(ArchitectureViolation)
        # If scan_id specified, filter by it
        if scan_id is not None:
            query = query.filter(ArchitectureViolation.scan_id == scan_id)
        else:
            # If no scan_id, get violations from latest scan(s)
            if validator_type:
                # Get latest scan for specific validator type
                latest_scan = self.get_latest_scan(db, validator_type)
                if not latest_scan:
                    return [], 0
                query = query.filter(ArchitectureViolation.scan_id == latest_scan.id)
            else:
                # Get violations from latest scans of all types
                latest_scans = self.get_latest_scans_by_type(db)
                if not latest_scans:
                    return [], 0
                scan_ids = [s.id for s in latest_scans.values()]
                query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids))
        # Apply validator_type filter if specified (for scan_id queries)
        # (in the scan_id=None branch the type is already fixed by the scan
        # selection above, so this extra filter is only needed here)
        if validator_type and scan_id is not None:
            query = query.filter(ArchitectureViolation.validator_type == validator_type)
        # Apply other filters
        if severity:
            query = query.filter(ArchitectureViolation.severity == severity)
        if status:
            query = query.filter(ArchitectureViolation.status == status)
        if rule_id:
            query = query.filter(ArchitectureViolation.rule_id == rule_id)
        if file_path:
            query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
        # Get total count
        total = query.count()
        # Get page of results
        violations = (
            query.order_by(
                ArchitectureViolation.severity.desc(),
                ArchitectureViolation.validator_type,
                ArchitectureViolation.file_path,
            )
            .limit(limit)
            .offset(offset)
            .all()
        )
        return violations, total
    def get_violation_by_id(
        self, db: Session, violation_id: int
    ) -> ArchitectureViolation | None:
        """Get single violation with details, or None if it does not exist."""
        return (
            db.query(ArchitectureViolation)
            .filter(ArchitectureViolation.id == violation_id)
            .first()
        )
    def assign_violation(
        self,
        db: Session,
        violation_id: int,
        user_id: int,
        assigned_by: int,
        due_date: datetime = None,
        priority: str = "medium",
    ) -> ViolationAssignment:
        """
        Assign violation to a developer
        Args:
            db: Database session
            violation_id: Violation ID
            user_id: User to assign to
            assigned_by: User who is assigning
            due_date: Due date (optional)
            priority: Priority level ('low', 'medium', 'high', 'critical')
        Returns:
            ViolationAssignment object
        """
        # Update violation status
        # NOTE(review): a missing violation is silently skipped here, yet an
        # assignment row is still created — unlike resolve/ignore, which
        # raise ViolationNotFoundException; confirm this asymmetry is wanted.
        violation = self.get_violation_by_id(db, violation_id)
        if violation:
            violation.status = "assigned"
            violation.assigned_to = user_id
        # Create assignment record
        assignment = ViolationAssignment(
            violation_id=violation_id,
            user_id=user_id,
            assigned_by=assigned_by,
            due_date=due_date,
            priority=priority,
        )
        db.add(assignment)
        db.flush()
        logger.info(f"Violation {violation_id} assigned to user {user_id}")
        return assignment
    def resolve_violation(
        self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
    ) -> ArchitectureViolation:
        """
        Mark violation as resolved
        Args:
            db: Database session
            violation_id: Violation ID
            resolved_by: User who resolved it
            resolution_note: Note about resolution
        Returns:
            Updated ArchitectureViolation object
        Raises:
            ViolationNotFoundException: If the violation does not exist
        """
        violation = self.get_violation_by_id(db, violation_id)
        if not violation:
            raise ViolationNotFoundException(violation_id)
        violation.status = "resolved"
        violation.resolved_at = datetime.now()
        violation.resolved_by = resolved_by
        violation.resolution_note = resolution_note
        db.flush()
        logger.info(f"Violation {violation_id} resolved by user {resolved_by}")
        return violation
    def ignore_violation(
        self, db: Session, violation_id: int, ignored_by: int, reason: str
    ) -> ArchitectureViolation:
        """
        Mark violation as ignored/won't fix
        Args:
            db: Database session
            violation_id: Violation ID
            ignored_by: User who ignored it
            reason: Reason for ignoring
        Returns:
            Updated ArchitectureViolation object
        Raises:
            ViolationNotFoundException: If the violation does not exist
        """
        violation = self.get_violation_by_id(db, violation_id)
        if not violation:
            raise ViolationNotFoundException(violation_id)
        violation.status = "ignored"
        # Reuses the resolved_* columns to record who ignored it and why.
        violation.resolved_at = datetime.now()
        violation.resolved_by = ignored_by
        violation.resolution_note = f"Ignored: {reason}"
        db.flush()
        logger.info(f"Violation {violation_id} ignored by user {ignored_by}")
        return violation
    def add_comment(
        self, db: Session, violation_id: int, user_id: int, comment: str
    ) -> ViolationComment:
        """
        Add comment to violation
        Args:
            db: Database session
            violation_id: Violation ID
            user_id: User posting comment
            comment: Comment text
        Returns:
            ViolationComment object
        """
        comment_obj = ViolationComment(
            violation_id=violation_id, user_id=user_id, comment=comment
        )
        db.add(comment_obj)
        db.flush()
        logger.info(f"Comment added to violation {violation_id} by user {user_id}")
        return comment_obj
    def get_dashboard_stats(
        self, db: Session, validator_type: str = None
    ) -> dict:
        """
        Get statistics for dashboard
        Args:
            db: Database session
            validator_type: Optional filter by validator type. If None, returns combined stats.
        Returns:
            Dictionary with various statistics including per-validator breakdown
        """
        # Get latest scans by type
        latest_scans = self.get_latest_scans_by_type(db)
        if not latest_scans:
            return self._empty_dashboard_stats()
        # If specific validator type requested
        if validator_type and validator_type in latest_scans:
            scan = latest_scans[validator_type]
            return self._get_stats_for_scan(db, scan, validator_type)
        # Combined stats across all validators
        # (also reached when validator_type has no scan yet)
        return self._get_combined_stats(db, latest_scans)
    def _empty_dashboard_stats(self) -> dict:
        """Return empty dashboard stats structure (no scans recorded yet)."""
        return {
            "total_violations": 0,
            "errors": 0,
            "warnings": 0,
            "info": 0,
            "open": 0,
            "assigned": 0,
            "resolved": 0,
            "ignored": 0,
            "technical_debt_score": 100,
            "trend": [],
            "by_severity": {},
            "by_rule": {},
            "by_module": {},
            "top_files": [],
            "last_scan": None,
            "by_validator": {},
        }
    def _get_stats_for_scan(
        self, db: Session, scan: ArchitectureScan, validator_type: str
    ) -> dict:
        """Get stats for a single scan/validator type."""
        # Get violation counts by status
        status_counts = (
            db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
            .filter(ArchitectureViolation.scan_id == scan.id)
            .group_by(ArchitectureViolation.status)
            .all()
        )
        status_dict = {status: count for status, count in status_counts}
        # Get violations by severity
        severity_counts = (
            db.query(
                ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
            )
            .filter(ArchitectureViolation.scan_id == scan.id)
            .group_by(ArchitectureViolation.severity)
            .all()
        )
        by_severity = {sev: count for sev, count in severity_counts}
        # Get violations by rule (top 10 by count)
        rule_counts = (
            db.query(
                ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
            )
            .filter(ArchitectureViolation.scan_id == scan.id)
            .group_by(ArchitectureViolation.rule_id)
            .all()
        )
        by_rule = {
            rule: count
            for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
        }
        # Get top violating files
        file_counts = (
            db.query(
                ArchitectureViolation.file_path,
                func.count(ArchitectureViolation.id).label("count"),
            )
            .filter(ArchitectureViolation.scan_id == scan.id)
            .group_by(ArchitectureViolation.file_path)
            .order_by(desc("count"))
            .limit(10)
            .all()
        )
        top_files = [{"file": file, "count": count} for file, count in file_counts]
        # Get violations by module
        by_module = self._get_violations_by_module(db, scan.id)
        # Get trend for this validator type (last 7 scans, oldest first)
        trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type)
        trend = [
            {
                "timestamp": s.timestamp.isoformat(),
                "violations": s.total_violations,
                "errors": s.errors,
                "warnings": s.warnings,
            }
            for s in reversed(trend_scans)
        ]
        return {
            "total_violations": scan.total_violations,
            "errors": scan.errors,
            "warnings": scan.warnings,
            "info": by_severity.get("info", 0),
            "open": status_dict.get("open", 0),
            "assigned": status_dict.get("assigned", 0),
            "resolved": status_dict.get("resolved", 0),
            "ignored": status_dict.get("ignored", 0),
            "technical_debt_score": self._calculate_score(scan.errors, scan.warnings),
            "trend": trend,
            "by_severity": by_severity,
            "by_rule": by_rule,
            "by_module": by_module,
            "top_files": top_files,
            "last_scan": scan.timestamp.isoformat(),
            "validator_type": validator_type,
            "by_validator": {
                validator_type: {
                    "total_violations": scan.total_violations,
                    "errors": scan.errors,
                    "warnings": scan.warnings,
                    "last_scan": scan.timestamp.isoformat(),
                }
            },
        }
    def _get_combined_stats(
        self, db: Session, latest_scans: dict[str, ArchitectureScan]
    ) -> dict:
        """Get combined stats across all validators (latest scan of each type)."""
        # Aggregate totals
        total_violations = sum(s.total_violations for s in latest_scans.values())
        total_errors = sum(s.errors for s in latest_scans.values())
        total_warnings = sum(s.warnings for s in latest_scans.values())
        # Get all scan IDs
        scan_ids = [s.id for s in latest_scans.values()]
        # Get violation counts by status
        status_counts = (
            db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
            .filter(ArchitectureViolation.scan_id.in_(scan_ids))
            .group_by(ArchitectureViolation.status)
            .all()
        )
        status_dict = {status: count for status, count in status_counts}
        # Get violations by severity
        severity_counts = (
            db.query(
                ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
            )
            .filter(ArchitectureViolation.scan_id.in_(scan_ids))
            .group_by(ArchitectureViolation.severity)
            .all()
        )
        by_severity = {sev: count for sev, count in severity_counts}
        # Get violations by rule (across all validators)
        rule_counts = (
            db.query(
                ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
            )
            .filter(ArchitectureViolation.scan_id.in_(scan_ids))
            .group_by(ArchitectureViolation.rule_id)
            .all()
        )
        by_rule = {
            rule: count
            for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
        }
        # Get top violating files
        file_counts = (
            db.query(
                ArchitectureViolation.file_path,
                func.count(ArchitectureViolation.id).label("count"),
            )
            .filter(ArchitectureViolation.scan_id.in_(scan_ids))
            .group_by(ArchitectureViolation.file_path)
            .order_by(desc("count"))
            .limit(10)
            .all()
        )
        top_files = [{"file": file, "count": count} for file, count in file_counts]
        # Get violations by module (merge per-scan counts, keep top 10)
        by_module = {}
        for scan_id in scan_ids:
            module_counts = self._get_violations_by_module(db, scan_id)
            for module, count in module_counts.items():
                by_module[module] = by_module.get(module, 0) + count
        by_module = dict(
            sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
        )
        # Per-validator breakdown
        by_validator = {}
        for vtype, scan in latest_scans.items():
            by_validator[vtype] = {
                "total_violations": scan.total_violations,
                "errors": scan.errors,
                "warnings": scan.warnings,
                "last_scan": scan.timestamp.isoformat(),
            }
        # Get most recent scan timestamp
        most_recent = max(latest_scans.values(), key=lambda s: s.timestamp)
        return {
            "total_violations": total_violations,
            "errors": total_errors,
            "warnings": total_warnings,
            "info": by_severity.get("info", 0),
            "open": status_dict.get("open", 0),
            "assigned": status_dict.get("assigned", 0),
            "resolved": status_dict.get("resolved", 0),
            "ignored": status_dict.get("ignored", 0),
            "technical_debt_score": self._calculate_score(total_errors, total_warnings),
            "trend": [],  # Combined trend would need special handling
            "by_severity": by_severity,
            "by_rule": by_rule,
            "by_module": by_module,
            "top_files": top_files,
            "last_scan": most_recent.timestamp.isoformat(),
            "by_validator": by_validator,
        }
    def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]:
        """Extract module from file paths and count violations.

        A "module" is the first two path components (e.g. "app/services");
        only the 10 most-violating modules are returned.
        """
        by_module = {}
        violations = (
            db.query(ArchitectureViolation.file_path)
            .filter(ArchitectureViolation.scan_id == scan_id)
            .all()
        )
        for v in violations:
            path_parts = v.file_path.split("/")
            if len(path_parts) >= 2:
                module = "/".join(path_parts[:2])
            else:
                module = path_parts[0]
            by_module[module] = by_module.get(module, 0) + 1
        return dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
    def _calculate_score(self, errors: int, warnings: int) -> int:
        """Calculate technical debt score (0-100).

        Each error costs 0.5 points, each warning 0.05; clamped to [0, 100].
        """
        score = 100 - (errors * 0.5 + warnings * 0.05)
        return max(0, min(100, int(score)))
    def calculate_technical_debt_score(
        self, db: Session, scan_id: int = None, validator_type: str = None
    ) -> int:
        """
        Calculate technical debt score (0-100)
        Formula: 100 - (errors * 0.5 + warnings * 0.05)
        Capped at 0 minimum
        Args:
            db: Database session
            scan_id: Scan ID (if None, use latest)
            validator_type: Filter by validator type
        Returns:
            Score from 0-100 (100 when no matching scan exists)
        """
        if scan_id is None:
            latest_scan = self.get_latest_scan(db, validator_type)
            if not latest_scan:
                return 100
            scan_id = latest_scan.id
        scan = self.get_scan_by_id(db, scan_id)
        if not scan:
            return 100
        return self._calculate_score(scan.errors, scan.warnings)
    def _get_git_commit_hash(self) -> str | None:
        """Get current git commit hash (best effort; None on any failure)."""
        try:
            result = subprocess.run(
                ["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                return result.stdout.strip()[:40]
        except Exception:
            pass
        return None
# Singleton instance
# Module-level shared instance; import this rather than constructing a new one.
code_quality_service = CodeQualityService()

# Public API: the service singleton, its class, and the validator constants.
__all__ = [
    "code_quality_service",
    "CodeQualityService",
    "VALIDATOR_ARCHITECTURE",
    "VALIDATOR_SECURITY",
    "VALIDATOR_PERFORMANCE",
    "VALID_VALIDATOR_TYPES",
    "VALIDATOR_SCRIPTS",
    "VALIDATOR_NAMES",
]

View File

@@ -1,314 +1,23 @@
# app/services/customer_address_service.py
"""
Customer Address Service
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Business logic for managing customer addresses with vendor isolation.
The canonical implementation is now in:
app/modules/customers/services/customer_address_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.customers.services import customer_address_service
"""
import logging
from sqlalchemy.orm import Session
from app.exceptions import (
AddressLimitExceededException,
AddressNotFoundException,
from app.modules.customers.services.customer_address_service import (
customer_address_service,
CustomerAddressService,
)
from models.database.customer import CustomerAddress
from models.schema.customer import CustomerAddressCreate, CustomerAddressUpdate
logger = logging.getLogger(__name__)
class CustomerAddressService:
    """Service for managing customer addresses with vendor isolation.

    Every query filters on both ``vendor_id`` and ``customer_id`` so one
    vendor's data can never be touched through another vendor's context.
    Methods call ``db.flush()`` but never commit; the caller owns the
    transaction boundary.
    """

    # Hard cap enforced in create_address.
    MAX_ADDRESSES_PER_CUSTOMER = 10

    def list_addresses(
        self, db: Session, vendor_id: int, customer_id: int
    ) -> list[CustomerAddress]:
        """
        Get all addresses for a customer.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID

        Returns:
            List of customer addresses, default addresses first, then newest.
        """
        return (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc())
            .all()
        )

    def get_address(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> CustomerAddress:
        """
        Get a specific address with ownership validation.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Returns:
            Customer address

        Raises:
            AddressNotFoundException: If address not found or doesn't belong to customer
        """
        address = (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.id == address_id,
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .first()
        )
        if not address:
            raise AddressNotFoundException(address_id)
        return address

    def get_default_address(
        self, db: Session, vendor_id: int, customer_id: int, address_type: str
    ) -> CustomerAddress | None:
        """
        Get the default address for a specific type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_type: 'shipping' or 'billing'

        Returns:
            Default address or None if not set
        """
        return (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
                CustomerAddress.address_type == address_type,
                CustomerAddress.is_default == True,  # noqa: E712
            )
            .first()
        )

    def create_address(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_data: CustomerAddressCreate,
    ) -> CustomerAddress:
        """
        Create a new address for a customer.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_data: Address creation data

        Returns:
            Created customer address (flushed, so ``id`` is populated)

        Raises:
            AddressLimitExceededException: If max addresses reached
        """
        # Check address limit
        current_count = (
            db.query(CustomerAddress)
            .filter(
                CustomerAddress.vendor_id == vendor_id,
                CustomerAddress.customer_id == customer_id,
            )
            .count()
        )
        if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER:
            raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER)

        # If setting as default, clear other defaults of same type
        if address_data.is_default:
            self._clear_other_defaults(
                db, vendor_id, customer_id, address_data.address_type
            )

        # Create the address
        address = CustomerAddress(
            vendor_id=vendor_id,
            customer_id=customer_id,
            address_type=address_data.address_type,
            first_name=address_data.first_name,
            last_name=address_data.last_name,
            company=address_data.company,
            address_line_1=address_data.address_line_1,
            address_line_2=address_data.address_line_2,
            city=address_data.city,
            postal_code=address_data.postal_code,
            country_name=address_data.country_name,
            country_iso=address_data.country_iso,
            is_default=address_data.is_default,
        )
        db.add(address)
        db.flush()
        logger.info(
            f"Created address {address.id} for customer {customer_id} "
            f"(type={address_data.address_type}, default={address_data.is_default})"
        )
        return address

    def update_address(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_id: int,
        address_data: CustomerAddressUpdate,
    ) -> CustomerAddress:
        """
        Update an existing address.

        Only fields explicitly set on ``address_data`` are applied
        (``exclude_unset`` semantics).

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID
            address_data: Address update data

        Returns:
            Updated customer address

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)

        # Update only provided fields
        update_data = address_data.model_dump(exclude_unset=True)

        # Handle default flag - clear others if setting to default
        if update_data.get("is_default") is True:
            # Use updated type if provided, otherwise current type
            address_type = update_data.get("address_type", address.address_type)
            self._clear_other_defaults(
                db, vendor_id, customer_id, address_type, exclude_id=address_id
            )

        for field, value in update_data.items():
            setattr(address, field, value)

        db.flush()
        logger.info(f"Updated address {address_id} for customer {customer_id}")
        return address

    def delete_address(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> None:
        """
        Delete an address.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)
        db.delete(address)
        db.flush()
        logger.info(f"Deleted address {address_id} for customer {customer_id}")

    def set_default(
        self, db: Session, vendor_id: int, customer_id: int, address_id: int
    ) -> CustomerAddress:
        """
        Set an address as the default for its type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_id: Address ID

        Returns:
            Updated customer address

        Raises:
            AddressNotFoundException: If address not found
        """
        address = self.get_address(db, vendor_id, customer_id, address_id)

        # Clear other defaults of same type
        self._clear_other_defaults(
            db, vendor_id, customer_id, address.address_type, exclude_id=address_id
        )

        # Set this one as default
        address.is_default = True
        db.flush()
        logger.info(
            f"Set address {address_id} as default {address.address_type} "
            f"for customer {customer_id}"
        )
        return address

    def _clear_other_defaults(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        address_type: str,
        exclude_id: int | None = None,
    ) -> None:
        """
        Clear the default flag on other addresses of the same type.

        Args:
            db: Database session
            vendor_id: Vendor ID for isolation
            customer_id: Customer ID
            address_type: 'shipping' or 'billing'
            exclude_id: Address ID to exclude from clearing
        """
        query = db.query(CustomerAddress).filter(
            CustomerAddress.vendor_id == vendor_id,
            CustomerAddress.customer_id == customer_id,
            CustomerAddress.address_type == address_type,
            CustomerAddress.is_default == True,  # noqa: E712
        )
        if exclude_id:
            query = query.filter(CustomerAddress.id != exclude_id)
        # Bulk UPDATE; synchronize_session=False means already-loaded objects
        # are not refreshed in this session until the next expire/refresh.
        query.update({"is_default": False}, synchronize_session=False)
# Singleton instance
# Module-level shared instance; import this rather than constructing a new one.
customer_address_service = CustomerAddressService()

__all__ = [
    "customer_address_service",
    "CustomerAddressService",
]

View File

@@ -1,664 +1,23 @@
# app/services/customer_service.py
"""
Customer management service.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Handles customer registration, authentication, and profile management
with complete vendor isolation.
The canonical implementation is now in:
app/modules/customers/services/customer_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.customers.services import customer_service
"""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy import and_
from sqlalchemy.orm import Session
from app.exceptions.customer import (
CustomerNotActiveException,
CustomerNotFoundException,
CustomerValidationException,
DuplicateCustomerEmailException,
InvalidCustomerCredentialsException,
InvalidPasswordResetTokenException,
PasswordTooShortException,
from app.modules.customers.services.customer_service import (
customer_service,
CustomerService,
)
from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException
from app.services.auth_service import AuthService
from models.database.customer import Customer
from models.database.password_reset_token import PasswordResetToken
from models.database.vendor import Vendor
from models.schema.auth import UserLogin
from models.schema.customer import CustomerRegister, CustomerUpdate
logger = logging.getLogger(__name__)
class CustomerService:
    """Service for managing vendor-scoped customers.

    Every lookup filters on ``vendor_id`` so customers of one vendor can
    never be read or modified through another vendor's context. Methods
    call ``db.flush()`` but never commit; callers own the transaction.
    """

    def __init__(self):
        # Password hashing and token helpers are delegated to AuthService.
        self.auth_service = AuthService()

    def register_customer(
        self, db: Session, vendor_id: int, customer_data: CustomerRegister
    ) -> Customer:
        """
        Register a new customer for a specific vendor.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_data: Customer registration data

        Returns:
            Customer: Created customer object

        Raises:
            VendorNotFoundException: If vendor doesn't exist
            VendorNotActiveException: If vendor is not active
            DuplicateCustomerEmailException: If email already exists for this vendor
            CustomerValidationException: If customer data is invalid
        """
        # Verify vendor exists and is active
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            raise VendorNotFoundException(str(vendor_id), identifier_type="id")
        if not vendor.is_active:
            raise VendorNotActiveException(vendor.vendor_code)

        # Check if email already exists for this vendor (emails are stored
        # lower-cased, so compare lower-cased too)
        existing_customer = (
            db.query(Customer)
            .filter(
                and_(
                    Customer.vendor_id == vendor_id,
                    Customer.email == customer_data.email.lower(),
                )
            )
            .first()
        )
        if existing_customer:
            raise DuplicateCustomerEmailException(
                customer_data.email, vendor.vendor_code
            )

        # Generate unique customer number for this vendor
        customer_number = self._generate_customer_number(
            db, vendor_id, vendor.vendor_code
        )

        # Hash password
        hashed_password = self.auth_service.hash_password(customer_data.password)

        # Create customer
        customer = Customer(
            vendor_id=vendor_id,
            email=customer_data.email.lower(),
            hashed_password=hashed_password,
            first_name=customer_data.first_name,
            last_name=customer_data.last_name,
            phone=customer_data.phone,
            customer_number=customer_number,
            marketing_consent=(
                customer_data.marketing_consent
                if hasattr(customer_data, "marketing_consent")
                else False
            ),
            is_active=True,
        )

        try:
            db.add(customer)
            db.flush()
            db.refresh(customer)
            logger.info(
                f"Customer registered successfully: {customer.email} "
                f"(ID: {customer.id}, Number: {customer.customer_number}) "
                f"for vendor {vendor.vendor_code}"
            )
            return customer
        except Exception as e:
            logger.error(f"Error registering customer: {str(e)}")
            raise CustomerValidationException(
                message="Failed to register customer", details={"error": str(e)}
            )

    def login_customer(
        self, db: Session, vendor_id: int, credentials: UserLogin
    ) -> dict[str, Any]:
        """
        Authenticate customer and generate JWT token.

        Args:
            db: Database session
            vendor_id: Vendor ID
            credentials: Login credentials

        Returns:
            Dict containing customer and token data

        Raises:
            VendorNotFoundException: If vendor doesn't exist
            InvalidCustomerCredentialsException: If credentials are invalid
            CustomerNotActiveException: If customer account is inactive
        """
        # Verify vendor exists
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            raise VendorNotFoundException(str(vendor_id), identifier_type="id")

        # Find customer by email (vendor-scoped)
        customer = (
            db.query(Customer)
            .filter(
                and_(
                    Customer.vendor_id == vendor_id,
                    Customer.email == credentials.email_or_username.lower(),
                )
            )
            .first()
        )
        # Same exception for unknown email and wrong password, so the
        # response does not leak which one failed.
        if not customer:
            raise InvalidCustomerCredentialsException()

        # Verify password using auth_manager directly
        if not self.auth_service.auth_manager.verify_password(
            credentials.password, customer.hashed_password
        ):
            raise InvalidCustomerCredentialsException()

        # Check if customer is active
        if not customer.is_active:
            raise CustomerNotActiveException(customer.email)

        # Generate JWT token with customer context
        # Use auth_manager directly since Customer is not a User model
        from datetime import datetime

        from jose import jwt

        auth_manager = self.auth_service.auth_manager
        expires_delta = timedelta(minutes=auth_manager.token_expire_minutes)
        expire = datetime.now(UTC) + expires_delta
        payload = {
            "sub": str(customer.id),
            "email": customer.email,
            "vendor_id": vendor_id,
            "type": "customer",
            "exp": expire,
            "iat": datetime.now(UTC),
        }
        token = jwt.encode(
            payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
        )
        token_data = {
            "access_token": token,
            "token_type": "bearer",
            "expires_in": auth_manager.token_expire_minutes * 60,
        }

        logger.info(
            f"Customer login successful: {customer.email} "
            f"for vendor {vendor.vendor_code}"
        )
        return {"customer": customer, "token_data": token_data}

    def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
        """
        Get customer by ID with vendor isolation.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID

        Returns:
            Customer: Customer object

        Raises:
            CustomerNotFoundException: If customer not found
        """
        customer = (
            db.query(Customer)
            .filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
            .first()
        )
        if not customer:
            raise CustomerNotFoundException(str(customer_id))
        return customer

    def get_customer_by_email(
        self, db: Session, vendor_id: int, email: str
    ) -> Customer | None:
        """
        Get customer by email (vendor-scoped).

        Args:
            db: Database session
            vendor_id: Vendor ID
            email: Customer email

        Returns:
            Optional[Customer]: Customer object or None
        """
        return (
            db.query(Customer)
            .filter(
                and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
            )
            .first()
        )

    def get_vendor_customers(
        self,
        db: Session,
        vendor_id: int,
        skip: int = 0,
        limit: int = 100,
        search: str | None = None,
        is_active: bool | None = None,
    ) -> tuple[list[Customer], int]:
        """
        Get all customers for a vendor with filtering and pagination.

        Args:
            db: Database session
            vendor_id: Vendor ID
            skip: Pagination offset
            limit: Pagination limit
            search: Search in name/email
            is_active: Filter by active status

        Returns:
            Tuple of (customers, total_count); total_count ignores pagination.
        """
        from sqlalchemy import or_

        query = db.query(Customer).filter(Customer.vendor_id == vendor_id)

        if search:
            search_pattern = f"%{search}%"
            query = query.filter(
                or_(
                    Customer.email.ilike(search_pattern),
                    Customer.first_name.ilike(search_pattern),
                    Customer.last_name.ilike(search_pattern),
                    Customer.customer_number.ilike(search_pattern),
                )
            )

        if is_active is not None:
            query = query.filter(Customer.is_active == is_active)

        # Order by most recent first
        query = query.order_by(Customer.created_at.desc())

        total = query.count()
        customers = query.offset(skip).limit(limit).all()
        return customers, total

    def get_customer_orders(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list, int]:
        """
        Get orders for a specific customer.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID
            skip: Pagination offset
            limit: Pagination limit

        Returns:
            Tuple of (orders, total_count)

        Raises:
            CustomerNotFoundException: If customer not found
        """
        from models.database.order import Order

        # Verify customer belongs to vendor (raises if not)
        self.get_customer(db, vendor_id, customer_id)

        # Get customer orders
        query = (
            db.query(Order)
            .filter(
                Order.customer_id == customer_id,
                Order.vendor_id == vendor_id,
            )
            .order_by(Order.created_at.desc())
        )

        total = query.count()
        orders = query.offset(skip).limit(limit).all()
        return orders, total

    def get_customer_statistics(
        self, db: Session, vendor_id: int, customer_id: int
    ) -> dict:
        """
        Get detailed statistics for a customer.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID

        Returns:
            Dict with customer statistics (monetary values in euros)
        """
        from sqlalchemy import func

        from models.database.order import Order

        customer = self.get_customer(db, vendor_id, customer_id)

        # Get order statistics in a single aggregate query
        order_stats = (
            db.query(
                func.count(Order.id).label("total_orders"),
                func.sum(Order.total_cents).label("total_spent_cents"),
                func.avg(Order.total_cents).label("avg_order_cents"),
                func.max(Order.created_at).label("last_order_date"),
            )
            .filter(Order.customer_id == customer_id)
            .first()
        )

        # Aggregates are NULL when the customer has no orders
        total_orders = order_stats.total_orders or 0
        total_spent_cents = order_stats.total_spent_cents or 0
        avg_order_cents = order_stats.avg_order_cents or 0

        return {
            "customer_id": customer_id,
            "total_orders": total_orders,
            "total_spent": total_spent_cents / 100,  # Convert to euros
            "average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0,
            "last_order_date": order_stats.last_order_date,
            "member_since": customer.created_at,
            "is_active": customer.is_active,
        }

    def toggle_customer_status(
        self, db: Session, vendor_id: int, customer_id: int
    ) -> Customer:
        """
        Toggle customer active status.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID

        Returns:
            Customer: Updated customer
        """
        customer = self.get_customer(db, vendor_id, customer_id)
        customer.is_active = not customer.is_active
        db.flush()
        db.refresh(customer)

        action = "activated" if customer.is_active else "deactivated"
        logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})")
        return customer

    def update_customer(
        self,
        db: Session,
        vendor_id: int,
        customer_id: int,
        customer_data: CustomerUpdate,
    ) -> Customer:
        """
        Update customer profile.

        Only fields explicitly set on ``customer_data`` are applied; email
        changes are checked for uniqueness and stored lower-cased.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID
            customer_data: Updated customer data

        Returns:
            Customer: Updated customer object

        Raises:
            CustomerNotFoundException: If customer not found
            DuplicateCustomerEmailException: If the new email is already taken
            CustomerValidationException: If update data is invalid
        """
        customer = self.get_customer(db, vendor_id, customer_id)

        # Update fields
        update_data = customer_data.model_dump(exclude_unset=True)

        for field, value in update_data.items():
            if field == "email" and value:
                # Check if new email already exists for this vendor
                existing = (
                    db.query(Customer)
                    .filter(
                        and_(
                            Customer.vendor_id == vendor_id,
                            Customer.email == value.lower(),
                            Customer.id != customer_id,
                        )
                    )
                    .first()
                )
                if existing:
                    raise DuplicateCustomerEmailException(value, "vendor")
                setattr(customer, field, value.lower())
            elif hasattr(customer, field):
                setattr(customer, field, value)

        try:
            db.flush()
            db.refresh(customer)
            logger.info(f"Customer updated: {customer.email} (ID: {customer.id})")
            return customer
        except Exception as e:
            logger.error(f"Error updating customer: {str(e)}")
            raise CustomerValidationException(
                message="Failed to update customer", details={"error": str(e)}
            )

    def deactivate_customer(
        self, db: Session, vendor_id: int, customer_id: int
    ) -> Customer:
        """
        Deactivate customer account.

        Args:
            db: Database session
            vendor_id: Vendor ID
            customer_id: Customer ID

        Returns:
            Customer: Deactivated customer object

        Raises:
            CustomerNotFoundException: If customer not found
        """
        customer = self.get_customer(db, vendor_id, customer_id)
        customer.is_active = False
        db.flush()
        db.refresh(customer)

        logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})")
        return customer

    def update_customer_stats(
        self, db: Session, customer_id: int, order_total: float
    ) -> None:
        """
        Update customer statistics after order.

        Silently does nothing when the customer id is unknown.

        Args:
            db: Database session
            customer_id: Customer ID
            order_total: Order total amount
        """
        customer = db.query(Customer).filter(Customer.id == customer_id).first()
        if customer:
            customer.total_orders += 1
            customer.total_spent += order_total
            # NOTE(review): utcnow() is naive while the rest of this class uses
            # timezone-aware datetime.now(UTC) — confirm the column expects naive.
            customer.last_order_date = datetime.utcnow()
            logger.debug(f"Updated stats for customer {customer.email}")

    def _generate_customer_number(
        self, db: Session, vendor_id: int, vendor_code: str
    ) -> str:
        """
        Generate unique customer number for vendor.

        Format: {VENDOR_CODE}-CUST-{SEQUENCE}
        Example: VENDORA-CUST-00001

        Args:
            db: Database session
            vendor_id: Vendor ID
            vendor_code: Vendor code

        Returns:
            str: Unique customer number
        """
        # Get count of customers for this vendor
        count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()

        # Generate number with padding
        sequence = str(count + 1).zfill(5)
        customer_number = f"{vendor_code.upper()}-CUST-{sequence}"

        # Ensure uniqueness (in case of deletions); keep bumping the
        # sequence until the number is free
        while (
            db.query(Customer)
            .filter(
                and_(
                    Customer.vendor_id == vendor_id,
                    Customer.customer_number == customer_number,
                )
            )
            .first()
        ):
            count += 1
            sequence = str(count + 1).zfill(5)
            customer_number = f"{vendor_code.upper()}-CUST-{sequence}"

        return customer_number

    def get_customer_for_password_reset(
        self, db: Session, vendor_id: int, email: str
    ) -> Customer | None:
        """
        Get active customer by email for password reset.

        Args:
            db: Database session
            vendor_id: Vendor ID
            email: Customer email

        Returns:
            Customer if found and active, None otherwise
        """
        return (
            db.query(Customer)
            .filter(
                Customer.vendor_id == vendor_id,
                Customer.email == email.lower(),
                Customer.is_active == True,  # noqa: E712
            )
            .first()
        )

    def validate_and_reset_password(
        self,
        db: Session,
        vendor_id: int,
        reset_token: str,
        new_password: str,
    ) -> Customer:
        """
        Validate reset token and update customer password.

        Args:
            db: Database session
            vendor_id: Vendor ID
            reset_token: Password reset token from email
            new_password: New password

        Returns:
            Customer: Updated customer

        Raises:
            PasswordTooShortException: If password too short
            InvalidPasswordResetTokenException: If token invalid/expired
            CustomerNotActiveException: If customer not active
        """
        # Validate password length
        if len(new_password) < 8:
            raise PasswordTooShortException(min_length=8)

        # Find valid token
        token_record = PasswordResetToken.find_valid_token(db, reset_token)
        if not token_record:
            raise InvalidPasswordResetTokenException()

        # Get the customer and verify they belong to this vendor. Same
        # exception for a cross-vendor token as for an invalid one, so no
        # information about other vendors' customers leaks.
        customer = (
            db.query(Customer)
            .filter(Customer.id == token_record.customer_id)
            .first()
        )
        if not customer or customer.vendor_id != vendor_id:
            raise InvalidPasswordResetTokenException()

        if not customer.is_active:
            raise CustomerNotActiveException(customer.email)

        # Hash the new password and update customer
        hashed_password = self.auth_service.hash_password(new_password)
        customer.hashed_password = hashed_password

        # Mark token as used (single-use tokens)
        token_record.mark_used(db)

        logger.info(f"Password reset completed for customer {customer.id}")
        return customer
# Singleton instance
# Module-level shared instance; import this rather than constructing a new one.
customer_service = CustomerService()

__all__ = [
    "customer_service",
    "CustomerService",
]

View File

@@ -1,250 +1,25 @@
# app/services/inventory_import_service.py
"""
Inventory import service for bulk importing stock from TSV/CSV files.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Supports two formats:
1. One row per unit (quantity = count of rows):
BIN EAN PRODUCT
SA-10-02 0810050910101 Product Name
SA-10-02 0810050910101 Product Name (2nd unit)
The canonical implementation is now in:
app/modules/inventory/services/inventory_import_service.py
2. With explicit quantity column:
BIN EAN PRODUCT QUANTITY
SA-10-02 0810050910101 Product Name 12
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
Products are matched by GTIN/EAN to existing vendor products.
from app.modules.inventory.services import inventory_import_service
"""
import csv
import io
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from app.modules.inventory.services.inventory_import_service import (
inventory_import_service,
InventoryImportService,
ImportResult,
)
from sqlalchemy.orm import Session
from models.database.inventory import Inventory
from models.database.product import Product
logger = logging.getLogger(__name__)
@dataclass
class ImportResult:
    """Result of an inventory import operation."""

    # False when parsing failed or an unexpected exception occurred.
    success: bool = True
    # Number of data rows read from the input (excluding the header).
    total_rows: int = 0
    # Inventory rows newly inserted vs. existing rows whose quantity was set.
    entries_created: int = 0
    entries_updated: int = 0
    # Total units written to matched inventory entries.
    quantity_imported: int = 0
    # Dicts of {gtin, quantity, product_name} for EANs with no vendor product.
    unmatched_gtins: list = field(default_factory=list)
    # Human-readable row-level and fatal error messages.
    errors: list = field(default_factory=list)
class InventoryImportService:
    """Service for importing inventory from TSV/CSV files.

    Rows are matched to existing vendor products via the GTIN/EAN column.
    Without a QUANTITY column, each row counts as one unit; with it, the
    column value is used. Work is flushed but not committed — the caller
    owns the transaction.
    """

    def import_from_text(
        self,
        db: Session,
        content: str,
        vendor_id: int,
        warehouse: str = "strassen",
        delimiter: str = "\t",
        clear_existing: bool = False,
    ) -> ImportResult:
        """
        Import inventory from TSV/CSV text content.

        Args:
            db: Database session
            content: TSV/CSV content as string
            vendor_id: Vendor ID for inventory
            warehouse: Warehouse name (default: "strassen")
            delimiter: Column delimiter (default: tab)
            clear_existing: If True, clear existing inventory before import
                (only rows for this vendor and warehouse)

        Returns:
            ImportResult with summary and errors
        """
        result = ImportResult()

        try:
            # Parse CSV/TSV
            reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)

            # Normalize headers (case-insensitive, strip whitespace)
            if reader.fieldnames:
                reader.fieldnames = [h.strip().upper() for h in reader.fieldnames]

            # Validate required columns
            required = {"BIN", "EAN"}
            if not reader.fieldnames or not required.issubset(set(reader.fieldnames)):
                result.success = False
                result.errors.append(
                    f"Missing required columns. Found: {reader.fieldnames}, Required: {required}"
                )
                return result

            has_quantity = "QUANTITY" in reader.fieldnames

            # Group entries by (EAN, BIN)
            # Key: (ean, bin) -> quantity
            inventory_data: dict[tuple[str, str], int] = defaultdict(int)
            product_names: dict[str, str] = {}  # EAN -> product name (for logging)

            for row in reader:
                result.total_rows += 1
                ean = row.get("EAN", "").strip()
                bin_loc = row.get("BIN", "").strip()
                product_name = row.get("PRODUCT", "").strip()

                if not ean or not bin_loc:
                    result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN")
                    continue

                # Get quantity
                if has_quantity:
                    try:
                        qty = int(row.get("QUANTITY", "1").strip())
                    except ValueError:
                        result.errors.append(
                            f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'"
                        )
                        continue
                else:
                    qty = 1  # Each row = 1 unit

                inventory_data[(ean, bin_loc)] += qty
                if product_name:
                    product_names[ean] = product_name

            # Clear existing inventory if requested (vendor + warehouse scope)
            if clear_existing:
                db.query(Inventory).filter(
                    Inventory.vendor_id == vendor_id,
                    Inventory.warehouse == warehouse,
                ).delete()
                db.flush()

            # Build EAN to Product mapping for this vendor
            products = (
                db.query(Product)
                .filter(
                    Product.vendor_id == vendor_id,
                    Product.gtin.isnot(None),
                )
                .all()
            )
            ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin}

            # Track unmatched GTINs
            unmatched: dict[str, int] = {}  # EAN -> total quantity

            # Process inventory entries
            for (ean, bin_loc), quantity in inventory_data.items():
                product = ean_to_product.get(ean)

                if not product:
                    # Track unmatched
                    if ean not in unmatched:
                        unmatched[ean] = 0
                    unmatched[ean] += quantity
                    continue

                # Upsert inventory entry (update if the same product/warehouse/
                # bin already exists, otherwise insert)
                existing = (
                    db.query(Inventory)
                    .filter(
                        Inventory.product_id == product.id,
                        Inventory.warehouse == warehouse,
                        Inventory.bin_location == bin_loc,
                    )
                    .first()
                )

                if existing:
                    existing.quantity = quantity
                    existing.gtin = ean
                    result.entries_updated += 1
                else:
                    inv = Inventory(
                        product_id=product.id,
                        vendor_id=vendor_id,
                        warehouse=warehouse,
                        bin_location=bin_loc,
                        location=bin_loc,  # Legacy field
                        quantity=quantity,
                        gtin=ean,
                    )
                    db.add(inv)
                    result.entries_created += 1

                result.quantity_imported += quantity

            db.flush()

            # Format unmatched GTINs for result
            for ean, qty in unmatched.items():
                product_name = product_names.get(ean, "Unknown")
                result.unmatched_gtins.append(
                    {"gtin": ean, "quantity": qty, "product_name": product_name}
                )

            if result.unmatched_gtins:
                logger.warning(
                    f"Import had {len(result.unmatched_gtins)} unmatched GTINs"
                )

        except Exception as e:
            logger.exception("Inventory import failed")
            result.success = False
            result.errors.append(str(e))

        return result

    def import_from_file(
        self,
        db: Session,
        file_path: str,
        vendor_id: int,
        warehouse: str = "strassen",
        clear_existing: bool = False,
    ) -> ImportResult:
        """
        Import inventory from a TSV/CSV file.

        The delimiter is auto-detected from the first line (tab wins over
        comma).

        Args:
            db: Database session
            file_path: Path to TSV/CSV file
            vendor_id: Vendor ID for inventory
            warehouse: Warehouse name
            clear_existing: If True, clear existing inventory before import

        Returns:
            ImportResult with summary and errors
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
        except Exception as e:
            return ImportResult(success=False, errors=[f"Failed to read file: {e}"])

        # Detect delimiter
        first_line = content.split("\n")[0] if content else ""
        delimiter = "\t" if "\t" in first_line else ","

        return self.import_from_text(
            db=db,
            content=content,
            vendor_id=vendor_id,
            warehouse=warehouse,
            delimiter=delimiter,
            clear_existing=clear_existing,
        )
# Singleton instance
# Module-level shared instance; import this rather than constructing a new one.
inventory_import_service = InventoryImportService()

__all__ = [
    "inventory_import_service",
    "InventoryImportService",
    "ImportResult",
]

View File

@@ -1,949 +1,23 @@
# app/services/inventory_service.py
import logging
from datetime import UTC, datetime
"""
LEGACY LOCATION - Re-exports from module for backwards compatibility.
from sqlalchemy import func
from sqlalchemy.orm import Session
The canonical implementation is now in:
app/modules/inventory/services/inventory_service.py
from app.exceptions import (
InsufficientInventoryException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
ProductNotFoundException,
ValidationException,
VendorNotFoundException,
)
from models.database.inventory import Inventory
from models.database.product import Product
from models.database.vendor import Vendor
from models.schema.inventory import (
AdminInventoryItem,
AdminInventoryListResponse,
AdminInventoryLocationsResponse,
AdminInventoryStats,
AdminLowStockItem,
AdminVendorsWithInventoryResponse,
AdminVendorWithInventory,
InventoryAdjust,
InventoryCreate,
InventoryLocationResponse,
InventoryReserve,
InventoryUpdate,
ProductInventorySummary,
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.inventory.services import inventory_service
"""
from app.modules.inventory.services.inventory_service import (
inventory_service,
InventoryService,
)
logger = logging.getLogger(__name__)
class InventoryService:
"""Service for inventory operations with vendor isolation."""
    def set_inventory(
        self, db: Session, vendor_id: int, inventory_data: InventoryCreate
    ) -> Inventory:
        """
        Set exact inventory quantity for a product at a location (replaces existing).
        Args:
            db: Database session
            vendor_id: Vendor ID (from middleware)
            inventory_data: Inventory data
        Returns:
            Inventory object
        Raises:
            ProductNotFoundException: Product not in this vendor's catalog.
            InvalidQuantityException: Quantity is None, non-int, or negative.
            InventoryValidationException: Location is blank.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            # Validate product belongs to vendor
            product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
            # Validate location
            location = self._validate_location(inventory_data.location)
            # Validate quantity
            self._validate_quantity(inventory_data.quantity, allow_zero=True)
            # Check if inventory entry exists
            existing = self._get_inventory_entry(
                db, inventory_data.product_id, location
            )
            if existing:
                # Overwrite (not add to) the stored quantity.
                old_qty = existing.quantity
                existing.quantity = inventory_data.quantity
                existing.updated_at = datetime.now(UTC)
                # flush + refresh, no commit: the caller owns the transaction.
                db.flush()
                db.refresh(existing)
                logger.info(
                    f"Set inventory for product {inventory_data.product_id} at {location}: "
                    f"{old_qty}{inventory_data.quantity}"
                )
                return existing
            # Create new inventory entry
            new_inventory = Inventory(
                product_id=inventory_data.product_id,
                vendor_id=vendor_id,
                warehouse="strassen",  # Default warehouse
                bin_location=location,  # Use location as bin location
                location=location,  # Keep for backward compatibility
                quantity=inventory_data.quantity,
                gtin=product.marketplace_product.gtin,  # Optional reference
            )
            db.add(new_inventory)
            db.flush()
            db.refresh(new_inventory)
            logger.info(
                f"Created inventory for product {inventory_data.product_id} at {location}: "
                f"{inventory_data.quantity}"
            )
            return new_inventory
        except (
            ProductNotFoundException,
            InvalidQuantityException,
            InventoryValidationException,
        ):
            # Known validation failures: roll back and propagate unchanged.
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error setting inventory: {str(e)}")
            raise ValidationException("Failed to set inventory")
    def adjust_inventory(
        self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
    ) -> Inventory:
        """
        Adjust inventory by adding or removing quantity.
        Positive quantity = add, negative = remove.
        Args:
            db: Database session
            vendor_id: Vendor ID
            inventory_data: Adjustment data
        Returns:
            Updated Inventory object
        Raises:
            ProductNotFoundException: Product not in this vendor's catalog.
            InventoryNotFoundException: Removing from a non-existent entry.
            InsufficientInventoryException: Removal would drive quantity below 0.
            InventoryValidationException: Location is blank.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            # Validate product belongs to vendor
            product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
            # Validate location
            location = self._validate_location(inventory_data.location)
            # Check if inventory exists
            existing = self._get_inventory_entry(
                db, inventory_data.product_id, location
            )
            if not existing:
                # Create new if adding, error if removing
                if inventory_data.quantity < 0:
                    raise InventoryNotFoundException(
                        f"No inventory found for product {inventory_data.product_id} at {location}"
                    )
                # Create with positive quantity
                new_inventory = Inventory(
                    product_id=inventory_data.product_id,
                    vendor_id=vendor_id,
                    warehouse="strassen",  # Default warehouse
                    bin_location=location,  # Use location as bin location
                    location=location,  # Keep for backward compatibility
                    quantity=inventory_data.quantity,
                    gtin=product.marketplace_product.gtin,
                )
                db.add(new_inventory)
                db.flush()
                db.refresh(new_inventory)
                logger.info(
                    f"Created inventory for product {inventory_data.product_id} at {location}: "
                    f"+{inventory_data.quantity}"
                )
                return new_inventory
            # Adjust existing inventory
            old_qty = existing.quantity
            new_qty = old_qty + inventory_data.quantity
            # Validate resulting quantity
            if new_qty < 0:
                raise InsufficientInventoryException(
                    f"Insufficient inventory. Available: {old_qty}, "
                    f"Requested removal: {abs(inventory_data.quantity)}"
                )
            existing.quantity = new_qty
            existing.updated_at = datetime.now(UTC)
            # flush, no commit: the caller owns the transaction boundary.
            db.flush()
            db.refresh(existing)
            logger.info(
                f"Adjusted inventory for product {inventory_data.product_id} at {location}: "
                f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}"
            )
            return existing
        except (
            ProductNotFoundException,
            InventoryNotFoundException,
            InsufficientInventoryException,
            InventoryValidationException,
        ):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error adjusting inventory: {str(e)}")
            raise ValidationException("Failed to adjust inventory")
    def reserve_inventory(
        self, db: Session, vendor_id: int, reserve_data: InventoryReserve
    ) -> Inventory:
        """
        Reserve inventory for an order (increases reserved_quantity).
        Args:
            db: Database session
            vendor_id: Vendor ID
            reserve_data: Reservation data
        Returns:
            Updated Inventory object
        Raises:
            ProductNotFoundException: Product not in this vendor's catalog.
            InventoryNotFoundException: No entry at the given location.
            InsufficientInventoryException: Not enough unreserved stock.
            InvalidQuantityException: Quantity missing, non-int, or <= 0.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            # Validate product (ownership check; return value unused)
            product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
            # Validate location and quantity
            location = self._validate_location(reserve_data.location)
            self._validate_quantity(reserve_data.quantity, allow_zero=False)
            # Get inventory entry
            inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
            if not inventory:
                raise InventoryNotFoundException(
                    f"No inventory found for product {reserve_data.product_id} at {location}"
                )
            # Check available quantity: on-hand minus already reserved.
            available = inventory.quantity - inventory.reserved_quantity
            if available < reserve_data.quantity:
                raise InsufficientInventoryException(
                    f"Insufficient available inventory. Available: {available}, "
                    f"Requested: {reserve_data.quantity}"
                )
            # Reserve inventory
            inventory.reserved_quantity += reserve_data.quantity
            inventory.updated_at = datetime.now(UTC)
            db.flush()
            db.refresh(inventory)
            logger.info(
                f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} "
                f"at {location}"
            )
            return inventory
        except (
            ProductNotFoundException,
            InventoryNotFoundException,
            InsufficientInventoryException,
            InvalidQuantityException,
        ):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error reserving inventory: {str(e)}")
            raise ValidationException("Failed to reserve inventory")
    def release_reservation(
        self, db: Session, vendor_id: int, reserve_data: InventoryReserve
    ) -> Inventory:
        """
        Release reserved inventory (decreases reserved_quantity).
        Args:
            db: Database session
            vendor_id: Vendor ID
            reserve_data: Reservation data
        Returns:
            Updated Inventory object
        Raises:
            ProductNotFoundException: Product not in this vendor's catalog.
            InventoryNotFoundException: No entry at the given location.
            InvalidQuantityException: Quantity missing, non-int, or <= 0.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            # Validate product (ownership check; return value unused)
            product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
            location = self._validate_location(reserve_data.location)
            self._validate_quantity(reserve_data.quantity, allow_zero=False)
            inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
            if not inventory:
                raise InventoryNotFoundException(
                    f"No inventory found for product {reserve_data.product_id} at {location}"
                )
            # Validate reserved quantity
            if inventory.reserved_quantity < reserve_data.quantity:
                # Over-release is tolerated: log a warning and clamp at zero
                # rather than failing the caller.
                logger.warning(
                    f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, "
                    f"Requested: {reserve_data.quantity}"
                )
                inventory.reserved_quantity = 0
            else:
                inventory.reserved_quantity -= reserve_data.quantity
            inventory.updated_at = datetime.now(UTC)
            db.flush()
            db.refresh(inventory)
            logger.info(
                f"Released {reserve_data.quantity} units for product {reserve_data.product_id} "
                f"at {location}"
            )
            return inventory
        except (
            ProductNotFoundException,
            InventoryNotFoundException,
            InvalidQuantityException,
        ):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error releasing reservation: {str(e)}")
            raise ValidationException("Failed to release reservation")
    def fulfill_reservation(
        self, db: Session, vendor_id: int, reserve_data: InventoryReserve
    ) -> Inventory:
        """
        Fulfill a reservation (decreases both quantity and reserved_quantity).
        Use when order is shipped/completed.
        Args:
            db: Database session
            vendor_id: Vendor ID
            reserve_data: Reservation data
        Returns:
            Updated Inventory object
        Raises:
            ProductNotFoundException: Product not in this vendor's catalog.
            InventoryNotFoundException: No entry at the given location.
            InsufficientInventoryException: On-hand quantity too low.
            InvalidQuantityException: Quantity missing, non-int, or <= 0.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
            location = self._validate_location(reserve_data.location)
            self._validate_quantity(reserve_data.quantity, allow_zero=False)
            inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
            if not inventory:
                raise InventoryNotFoundException(
                    f"No inventory found for product {reserve_data.product_id} at {location}"
                )
            # Validate quantities
            if inventory.quantity < reserve_data.quantity:
                raise InsufficientInventoryException(
                    f"Insufficient inventory. Available: {inventory.quantity}, "
                    f"Requested: {reserve_data.quantity}"
                )
            if inventory.reserved_quantity < reserve_data.quantity:
                # Over-fulfilling a reservation is tolerated; only logged.
                logger.warning(
                    f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, "
                    f"Fulfilling: {reserve_data.quantity}"
                )
            # Fulfill (remove from both quantity and reserved)
            inventory.quantity -= reserve_data.quantity
            # max() keeps reserved_quantity from going negative on over-fulfill.
            inventory.reserved_quantity = max(
                0, inventory.reserved_quantity - reserve_data.quantity
            )
            inventory.updated_at = datetime.now(UTC)
            db.flush()
            db.refresh(inventory)
            logger.info(
                f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} "
                f"at {location}"
            )
            return inventory
        except (
            ProductNotFoundException,
            InventoryNotFoundException,
            InsufficientInventoryException,
            InvalidQuantityException,
        ):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error fulfilling reservation: {str(e)}")
            raise ValidationException("Failed to fulfill reservation")
def get_product_inventory(
self, db: Session, vendor_id: int, product_id: int
) -> ProductInventorySummary:
"""
Get inventory summary for a product across all locations.
Args:
db: Database session
vendor_id: Vendor ID
product_id: Product ID
Returns:
ProductInventorySummary
"""
try:
product = self._get_vendor_product(db, vendor_id, product_id)
inventory_entries = (
db.query(Inventory).filter(Inventory.product_id == product_id).all()
)
if not inventory_entries:
return ProductInventorySummary(
product_id=product_id,
vendor_id=vendor_id,
product_sku=product.vendor_sku,
product_title=product.marketplace_product.get_title() or "",
total_quantity=0,
total_reserved=0,
total_available=0,
locations=[],
)
total_qty = sum(inv.quantity for inv in inventory_entries)
total_reserved = sum(inv.reserved_quantity for inv in inventory_entries)
total_available = sum(inv.available_quantity for inv in inventory_entries)
locations = [
InventoryLocationResponse(
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
)
for inv in inventory_entries
]
return ProductInventorySummary(
product_id=product_id,
vendor_id=vendor_id,
product_sku=product.vendor_sku,
product_title=product.marketplace_product.get_title() or "",
total_quantity=total_qty,
total_reserved=total_reserved,
total_available=total_available,
locations=locations,
)
except ProductNotFoundException:
raise
except Exception as e:
logger.error(f"Error getting product inventory: {str(e)}")
raise ValidationException("Failed to retrieve product inventory")
def get_vendor_inventory(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 100,
location: str | None = None,
low_stock_threshold: int | None = None,
) -> list[Inventory]:
"""
Get all inventory for a vendor with filtering.
Args:
db: Database session
vendor_id: Vendor ID
skip: Pagination offset
limit: Pagination limit
location: Filter by location
low_stock_threshold: Filter items below threshold
Returns:
List of Inventory objects
"""
try:
query = db.query(Inventory).filter(Inventory.vendor_id == vendor_id)
if location:
query = query.filter(Inventory.location.ilike(f"%{location}%"))
if low_stock_threshold is not None:
query = query.filter(Inventory.quantity <= low_stock_threshold)
return query.offset(skip).limit(limit).all()
except Exception as e:
logger.error(f"Error getting vendor inventory: {str(e)}")
raise ValidationException("Failed to retrieve vendor inventory")
    def update_inventory(
        self,
        db: Session,
        vendor_id: int,
        inventory_id: int,
        inventory_update: InventoryUpdate,
    ) -> Inventory:
        """Update inventory entry.

        Only fields present on the update payload are applied; ownership is
        verified (vendor_id must match) before any change is made.

        Raises:
            InventoryNotFoundException: Unknown id or owned by another vendor.
            InvalidQuantityException: Negative or non-int quantity values.
            InventoryValidationException: Blank location.
            ValidationException: Any unexpected failure (after rollback).
        """
        try:
            inventory = self._get_inventory_by_id(db, inventory_id)
            # Verify ownership; report "not found" to avoid leaking existence.
            if inventory.vendor_id != vendor_id:
                raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
            # Update fields
            if inventory_update.quantity is not None:
                self._validate_quantity(inventory_update.quantity, allow_zero=True)
                inventory.quantity = inventory_update.quantity
            if inventory_update.reserved_quantity is not None:
                self._validate_quantity(
                    inventory_update.reserved_quantity, allow_zero=True
                )
                inventory.reserved_quantity = inventory_update.reserved_quantity
            if inventory_update.location:
                inventory.location = self._validate_location(inventory_update.location)
            inventory.updated_at = datetime.now(UTC)
            db.flush()
            db.refresh(inventory)
            logger.info(f"Updated inventory {inventory_id}")
            return inventory
        except (
            InventoryNotFoundException,
            InvalidQuantityException,
            InventoryValidationException,
        ):
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error updating inventory: {str(e)}")
            raise ValidationException("Failed to update inventory")
    def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
        """Delete inventory entry.

        Returns True on success. Raises InventoryNotFoundException when the
        id is unknown or belongs to a different vendor.
        """
        try:
            inventory = self._get_inventory_by_id(db, inventory_id)
            # Verify ownership; report "not found" to avoid leaking existence.
            if inventory.vendor_id != vendor_id:
                raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
            db.delete(inventory)
            db.flush()
            logger.info(f"Deleted inventory {inventory_id}")
            return True
        except InventoryNotFoundException:
            raise
        except Exception as e:
            db.rollback()
            logger.error(f"Error deleting inventory: {str(e)}")
            raise ValidationException("Failed to delete inventory")
# =========================================================================
# Admin Methods (cross-vendor operations)
# =========================================================================
    def get_all_inventory_admin(
        self,
        db: Session,
        skip: int = 0,
        limit: int = 50,
        vendor_id: int | None = None,
        location: str | None = None,
        low_stock: int | None = None,
        search: str | None = None,
    ) -> AdminInventoryListResponse:
        """
        Get inventory across all vendors with filtering (admin only).
        Args:
            db: Database session
            skip: Pagination offset
            limit: Pagination limit
            vendor_id: Filter by vendor
            location: Filter by location
            low_stock: Filter items below threshold
            search: Search by product title or SKU
        Returns:
            AdminInventoryListResponse
        """
        query = db.query(Inventory).join(Product).join(Vendor)
        # Apply filters
        if vendor_id is not None:
            query = query.filter(Inventory.vendor_id == vendor_id)
        if location:
            query = query.filter(Inventory.location.ilike(f"%{location}%"))
        if low_stock is not None:
            query = query.filter(Inventory.quantity <= low_stock)
        if search:
            # Local import, only needed for the title/SKU search join;
            # presumably avoids an import cycle — verify before hoisting.
            from models.database.marketplace_product import MarketplaceProduct
            from models.database.marketplace_product_translation import (
                MarketplaceProductTranslation,
            )
            query = (
                query.join(MarketplaceProduct)
                .outerjoin(MarketplaceProductTranslation)
                .filter(
                    (MarketplaceProductTranslation.title.ilike(f"%{search}%"))
                    | (Product.vendor_sku.ilike(f"%{search}%"))
                )
            )
        # Get total count before pagination
        total = query.count()
        # Apply pagination
        inventories = query.offset(skip).limit(limit).all()
        # Build response with vendor/product info
        items = []
        for inv in inventories:
            product = inv.product
            vendor = inv.vendor
            title = None
            if product and product.marketplace_product:
                title = product.marketplace_product.get_title()
            items.append(
                AdminInventoryItem(
                    id=inv.id,
                    product_id=inv.product_id,
                    vendor_id=inv.vendor_id,
                    vendor_name=vendor.name if vendor else None,
                    vendor_code=vendor.vendor_code if vendor else None,
                    product_title=title,
                    product_sku=product.vendor_sku if product else None,
                    location=inv.location,
                    quantity=inv.quantity,
                    reserved_quantity=inv.reserved_quantity,
                    available_quantity=inv.available_quantity,
                    gtin=inv.gtin,
                    created_at=inv.created_at,
                    updated_at=inv.updated_at,
                )
            )
        return AdminInventoryListResponse(
            inventories=items,
            total=total,
            skip=skip,
            limit=limit,
            vendor_filter=vendor_id,
            location_filter=location,
        )
def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats:
"""Get platform-wide inventory statistics (admin only)."""
# Total entries
total_entries = db.query(func.count(Inventory.id)).scalar() or 0
# Aggregate quantities
totals = db.query(
func.sum(Inventory.quantity).label("total_qty"),
func.sum(Inventory.reserved_quantity).label("total_reserved"),
).first()
total_quantity = totals.total_qty or 0
total_reserved = totals.total_reserved or 0
total_available = total_quantity - total_reserved
# Low stock count (default threshold: 10)
low_stock_count = (
db.query(func.count(Inventory.id))
.filter(Inventory.quantity <= 10)
.scalar()
or 0
)
# Vendors with inventory
vendors_with_inventory = (
db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0
)
# Unique locations
unique_locations = (
db.query(func.count(func.distinct(Inventory.location))).scalar() or 0
)
return AdminInventoryStats(
total_entries=total_entries,
total_quantity=total_quantity,
total_reserved=total_reserved,
total_available=total_available,
low_stock_count=low_stock_count,
vendors_with_inventory=vendors_with_inventory,
unique_locations=unique_locations,
)
def get_low_stock_items_admin(
self,
db: Session,
threshold: int = 10,
vendor_id: int | None = None,
limit: int = 50,
) -> list[AdminLowStockItem]:
"""Get items with low stock levels (admin only)."""
query = (
db.query(Inventory)
.join(Product)
.join(Vendor)
.filter(Inventory.quantity <= threshold)
)
if vendor_id is not None:
query = query.filter(Inventory.vendor_id == vendor_id)
# Order by quantity ascending (most critical first)
query = query.order_by(Inventory.quantity.asc())
inventories = query.limit(limit).all()
items = []
for inv in inventories:
product = inv.product
vendor = inv.vendor
title = None
if product and product.marketplace_product:
title = product.marketplace_product.get_title()
items.append(
AdminLowStockItem(
id=inv.id,
product_id=inv.product_id,
vendor_id=inv.vendor_id,
vendor_name=vendor.name if vendor else None,
product_title=title,
location=inv.location,
quantity=inv.quantity,
reserved_quantity=inv.reserved_quantity,
available_quantity=inv.available_quantity,
)
)
return items
def get_vendors_with_inventory_admin(
self, db: Session
) -> AdminVendorsWithInventoryResponse:
"""Get list of vendors that have inventory entries (admin only)."""
# noqa: SVC-005 - Admin function, intentionally cross-vendor
# Use subquery to avoid DISTINCT on JSON columns (PostgreSQL can't compare JSON)
vendor_ids_subquery = (
db.query(Inventory.vendor_id)
.distinct()
.subquery()
)
vendors = (
db.query(Vendor)
.filter(Vendor.id.in_(db.query(vendor_ids_subquery.c.vendor_id)))
.order_by(Vendor.name)
.all()
)
return AdminVendorsWithInventoryResponse(
vendors=[
AdminVendorWithInventory(
id=v.id, name=v.name, vendor_code=v.vendor_code
)
for v in vendors
]
)
def get_inventory_locations_admin(
self, db: Session, vendor_id: int | None = None
) -> AdminInventoryLocationsResponse:
"""Get list of unique inventory locations (admin only)."""
query = db.query(func.distinct(Inventory.location))
if vendor_id is not None:
query = query.filter(Inventory.vendor_id == vendor_id)
locations = [loc[0] for loc in query.all()]
return AdminInventoryLocationsResponse(locations=sorted(locations))
    def get_vendor_inventory_admin(
        self,
        db: Session,
        vendor_id: int,
        skip: int = 0,
        limit: int = 50,
        location: str | None = None,
        low_stock: int | None = None,
    ) -> AdminInventoryListResponse:
        """Get inventory for a specific vendor (admin only).

        Raises:
            VendorNotFoundException: If the vendor id does not exist.
        """
        # Verify vendor exists
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            raise VendorNotFoundException(f"Vendor {vendor_id} not found")
        # Use the existing method
        inventories = self.get_vendor_inventory(
            db=db,
            vendor_id=vendor_id,
            skip=skip,
            limit=limit,
            location=location,
            low_stock_threshold=low_stock,
        )
        # Build response with product info
        items = []
        for inv in inventories:
            product = inv.product
            title = None
            if product and product.marketplace_product:
                title = product.marketplace_product.get_title()
            items.append(
                AdminInventoryItem(
                    id=inv.id,
                    product_id=inv.product_id,
                    vendor_id=inv.vendor_id,
                    vendor_name=vendor.name,
                    vendor_code=vendor.vendor_code,
                    product_title=title,
                    product_sku=product.vendor_sku if product else None,
                    location=inv.location,
                    quantity=inv.quantity,
                    reserved_quantity=inv.reserved_quantity,
                    available_quantity=inv.available_quantity,
                    gtin=inv.gtin,
                    created_at=inv.created_at,
                    updated_at=inv.updated_at,
                )
            )
        # Get total count for pagination
        # Re-applies the same filters used for the page so total matches.
        total_query = db.query(func.count(Inventory.id)).filter(
            Inventory.vendor_id == vendor_id
        )
        if location:
            total_query = total_query.filter(Inventory.location.ilike(f"%{location}%"))
        if low_stock is not None:
            total_query = total_query.filter(Inventory.quantity <= low_stock)
        total = total_query.scalar() or 0
        return AdminInventoryListResponse(
            inventories=items,
            total=total,
            skip=skip,
            limit=limit,
            vendor_filter=vendor_id,
            location_filter=location,
        )
def get_product_inventory_admin(
self, db: Session, product_id: int
) -> ProductInventorySummary:
"""Get inventory summary for a product (admin only - no vendor check)."""
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise ProductNotFoundException(f"Product {product_id} not found")
# Use existing method with the product's vendor_id
return self.get_product_inventory(db, product.vendor_id, product_id)
def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor:
"""Verify vendor exists and return it."""
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
raise VendorNotFoundException(f"Vendor {vendor_id} not found")
return vendor
def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory:
"""Get inventory by ID (admin only - returns inventory with vendor_id)."""
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
if not inventory:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
return inventory
# =========================================================================
# Private helper methods
# =========================================================================
def _get_vendor_product(
self, db: Session, vendor_id: int, product_id: int
) -> Product:
"""Get product and verify it belongs to vendor."""
product = (
db.query(Product)
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
.first()
)
if not product:
raise ProductNotFoundException(
f"Product {product_id} not found in your catalog"
)
return product
def _get_inventory_entry(
self, db: Session, product_id: int, location: str
) -> Inventory | None:
"""Get inventory entry by product and location."""
return (
db.query(Inventory)
.filter(Inventory.product_id == product_id, Inventory.location == location)
.first()
)
def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory:
"""Get inventory by ID or raise exception."""
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
if not inventory:
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
return inventory
def _validate_location(self, location: str) -> str:
"""Validate and normalize location."""
if not location or not location.strip():
raise InventoryValidationException("Location is required")
return location.strip().upper()
def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None:
"""Validate quantity value."""
if quantity is None:
raise InvalidQuantityException("Quantity is required")
if not isinstance(quantity, int):
raise InvalidQuantityException("Quantity must be an integer")
if quantity < 0:
raise InvalidQuantityException("Quantity cannot be negative")
if not allow_zero and quantity == 0:
raise InvalidQuantityException("Quantity must be positive")
# Create service instance
# Module-level singleton used by callers of this legacy module.
inventory_service = InventoryService()
# Explicit public API of the legacy re-export shim.
__all__ = [
    "inventory_service",
    "InventoryService",
]

View File

@@ -1,431 +1,23 @@
# app/services/inventory_transaction_service.py
"""
Inventory Transaction Service.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Provides query operations for inventory transaction history.
All transaction WRITES are handled by OrderInventoryService.
This service handles transaction READS for reporting and auditing.
The canonical implementation is now in:
app/modules/inventory/services/inventory_transaction_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.inventory.services import inventory_transaction_service
"""
import logging
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.modules.inventory.services.inventory_transaction_service import (
inventory_transaction_service,
InventoryTransactionService,
)
from app.exceptions import OrderNotFoundException, ProductNotFoundException
from models.database.inventory import Inventory
from models.database.inventory_transaction import InventoryTransaction
from models.database.order import Order
from models.database.product import Product
logger = logging.getLogger(__name__)
class InventoryTransactionService:
"""Service for querying inventory transaction history."""
def get_vendor_transactions(
self,
db: Session,
vendor_id: int,
skip: int = 0,
limit: int = 50,
product_id: int | None = None,
transaction_type: str | None = None,
) -> tuple[list[dict], int]:
"""
Get inventory transactions for a vendor with optional filters.
Args:
db: Database session
vendor_id: Vendor ID
skip: Pagination offset
limit: Pagination limit
product_id: Optional product filter
transaction_type: Optional transaction type filter
Returns:
Tuple of (transactions with product details, total count)
"""
# Build query
query = db.query(InventoryTransaction).filter(
InventoryTransaction.vendor_id == vendor_id
)
# Apply filters
if product_id:
query = query.filter(InventoryTransaction.product_id == product_id)
if transaction_type:
query = query.filter(
InventoryTransaction.transaction_type == transaction_type
)
# Get total count
total = query.count()
# Get transactions with pagination (newest first)
transactions = (
query.order_by(InventoryTransaction.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
# Build result with product details
result = []
for tx in transactions:
product = db.query(Product).filter(Product.id == tx.product_id).first()
product_title = None
product_sku = None
if product:
product_sku = product.vendor_sku
if product.marketplace_product:
product_title = product.marketplace_product.get_title()
result.append(
{
"id": tx.id,
"vendor_id": tx.vendor_id,
"product_id": tx.product_id,
"inventory_id": tx.inventory_id,
"transaction_type": (
tx.transaction_type.value if tx.transaction_type else None
),
"quantity_change": tx.quantity_change,
"quantity_after": tx.quantity_after,
"reserved_after": tx.reserved_after,
"location": tx.location,
"warehouse": tx.warehouse,
"order_id": tx.order_id,
"order_number": tx.order_number,
"reason": tx.reason,
"created_by": tx.created_by,
"created_at": tx.created_at,
"product_title": product_title,
"product_sku": product_sku,
}
)
return result, total
    def get_product_history(
        self,
        db: Session,
        vendor_id: int,
        product_id: int,
        limit: int = 50,
    ) -> dict:
        """
        Get transaction history for a specific product.
        Args:
            db: Database session
            vendor_id: Vendor ID
            product_id: Product ID
            limit: Max transactions to return
        Returns:
            Dict with product info, current inventory, and transactions
        Raises:
            ProductNotFoundException: If product not found or doesn't belong to vendor
        """
        # Get product details
        product = (
            db.query(Product)
            .filter(Product.id == product_id, Product.vendor_id == vendor_id)
            .first()
        )
        if not product:
            raise ProductNotFoundException(
                f"Product {product_id} not found in vendor catalog"
            )
        product_title = None
        product_sku = product.vendor_sku
        if product.marketplace_product:
            product_title = product.marketplace_product.get_title()
        # Get current inventory
        inventory = (
            db.query(Inventory)
            .filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id)
            .first()
        )
        # A missing inventory row is reported as zero stock, not an error.
        current_quantity = inventory.quantity if inventory else 0
        current_reserved = inventory.reserved_quantity if inventory else 0
        # Get transactions (newest first, capped by `limit`)
        transactions = (
            db.query(InventoryTransaction)
            .filter(
                InventoryTransaction.vendor_id == vendor_id,
                InventoryTransaction.product_id == product_id,
            )
            .order_by(InventoryTransaction.created_at.desc())
            .limit(limit)
            .all()
        )
        # Total ignores the limit so callers can detect truncation.
        total = (
            db.query(func.count(InventoryTransaction.id))
            .filter(
                InventoryTransaction.vendor_id == vendor_id,
                InventoryTransaction.product_id == product_id,
            )
            .scalar()
            or 0
        )
        return {
            "product_id": product_id,
            "product_title": product_title,
            "product_sku": product_sku,
            "current_quantity": current_quantity,
            "current_reserved": current_reserved,
            "transactions": [
                {
                    "id": tx.id,
                    "vendor_id": tx.vendor_id,
                    "product_id": tx.product_id,
                    "inventory_id": tx.inventory_id,
                    "transaction_type": (
                        tx.transaction_type.value if tx.transaction_type else None
                    ),
                    "quantity_change": tx.quantity_change,
                    "quantity_after": tx.quantity_after,
                    "reserved_after": tx.reserved_after,
                    "location": tx.location,
                    "warehouse": tx.warehouse,
                    "order_id": tx.order_id,
                    "order_number": tx.order_number,
                    "reason": tx.reason,
                    "created_by": tx.created_by,
                    "created_at": tx.created_at,
                }
                for tx in transactions
            ],
            "total": total,
        }
    def get_order_history(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
    ) -> dict:
        """
        Get all inventory transactions for a specific order.
        Args:
            db: Database session
            vendor_id: Vendor ID
            order_id: Order ID
        Returns:
            Dict with order info and transactions
        Raises:
            OrderNotFoundException: If order not found or doesn't belong to vendor
        """
        # Verify order belongs to vendor
        order = (
            db.query(Order)
            .filter(Order.id == order_id, Order.vendor_id == vendor_id)
            .first()
        )
        if not order:
            raise OrderNotFoundException(f"Order {order_id} not found")
        # Get transactions for this order
        # Ascending order: a chronological audit trail for the order.
        transactions = (
            db.query(InventoryTransaction)
            .filter(InventoryTransaction.order_id == order_id)
            .order_by(InventoryTransaction.created_at.asc())
            .all()
        )
        # Build result with product details
        result = []
        for tx in transactions:
            product = db.query(Product).filter(Product.id == tx.product_id).first()
            product_title = None
            product_sku = None
            if product:
                product_sku = product.vendor_sku
                if product.marketplace_product:
                    product_title = product.marketplace_product.get_title()
            result.append(
                {
                    "id": tx.id,
                    "vendor_id": tx.vendor_id,
                    "product_id": tx.product_id,
                    "inventory_id": tx.inventory_id,
                    "transaction_type": (
                        tx.transaction_type.value if tx.transaction_type else None
                    ),
                    "quantity_change": tx.quantity_change,
                    "quantity_after": tx.quantity_after,
                    "reserved_after": tx.reserved_after,
                    "location": tx.location,
                    "warehouse": tx.warehouse,
                    "order_id": tx.order_id,
                    "order_number": tx.order_number,
                    "reason": tx.reason,
                    "created_by": tx.created_by,
                    "created_at": tx.created_at,
                    "product_title": product_title,
                    "product_sku": product_sku,
                }
            )
        return {
            "order_id": order_id,
            "order_number": order.order_number,
            "transactions": result,
        }
# =========================================================================
# Admin Methods (cross-vendor operations)
# =========================================================================
    def get_all_transactions_admin(
        self,
        db: Session,
        skip: int = 0,
        limit: int = 50,
        vendor_id: int | None = None,
        product_id: int | None = None,
        transaction_type: str | None = None,
        order_id: int | None = None,
    ) -> tuple[list[dict], int]:
        """
        Get inventory transactions across all vendors (admin only).
        Args:
            db: Database session
            skip: Pagination offset
            limit: Pagination limit
            vendor_id: Optional vendor filter
            product_id: Optional product filter
            transaction_type: Optional transaction type filter
            order_id: Optional order filter
        Returns:
            Tuple of (transactions with details, total count)
        """
        # Local import kept as in the original; presumably avoids an import
        # cycle — verify before hoisting to module level.
        from models.database.vendor import Vendor
        # Build query
        query = db.query(InventoryTransaction)
        # Apply filters
        if vendor_id:
            query = query.filter(InventoryTransaction.vendor_id == vendor_id)
        if product_id:
            query = query.filter(InventoryTransaction.product_id == product_id)
        if transaction_type:
            query = query.filter(
                InventoryTransaction.transaction_type == transaction_type
            )
        if order_id:
            query = query.filter(InventoryTransaction.order_id == order_id)
        # Get total count
        total = query.count()
        # Get transactions with pagination (newest first)
        transactions = (
            query.order_by(InventoryTransaction.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        # Build result with vendor and product details
        # NOTE(review): one vendor + one product query per row (N+1);
        # bounded by `limit`, but a batched lookup would be cheaper.
        result = []
        for tx in transactions:
            vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first()
            product = db.query(Product).filter(Product.id == tx.product_id).first()
            product_title = None
            product_sku = None
            if product:
                product_sku = product.vendor_sku
                if product.marketplace_product:
                    product_title = product.marketplace_product.get_title()
            result.append(
                {
                    "id": tx.id,
                    "vendor_id": tx.vendor_id,
                    "vendor_name": vendor.name if vendor else None,
                    "vendor_code": vendor.vendor_code if vendor else None,
                    "product_id": tx.product_id,
                    "inventory_id": tx.inventory_id,
                    "transaction_type": (
                        tx.transaction_type.value if tx.transaction_type else None
                    ),
                    "quantity_change": tx.quantity_change,
                    "quantity_after": tx.quantity_after,
                    "reserved_after": tx.reserved_after,
                    "location": tx.location,
                    "warehouse": tx.warehouse,
                    "order_id": tx.order_id,
                    "order_number": tx.order_number,
                    "reason": tx.reason,
                    "created_by": tx.created_by,
                    "created_at": tx.created_at,
                    "product_title": product_title,
                    "product_sku": product_sku,
                }
            )
        return result, total
def get_transaction_stats_admin(self, db: Session) -> dict:
    """
    Platform-wide inventory transaction statistics (admin only).

    Returns:
        Dict with total transaction count, today's count, and a
        per-transaction-type breakdown.
    """
    from datetime import UTC, datetime

    from sqlalchemy import func as sql_func

    # Per-type counts in a single grouped query.
    grouped = (
        db.query(
            InventoryTransaction.transaction_type,
            sql_func.count(InventoryTransaction.id).label("count"),
        )
        .group_by(InventoryTransaction.transaction_type)
        .all()
    )

    # Overall count across the platform.
    total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0

    # Count transactions created since UTC midnight.
    midnight = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
    today = (
        db.query(sql_func.count(InventoryTransaction.id))
        .filter(InventoryTransaction.created_at >= midnight)
        .scalar()
        or 0
    )

    return {
        "total_transactions": total,
        "transactions_today": today,
        "by_type": {row.transaction_type.value: row.count for row in grouped},
    }
# Create service instance
# Module-level singleton shared by all importers of this service.
inventory_transaction_service = InventoryTransactionService()

# Explicit public API of this module.
__all__ = [
    "inventory_transaction_service",
    "InventoryTransactionService",
]

View File

@@ -1,164 +1,23 @@
# app/services/invoice_pdf_service.py
"""
Invoice PDF generation service using WeasyPrint.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
Stores generated PDFs in the configured storage location.
The canonical implementation is now in:
app/modules/orders/services/invoice_pdf_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.orders.services import invoice_pdf_service
"""
import logging
import os
from datetime import UTC, datetime
from pathlib import Path
from app.modules.orders.services.invoice_pdf_service import (
invoice_pdf_service,
InvoicePDFService,
)
from jinja2 import Environment, FileSystemLoader
from sqlalchemy.orm import Session
from app.core.config import settings
from models.database.invoice import Invoice
logger = logging.getLogger(__name__)
# Template directory
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
# PDF storage directory (relative to project root)
PDF_STORAGE_DIR = Path("storage") / "invoices"
class InvoicePDFService:
    """Service for generating invoice PDFs.

    Renders the Jinja2 HTML invoice template and converts it to PDF with
    WeasyPrint. Generated files are stored per vendor under
    ``PDF_STORAGE_DIR/<vendor_id>/``.
    """

    def __init__(self):
        """Initialize the PDF service with Jinja2 environment."""
        # autoescape guards the template against HTML injection from
        # invoice field values.
        self.env = Environment(
            loader=FileSystemLoader(str(TEMPLATE_DIR)),
            autoescape=True,
        )

    def _ensure_storage_dir(self, vendor_id: int) -> Path:
        """Ensure the storage directory exists for a vendor."""
        storage_path = PDF_STORAGE_DIR / str(vendor_id)
        storage_path.mkdir(parents=True, exist_ok=True)
        return storage_path

    def _get_pdf_filename(self, invoice: Invoice) -> str:
        """Generate PDF filename for an invoice."""
        # Sanitize invoice number for filename: only path separators are
        # replaced; other characters pass through unchanged.
        safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
        return f"{safe_number}.pdf"

    def generate_pdf(
        self,
        db: Session,
        invoice: Invoice,
        force_regenerate: bool = False,
    ) -> str:
        """
        Generate PDF for an invoice.

        Args:
            db: Database session
            invoice: Invoice to generate PDF for
            force_regenerate: If True, regenerate even if PDF already exists

        Returns:
            Path to the generated PDF file

        Raises:
            RuntimeError: If WeasyPrint is not installed.
        """
        # Check if PDF already exists; reuse it unless regeneration is forced.
        if invoice.pdf_path and not force_regenerate:
            if Path(invoice.pdf_path).exists():
                logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
                return invoice.pdf_path

        # Ensure storage directory exists
        storage_dir = self._ensure_storage_dir(invoice.vendor_id)
        pdf_filename = self._get_pdf_filename(invoice)
        pdf_path = storage_dir / pdf_filename

        # Render HTML template
        html_content = self._render_html(invoice)

        # Generate PDF using WeasyPrint (imported lazily so the app can
        # start even when the optional dependency is absent).
        try:
            from weasyprint import HTML

            html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
            html_doc.write_pdf(str(pdf_path))
            logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
        except ImportError:
            logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
            raise RuntimeError("WeasyPrint not installed")
        except Exception as e:
            logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
            raise

        # Update invoice record with PDF path and timestamp.
        # Flush only — the caller owns the transaction commit.
        invoice.pdf_path = str(pdf_path)
        invoice.pdf_generated_at = datetime.now(UTC)
        db.flush()

        return str(pdf_path)

    def _render_html(self, invoice: Invoice) -> str:
        """Render the invoice HTML template."""
        template = self.env.get_template("invoice.html")

        # Prepare template context from the snapshot fields stored on the
        # invoice row.
        context = {
            "invoice": invoice,
            "seller": invoice.seller_details,
            "buyer": invoice.buyer_details,
            "line_items": invoice.line_items,
            "bank_details": invoice.bank_details,
            "payment_terms": invoice.payment_terms,
            "footer_text": invoice.footer_text,
            "now": datetime.now(UTC),
        }

        return template.render(**context)

    def get_pdf_path(self, invoice: Invoice) -> str | None:
        """Get the PDF path for an invoice if it exists on disk."""
        if invoice.pdf_path and Path(invoice.pdf_path).exists():
            return invoice.pdf_path
        return None

    def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
        """
        Delete the PDF file for an invoice.

        Args:
            invoice: Invoice whose PDF to delete
            db: Database session

        Returns:
            True if deleted, False if not found
        """
        if not invoice.pdf_path:
            return False

        pdf_path = Path(invoice.pdf_path)
        if pdf_path.exists():
            try:
                pdf_path.unlink()
                logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
            except Exception as e:
                logger.error(f"Failed to delete PDF {pdf_path}: {e}")
                return False

        # Clear PDF fields. NOTE: fields are cleared (and True returned)
        # even when the file was already missing on disk.
        invoice.pdf_path = None
        invoice.pdf_generated_at = None
        db.flush()

        return True

    def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
        """Force regenerate PDF for an invoice."""
        return self.generate_pdf(db, invoice, force_regenerate=True)
# Singleton instance
# Module-level singleton shared by all importers of this service.
invoice_pdf_service = InvoicePDFService()

# Explicit public API of this module.
__all__ = [
    "invoice_pdf_service",
    "InvoicePDFService",
]

View File

@@ -1,675 +1,23 @@
# app/services/invoice_service.py
"""
Invoice service for generating and managing invoices.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Handles:
- Vendor invoice settings management
- Invoice generation from orders
- VAT calculation (Luxembourg, EU, B2B reverse charge)
- Invoice number sequencing
- PDF generation (via separate module)
The canonical implementation is now in:
app/modules/orders/services/invoice_service.py
VAT Logic:
- Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate)
- EU cross-border B2C with OSS: Use destination country VAT rate
- EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle)
- EU B2B with valid VAT number: Reverse charge (0% VAT)
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.orders.services import invoice_service
"""
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from app.exceptions import (
ValidationException,
)
from app.exceptions.invoice import (
InvoiceNotFoundException,
InvoicePDFGenerationException,
InvoicePDFNotFoundException,
InvoiceSettingsAlreadyExistException,
InvoiceSettingsNotFoundException,
InvoiceValidationException,
OrderNotFoundException,
)
from models.database.invoice import (
Invoice,
InvoiceStatus,
VATRegime,
VendorInvoiceSettings,
)
from models.database.order import Order
from models.database.vendor import Vendor
from models.schema.invoice import (
InvoiceBuyerDetails,
InvoiceCreate,
InvoiceLineItem,
InvoiceManualCreate,
InvoiceSellerDetails,
VendorInvoiceSettingsCreate,
VendorInvoiceSettingsUpdate,
from app.modules.orders.services.invoice_service import (
invoice_service,
InvoiceService,
)
logger = logging.getLogger(__name__)
# EU VAT rates by country code (2024 standard rates)
# NOTE(review): static snapshot — EU standard rates change occasionally;
# verify these values against current legislation before relying on them
# for live invoicing.
EU_VAT_RATES: dict[str, Decimal] = {
    "AT": Decimal("20.00"),  # Austria
    "BE": Decimal("21.00"),  # Belgium
    "BG": Decimal("20.00"),  # Bulgaria
    "HR": Decimal("25.00"),  # Croatia
    "CY": Decimal("19.00"),  # Cyprus
    "CZ": Decimal("21.00"),  # Czech Republic
    "DK": Decimal("25.00"),  # Denmark
    "EE": Decimal("22.00"),  # Estonia
    "FI": Decimal("24.00"),  # Finland
    "FR": Decimal("20.00"),  # France
    "DE": Decimal("19.00"),  # Germany
    "GR": Decimal("24.00"),  # Greece
    "HU": Decimal("27.00"),  # Hungary
    "IE": Decimal("23.00"),  # Ireland
    "IT": Decimal("22.00"),  # Italy
    "LV": Decimal("21.00"),  # Latvia
    "LT": Decimal("21.00"),  # Lithuania
    "LU": Decimal("17.00"),  # Luxembourg (standard)
    "MT": Decimal("18.00"),  # Malta
    "NL": Decimal("21.00"),  # Netherlands
    "PL": Decimal("23.00"),  # Poland
    "PT": Decimal("23.00"),  # Portugal
    "RO": Decimal("19.00"),  # Romania
    "SK": Decimal("20.00"),  # Slovakia
    "SI": Decimal("22.00"),  # Slovenia
    "ES": Decimal("21.00"),  # Spain
    "SE": Decimal("25.00"),  # Sweden
}

# Luxembourg specific VAT rates
# Keys name the rate tiers used by Luxembourg VAT law; only "standard"
# is exercised by the EU table above.
LU_VAT_RATES = {
    "standard": Decimal("17.00"),
    "intermediate": Decimal("14.00"),
    "reduced": Decimal("8.00"),
    "super_reduced": Decimal("3.00"),
}
class InvoiceService:
    """Service for invoice operations.

    Covers vendor invoice settings, VAT regime/rate determination,
    invoice creation from orders, sequential invoice numbering, status
    transitions, statistics, and PDF generation delegation.
    """

    # =========================================================================
    # VAT Calculation
    # =========================================================================

    def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
        """Get standard VAT rate for EU country.

        Unknown / non-EU country codes fall back to 0%.
        """
        return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))

    def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
        """Get human-readable VAT rate label, e.g. "Luxembourg VAT 17.00%"."""
        country_names = {
            "AT": "Austria",
            "BE": "Belgium",
            "BG": "Bulgaria",
            "HR": "Croatia",
            "CY": "Cyprus",
            "CZ": "Czech Republic",
            "DK": "Denmark",
            "EE": "Estonia",
            "FI": "Finland",
            "FR": "France",
            "DE": "Germany",
            "GR": "Greece",
            "HU": "Hungary",
            "IE": "Ireland",
            "IT": "Italy",
            "LV": "Latvia",
            "LT": "Lithuania",
            "LU": "Luxembourg",
            "MT": "Malta",
            "NL": "Netherlands",
            "PL": "Poland",
            "PT": "Portugal",
            "RO": "Romania",
            "SK": "Slovakia",
            "SI": "Slovenia",
            "ES": "Spain",
            "SE": "Sweden",
        }
        # Unknown codes fall back to the raw ISO code as the display name.
        country_name = country_names.get(country_iso.upper(), country_iso)
        return f"{country_name} VAT {vat_rate}%"

    def determine_vat_regime(
        self,
        seller_country: str,
        buyer_country: str,
        buyer_vat_number: str | None,
        seller_oss_registered: bool,
    ) -> tuple[VATRegime, Decimal, str | None]:
        """
        Determine VAT regime and rate for invoice.

        Returns: (regime, vat_rate, destination_country)

        destination_country is None for domestic sales, otherwise the
        buyer's country code.
        """
        seller_country = seller_country.upper()
        buyer_country = buyer_country.upper()

        # Same country = domestic VAT
        if seller_country == buyer_country:
            vat_rate = self.get_vat_rate_for_country(seller_country)
            return VATRegime.DOMESTIC, vat_rate, None

        # Different EU countries
        if buyer_country in EU_VAT_RATES:
            # B2B with valid VAT number = reverse charge
            # NOTE(review): the VAT number is not actually validated here
            # (e.g. via VIES) — presence alone triggers reverse charge.
            if buyer_vat_number:
                return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country

            # B2C cross-border
            if seller_oss_registered:
                # OSS: use destination country VAT
                vat_rate = self.get_vat_rate_for_country(buyer_country)
                return VATRegime.OSS, vat_rate, buyer_country
            else:
                # No OSS: use origin country VAT
                vat_rate = self.get_vat_rate_for_country(seller_country)
                return VATRegime.ORIGIN, vat_rate, buyer_country

        # Non-EU = VAT exempt (export)
        return VATRegime.EXEMPT, Decimal("0.00"), buyer_country

    # =========================================================================
    # Invoice Settings Management
    # =========================================================================

    def get_settings(
        self, db: Session, vendor_id: int
    ) -> VendorInvoiceSettings | None:
        """Get vendor invoice settings, or None when not configured."""
        return (
            db.query(VendorInvoiceSettings)
            .filter(VendorInvoiceSettings.vendor_id == vendor_id)
            .first()
        )

    def get_settings_or_raise(
        self, db: Session, vendor_id: int
    ) -> VendorInvoiceSettings:
        """Get vendor invoice settings or raise InvoiceSettingsNotFoundException."""
        settings = self.get_settings(db, vendor_id)
        if not settings:
            raise InvoiceSettingsNotFoundException(vendor_id)
        return settings

    def create_settings(
        self,
        db: Session,
        vendor_id: int,
        data: VendorInvoiceSettingsCreate,
    ) -> VendorInvoiceSettings:
        """Create vendor invoice settings.

        Raises:
            ValidationException: If settings already exist for the vendor.
        """
        # Check if settings already exist
        existing = self.get_settings(db, vendor_id)
        if existing:
            raise ValidationException(
                "Invoice settings already exist for this vendor"
            )

        settings = VendorInvoiceSettings(
            vendor_id=vendor_id,
            **data.model_dump(),
        )
        db.add(settings)
        db.flush()
        db.refresh(settings)

        logger.info(f"Created invoice settings for vendor {vendor_id}")
        return settings

    def update_settings(
        self,
        db: Session,
        vendor_id: int,
        data: VendorInvoiceSettingsUpdate,
    ) -> VendorInvoiceSettings:
        """Update vendor invoice settings (partial update; unset fields kept)."""
        settings = self.get_settings_or_raise(db, vendor_id)

        update_data = data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(settings, key, value)

        settings.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(settings)

        logger.info(f"Updated invoice settings for vendor {vendor_id}")
        return settings

    def create_settings_from_vendor(
        self,
        db: Session,
        vendor: Vendor,
    ) -> VendorInvoiceSettings:
        """
        Create invoice settings from vendor/company info.

        Used for initial setup based on existing vendor data.
        NOTE: unlike create_settings, this does not check for existing
        settings before inserting.
        """
        company = vendor.company

        settings = VendorInvoiceSettings(
            vendor_id=vendor.id,
            company_name=company.legal_name if company else vendor.name,
            company_address=vendor.effective_business_address,
            company_city=None,  # Would need to parse from address
            company_postal_code=None,
            company_country="LU",
            vat_number=vendor.effective_tax_number,
            is_vat_registered=bool(vendor.effective_tax_number),
        )
        db.add(settings)
        db.flush()
        db.refresh(settings)

        logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
        return settings

    # =========================================================================
    # Invoice Number Generation
    # =========================================================================

    def _get_next_invoice_number(
        self, db: Session, settings: VendorInvoiceSettings
    ) -> str:
        """Generate next invoice number and increment counter.

        Format: <prefix><zero-padded counter>. The counter is incremented
        and flushed immediately so sequential calls in one transaction
        yield distinct numbers.
        """
        number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
        invoice_number = f"{settings.invoice_prefix}{number}"

        # Increment counter
        settings.invoice_next_number += 1
        db.flush()

        return invoice_number

    # =========================================================================
    # Invoice Creation
    # =========================================================================

    def create_invoice_from_order(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
        notes: str | None = None,
    ) -> Invoice:
        """
        Create an invoice from an order.

        Captures snapshots of seller/buyer details and calculates VAT.

        Raises:
            InvoiceSettingsNotFoundException: If the vendor has no settings.
            OrderNotFoundException: If the order does not exist for the vendor.
            ValidationException: If an invoice already exists for the order.
        """
        # Get invoice settings
        settings = self.get_settings_or_raise(db, vendor_id)

        # Get order
        order = (
            db.query(Order)
            .filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
            .first()
        )
        if not order:
            raise OrderNotFoundException(f"Order {order_id} not found")

        # Check for existing invoice (one invoice per order)
        existing = (
            db.query(Invoice)
            .filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
            .first()
        )
        if existing:
            raise ValidationException(f"Invoice already exists for order {order_id}")

        # Determine VAT regime from billing country
        buyer_country = order.bill_country_iso
        vat_regime, vat_rate, destination_country = self.determine_vat_regime(
            seller_country=settings.company_country,
            buyer_country=buyer_country,
            buyer_vat_number=None,  # TODO: Add B2B VAT number support
            seller_oss_registered=settings.is_oss_registered,
        )

        # Build seller details snapshot (frozen at invoice time)
        seller_details = {
            "company_name": settings.company_name,
            "address": settings.company_address,
            "city": settings.company_city,
            "postal_code": settings.company_postal_code,
            "country": settings.company_country,
            "vat_number": settings.vat_number,
        }

        # Build buyer details snapshot
        buyer_details = {
            "name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
            "email": order.customer_email,
            "address": order.bill_address_line_1,
            "city": order.bill_city,
            "postal_code": order.bill_postal_code,
            "country": order.bill_country_iso,
            "vat_number": None,  # TODO: B2B support
        }
        if order.bill_company:
            buyer_details["company"] = order.bill_company

        # Build line items from order items
        line_items = []
        for item in order.items:
            line_items.append({
                "description": item.product_name,
                "quantity": item.quantity,
                "unit_price_cents": item.unit_price_cents,
                "total_cents": item.total_price_cents,
                "sku": item.product_sku,
                "ean": item.gtin,
            })

        # Calculate amounts
        subtotal_cents = sum(item["total_cents"] for item in line_items)

        # Calculate VAT
        # NOTE(review): int() truncates toward zero rather than rounding —
        # confirm this matches the intended VAT rounding rule.
        if vat_rate > 0:
            vat_amount_cents = int(
                subtotal_cents * float(vat_rate) / 100
            )
        else:
            vat_amount_cents = 0

        total_cents = subtotal_cents + vat_amount_cents

        # Get VAT label (destination country takes precedence when set)
        vat_rate_label = None
        if vat_rate > 0:
            if destination_country:
                vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
            else:
                vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)

        # Generate invoice number (increments the vendor counter)
        invoice_number = self._get_next_invoice_number(db, settings)

        # Create invoice in DRAFT status
        invoice = Invoice(
            vendor_id=vendor_id,
            order_id=order_id,
            invoice_number=invoice_number,
            invoice_date=datetime.now(UTC),
            status=InvoiceStatus.DRAFT.value,
            seller_details=seller_details,
            buyer_details=buyer_details,
            line_items=line_items,
            vat_regime=vat_regime.value,
            destination_country=destination_country,
            vat_rate=vat_rate,
            vat_rate_label=vat_rate_label,
            currency=order.currency,
            subtotal_cents=subtotal_cents,
            vat_amount_cents=vat_amount_cents,
            total_cents=total_cents,
            payment_terms=settings.payment_terms,
            bank_details={
                "bank_name": settings.bank_name,
                "iban": settings.bank_iban,
                "bic": settings.bank_bic,
            } if settings.bank_iban else None,
            footer_text=settings.footer_text,
            notes=notes,
        )
        db.add(invoice)
        db.flush()
        db.refresh(invoice)

        logger.info(
            f"Created invoice {invoice_number} for order {order_id} "
            f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
        )

        return invoice

    # =========================================================================
    # Invoice Retrieval
    # =========================================================================

    def get_invoice(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice | None:
        """Get invoice by ID, scoped to the vendor."""
        return (
            db.query(Invoice)
            .filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
            .first()
        )

    def get_invoice_or_raise(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Get invoice by ID or raise InvoiceNotFoundException."""
        invoice = self.get_invoice(db, vendor_id, invoice_id)
        if not invoice:
            raise InvoiceNotFoundException(invoice_id)
        return invoice

    def get_invoice_by_number(
        self, db: Session, vendor_id: int, invoice_number: str
    ) -> Invoice | None:
        """Get invoice by invoice number, scoped to the vendor."""
        return (
            db.query(Invoice)
            .filter(
                and_(
                    Invoice.invoice_number == invoice_number,
                    Invoice.vendor_id == vendor_id,
                )
            )
            .first()
        )

    def get_invoice_by_order_id(
        self, db: Session, vendor_id: int, order_id: int
    ) -> Invoice | None:
        """Get invoice by order ID, scoped to the vendor."""
        return (
            db.query(Invoice)
            .filter(
                and_(
                    Invoice.order_id == order_id,
                    Invoice.vendor_id == vendor_id,
                )
            )
            .first()
        )

    def list_invoices(
        self,
        db: Session,
        vendor_id: int,
        status: str | None = None,
        page: int = 1,
        per_page: int = 20,
    ) -> tuple[list[Invoice], int]:
        """
        List invoices for vendor with pagination (newest invoice date first).

        Returns: (invoices, total_count)
        """
        query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)

        if status:
            query = query.filter(Invoice.status == status)

        # Get total count before pagination
        total = query.count()

        # Apply pagination and order
        invoices = (
            query.order_by(Invoice.invoice_date.desc())
            .offset((page - 1) * per_page)
            .limit(per_page)
            .all()
        )

        return invoices, total

    # =========================================================================
    # Invoice Status Management
    # =========================================================================

    def update_status(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
        new_status: str,
    ) -> Invoice:
        """Update invoice status.

        Raises:
            ValidationException: For unknown statuses or attempts to change
                a cancelled invoice.
        """
        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)

        # Validate status transition against the enum values
        valid_statuses = [s.value for s in InvoiceStatus]
        if new_status not in valid_statuses:
            raise ValidationException(f"Invalid status: {new_status}")

        # Cannot change cancelled invoices (terminal state)
        if invoice.status == InvoiceStatus.CANCELLED.value:
            raise ValidationException("Cannot change status of cancelled invoice")

        invoice.status = new_status
        invoice.updated_at = datetime.now(UTC)
        db.flush()
        db.refresh(invoice)

        logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
        return invoice

    def mark_as_issued(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Mark invoice as issued."""
        return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)

    def mark_as_paid(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Mark invoice as paid."""
        return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)

    def cancel_invoice(
        self, db: Session, vendor_id: int, invoice_id: int
    ) -> Invoice:
        """Cancel invoice (terminal; no further status changes allowed)."""
        return self.update_status(
            db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value
        )

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_invoice_stats(
        self, db: Session, vendor_id: int
    ) -> dict[str, Any]:
        """Get invoice statistics for vendor.

        Revenue counts only ISSUED and PAID invoices.
        """
        total_count = (
            db.query(func.count(Invoice.id))
            .filter(Invoice.vendor_id == vendor_id)
            .scalar()
            or 0
        )

        total_revenue = (
            db.query(func.sum(Invoice.total_cents))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status.in_([
                        InvoiceStatus.ISSUED.value,
                        InvoiceStatus.PAID.value,
                    ]),
                )
            )
            .scalar()
            or 0
        )

        draft_count = (
            db.query(func.count(Invoice.id))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status == InvoiceStatus.DRAFT.value,
                )
            )
            .scalar()
            or 0
        )

        paid_count = (
            db.query(func.count(Invoice.id))
            .filter(
                and_(
                    Invoice.vendor_id == vendor_id,
                    Invoice.status == InvoiceStatus.PAID.value,
                )
            )
            .scalar()
            or 0
        )

        return {
            "total_invoices": total_count,
            "total_revenue_cents": total_revenue,
            "total_revenue": total_revenue / 100 if total_revenue else 0,
            "draft_count": draft_count,
            "paid_count": paid_count,
        }

    # =========================================================================
    # PDF Generation
    # =========================================================================

    def generate_pdf(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
        force_regenerate: bool = False,
    ) -> str:
        """
        Generate PDF for an invoice.

        Returns path to the generated PDF. Delegates to invoice_pdf_service
        (imported lazily to avoid a circular import).
        """
        from app.services.invoice_pdf_service import invoice_pdf_service

        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
        return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)

    def get_pdf_path(
        self,
        db: Session,
        vendor_id: int,
        invoice_id: int,
    ) -> str | None:
        """Get PDF path for an invoice if it exists on disk."""
        from app.services.invoice_pdf_service import invoice_pdf_service

        invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
        return invoice_pdf_service.get_pdf_path(invoice)
# Singleton instance
# Module-level singleton shared by all importers of this service.
invoice_service = InvoiceService()

# Explicit public API of this module.
__all__ = [
    "invoice_service",
    "InvoiceService",
]

View File

@@ -1,33 +1,33 @@
# app/services/letzshop/__init__.py
"""
Letzshop marketplace integration services.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Provides:
- GraphQL client for API communication
- Credential management service
- Order import service
- Fulfillment sync service
- Vendor directory sync service
The canonical implementation is now in:
app/modules/marketplace/services/letzshop/
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.marketplace.services.letzshop import LetzshopClient
"""
from .client_service import (
LetzshopAPIError,
LetzshopAuthError,
from app.modules.marketplace.services.letzshop import (
# Client
LetzshopClient,
LetzshopClientError,
LetzshopAuthError,
LetzshopAPIError,
LetzshopConnectionError,
)
from .credentials_service import (
# Credentials
LetzshopCredentialsService,
CredentialsError,
CredentialsNotFoundError,
LetzshopCredentialsService,
)
from .order_service import (
# Order Service
LetzshopOrderService,
OrderNotFoundError,
VendorNotFoundError,
)
from .vendor_sync_service import (
# Vendor Sync Service
LetzshopVendorSyncService,
get_vendor_sync_service,
)

View File

@@ -1,338 +1,25 @@
# app/services/letzshop_export_service.py
"""
Service for exporting products to Letzshop CSV format.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Generates Google Shopping compatible CSV files for Letzshop marketplace.
The canonical implementation is now in:
app/modules/marketplace/services/letzshop_export_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.marketplace.services import letzshop_export_service
"""
import csv
import io
import logging
from datetime import UTC, datetime
from app.modules.marketplace.services.letzshop_export_service import (
LetzshopExportService,
letzshop_export_service,
LETZSHOP_CSV_COLUMNS,
)
from sqlalchemy.orm import Session, joinedload
from models.database.letzshop import LetzshopSyncLog
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
logger = logging.getLogger(__name__)
# Letzshop CSV columns in order
# Letzshop feed columns in output order: the Google Shopping feed
# specification plus Letzshop/atalanda-specific extension columns.
# Defect fixed: the list literal was left unterminated (closing bracket
# missing), which is a syntax error.
LETZSHOP_CSV_COLUMNS = [
    "id",
    "title",
    "description",
    "link",
    "image_link",
    "additional_image_link",
    "availability",
    "price",
    "sale_price",
    "brand",
    "gtin",
    "mpn",
    "google_product_category",
    "product_type",
    "condition",
    "adult",
    "multipack",
    "is_bundle",
    "age_group",
    "color",
    "gender",
    "material",
    "pattern",
    "size",
    "size_type",
    "size_system",
    "item_group_id",
    "custom_label_0",
    "custom_label_1",
    "custom_label_2",
    "custom_label_3",
    "custom_label_4",
    "identifier_exists",
    "unit_pricing_measure",
    "unit_pricing_base_measure",
    "shipping",
    "atalanda:tax_rate",
    "atalanda:quantity",
    "atalanda:boost_sort",
    "atalanda:delivery_method",
]
# Explicit public re-export surface of this module.
__all__ = [
    "LetzshopExportService",
    "letzshop_export_service",
    "LETZSHOP_CSV_COLUMNS",
]
class LetzshopExportService:
"""Service for exporting products to Letzshop CSV format."""
def __init__(self, default_tax_rate: float = 17.0):
    """
    Initialize the export service.

    Args:
        default_tax_rate: Default VAT rate for Luxembourg (17%)
    """
    # Stored for use when a product carries no explicit tax rate.
    self.default_tax_rate = default_tax_rate
def export_vendor_products(
    self,
    db: Session,
    vendor_id: int,
    language: str = "en",
    include_inactive: bool = False,
) -> str:
    """
    Export all products for a vendor in Letzshop CSV format.

    Args:
        db: Database session
        vendor_id: Vendor ID to export products for
        language: Language for title/description (en, fr, de)
        include_inactive: Whether to include inactive products

    Returns:
        CSV string content
    """
    # Query products for this vendor with their marketplace product data;
    # translations are eager-loaded to avoid per-row lazy loads.
    query = (
        db.query(Product)
        .filter(Product.vendor_id == vendor_id)
        .options(
            joinedload(Product.marketplace_product).joinedload(
                MarketplaceProduct.translations
            )
        )
    )

    if not include_inactive:
        # `== True` is intentional: SQLAlchemy renders it as a SQL comparison.
        query = query.filter(Product.is_active == True)

    products = query.all()

    logger.info(
        f"Exporting {len(products)} products for vendor {vendor_id} in {language}"
    )

    return self._generate_csv(products, language)
def export_marketplace_products(
    self,
    db: Session,
    marketplace: str = "Letzshop",
    language: str = "en",
    limit: int | None = None,
) -> str:
    """
    Export marketplace products directly (admin use).

    Args:
        db: Database session
        marketplace: Filter by marketplace source (case-insensitive substring)
        language: Language for title/description
        limit: Optional limit on number of products

    Returns:
        CSV string content
    """
    # Only active marketplace products, with translations eager-loaded.
    query = (
        db.query(MarketplaceProduct)
        .filter(MarketplaceProduct.is_active == True)
        .options(joinedload(MarketplaceProduct.translations))
    )

    if marketplace:
        query = query.filter(
            MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
        )

    if limit:
        query = query.limit(limit)

    products = query.all()

    logger.info(
        f"Exporting {len(products)} marketplace products for {marketplace} in {language}"
    )

    return self._generate_csv_from_marketplace_products(products, language)
def _generate_csv(self, products: list[Product], language: str) -> str:
    """Generate CSV from vendor Product objects.

    Products without linked marketplace data are skipped. Output is
    tab-delimited (Google Shopping feed convention).
    """
    output = io.StringIO()
    writer = csv.DictWriter(
        output,
        fieldnames=LETZSHOP_CSV_COLUMNS,
        delimiter="\t",
        quoting=csv.QUOTE_MINIMAL,
    )
    writer.writeheader()

    for product in products:
        if product.marketplace_product:
            row = self._product_to_row(product, language)
            writer.writerow(row)

    return output.getvalue()
def _generate_csv_from_marketplace_products(
    self, products: list[MarketplaceProduct], language: str
) -> str:
    """Generate CSV from MarketplaceProduct objects directly.

    Tab-delimited output; every product yields a row (no skipping).
    """
    output = io.StringIO()
    writer = csv.DictWriter(
        output,
        fieldnames=LETZSHOP_CSV_COLUMNS,
        delimiter="\t",
        quoting=csv.QUOTE_MINIMAL,
    )
    writer.writeheader()

    for mp in products:
        row = self._marketplace_product_to_row(mp, language)
        writer.writerow(row)

    return output.getvalue()
def _product_to_row(self, product: Product, language: str) -> dict:
    """Convert a Product (with MarketplaceProduct) to a CSV row.

    The vendor SKU overrides the marketplace id in the exported `id` column.
    """
    mp = product.marketplace_product
    return self._marketplace_product_to_row(
        mp, language, vendor_sku=product.vendor_sku
    )
    def _marketplace_product_to_row(
        self,
        mp: MarketplaceProduct,
        language: str,
        vendor_sku: str | None = None,
    ) -> dict:
        """Convert a MarketplaceProduct to a CSV row dict.

        Args:
            mp: Marketplace product to serialize.
            language: Language code for localized title/description.
            vendor_sku: Optional vendor SKU; when provided it replaces the
                marketplace product ID in the "id" column.

        Returns:
            Dict keyed by the Letzshop feed column names.
        """
        # Get localized title and description
        title = mp.get_title(language) or ""
        description = mp.get_description(language) or ""
        # Format price with currency
        # Prefer the numeric price (two decimals plus currency, defaulting
        # to EUR); fall back to the raw price string when only that exists.
        price = ""
        if mp.price_numeric:
            price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}"
        elif mp.price:
            price = mp.price
        # Format sale price
        sale_price = ""
        if mp.sale_price_numeric:
            sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}"
        elif mp.sale_price:
            sale_price = mp.sale_price
        # Additional images - join with comma if multiple
        additional_images = ""
        if mp.additional_images:
            additional_images = ",".join(mp.additional_images)
        elif mp.additional_image_link:
            additional_images = mp.additional_image_link
        # Determine identifier_exists
        # When not stored explicitly, derive it from whether a GTIN or MPN
        # is present, as the feed format expects.
        identifier_exists = mp.identifier_exists
        if not identifier_exists:
            identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no"
        return {
            "id": vendor_sku or mp.marketplace_product_id,
            "title": title,
            "description": description,
            "link": mp.link or mp.source_url or "",
            "image_link": mp.image_link or "",
            "additional_image_link": additional_images,
            "availability": mp.availability or "in stock",
            "price": price,
            "sale_price": sale_price,
            "brand": mp.brand or "",
            "gtin": mp.gtin or "",
            "mpn": mp.mpn or "",
            "google_product_category": mp.google_product_category or "",
            "product_type": mp.product_type_raw or "",
            "condition": mp.condition or "new",
            "adult": mp.adult or "no",
            "multipack": str(mp.multipack) if mp.multipack else "",
            "is_bundle": mp.is_bundle or "no",
            "age_group": mp.age_group or "",
            "color": mp.color or "",
            "gender": mp.gender or "",
            "material": mp.material or "",
            "pattern": mp.pattern or "",
            "size": mp.size or "",
            "size_type": mp.size_type or "",
            "size_system": mp.size_system or "",
            "item_group_id": mp.item_group_id or "",
            "custom_label_0": mp.custom_label_0 or "",
            "custom_label_1": mp.custom_label_1 or "",
            "custom_label_2": mp.custom_label_2 or "",
            "custom_label_3": mp.custom_label_3 or "",
            "custom_label_4": mp.custom_label_4 or "",
            "identifier_exists": identifier_exists,
            "unit_pricing_measure": mp.unit_pricing_measure or "",
            "unit_pricing_base_measure": mp.unit_pricing_base_measure or "",
            "shipping": mp.shipping or "",
            "atalanda:tax_rate": str(self.default_tax_rate),
            "atalanda:quantity": "",  # Would need inventory data
            "atalanda:boost_sort": "",
            "atalanda:delivery_method": "",
        }
    def log_export(
        self,
        db: Session,
        vendor_id: int,
        started_at: datetime,
        completed_at: datetime,
        files_processed: int,
        files_succeeded: int,
        files_failed: int,
        products_exported: int,
        triggered_by: str,
        error_details: dict | None = None,
    ) -> LetzshopSyncLog:
        """
        Log an export operation to the sync log.

        Args:
            db: Database session
            vendor_id: Vendor ID
            started_at: When the export started
            completed_at: When the export completed
            files_processed: Number of language files to export (e.g., 3)
            files_succeeded: Number of files successfully exported
            files_failed: Number of files that failed
            products_exported: Total products in the export
            triggered_by: Who triggered the export (e.g., "admin:123")
            error_details: Optional error details if any failures

        Returns:
            Created LetzshopSyncLog entry
        """
        sync_log = LetzshopSyncLog(
            vendor_id=vendor_id,
            operation_type="product_export",
            direction="outbound",
            # Any failed file downgrades the run from "completed" to
            # "partial" (even if every file failed).
            status="completed" if files_failed == 0 else "partial",
            records_processed=files_processed,
            records_succeeded=files_succeeded,
            records_failed=files_failed,
            started_at=started_at,
            completed_at=completed_at,
            duration_seconds=int((completed_at - started_at).total_seconds()),
            triggered_by=triggered_by,
            # Pack the export count into error_details alongside any real
            # errors; store NULL when there is nothing to record.
            error_details={
                "products_exported": products_exported,
                **(error_details or {}),
            } if products_exported or error_details else None,
        )
        db.add(sync_log)
        # Flush (not commit) so the caller's transaction owns the commit.
        db.flush()
        return sync_log
# Singleton instance
letzshop_export_service = LetzshopExportService()

View File

@@ -1,334 +1,23 @@
# app/services/marketplace_import_job_service.py
import logging
"""
LEGACY LOCATION - Re-exports from module for backwards compatibility.
from sqlalchemy.orm import Session
The canonical implementation is now in:
app/modules/marketplace/services/marketplace_import_job_service.py
from app.exceptions import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ValidationException,
)
from models.database.marketplace_import_job import (
MarketplaceImportError,
MarketplaceImportJob,
)
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (
AdminMarketplaceImportJobResponse,
MarketplaceImportJobRequest,
MarketplaceImportJobResponse,
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.marketplace.services import marketplace_import_job_service
"""
from app.modules.marketplace.services.marketplace_import_job_service import (
MarketplaceImportJobService,
marketplace_import_job_service,
)
logger = logging.getLogger(__name__)
class MarketplaceImportJobService:
    """Service layer for marketplace import jobs.

    Handles job creation, retrieval with access control (user-, vendor-,
    and admin-scoped), paginated listing, per-job error listing, and
    conversion of database models to API response models.
    """

    def create_import_job(
        self,
        db: Session,
        request: MarketplaceImportJobRequest,
        vendor: Vendor,  # CHANGED: Vendor object from middleware
        user: User,
    ) -> MarketplaceImportJob:
        """
        Create a new marketplace import job.

        Args:
            db: Database session
            request: Import request data
            vendor: Vendor object (from middleware)
            user: User creating the job

        Returns:
            Created MarketplaceImportJob object

        Raises:
            ValidationException: If the job could not be persisted.
        """
        try:
            # Create marketplace import job record
            import_job = MarketplaceImportJob(
                status="pending",
                source_url=request.source_url,
                marketplace=request.marketplace,
                language=request.language,
                vendor_id=vendor.id,
                user_id=user.id,
            )
            db.add(import_job)
            db.flush()
            db.refresh(import_job)
            logger.info(
                f"Created marketplace import job {import_job.id}: "
                f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) "
                f"by user {user.username}"
            )
            return import_job
        except Exception as e:
            logger.error(f"Error creating import job: {str(e)}")
            # Chain the original exception so the root cause stays visible
            # in tracebacks instead of being swallowed.
            raise ValidationException("Failed to create import job") from e

    def get_import_job_by_id(
        self, db: Session, job_id: int, user: User
    ) -> MarketplaceImportJob:
        """Get a marketplace import job by ID with access control.

        Non-admin users may only access their own jobs.

        Raises:
            ImportJobNotFoundException: If no job with that ID exists.
            ImportJobNotOwnedException: If a non-admin requests a job they
                do not own.
            ValidationException: On unexpected database errors.
        """
        try:
            job = (
                db.query(MarketplaceImportJob)
                .filter(MarketplaceImportJob.id == job_id)
                .first()
            )
            if not job:
                raise ImportJobNotFoundException(job_id)
            # Users can only see their own jobs, admins can see all
            if user.role != "admin" and job.user_id != user.id:
                raise ImportJobNotOwnedException(job_id, user.id)
            return job
        except (ImportJobNotFoundException, ImportJobNotOwnedException):
            # Re-raise domain exceptions untouched.
            raise
        except Exception as e:
            logger.error(f"Error getting import job {job_id}: {str(e)}")
            raise ValidationException("Failed to retrieve import job") from e

    def get_import_job_for_vendor(
        self, db: Session, job_id: int, vendor_id: int
    ) -> MarketplaceImportJob:
        """
        Get a marketplace import job by ID with vendor access control.

        Validates that the job belongs to the specified vendor.

        Args:
            db: Database session
            job_id: Import job ID
            vendor_id: Vendor ID from token (to verify ownership)

        Raises:
            ImportJobNotFoundException: If job not found
            UnauthorizedVendorAccessException: If job doesn't belong to vendor
            ValidationException: On unexpected database errors.
        """
        # Local import to avoid a circular dependency at module load time.
        from app.exceptions import UnauthorizedVendorAccessException

        try:
            job = (
                db.query(MarketplaceImportJob)
                .filter(MarketplaceImportJob.id == job_id)
                .first()
            )
            if not job:
                raise ImportJobNotFoundException(job_id)
            # Verify job belongs to vendor (service layer validation)
            if job.vendor_id != vendor_id:
                raise UnauthorizedVendorAccessException(
                    vendor_code=str(vendor_id),
                    user_id=0,  # Not user-specific, but vendor mismatch
                )
            return job
        except (ImportJobNotFoundException, UnauthorizedVendorAccessException):
            raise
        except Exception as e:
            logger.error(
                f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}"
            )
            raise ValidationException("Failed to retrieve import job") from e

    def get_import_jobs(
        self,
        db: Session,
        vendor: Vendor,  # ADDED: Vendor filter
        user: User,
        marketplace: str | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> list[MarketplaceImportJob]:
        """Get marketplace import jobs for a specific vendor.

        Non-admin users only see their own jobs within the vendor scope.
        Results are newest-first with the job ID as a tiebreaker.
        """
        try:
            query = db.query(MarketplaceImportJob).filter(
                MarketplaceImportJob.vendor_id == vendor.id
            )
            # Users can only see their own jobs, admins can see all vendor jobs
            if user.role != "admin":
                query = query.filter(MarketplaceImportJob.user_id == user.id)
            # Apply marketplace filter
            if marketplace:
                query = query.filter(
                    MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
                )
            # Order by creation date (newest first) and apply pagination
            jobs = (
                query.order_by(
                    MarketplaceImportJob.created_at.desc(),
                    MarketplaceImportJob.id.desc(),  # Tiebreaker for same timestamp
                )
                .offset(skip)
                .limit(limit)
                .all()
            )
            return jobs
        except Exception as e:
            logger.error(f"Error getting import jobs: {str(e)}")
            raise ValidationException("Failed to retrieve import jobs") from e

    def convert_to_response_model(
        self, job: MarketplaceImportJob
    ) -> MarketplaceImportJobResponse:
        """Convert database model to API response model."""
        return MarketplaceImportJobResponse(
            job_id=job.id,
            status=job.status,
            marketplace=job.marketplace,
            language=job.language,
            vendor_id=job.vendor_id,
            vendor_code=job.vendor.vendor_code if job.vendor else None,
            vendor_name=job.vendor.name if job.vendor else None,
            source_url=job.source_url,
            # Counters default to 0 when the job has not run yet.
            imported=job.imported_count or 0,
            updated=job.updated_count or 0,
            total_processed=job.total_processed or 0,
            error_count=job.error_count or 0,
            error_message=job.error_message,
            created_at=job.created_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
        )

    def convert_to_admin_response_model(
        self, job: MarketplaceImportJob
    ) -> AdminMarketplaceImportJobResponse:
        """Convert database model to admin API response model with extra fields."""
        return AdminMarketplaceImportJobResponse(
            id=job.id,
            job_id=job.id,
            status=job.status,
            marketplace=job.marketplace,
            language=job.language,
            vendor_id=job.vendor_id,
            vendor_code=job.vendor.vendor_code if job.vendor else None,
            vendor_name=job.vendor.name if job.vendor else None,
            source_url=job.source_url,
            imported=job.imported_count or 0,
            updated=job.updated_count or 0,
            total_processed=job.total_processed or 0,
            error_count=job.error_count or 0,
            error_message=job.error_message,
            # Detailed errors are fetched separately via get_import_job_errors.
            error_details=[],
            created_at=job.created_at,
            started_at=job.started_at,
            completed_at=job.completed_at,
            created_by_name=job.user.username if job.user else None,
        )

    def get_all_import_jobs_paginated(
        self,
        db: Session,
        marketplace: str | None = None,
        status: str | None = None,
        page: int = 1,
        limit: int = 100,
    ) -> tuple[list[MarketplaceImportJob], int]:
        """Get all marketplace import jobs with pagination (for admin).

        Args:
            db: Database session
            marketplace: Optional marketplace substring filter
            status: Optional exact status filter
            page: 1-indexed page number
            limit: Page size

        Returns:
            Tuple of (jobs on the requested page, total matching count)
        """
        try:
            query = db.query(MarketplaceImportJob)
            if marketplace:
                query = query.filter(
                    MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
                )
            if status:
                query = query.filter(MarketplaceImportJob.status == status)
            total = query.count()
            skip = (page - 1) * limit
            jobs = (
                query.order_by(
                    MarketplaceImportJob.created_at.desc(),
                    MarketplaceImportJob.id.desc(),  # Tiebreaker for same timestamp
                )
                .offset(skip)
                .limit(limit)
                .all()
            )
            return jobs, total
        except Exception as e:
            logger.error(f"Error getting all import jobs: {str(e)}")
            raise ValidationException("Failed to retrieve import jobs") from e

    def get_import_job_by_id_admin(
        self, db: Session, job_id: int
    ) -> MarketplaceImportJob:
        """Get a marketplace import job by ID (admin - no access control).

        Raises:
            ImportJobNotFoundException: If no job with that ID exists.
        """
        job = (
            db.query(MarketplaceImportJob)
            .filter(MarketplaceImportJob.id == job_id)
            .first()
        )
        if not job:
            raise ImportJobNotFoundException(job_id)
        return job

    def get_import_job_errors(
        self,
        db: Session,
        job_id: int,
        error_type: str | None = None,
        page: int = 1,
        limit: int = 50,
    ) -> tuple[list[MarketplaceImportError], int]:
        """
        Get import errors for a specific job with pagination.

        Args:
            db: Database session
            job_id: Import job ID
            error_type: Optional filter by error type
            page: Page number (1-indexed)
            limit: Number of items per page

        Returns:
            Tuple of (list of errors, total count)
        """
        try:
            query = db.query(MarketplaceImportError).filter(
                MarketplaceImportError.import_job_id == job_id
            )
            if error_type:
                query = query.filter(MarketplaceImportError.error_type == error_type)
            total = query.count()
            offset = (page - 1) * limit
            # Errors are ordered by their source row for easy correlation
            # with the imported file.
            errors = (
                query.order_by(MarketplaceImportError.row_number)
                .offset(offset)
                .limit(limit)
                .all()
            )
            return errors, total
        except Exception as e:
            logger.error(f"Error getting import job errors for job {job_id}: {str(e)}")
            raise ValidationException("Failed to retrieve import errors") from e
marketplace_import_job_service = MarketplaceImportJobService()
__all__ = [
"MarketplaceImportJobService",
"marketplace_import_job_service",
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,225 +1,23 @@
# app/services/message_attachment_service.py
"""
Attachment handling service for messaging system.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Handles file upload, validation, storage, and retrieval.
The canonical implementation is now in:
app/modules/messaging/services/message_attachment_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.messaging.services import message_attachment_service
"""
import logging
import os
import uuid
from datetime import datetime
from pathlib import Path
from app.modules.messaging.services.message_attachment_service import (
message_attachment_service,
MessageAttachmentService,
)
from fastapi import UploadFile
from sqlalchemy.orm import Session
from app.services.admin_settings_service import admin_settings_service
logger = logging.getLogger(__name__)
# Allowed MIME types for attachments
ALLOWED_MIME_TYPES = {
# Images
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
# Documents
"application/pdf",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
# Archives
"application/zip",
# Text
"text/plain",
"text/csv",
}
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
# Default max file size in MB
DEFAULT_MAX_FILE_SIZE_MB = 10
class MessageAttachmentService:
    """Service for handling message attachments.

    Validates uploads against an allow-list of MIME types and a size
    limit from platform settings, stores files under a dated directory
    tree, generates thumbnails for images, and serves deletion/retrieval.
    """

    def __init__(self, storage_base: str = "uploads/messages"):
        # Root directory under which all attachment files are stored.
        self.storage_base = storage_base

    def get_max_file_size_bytes(self, db: Session) -> int:
        """Get maximum file size from platform settings, in bytes."""
        max_mb = admin_settings_service.get_setting_value(
            db,
            "message_attachment_max_size_mb",
            default=DEFAULT_MAX_FILE_SIZE_MB,
        )
        # The setting may come back as a string or be malformed; fall back
        # to the default rather than failing the upload.
        try:
            max_mb = int(max_mb)
        except (TypeError, ValueError):
            max_mb = DEFAULT_MAX_FILE_SIZE_MB
        return max_mb * 1024 * 1024  # Convert to bytes

    def validate_file_type(self, mime_type: str) -> bool:
        """Check if file type is allowed."""
        return mime_type in ALLOWED_MIME_TYPES

    def is_image(self, mime_type: str) -> bool:
        """Check if file is an image."""
        return mime_type in IMAGE_MIME_TYPES

    async def validate_and_store(
        self,
        db: Session,
        file: UploadFile,
        conversation_id: int,
    ) -> dict:
        """
        Validate and store an uploaded file.

        Returns dict with file metadata for MessageAttachment creation
        (filename, original_filename, file_path, file_size, mime_type,
        is_image, plus thumbnail fields for images).

        Raises:
            ValueError: If file type or size is invalid
        """
        # Validate MIME type
        content_type = file.content_type or "application/octet-stream"
        if not self.validate_file_type(content_type):
            raise ValueError(
                f"File type '{content_type}' not allowed. "
                "Allowed types: images (JPEG, PNG, GIF, WebP), "
                "PDF, Office documents, ZIP, text files."
            )
        # Read file content
        content = await file.read()
        file_size = len(content)
        # Validate file size
        max_size = self.get_max_file_size_bytes(db)
        if file_size > max_size:
            raise ValueError(
                f"File size {file_size / 1024 / 1024:.1f}MB exceeds "
                f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB"
            )
        # Generate unique filename so uploads can never collide or
        # overwrite each other; keep the original extension.
        original_filename = file.filename or "attachment"
        ext = Path(original_filename).suffix.lower()
        unique_filename = f"{uuid.uuid4().hex}{ext}"
        # Create storage path: uploads/messages/YYYY/MM/conversation_id/filename
        now = datetime.utcnow()
        relative_path = os.path.join(
            self.storage_base,
            str(now.year),
            f"{now.month:02d}",
            str(conversation_id),
        )
        # Ensure directory exists
        os.makedirs(relative_path, exist_ok=True)
        # Full file path
        file_path = os.path.join(relative_path, unique_filename)
        # Write file
        with open(file_path, "wb") as f:
            f.write(content)
        # Prepare metadata
        is_image = self.is_image(content_type)
        metadata = {
            "filename": unique_filename,
            "original_filename": original_filename,
            "file_path": file_path,
            "file_size": file_size,
            "mime_type": content_type,
            "is_image": is_image,
        }
        # Generate thumbnail for images
        if is_image:
            thumbnail_data = self._create_thumbnail(content, file_path)
            metadata.update(thumbnail_data)
        logger.info(
            f"Stored attachment {unique_filename} for conversation {conversation_id} "
            f"({file_size} bytes, type: {content_type})"
        )
        return metadata

    def _create_thumbnail(self, content: bytes, original_path: str) -> dict:
        """Create thumbnail for image attachments.

        Returns a dict with the original image dimensions and thumbnail
        path, or an empty dict if PIL is missing or thumbnailing fails.
        """
        try:
            import io

            from PIL import Image

            img = Image.open(io.BytesIO(content))
            width, height = img.size
            # Create thumbnail (bounded to 200x200, aspect ratio preserved)
            img.thumbnail((200, 200))
            # BUGFIX: previously `original_path.replace(".", "_thumb.")`
            # rewrote EVERY dot in the path (breaking on dotted directory
            # names or multi-dot filenames, and overwriting the original
            # when the file had no extension). Only rewrite the final
            # filename suffix.
            original = Path(original_path)
            thumb_path = str(
                original.with_name(f"{original.stem}_thumb{original.suffix}")
            )
            img.save(thumb_path)
            return {
                "image_width": width,
                "image_height": height,
                "thumbnail_path": thumb_path,
            }
        except ImportError:
            logger.warning("PIL not installed, skipping thumbnail generation")
            return {}
        except Exception as e:
            logger.error(f"Failed to create thumbnail: {e}")
            return {}

    def delete_attachment(
        self, file_path: str, thumbnail_path: str | None = None
    ) -> bool:
        """Delete attachment files from storage.

        Missing files are treated as already deleted; returns False only
        on an actual filesystem error.
        """
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.info(f"Deleted attachment file: {file_path}")
            if thumbnail_path and os.path.exists(thumbnail_path):
                os.remove(thumbnail_path)
                logger.info(f"Deleted thumbnail: {thumbnail_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to delete attachment {file_path}: {e}")
            return False

    def get_download_url(self, file_path: str) -> str:
        """
        Get download URL for an attachment.

        For local storage, returns a relative path that can be served
        by the static file handler or a dedicated download endpoint.
        """
        # Convert local path to URL path
        # Assumes files are served from /static/uploads or similar
        return f"/static/{file_path}"

    def get_file_content(self, file_path: str) -> bytes | None:
        """Read file content from storage; None if missing or unreadable."""
        try:
            if os.path.exists(file_path):
                with open(file_path, "rb") as f:
                    return f.read()
            return None
        except Exception as e:
            logger.error(f"Failed to read file {file_path}: {e}")
            return None
# Singleton instance
message_attachment_service = MessageAttachmentService()
__all__ = [
"message_attachment_service",
"MessageAttachmentService",
]

View File

@@ -1,684 +1,23 @@
# app/services/messaging_service.py
"""
Messaging service for conversation and message management.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
Provides functionality for:
- Creating conversations between different participant types
- Sending messages with attachments
- Managing read status and unread counts
- Conversation listing with filters
- Multi-tenant data isolation
The canonical implementation is now in:
app/modules/messaging/services/messaging_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.messaging.services import messaging_service
"""
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from models.database.customer import Customer
from models.database.message import (
Conversation,
ConversationParticipant,
ConversationType,
Message,
MessageAttachment,
ParticipantType,
from app.modules.messaging.services.messaging_service import (
messaging_service,
MessagingService,
)
from models.database.user import User
logger = logging.getLogger(__name__)
class MessagingService:
"""Service for managing conversations and messages."""
# =========================================================================
# CONVERSATION MANAGEMENT
# =========================================================================
    def create_conversation(
        self,
        db: Session,
        conversation_type: ConversationType,
        subject: str,
        initiator_type: ParticipantType,
        initiator_id: int,
        recipient_type: ParticipantType,
        recipient_id: int,
        vendor_id: int | None = None,
        initial_message: str | None = None,
    ) -> Conversation:
        """
        Create a new conversation between two participants.

        Args:
            db: Database session
            conversation_type: Type of conversation channel
            subject: Conversation subject line
            initiator_type: Type of initiating participant
            initiator_id: ID of initiating participant
            recipient_type: Type of receiving participant
            recipient_id: ID of receiving participant
            vendor_id: Required for vendor_customer/admin_customer types
            initial_message: Optional first message content

        Returns:
            Created Conversation object

        Raises:
            ValueError: If vendor_id is missing for a conversation type
                that requires vendor scoping.
        """
        # Validate vendor_id requirement
        if conversation_type in [
            ConversationType.VENDOR_CUSTOMER,
            ConversationType.ADMIN_CUSTOMER,
        ]:
            if not vendor_id:
                raise ValueError(
                    f"vendor_id required for {conversation_type.value} conversations"
                )
        # Create conversation
        conversation = Conversation(
            conversation_type=conversation_type,
            subject=subject,
            vendor_id=vendor_id,
        )
        db.add(conversation)
        # Flush so conversation.id is available for the participant rows.
        db.flush()
        # Add participants. vendor_id is only stamped onto a participant
        # row when that participant is a vendor-side user.
        initiator_vendor_id = (
            vendor_id if initiator_type == ParticipantType.VENDOR else None
        )
        recipient_vendor_id = (
            vendor_id if recipient_type == ParticipantType.VENDOR else None
        )
        initiator = ConversationParticipant(
            conversation_id=conversation.id,
            participant_type=initiator_type,
            participant_id=initiator_id,
            vendor_id=initiator_vendor_id,
            unread_count=0,  # Initiator has read their own message
        )
        recipient = ConversationParticipant(
            conversation_id=conversation.id,
            participant_type=recipient_type,
            participant_id=recipient_id,
            vendor_id=recipient_vendor_id,
            # Recipient starts with one unread if we seed an initial message.
            unread_count=1 if initial_message else 0,
        )
        db.add(initiator)
        db.add(recipient)
        db.flush()
        # Add initial message if provided
        if initial_message:
            self.send_message(
                db=db,
                conversation_id=conversation.id,
                sender_type=initiator_type,
                sender_id=initiator_id,
                content=initial_message,
                _skip_unread_update=True,  # Already set above
            )
        logger.info(
            f"Created {conversation_type.value} conversation {conversation.id}: "
            f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}"
        )
        return conversation
def get_conversation(
self,
db: Session,
conversation_id: int,
participant_type: ParticipantType,
participant_id: int,
) -> Conversation | None:
"""
Get conversation if participant has access.
Validates that the requester is a participant.
"""
conversation = (
db.query(Conversation)
.options(
joinedload(Conversation.participants),
joinedload(Conversation.messages).joinedload(Message.attachments),
)
.filter(Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
# Verify participant access
has_access = any(
p.participant_type == participant_type
and p.participant_id == participant_id
for p in conversation.participants
)
if not has_access:
logger.warning(
f"Access denied to conversation {conversation_id} for "
f"{participant_type.value}:{participant_id}"
)
return None
return conversation
    def list_conversations(
        self,
        db: Session,
        participant_type: ParticipantType,
        participant_id: int,
        vendor_id: int | None = None,
        conversation_type: ConversationType | None = None,
        is_closed: bool | None = None,
        skip: int = 0,
        limit: int = 20,
    ) -> tuple[list[Conversation], int, int]:
        """
        List conversations for a participant with filters.

        Args:
            db: Database session
            participant_type: Type of the requesting participant
            participant_id: ID of the requesting participant
            vendor_id: Vendor scope for multi-tenant isolation
            conversation_type: Optional channel-type filter
            is_closed: Optional open/closed status filter
            skip: Pagination offset
            limit: Max results per page

        Returns:
            Tuple of (conversations, total_count, total_unread)
        """
        # Base query: conversations where user is a participant
        query = (
            db.query(Conversation)
            .join(ConversationParticipant)
            .filter(
                and_(
                    ConversationParticipant.participant_type == participant_type,
                    ConversationParticipant.participant_id == participant_id,
                )
            )
        )
        # Multi-tenant filter for vendor users: restrict to their own
        # vendor's participant rows.
        if participant_type == ParticipantType.VENDOR and vendor_id:
            query = query.filter(ConversationParticipant.vendor_id == vendor_id)
        # Customer vendor isolation: customers only see conversations
        # scoped to the vendor they are browsing.
        if participant_type == ParticipantType.CUSTOMER and vendor_id:
            query = query.filter(Conversation.vendor_id == vendor_id)
        # Type filter
        if conversation_type:
            query = query.filter(Conversation.conversation_type == conversation_type)
        # Status filter
        if is_closed is not None:
            query = query.filter(Conversation.is_closed == is_closed)
        # Get total count
        total = query.count()
        # Get total unread across all conversations (not just this page).
        unread_query = db.query(
            func.sum(ConversationParticipant.unread_count)
        ).filter(
            and_(
                ConversationParticipant.participant_type == participant_type,
                ConversationParticipant.participant_id == participant_id,
            )
        )
        if participant_type == ParticipantType.VENDOR and vendor_id:
            unread_query = unread_query.filter(
                ConversationParticipant.vendor_id == vendor_id
            )
        # SUM returns NULL when there are no rows; coalesce to 0.
        total_unread = unread_query.scalar() or 0
        # Get paginated results, ordered by last activity; conversations
        # with no messages yet (NULL last_message_at) sort last.
        conversations = (
            query.options(joinedload(Conversation.participants))
            .order_by(Conversation.last_message_at.desc().nullslast())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return conversations, total, total_unread
def close_conversation(
self,
db: Session,
conversation_id: int,
closer_type: ParticipantType,
closer_id: int,
) -> Conversation | None:
"""Close a conversation thread."""
conversation = self.get_conversation(
db, conversation_id, closer_type, closer_id
)
if not conversation:
return None
conversation.is_closed = True
conversation.closed_at = datetime.now(UTC)
conversation.closed_by_type = closer_type
conversation.closed_by_id = closer_id
# Add system message
self.send_message(
db=db,
conversation_id=conversation_id,
sender_type=closer_type,
sender_id=closer_id,
content="Conversation closed",
is_system_message=True,
)
db.flush()
return conversation
def reopen_conversation(
self,
db: Session,
conversation_id: int,
opener_type: ParticipantType,
opener_id: int,
) -> Conversation | None:
"""Reopen a closed conversation."""
conversation = self.get_conversation(
db, conversation_id, opener_type, opener_id
)
if not conversation:
return None
conversation.is_closed = False
conversation.closed_at = None
conversation.closed_by_type = None
conversation.closed_by_id = None
# Add system message
self.send_message(
db=db,
conversation_id=conversation_id,
sender_type=opener_type,
sender_id=opener_id,
content="Conversation reopened",
is_system_message=True,
)
db.flush()
return conversation
# =========================================================================
# MESSAGE MANAGEMENT
# =========================================================================
    def send_message(
        self,
        db: Session,
        conversation_id: int,
        sender_type: ParticipantType,
        sender_id: int,
        content: str,
        attachments: list[dict[str, Any]] | None = None,
        is_system_message: bool = False,
        _skip_unread_update: bool = False,
    ) -> Message:
        """
        Send a message in a conversation.

        Args:
            db: Database session
            conversation_id: Target conversation ID
            sender_type: Type of sender
            sender_id: ID of sender
            content: Message text content
            attachments: List of attachment dicts with file metadata
                (keys: filename, original_filename, file_path, file_size,
                mime_type, and optional is_image/image_width/image_height/
                thumbnail_path)
            is_system_message: Whether this is a system-generated message
            _skip_unread_update: Internal flag to skip unread increment
                (used when the caller has already set unread counts, e.g.
                create_conversation's initial message)

        Returns:
            Created Message object
        """
        # Create message
        message = Message(
            conversation_id=conversation_id,
            sender_type=sender_type,
            sender_id=sender_id,
            content=content,
            is_system_message=is_system_message,
        )
        db.add(message)
        # Flush so message.id is available for attachment rows.
        db.flush()
        # Add attachments if any
        if attachments:
            for att_data in attachments:
                attachment = MessageAttachment(
                    message_id=message.id,
                    filename=att_data["filename"],
                    original_filename=att_data["original_filename"],
                    file_path=att_data["file_path"],
                    file_size=att_data["file_size"],
                    mime_type=att_data["mime_type"],
                    is_image=att_data.get("is_image", False),
                    image_width=att_data.get("image_width"),
                    image_height=att_data.get("image_height"),
                    thumbnail_path=att_data.get("thumbnail_path"),
                )
                db.add(attachment)
        # Update conversation metadata
        conversation = (
            db.query(Conversation).filter(Conversation.id == conversation_id).first()
        )
        if conversation:
            conversation.last_message_at = datetime.now(UTC)
            conversation.message_count += 1
        # Update unread counts for other participants: every participant
        # row that is not exactly (sender_type, sender_id) gets +1, done
        # as a single bulk UPDATE.
        if not _skip_unread_update:
            db.query(ConversationParticipant).filter(
                and_(
                    ConversationParticipant.conversation_id == conversation_id,
                    or_(
                        ConversationParticipant.participant_type != sender_type,
                        ConversationParticipant.participant_id != sender_id,
                    ),
                )
            ).update(
                {
                    ConversationParticipant.unread_count: ConversationParticipant.unread_count
                    + 1
                }
            )
        db.flush()
        logger.info(
            f"Message {message.id} sent in conversation {conversation_id} "
            f"by {sender_type.value}:{sender_id}"
        )
        return message
def delete_message(
self,
db: Session,
message_id: int,
deleter_type: ParticipantType,
deleter_id: int,
) -> Message | None:
"""Soft delete a message (for moderation)."""
message = db.query(Message).filter(Message.id == message_id).first()
if not message:
return None
# Verify deleter has access to conversation
conversation = self.get_conversation(
db, message.conversation_id, deleter_type, deleter_id
)
if not conversation:
return None
message.is_deleted = True
message.deleted_at = datetime.now(UTC)
message.deleted_by_type = deleter_type
message.deleted_by_id = deleter_id
db.flush()
return message
def mark_conversation_read(
self,
db: Session,
conversation_id: int,
reader_type: ParticipantType,
reader_id: int,
) -> bool:
"""Mark all messages in conversation as read for participant."""
result = (
db.query(ConversationParticipant)
.filter(
and_(
ConversationParticipant.conversation_id == conversation_id,
ConversationParticipant.participant_type == reader_type,
ConversationParticipant.participant_id == reader_id,
)
)
.update(
{
ConversationParticipant.unread_count: 0,
ConversationParticipant.last_read_at: datetime.now(UTC),
}
)
)
db.flush()
return result > 0
def get_unread_count(
self,
db: Session,
participant_type: ParticipantType,
participant_id: int,
vendor_id: int | None = None,
) -> int:
"""Get total unread message count for a participant."""
query = db.query(func.sum(ConversationParticipant.unread_count)).filter(
and_(
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
if vendor_id:
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
return query.scalar() or 0
# =========================================================================
# PARTICIPANT HELPERS
# =========================================================================
def get_participant_info(
self,
db: Session,
participant_type: ParticipantType,
participant_id: int,
) -> dict[str, Any] | None:
"""Get display info for a participant (name, email, avatar)."""
if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]:
user = db.query(User).filter(User.id == participant_id).first()
if user:
return {
"id": user.id,
"type": participant_type.value,
"name": f"{user.first_name or ''} {user.last_name or ''}".strip()
or user.username,
"email": user.email,
"avatar_url": None, # Could add avatar support later
}
elif participant_type == ParticipantType.CUSTOMER:
customer = db.query(Customer).filter(Customer.id == participant_id).first()
if customer:
return {
"id": customer.id,
"type": participant_type.value,
"name": f"{customer.first_name or ''} {customer.last_name or ''}".strip()
or customer.email,
"email": customer.email,
"avatar_url": None,
}
return None
def get_other_participant(
self,
conversation: Conversation,
my_type: ParticipantType,
my_id: int,
) -> ConversationParticipant | None:
"""Get the other participant in a conversation."""
for p in conversation.participants:
if p.participant_type != my_type or p.participant_id != my_id:
return p
return None
# =========================================================================
# NOTIFICATION PREFERENCES
# =========================================================================
def update_notification_preferences(
self,
db: Session,
conversation_id: int,
participant_type: ParticipantType,
participant_id: int,
email_notifications: bool | None = None,
muted: bool | None = None,
) -> bool:
"""Update notification preferences for a participant in a conversation."""
updates = {}
if email_notifications is not None:
updates[ConversationParticipant.email_notifications] = email_notifications
if muted is not None:
updates[ConversationParticipant.muted] = muted
if not updates:
return False
result = (
db.query(ConversationParticipant)
.filter(
and_(
ConversationParticipant.conversation_id == conversation_id,
ConversationParticipant.participant_type == participant_type,
ConversationParticipant.participant_id == participant_id,
)
)
.update(updates)
)
db.flush()
return result > 0
# =========================================================================
# RECIPIENT QUERIES
# =========================================================================
def get_vendor_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of vendor users as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
from models.database.vendor import VendorUser
query = (
db.query(User, VendorUser)
.join(VendorUser, User.id == VendorUser.user_id)
.filter(User.is_active == True) # noqa: E712
)
if vendor_id:
query = query.filter(VendorUser.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(User.username.ilike(search_pattern))
| (User.email.ilike(search_pattern))
| (User.first_name.ilike(search_pattern))
| (User.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for user, vendor_user in results:
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
recipients.append({
"id": user.id,
"type": ParticipantType.VENDOR,
"name": name,
"email": user.email,
"vendor_id": vendor_user.vendor_id,
"vendor_name": vendor_user.vendor.name if vendor_user.vendor else None,
})
return recipients, total
def get_customer_recipients(
self,
db: Session,
vendor_id: int | None = None,
search: str | None = None,
skip: int = 0,
limit: int = 50,
) -> tuple[list[dict], int]:
"""
Get list of customers as potential recipients.
Args:
db: Database session
vendor_id: Optional vendor ID filter (required for vendor users)
search: Search term for name/email
skip: Pagination offset
limit: Max results
Returns:
Tuple of (recipients list, total count)
"""
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
if vendor_id:
query = query.filter(Customer.vendor_id == vendor_id)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(Customer.email.ilike(search_pattern))
| (Customer.first_name.ilike(search_pattern))
| (Customer.last_name.ilike(search_pattern))
)
total = query.count()
results = query.offset(skip).limit(limit).all()
recipients = []
for customer in results:
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
recipients.append({
"id": customer.id,
"type": ParticipantType.CUSTOMER,
"name": name or customer.email,
"email": customer.email,
"vendor_id": customer.vendor_id,
})
return recipients, total
# Module-level singleton: importers should use this shared instance rather
# than constructing their own MessagingService.
messaging_service = MessagingService()

# Explicit public API of this module.
__all__ = [
    "messaging_service",
    "MessagingService",
]

View File

@@ -1,735 +1,23 @@
# app/services/order_inventory_service.py
"""
Order-Inventory Integration Service.
LEGACY LOCATION - Re-exports from module for backwards compatibility.
This service orchestrates inventory operations for orders:
- Reserve inventory when orders are confirmed
- Fulfill (deduct) inventory when orders are shipped
- Release reservations when orders are cancelled
The canonical implementation is now in:
app/modules/orders/services/order_inventory_service.py
This is the critical link between the order and inventory systems
that ensures stock accuracy.
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
All operations are logged to the inventory_transactions table for audit trail.
from app.modules.orders.services import order_inventory_service
"""
import logging
from sqlalchemy.orm import Session
from app.exceptions import (
InsufficientInventoryException,
InventoryNotFoundException,
OrderNotFoundException,
ValidationException,
from app.modules.orders.services.order_inventory_service import (
order_inventory_service,
OrderInventoryService,
)
from app.services.inventory_service import inventory_service
from models.database.inventory import Inventory
from models.database.inventory_transaction import InventoryTransaction, TransactionType
from models.database.order import Order, OrderItem
from models.schema.inventory import InventoryReserve
logger = logging.getLogger(__name__)
# Default location for inventory operations
DEFAULT_LOCATION = "DEFAULT"
class OrderInventoryService:
"""
Orchestrate order and inventory operations together.
This service ensures that:
1. When orders are confirmed, inventory is reserved
2. When orders are shipped, inventory is fulfilled (deducted)
3. When orders are cancelled, reservations are released
Note: Letzshop orders with unmatched products (placeholder) skip
inventory operations for those items.
"""
def get_order_with_items(
self, db: Session, vendor_id: int, order_id: int
) -> Order:
"""Get order with items or raise OrderNotFoundException."""
order = (
db.query(Order)
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
.first()
)
if not order:
raise OrderNotFoundException(f"Order {order_id} not found")
return order
def _find_inventory_location(
self, db: Session, product_id: int, vendor_id: int
) -> str | None:
"""
Find the location with available inventory for a product.
Returns the first location with available quantity, or None if no
inventory exists.
"""
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == product_id,
Inventory.vendor_id == vendor_id,
Inventory.quantity > Inventory.reserved_quantity,
)
.first()
)
return inventory.location if inventory else None
def _is_placeholder_product(self, order_item: OrderItem) -> bool:
"""Check if the order item uses a placeholder product."""
if not order_item.product:
return True
# Check if it's the placeholder product (GTIN 0000000000000)
return order_item.product.gtin == "0000000000000"
    def _log_transaction(
        self,
        db: Session,
        vendor_id: int,
        product_id: int,
        inventory: Inventory,
        transaction_type: TransactionType,
        quantity_change: int,
        order: Order,
        reason: str | None = None,
    ) -> InventoryTransaction:
        """
        Create an inventory transaction record for the audit trail.

        The transaction snapshots the inventory state *after* the operation
        (quantity_after / reserved_after), so callers must pass the already
        updated Inventory row.

        Args:
            db: Database session
            vendor_id: Vendor ID
            product_id: Product ID
            inventory: Inventory record after the operation
            transaction_type: Type of transaction
            quantity_change: Change in quantity (positive = add, negative = remove)
            order: Order associated with this transaction
            reason: Optional reason for the transaction

        Returns:
            Created InventoryTransaction (added to the session, not flushed)
        """
        # Guard every inventory access: callers may pass a falsy inventory,
        # in which case the snapshot fields default to None/0.
        transaction = InventoryTransaction.create_transaction(
            vendor_id=vendor_id,
            product_id=product_id,
            inventory_id=inventory.id if inventory else None,
            transaction_type=transaction_type,
            quantity_change=quantity_change,
            quantity_after=inventory.quantity if inventory else 0,
            reserved_after=inventory.reserved_quantity if inventory else 0,
            location=inventory.location if inventory else None,
            warehouse=inventory.warehouse if inventory else None,
            order_id=order.id,
            order_number=order.order_number,
            reason=reason,
            created_by="system",  # automated path, no acting user
        )
        db.add(transaction)
        return transaction
    def reserve_for_order(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
        skip_missing: bool = True,
    ) -> dict:
        """
        Reserve inventory for all items in an order.

        Placeholder items (unmatched marketplace products) are always
        skipped. Does not commit; transaction management is the caller's
        responsibility.

        Args:
            db: Database session
            vendor_id: Vendor ID
            order_id: Order ID
            skip_missing: If True, skip items without inventory instead of failing

        Returns:
            Dict with reserved count and any skipped items

        Raises:
            InsufficientInventoryException: If skip_missing=False and inventory unavailable
            InventoryNotFoundException: If skip_missing=False and no location holds the product
        """
        order = self.get_order_with_items(db, vendor_id, order_id)
        reserved_count = 0
        skipped_items = []
        for item in order.items:
            # Skip placeholder products
            if self._is_placeholder_product(item):
                skipped_items.append({
                    "item_id": item.id,
                    "reason": "placeholder_product",
                })
                continue
            # Find a location that still has unreserved stock
            location = self._find_inventory_location(db, item.product_id, vendor_id)
            if not location:
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": "no_inventory",
                    })
                    continue
                else:
                    raise InventoryNotFoundException(
                        f"No inventory found for product {item.product_id}"
                    )
            try:
                reserve_data = InventoryReserve(
                    product_id=item.product_id,
                    location=location,
                    quantity=item.quantity,
                )
                updated_inventory = inventory_service.reserve_inventory(
                    db, vendor_id, reserve_data
                )
                reserved_count += 1
                # Log transaction for audit trail
                self._log_transaction(
                    db=db,
                    vendor_id=vendor_id,
                    product_id=item.product_id,
                    inventory=updated_inventory,
                    transaction_type=TransactionType.RESERVE,
                    quantity_change=0,  # Reserve doesn't change quantity, only reserved_quantity
                    order=order,
                    reason=f"Reserved for order {order.order_number}",
                )
                logger.info(
                    f"Reserved {item.quantity} units of product {item.product_id} "
                    f"for order {order.order_number}"
                )
            except InsufficientInventoryException:
                # Stock exists but not enough of it; record and move on when tolerant.
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": "insufficient_inventory",
                    })
                else:
                    raise
        logger.info(
            f"Order {order.order_number}: reserved {reserved_count} items, "
            f"skipped {len(skipped_items)}"
        )
        return {
            "order_id": order_id,
            "order_number": order.order_number,
            "reserved_count": reserved_count,
            "skipped_items": skipped_items,
        }
    def fulfill_order(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
        skip_missing: bool = True,
    ) -> dict:
        """
        Fulfill (deduct) inventory when an order is shipped.

        This decreases both the total quantity and reserved quantity,
        effectively consuming the reserved stock. Items that are already
        fully shipped or that use placeholder products are skipped.

        Args:
            db: Database session
            vendor_id: Vendor ID
            order_id: Order ID
            skip_missing: If True, skip items without inventory

        Returns:
            Dict with fulfilled count and any skipped items

        Raises:
            InventoryNotFoundException: If skip_missing=False and no inventory row exists
        """
        order = self.get_order_with_items(db, vendor_id, order_id)
        fulfilled_count = 0
        skipped_items = []
        for item in order.items:
            # Skip already fully shipped items
            if item.is_fully_shipped:
                continue
            # Skip placeholder products
            if self._is_placeholder_product(item):
                skipped_items.append({
                    "item_id": item.id,
                    "reason": "placeholder_product",
                })
                continue
            # Only fulfill remaining quantity (supports prior partial shipments)
            quantity_to_fulfill = item.remaining_quantity
            # Find inventory location with free (unreserved) stock first
            location = self._find_inventory_location(db, item.product_id, vendor_id)
            # Fall back to any inventory row: fully-reserved stock is still
            # fulfillable, since fulfillment consumes the reservation.
            if not location:
                inventory = (
                    db.query(Inventory)
                    .filter(
                        Inventory.product_id == item.product_id,
                        Inventory.vendor_id == vendor_id,
                    )
                    .first()
                )
                if inventory:
                    location = inventory.location
            if not location:
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": "no_inventory",
                    })
                    continue
                else:
                    raise InventoryNotFoundException(
                        f"No inventory found for product {item.product_id}"
                    )
            try:
                reserve_data = InventoryReserve(
                    product_id=item.product_id,
                    location=location,
                    quantity=quantity_to_fulfill,
                )
                updated_inventory = inventory_service.fulfill_reservation(
                    db, vendor_id, reserve_data
                )
                fulfilled_count += 1
                # Mark the whole item shipped (order-level fulfillment ships everything)
                item.shipped_quantity = item.quantity
                item.inventory_fulfilled = True
                # Log transaction for audit trail
                self._log_transaction(
                    db=db,
                    vendor_id=vendor_id,
                    product_id=item.product_id,
                    inventory=updated_inventory,
                    transaction_type=TransactionType.FULFILL,
                    quantity_change=-quantity_to_fulfill,  # Negative because stock is consumed
                    order=order,
                    reason=f"Fulfilled for order {order.order_number}",
                )
                logger.info(
                    f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} "
                    f"for order {order.order_number}"
                )
            except (InsufficientInventoryException, InventoryNotFoundException) as e:
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": str(e),
                    })
                else:
                    raise
        logger.info(
            f"Order {order.order_number}: fulfilled {fulfilled_count} items, "
            f"skipped {len(skipped_items)}"
        )
        return {
            "order_id": order_id,
            "order_number": order.order_number,
            "fulfilled_count": fulfilled_count,
            "skipped_items": skipped_items,
        }
def fulfill_item(
self,
db: Session,
vendor_id: int,
order_id: int,
item_id: int,
quantity: int | None = None,
skip_missing: bool = True,
) -> dict:
"""
Fulfill (deduct) inventory for a specific order item.
Supports partial fulfillment - ship some units now, rest later.
Args:
db: Database session
vendor_id: Vendor ID
order_id: Order ID
item_id: Order item ID
quantity: Quantity to ship (defaults to remaining quantity)
skip_missing: If True, skip if inventory not found
Returns:
Dict with fulfillment result
Raises:
ValidationException: If quantity exceeds remaining
"""
order = self.get_order_with_items(db, vendor_id, order_id)
# Find the item
item = None
for order_item in order.items:
if order_item.id == item_id:
item = order_item
break
if not item:
raise ValidationException(f"Item {item_id} not found in order {order_id}")
# Check if already fully shipped
if item.is_fully_shipped:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Item already fully shipped",
}
# Default to remaining quantity
quantity_to_fulfill = quantity or item.remaining_quantity
# Validate quantity
if quantity_to_fulfill > item.remaining_quantity:
raise ValidationException(
f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining"
)
if quantity_to_fulfill <= 0:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Nothing to fulfill",
}
# Skip placeholder products
if self._is_placeholder_product(item):
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "Placeholder product - skipped",
}
# Find inventory location
location = self._find_inventory_location(db, item.product_id, vendor_id)
if not location:
inventory = (
db.query(Inventory)
.filter(
Inventory.product_id == item.product_id,
Inventory.vendor_id == vendor_id,
)
.first()
)
if inventory:
location = inventory.location
if not location:
if skip_missing:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": "No inventory found",
}
else:
raise InventoryNotFoundException(
f"No inventory found for product {item.product_id}"
)
try:
reserve_data = InventoryReserve(
product_id=item.product_id,
location=location,
quantity=quantity_to_fulfill,
)
updated_inventory = inventory_service.fulfill_reservation(
db, vendor_id, reserve_data
)
# Update item shipped quantity
item.shipped_quantity += quantity_to_fulfill
# Mark as fulfilled only if fully shipped
if item.is_fully_shipped:
item.inventory_fulfilled = True
# Log transaction
self._log_transaction(
db=db,
vendor_id=vendor_id,
product_id=item.product_id,
inventory=updated_inventory,
transaction_type=TransactionType.FULFILL,
quantity_change=-quantity_to_fulfill,
order=order,
reason=f"Partial shipment for order {order.order_number}",
)
logger.info(
f"Fulfilled {quantity_to_fulfill} of {item.quantity} units "
f"for item {item_id} in order {order.order_number}"
)
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": quantity_to_fulfill,
"shipped_quantity": item.shipped_quantity,
"remaining_quantity": item.remaining_quantity,
"is_fully_shipped": item.is_fully_shipped,
}
except (InsufficientInventoryException, InventoryNotFoundException) as e:
if skip_missing:
return {
"order_id": order_id,
"item_id": item_id,
"fulfilled_quantity": 0,
"message": str(e),
}
else:
raise
    def release_order_reservation(
        self,
        db: Session,
        vendor_id: int,
        order_id: int,
        skip_missing: bool = True,
    ) -> dict:
        """
        Release reserved inventory when an order is cancelled.

        This decreases the reserved quantity, making the stock available
        again. Placeholder items are always skipped.

        NOTE(review): releases the full ``item.quantity`` per item; for a
        partially shipped order this may exceed what is still reserved —
        confirm release_reservation tolerates that.

        Args:
            db: Database session
            vendor_id: Vendor ID
            order_id: Order ID
            skip_missing: If True, skip items without inventory

        Returns:
            Dict with released count and any skipped items
        """
        order = self.get_order_with_items(db, vendor_id, order_id)
        released_count = 0
        skipped_items = []
        for item in order.items:
            # Skip placeholder products
            if self._is_placeholder_product(item):
                skipped_items.append({
                    "item_id": item.id,
                    "reason": "placeholder_product",
                })
                continue
            # Find inventory - look for any inventory for this product
            inventory = (
                db.query(Inventory)
                .filter(
                    Inventory.product_id == item.product_id,
                    Inventory.vendor_id == vendor_id,
                )
                .first()
            )
            if not inventory:
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": "no_inventory",
                    })
                    continue
                else:
                    raise InventoryNotFoundException(
                        f"No inventory found for product {item.product_id}"
                    )
            try:
                reserve_data = InventoryReserve(
                    product_id=item.product_id,
                    location=inventory.location,
                    quantity=item.quantity,
                )
                updated_inventory = inventory_service.release_reservation(
                    db, vendor_id, reserve_data
                )
                released_count += 1
                # Log transaction for audit trail
                self._log_transaction(
                    db=db,
                    vendor_id=vendor_id,
                    product_id=item.product_id,
                    inventory=updated_inventory,
                    transaction_type=TransactionType.RELEASE,
                    quantity_change=0,  # Release doesn't change quantity, only reserved_quantity
                    order=order,
                    reason=f"Released for cancelled order {order.order_number}",
                )
                logger.info(
                    f"Released {item.quantity} units of product {item.product_id} "
                    f"for cancelled order {order.order_number}"
                )
            except Exception as e:
                # Broad catch is deliberate best-effort here: a failed release
                # for one item must not block cancelling the rest.
                if skip_missing:
                    skipped_items.append({
                        "item_id": item.id,
                        "product_id": item.product_id,
                        "reason": str(e),
                    })
                else:
                    raise
        logger.info(
            f"Order {order.order_number}: released {released_count} items, "
            f"skipped {len(skipped_items)}"
        )
        return {
            "order_id": order_id,
            "order_number": order.order_number,
            "released_count": released_count,
            "skipped_items": skipped_items,
        }
def handle_status_change(
self,
db: Session,
vendor_id: int,
order_id: int,
old_status: str | None,
new_status: str,
) -> dict | None:
"""
Handle inventory operations based on order status changes.
Status transitions that trigger inventory operations:
- Any → processing: Reserve inventory (if not already reserved)
- processing → shipped: Fulfill inventory (deduct from stock)
- processing → partially_shipped: Partial fulfillment already done via fulfill_item
- Any → cancelled: Release reservations
Args:
db: Database session
vendor_id: Vendor ID
order_id: Order ID
old_status: Previous status (can be None for new orders)
new_status: New status
Returns:
Result of inventory operation, or None if no operation needed
"""
# Skip if status didn't change
if old_status == new_status:
return None
result = None
# Transitioning to processing - reserve inventory
if new_status == "processing":
result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True)
logger.info(f"Order {order_id} confirmed: inventory reserved")
# Transitioning to shipped - fulfill remaining inventory
elif new_status == "shipped":
result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True)
logger.info(f"Order {order_id} shipped: inventory fulfilled")
# partially_shipped - no automatic fulfillment (handled via fulfill_item)
elif new_status == "partially_shipped":
logger.info(
f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment"
)
result = {"order_id": order_id, "status": "partially_shipped"}
# Transitioning to cancelled - release reservations
elif new_status == "cancelled":
# Only release if there was a previous status (order was in progress)
if old_status and old_status not in ("cancelled", "refunded"):
result = self.release_order_reservation(
db, vendor_id, order_id, skip_missing=True
)
logger.info(f"Order {order_id} cancelled: reservations released")
return result
def get_shipment_status(
self,
db: Session,
vendor_id: int,
order_id: int,
) -> dict:
"""
Get detailed shipment status for an order.
Returns item-level shipment status for partial shipment tracking.
Args:
db: Database session
vendor_id: Vendor ID
order_id: Order ID
Returns:
Dict with shipment status details
"""
order = self.get_order_with_items(db, vendor_id, order_id)
items = []
for item in order.items:
items.append({
"item_id": item.id,
"product_id": item.product_id,
"product_name": item.product_name,
"quantity": item.quantity,
"shipped_quantity": item.shipped_quantity,
"remaining_quantity": item.remaining_quantity,
"is_fully_shipped": item.is_fully_shipped,
"is_partially_shipped": item.is_partially_shipped,
})
return {
"order_id": order_id,
"order_number": order.order_number,
"order_status": order.status,
"is_fully_shipped": order.is_fully_shipped,
"is_partially_shipped": order.is_partially_shipped,
"shipped_item_count": order.shipped_item_count,
"total_item_count": len(order.items),
"total_shipped_units": order.total_shipped_units,
"total_ordered_units": order.total_ordered_units,
"items": items,
}
# Module-level singleton: importers should use this shared instance rather
# than constructing their own OrderInventoryService.
order_inventory_service = OrderInventoryService()

# Explicit public API of this module.
__all__ = [
    "order_inventory_service",
    "OrderInventoryService",
]

View File

@@ -1,632 +1,23 @@
# app/services/order_item_exception_service.py
"""
Service for managing order item exceptions (unmatched products).
LEGACY LOCATION - Re-exports from module for backwards compatibility.
This service handles:
- Creating exceptions when products are not found during order import
- Resolving exceptions by assigning products
- Auto-matching when new products are imported
- Querying and statistics for exceptions
The canonical implementation is now in:
app/modules/orders/services/order_item_exception_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.orders.services import order_item_exception_service
"""
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload
from app.exceptions import (
ExceptionAlreadyResolvedException,
InvalidProductForExceptionException,
OrderItemExceptionNotFoundException,
ProductNotFoundException,
from app.modules.orders.services.order_item_exception_service import (
order_item_exception_service,
OrderItemExceptionService,
)
from models.database.order import Order, OrderItem
from models.database.order_item_exception import OrderItemException
from models.database.product import Product
logger = logging.getLogger(__name__)
class OrderItemExceptionService:
"""Service for order item exception CRUD and resolution workflow."""
# =========================================================================
# Exception Creation
# =========================================================================
    def create_exception(
        self,
        db: Session,
        order_item: OrderItem,
        vendor_id: int,
        original_gtin: str | None,
        original_product_name: str | None,
        original_sku: str | None,
        exception_type: str = "product_not_found",
    ) -> OrderItemException:
        """
        Create an exception record for an unmatched order item.

        Args:
            db: Database session
            order_item: The order item that couldn't be matched
            vendor_id: Vendor ID (denormalized for efficient queries)
            original_gtin: Original GTIN from marketplace
            original_product_name: Original product name from marketplace
            original_sku: Original SKU from marketplace
            exception_type: Type of exception (product_not_found, gtin_mismatch, etc.)

        Returns:
            Created OrderItemException, in "pending" state (flushed but not
            committed)
        """
        exception = OrderItemException(
            order_item_id=order_item.id,
            vendor_id=vendor_id,
            original_gtin=original_gtin,
            original_product_name=original_product_name,
            original_sku=original_sku,
            exception_type=exception_type,
            status="pending",
        )
        db.add(exception)
        # Flush so the generated primary key is available for the log line.
        db.flush()
        logger.info(
            f"Created order item exception {exception.id} for order item "
            f"{order_item.id}, GTIN: {original_gtin}"
        )
        return exception
# =========================================================================
# Exception Retrieval
# =========================================================================
def get_exception_by_id(
self,
db: Session,
exception_id: int,
vendor_id: int | None = None,
) -> OrderItemException:
"""
Get an exception by ID, optionally filtered by vendor.
Args:
db: Database session
exception_id: Exception ID
vendor_id: Optional vendor ID filter (for vendor-scoped access)
Returns:
OrderItemException
Raises:
OrderItemExceptionNotFoundException: If not found
"""
query = db.query(OrderItemException).filter(
OrderItemException.id == exception_id
)
if vendor_id is not None:
query = query.filter(OrderItemException.vendor_id == vendor_id)
exception = query.first()
if not exception:
raise OrderItemExceptionNotFoundException(exception_id)
return exception
    def get_pending_exceptions(
        self,
        db: Session,
        vendor_id: int | None = None,
        status: str | None = None,
        search: str | None = None,
        skip: int = 0,
        limit: int = 50,
    ) -> tuple[list[OrderItemException], int]:
        """
        Get exceptions with pagination and filtering, newest first.

        Args:
            db: Database session
            vendor_id: Optional vendor filter
            status: Optional status filter (pending, resolved, ignored)
            search: Optional case-insensitive search in GTIN, product name,
                SKU, or order number
            skip: Pagination offset
            limit: Pagination limit

        Returns:
            Tuple of (list of exceptions, total count before pagination)
        """
        # Eager-load the item->order chain so serializers don't trigger N+1.
        query = (
            db.query(OrderItemException)
            .join(OrderItem)
            .join(Order)
            .options(
                joinedload(OrderItemException.order_item).joinedload(OrderItem.order)
            )
        )
        # Apply filters
        if vendor_id is not None:
            query = query.filter(OrderItemException.vendor_id == vendor_id)
        if status:
            query = query.filter(OrderItemException.status == status)
        if search:
            search_pattern = f"%{search}%"
            query = query.filter(
                or_(
                    OrderItemException.original_gtin.ilike(search_pattern),
                    OrderItemException.original_product_name.ilike(search_pattern),
                    OrderItemException.original_sku.ilike(search_pattern),
                    Order.order_number.ilike(search_pattern),
                )
            )
        # Count before pagination so callers can render page controls.
        total = query.count()
        # Apply pagination and ordering
        exceptions = (
            query.order_by(OrderItemException.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
        return exceptions, total
def get_exceptions_for_order(
self,
db: Session,
order_id: int,
) -> list[OrderItemException]:
"""
Get all exceptions for items in an order.
Args:
db: Database session
order_id: Order ID
Returns:
List of exceptions for the order
"""
return (
db.query(OrderItemException)
.join(OrderItem)
.filter(OrderItem.order_id == order_id)
.all()
)
# =========================================================================
# Exception Statistics
# =========================================================================
    def get_exception_stats(
        self,
        db: Session,
        vendor_id: int | None = None,
    ) -> dict[str, int]:
        """
        Get exception counts by status.

        Args:
            db: Database session
            vendor_id: Optional vendor filter

        Returns:
            Dict with pending, resolved, ignored, total counts plus
            orders_with_exceptions (distinct orders having pending ones)
        """
        query = db.query(
            OrderItemException.status,
            func.count(OrderItemException.id).label("count"),
        )
        if vendor_id is not None:
            query = query.filter(OrderItemException.vendor_id == vendor_id)
        results = query.group_by(OrderItemException.status).all()
        stats = {
            "pending": 0,
            "resolved": 0,
            "ignored": 0,
            "total": 0,
        }
        for status, count in results:
            if status in stats:
                stats[status] = count
                # NOTE(review): rows with an unknown status are excluded from
                # "total" as well — confirm this is intentional.
                stats["total"] += count
        # Count orders with pending exceptions
        orders_query = (
            db.query(func.count(func.distinct(OrderItem.order_id)))
            .join(OrderItemException)
            .filter(OrderItemException.status == "pending")
        )
        if vendor_id is not None:
            orders_query = orders_query.filter(
                OrderItemException.vendor_id == vendor_id
            )
        stats["orders_with_exceptions"] = orders_query.scalar() or 0
        return stats
# =========================================================================
# Exception Resolution
# =========================================================================
    def resolve_exception(
        self,
        db: Session,
        exception_id: int,
        product_id: int,
        resolved_by: int,
        notes: str | None = None,
        vendor_id: int | None = None,
    ) -> OrderItemException:
        """
        Resolve an exception by assigning a product.

        This updates:
        - The exception record (status, resolved_product_id, etc.)
        - The order item (product_id, needs_product_match)

        Args:
            db: Database session
            exception_id: Exception ID to resolve
            product_id: Product ID to assign
            resolved_by: User ID who resolved
            notes: Optional resolution notes
            vendor_id: Optional vendor filter (for scoped access)

        Returns:
            Updated OrderItemException

        Raises:
            OrderItemExceptionNotFoundException: If exception not found
            ExceptionAlreadyResolvedException: If already resolved
            ProductNotFoundException: If the product does not exist
            InvalidProductForExceptionException: If product is invalid
        """
        exception = self.get_exception_by_id(db, exception_id, vendor_id)
        if exception.status == "resolved":
            raise ExceptionAlreadyResolvedException(exception_id)
        # Validate product exists, belongs to the same vendor, and is active
        product = db.query(Product).filter(Product.id == product_id).first()
        if not product:
            raise ProductNotFoundException(product_id)
        if product.vendor_id != exception.vendor_id:
            raise InvalidProductForExceptionException(
                product_id, "Product belongs to a different vendor"
            )
        if not product.is_active:
            raise InvalidProductForExceptionException(
                product_id, "Product is not active"
            )
        # Update exception
        exception.status = "resolved"
        exception.resolved_product_id = product_id
        exception.resolved_at = datetime.now(UTC)
        exception.resolved_by = resolved_by
        exception.resolution_notes = notes
        # Update order item so it no longer blocks confirmation
        order_item = exception.order_item
        order_item.product_id = product_id
        order_item.needs_product_match = False
        # Refresh the denormalized product snapshot on the order item
        if product.marketplace_product:
            order_item.product_name = product.marketplace_product.get_title("en")
            order_item.product_sku = product.vendor_sku or order_item.product_sku
        db.flush()
        logger.info(
            f"Resolved exception {exception_id} with product {product_id} "
            f"by user {resolved_by}"
        )
        return exception
def ignore_exception(
self,
db: Session,
exception_id: int,
resolved_by: int,
notes: str,
vendor_id: int | None = None,
) -> OrderItemException:
"""
Mark an exception as ignored.
Note: Ignored exceptions still block order confirmation.
This is for tracking purposes (e.g., product will never be matched).
Args:
db: Database session
exception_id: Exception ID
resolved_by: User ID who ignored
notes: Reason for ignoring (required)
vendor_id: Optional vendor filter
Returns:
Updated OrderItemException
"""
exception = self.get_exception_by_id(db, exception_id, vendor_id)
if exception.status == "resolved":
raise ExceptionAlreadyResolvedException(exception_id)
exception.status = "ignored"
exception.resolved_at = datetime.now(UTC)
exception.resolved_by = resolved_by
exception.resolution_notes = notes
db.flush()
logger.info(
f"Ignored exception {exception_id} by user {resolved_by}: {notes}"
)
return exception
# =========================================================================
# Auto-Matching
# =========================================================================
def auto_match_by_gtin(
    self,
    db: Session,
    vendor_id: int,
    gtin: str,
    product_id: int,
) -> list[OrderItemException]:
    """Auto-resolve pending exceptions that match a GTIN.

    Called after a product is imported with a GTIN.

    Args:
        db: Database session
        vendor_id: Vendor ID
        gtin: GTIN to match
        product_id: Product ID to assign

    Returns:
        List of resolved exceptions (empty when nothing matched).
    """
    if not gtin:
        return []

    # All of this vendor's still-pending exceptions carrying this GTIN.
    candidates = (
        db.query(OrderItemException)
        .filter(
            and_(
                OrderItemException.vendor_id == vendor_id,
                OrderItemException.original_gtin == gtin,
                OrderItemException.status == "pending",
            )
        )
        .all()
    )
    if not candidates:
        return []

    product = db.query(Product).filter(Product.id == product_id).first()
    if not product:
        logger.warning(f"Product {product_id} not found for auto-match")
        return []

    matched_at = datetime.now(UTC)
    resolved: list[OrderItemException] = []
    for exc in candidates:
        exc.status = "resolved"
        exc.resolved_product_id = product_id
        exc.resolved_at = matched_at
        exc.resolution_notes = "Auto-matched during product import"

        # Point the order item at the matched product and clear the flag.
        item = exc.order_item
        item.product_id = product_id
        item.needs_product_match = False
        if product.marketplace_product:
            item.product_name = product.marketplace_product.get_title("en")
        resolved.append(exc)

    if resolved:
        db.flush()
        logger.info(
            f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} "
            f"with product {product_id}"
        )
    return resolved
def auto_match_batch(
    self,
    db: Session,
    vendor_id: int,
    gtin_to_product: dict[str, int],
) -> int:
    """Batch auto-match multiple GTINs after a bulk import.

    Args:
        db: Database session
        vendor_id: Vendor ID
        gtin_to_product: Dict mapping GTIN to product ID

    Returns:
        Total number of resolved exceptions across all GTINs.
    """
    # Empty mapping naturally sums to 0, matching the old early return.
    return sum(
        len(self.auto_match_by_gtin(db, vendor_id, gtin, product_id))
        for gtin, product_id in gtin_to_product.items()
    )
# =========================================================================
# Confirmation Checks
# =========================================================================
def order_has_unresolved_exceptions(
    self,
    db: Session,
    order_id: int,
) -> bool:
    """
    Check if order has any unresolved exceptions.

    An order cannot be confirmed if it has pending or ignored exceptions.

    Args:
        db: Database session
        order_id: Order ID

    Returns:
        True if order has unresolved exceptions
    """
    # Delegate to the count query so both methods share a single definition
    # of "unresolved" (status pending or ignored). This also removes the
    # duplicated query and avoids a TypeError if the aggregate ever came
    # back as None (the count method already coalesces to 0).
    return self.get_unresolved_exception_count(db, order_id) > 0
def get_unresolved_exception_count(
    self,
    db: Session,
    order_id: int,
) -> int:
    """Count the unresolved exceptions attached to an order.

    "Unresolved" means status is pending or ignored — both block
    order confirmation.

    Args:
        db: Database session
        order_id: Order ID

    Returns:
        Count of unresolved exceptions (0 when none).
    """
    unresolved = (
        db.query(func.count(OrderItemException.id))
        .join(OrderItem)
        .filter(
            and_(
                OrderItem.order_id == order_id,
                OrderItemException.status.in_(["pending", "ignored"]),
            )
        )
        .scalar()
    )
    # Coalesce a possible None scalar to 0.
    return unresolved or 0
# =========================================================================
# Bulk Operations
# =========================================================================
def bulk_resolve_by_gtin(
    self,
    db: Session,
    vendor_id: int,
    gtin: str,
    product_id: int,
    resolved_by: int,
    notes: str | None = None,
) -> int:
    """Bulk resolve all pending exceptions for a GTIN.

    Args:
        db: Database session
        vendor_id: Vendor ID
        gtin: GTIN to match
        product_id: Product ID to assign
        resolved_by: User ID who resolved
        notes: Optional notes

    Returns:
        Number of resolved exceptions.

    Raises:
        ProductNotFoundException: if the product does not exist.
        InvalidProductForExceptionException: if it belongs to another vendor.
    """
    # Validate the replacement product before touching any exception rows.
    product = db.query(Product).filter(Product.id == product_id).first()
    if product is None:
        raise ProductNotFoundException(product_id)
    if product.vendor_id != vendor_id:
        raise InvalidProductForExceptionException(
            product_id, "Product belongs to a different vendor"
        )

    # Every still-pending exception for this vendor/GTIN pair.
    matches = (
        db.query(OrderItemException)
        .filter(
            and_(
                OrderItemException.vendor_id == vendor_id,
                OrderItemException.original_gtin == gtin,
                OrderItemException.status == "pending",
            )
        )
        .all()
    )

    resolved_at = datetime.now(UTC)
    resolution_notes = notes or f"Bulk resolved for GTIN {gtin}"
    for exc in matches:
        exc.status = "resolved"
        exc.resolved_product_id = product_id
        exc.resolved_at = resolved_at
        exc.resolved_by = resolved_by
        exc.resolution_notes = resolution_notes

        # Update order item snapshot alongside the exception.
        item = exc.order_item
        item.product_id = product_id
        item.needs_product_match = False
        if product.marketplace_product:
            item.product_name = product.marketplace_product.get_title("en")

    db.flush()
    logger.info(
        f"Bulk resolved {len(matches)} exceptions for GTIN {gtin} "
        f"with product {product_id} by user {resolved_by}"
    )
    return len(matches)
# Global service instance
# Module-level singleton: the service holds no per-request state, so one
# shared instance is used by importers/routes throughout the app.
order_item_exception_service = OrderItemExceptionService()

# Public API of this module.
__all__ = [
    "order_item_exception_service",
    "OrderItemExceptionService",
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,507 +1,23 @@
# app/services/test_runner_service.py
"""
Test Runner Service
Service for running pytest and storing results
LEGACY LOCATION - Re-exports from module for backwards compatibility.
The canonical implementation is now in:
app/modules/dev_tools/services/test_runner_service.py
This file exists to maintain backwards compatibility with code that
imports from the old location. All new code should import directly
from the module:
from app.modules.dev_tools.services import test_runner_service
"""
import json
import logging
import re
import subprocess
import tempfile
from datetime import UTC, datetime
from pathlib import Path
from sqlalchemy import desc, func
from sqlalchemy.orm import Session
from models.database.test_run import TestCollection, TestResult, TestRun
logger = logging.getLogger(__name__)
class TestRunnerService:
    """Service for managing pytest test runs.

    Shells out to pytest via subprocess and persists results using the
    TestRun / TestResult / TestCollection models.
    """

    def __init__(self):
        # Repository root: three levels up from app/services/<this file> —
        # TODO(review): confirm this still holds after the module migration.
        self.project_root = Path(__file__).parent.parent.parent
def create_test_run(
    self,
    db: Session,
    test_path: str = "tests",
    triggered_by: str = "manual",
) -> TestRun:
    """Create a TestRun row in 'running' state without executing pytest.

    Args:
        db: Database session
        test_path: Path to tests (relative to project root)
        triggered_by: Who/what triggered the run

    Returns:
        The flushed TestRun record (ID assigned, not committed).
    """
    run = TestRun(
        timestamp=datetime.now(UTC),
        triggered_by=triggered_by,
        test_path=test_path,
        status="running",
        git_commit_hash=self._get_git_commit(),
        git_branch=self._get_git_branch(),
    )
    db.add(run)
    # Flush (not commit) so the caller gets a primary key immediately.
    db.flush()
    return run
def run_tests(
    self,
    db: Session,
    test_path: str = "tests",
    triggered_by: str = "manual",
    extra_args: list[str] | None = None,
) -> TestRun:
    """
    Run pytest synchronously and store results in database

    Args:
        db: Database session
        test_path: Path to tests (relative to project root)
        triggered_by: Who triggered the run
        extra_args: Additional pytest arguments

    Returns:
        TestRun object with results
    """
    # Record the run first, then execute and fill in the outcome.
    run = self.create_test_run(db, test_path, triggered_by)
    self._execute_tests(db, run, test_path, extra_args)
    return run
def _execute_tests(
    self,
    db: Session,
    test_run: TestRun,
    test_path: str,
    extra_args: list[str] | None,
) -> None:
    """Execute pytest and update the test run record.

    Runs pytest with the json-report plugin, parses the JSON report
    (falling back to stdout parsing if the report is missing/corrupt),
    and sets test_run.status to passed/failed/error.

    Args:
        db: Database session
        test_run: The TestRun row to populate
        test_path: Path to tests (relative to project root)
        extra_args: Additional pytest arguments
    """
    json_report_path: str | None = None
    try:
        # Reserve a temp file path for the JSON report (closed immediately;
        # pytest writes to it).
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            json_report_path = f.name
        pytest_args = [
            "python",
            "-m",
            "pytest",
            test_path,
            "--json-report",
            f"--json-report-file={json_report_path}",
            "-v",
            "--tb=short",
        ]
        if extra_args:
            pytest_args.extend(extra_args)
        test_run.pytest_args = " ".join(pytest_args)

        # Run pytest
        start_time = datetime.now(UTC)
        result = subprocess.run(
            pytest_args,
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=600,  # 10 minute timeout
        )
        end_time = datetime.now(UTC)
        test_run.duration_seconds = (end_time - start_time).total_seconds()

        # Parse JSON report
        try:
            with open(json_report_path) as f:
                report = json.load(f)
            self._process_json_report(db, test_run, report)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # Fallback to parsing stdout if JSON report failed
            logger.warning(f"JSON report unavailable ({e}), parsing stdout")
            self._parse_pytest_output(test_run, result.stdout, result.stderr)

        # Set final status
        if test_run.failed > 0 or test_run.errors > 0:
            test_run.status = "failed"
        else:
            test_run.status = "passed"
    except subprocess.TimeoutExpired:
        test_run.status = "error"
        logger.error("Pytest run timed out")
    except Exception as e:
        test_run.status = "error"
        logger.error(f"Error running tests: {e}")
    finally:
        # BUGFIX: previously the temp report file was only deleted inside the
        # JSON-parse block, so it leaked when pytest timed out or any earlier
        # step raised. Clean it up on every exit path.
        if json_report_path is not None:
            Path(json_report_path).unlink(missing_ok=True)
def _process_json_report(self, db: Session, test_run: TestRun, report: dict):
    """Copy pytest-json-report output onto the TestRun and its TestResults."""
    # Summary counters: report key -> TestRun attribute.
    summary = report.get("summary", {})
    for attr, key in (
        ("total_tests", "total"),
        ("passed", "passed"),
        ("failed", "failed"),
        ("errors", "error"),
        ("skipped", "skipped"),
        ("xfailed", "xfailed"),
        ("xpassed", "xpassed"),
    ):
        setattr(test_run, attr, summary.get(key, 0))

    # One TestResult row per collected test.
    for entry in report.get("tests", []):
        node_id = entry.get("nodeid", "")
        outcome = entry.get("outcome", "unknown")
        test_file, test_class, test_name = self._parse_node_id(node_id)

        # Failure details: keep the full longrepr and pull a short message
        # from its last line.
        error_message = None
        traceback = None
        if outcome in ("failed", "error"):
            call_info = entry.get("call", {})
            if "longrepr" in call_info:
                traceback = call_info["longrepr"]
                if isinstance(traceback, str):
                    tb_lines = traceback.strip().split("\n")
                    if tb_lines:
                        error_message = tb_lines[-1][:500]  # Last line, limited length

        db.add(
            TestResult(
                run_id=test_run.id,
                node_id=node_id,
                test_name=test_name,
                test_file=test_file,
                test_class=test_class,
                outcome=outcome,
                duration_seconds=entry.get("duration", 0.0),
                error_message=error_message,
                traceback=traceback,
                markers=entry.get("keywords", []),
            )
        )
def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]:
"""Parse pytest node_id into file, class, function"""
# Format: tests/unit/test_foo.py::TestClass::test_method
# or: tests/unit/test_foo.py::test_function
parts = node_id.split("::")
test_file = parts[0] if parts else ""
test_class = None
test_name = parts[-1] if parts else ""
if len(parts) == 3:
test_class = parts[1]
elif len(parts) == 2:
# Could be Class::method or file::function
if parts[1].startswith("Test"):
test_class = parts[1]
test_name = parts[1]
# Handle parametrized tests
if "[" in test_name:
test_name = test_name.split("[")[0]
return test_file, test_class, test_name
def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str):
    """Fallback parser for pytest text output.

    Reads the summary line (e.g. "10 passed, 2 failed, 1 skipped") from
    stdout and fills the counters on test_run. stderr is accepted for
    signature compatibility but not inspected.
    """
    # Outcome word in the summary line -> attribute on TestRun.
    attr_for_status = {
        "passed": "passed",
        "failed": "failed",
        "error": "errors",
        "skipped": "skipped",
        "xfailed": "xfailed",
        "xpassed": "xpassed",
    }
    summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)"
    for match in re.finditer(summary_pattern, stdout):
        setattr(test_run, attr_for_status[match.group(2)], int(match.group(1)))

    test_run.total_tests = (
        test_run.passed
        + test_run.failed
        + test_run.errors
        + test_run.skipped
        + test_run.xfailed
        + test_run.xpassed
    )
def _get_git_commit(self) -> str | None:
"""Get current git commit hash"""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip()[:40] if result.returncode == 0 else None
except:
return None
def _get_git_branch(self) -> str | None:
"""Get current git branch"""
try:
result = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip() if result.returncode == 0 else None
except:
return None
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
    """Return the most recent test runs, newest first."""
    recent = db.query(TestRun).order_by(desc(TestRun.timestamp))
    return recent.limit(limit).all()
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
    """Fetch a single test run by primary key, or None if absent."""
    matching = db.query(TestRun).filter(TestRun.id == run_id)
    return matching.first()
def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]:
    """Return the failed/errored results belonging to one run."""
    failing = db.query(TestResult).filter(
        TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"])
    )
    return failing.all()
def get_run_results(
    self, db: Session, run_id: int, outcome: str | None = None
) -> list[TestResult]:
    """Return a run's results, optionally restricted to one outcome."""
    results = db.query(TestResult).filter(TestResult.run_id == run_id)
    if outcome is not None:
        results = results.filter(TestResult.outcome == outcome)
    return results.all()
def get_dashboard_stats(self, db: Session) -> dict:
    """Get statistics for the testing dashboard.

    Aggregates: the latest completed run's counters, collection totals,
    a 10-run trend series (oldest first), per-category pass/fail counts
    for the latest run, and the 10 most frequently failing tests.
    """
    # Get latest run (exclude runs still in progress).
    latest_run = (
        db.query(TestRun)
        .filter(TestRun.status != "running")
        .order_by(desc(TestRun.timestamp))
        .first()
    )
    # Get test collection info (or calculate from latest run)
    collection = (
        db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
    )
    # Get trend data (last 10 runs)
    trend_runs = (
        db.query(TestRun)
        .filter(TestRun.status != "running")
        .order_by(desc(TestRun.timestamp))
        .limit(10)
        .all()
    )
    # Calculate stats by category from latest run
    by_category = {}
    if latest_run:
        results = (
            db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
        )
        for result in results:
            # Categorize by test path. NOTE(review): plain substring checks,
            # so e.g. "community" would match "unit" — confirm acceptable.
            if "unit" in result.test_file:
                category = "Unit Tests"
            elif "integration" in result.test_file:
                category = "Integration Tests"
            elif "performance" in result.test_file:
                category = "Performance Tests"
            elif "system" in result.test_file:
                category = "System Tests"
            else:
                category = "Other"
            if category not in by_category:
                by_category[category] = {"total": 0, "passed": 0, "failed": 0}
            by_category[category]["total"] += 1
            if result.outcome == "passed":
                by_category[category]["passed"] += 1
            elif result.outcome in ("failed", "error"):
                by_category[category]["failed"] += 1
    # Get top failing tests (across recent runs) — all-time failure counts
    # grouped by (name, file).
    top_failing = (
        db.query(
            TestResult.test_name,
            TestResult.test_file,
            func.count(TestResult.id).label("failure_count"),
        )
        .filter(TestResult.outcome.in_(["failed", "error"]))
        .group_by(TestResult.test_name, TestResult.test_file)
        .order_by(desc("failure_count"))
        .limit(10)
        .all()
    )
    return {
        # Current run stats (zeros/None when no completed run exists).
        "total_tests": latest_run.total_tests if latest_run else 0,
        "passed": latest_run.passed if latest_run else 0,
        "failed": latest_run.failed if latest_run else 0,
        "errors": latest_run.errors if latest_run else 0,
        "skipped": latest_run.skipped if latest_run else 0,
        # pass_rate / coverage_percent come from the TestRun model —
        # presumably computed properties; verify on the model.
        "pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
        "duration_seconds": round(latest_run.duration_seconds, 2)
        if latest_run
        else 0,
        "coverage_percent": latest_run.coverage_percent if latest_run else None,
        "last_run": latest_run.timestamp.isoformat() if latest_run else None,
        "last_run_status": latest_run.status if latest_run else None,
        # Collection stats
        "total_test_files": collection.total_files if collection else 0,
        "collected_tests": collection.total_tests if collection else 0,
        "unit_tests": collection.unit_tests if collection else 0,
        "integration_tests": collection.integration_tests if collection else 0,
        "performance_tests": collection.performance_tests if collection else 0,
        "system_tests": collection.system_tests if collection else 0,
        "last_collected": collection.collected_at.isoformat()
        if collection
        else None,
        # Trend data — reversed so the series runs oldest -> newest.
        "trend": [
            {
                "timestamp": run.timestamp.isoformat(),
                "total": run.total_tests,
                "passed": run.passed,
                "failed": run.failed,
                "pass_rate": round(run.pass_rate, 1),
                "duration": round(run.duration_seconds, 1),
            }
            for run in reversed(trend_runs)
        ],
        # By category
        "by_category": by_category,
        # Top failing tests
        "top_failing": [
            {
                "test_name": t.test_name,
                "test_file": t.test_file,
                "failure_count": t.failure_count,
            }
            for t in top_failing
        ],
    }
def collect_tests(self, db: Session) -> TestCollection:
    """Collect test information without running tests.

    Runs ``pytest --collect-only`` with the json-report plugin, counts
    tests per file, categorizes them by directory, and stores a
    TestCollection row (added to the session, not committed here).
    On any failure an empty/partial collection is still recorded.
    """
    collection = TestCollection(
        collected_at=datetime.now(UTC),
    )
    try:
        # Run pytest --collect-only with JSON report
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            json_report_path = f.name
        # Return value is intentionally unused; the JSON report file is the
        # source of truth for what was collected.
        result = subprocess.run(
            [
                "python",
                "-m",
                "pytest",
                "--collect-only",
                "--json-report",
                f"--json-report-file={json_report_path}",
                "tests",
            ],
            cwd=str(self.project_root),
            capture_output=True,
            text=True,
            timeout=120,
        )
        # Parse JSON report
        json_path = Path(json_report_path)
        if json_path.exists():
            with open(json_path) as f:
                report = json.load(f)
            # Get total from summary
            collection.total_tests = report.get("summary", {}).get("collected", 0)
            # Parse collectors to get test files and counts
            test_files = {}
            for collector in report.get("collectors", []):
                for item in collector.get("result", []):
                    if item.get("type") == "Function":
                        node_id = item.get("nodeid", "")
                        if "::" in node_id:
                            file_path = node_id.split("::")[0]
                            if file_path not in test_files:
                                test_files[file_path] = 0
                            test_files[file_path] += 1
            # Count files and categorize. NOTE(review): the += below assumes
            # the TestCollection counters default to 0 on the model — confirm.
            for file_path, count in test_files.items():
                collection.total_files += 1
                if "/unit/" in file_path or file_path.startswith("tests/unit"):
                    collection.unit_tests += count
                elif "/integration/" in file_path or file_path.startswith(
                    "tests/integration"
                ):
                    collection.integration_tests += count
                elif "/performance/" in file_path or file_path.startswith(
                    "tests/performance"
                ):
                    collection.performance_tests += count
                elif "/system/" in file_path or file_path.startswith(
                    "tests/system"
                ):
                    collection.system_tests += count
            # Per-file breakdown, largest file first.
            collection.test_files = [
                {"file": f, "count": c}
                for f, c in sorted(test_files.items(), key=lambda x: -x[1])
            ]
        # Cleanup
        json_path.unlink(missing_ok=True)
        logger.info(
            f"Collected {collection.total_tests} tests from {collection.total_files} files"
        )
    except Exception as e:
        # Best-effort: log and still persist the (possibly empty) collection.
        logger.error(f"Error collecting tests: {e}", exc_info=True)
    db.add(collection)
    return collection
# Singleton instance
test_runner_service = TestRunnerService()

# Re-export from the canonical module location. NOTE(review): these bindings
# shadow the legacy definitions above — this file is described as a pure
# re-export shim, so the legacy implementation above is presumably slated for
# removal; confirm against the module version.
from app.modules.dev_tools.services.test_runner_service import (
    test_runner_service,
    TestRunnerService,
)

# Public API of this legacy shim.
__all__ = [
    "test_runner_service",
    "TestRunnerService",
]