refactor: migrate modules from re-exports to canonical implementations
Move actual code implementations into module directories: - orders: 5 services, 4 models, order/invoice schemas - inventory: 3 services, 2 models, 30+ schemas - customers: 3 services, 2 models, customer schemas - messaging: 3 services, 2 models, message/notification schemas - monitoring: background_tasks_service - marketplace: 5+ services including letzshop submodule - dev_tools: code_quality_service, test_runner_service - billing: billing_service - contracts: definition.py Legacy files in app/services/, models/database/, models/schema/ now re-export from canonical module locations for backwards compatibility. Architecture validator passes with 0 errors. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,242 +1,23 @@
|
||||
# app/services/admin_customer_service.py
|
||||
"""
|
||||
Admin customer management service.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Handles customer operations for admin users across all vendors.
|
||||
The canonical implementation is now in:
|
||||
app/modules/customers/services/admin_customer_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.customers.services import admin_customer_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from app.modules.customers.services.admin_customer_service import (
|
||||
admin_customer_service,
|
||||
AdminCustomerService,
|
||||
)
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions.customer import CustomerNotFoundException
|
||||
from models.database.customer import Customer
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminCustomerService:
|
||||
"""Service for admin-level customer management across vendors."""
|
||||
|
||||
def list_customers(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
search: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get paginated list of customers across all vendors.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter
|
||||
search: Search by email, name, or customer number
|
||||
is_active: Filter by active status
|
||||
skip: Number of records to skip
|
||||
limit: Maximum records to return
|
||||
|
||||
Returns:
|
||||
Tuple of (customers list, total count)
|
||||
"""
|
||||
# Build query
|
||||
query = db.query(Customer).join(Vendor, Customer.vendor_id == Vendor.id)
|
||||
|
||||
# Apply filters
|
||||
if vendor_id:
|
||||
query = query.filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.filter(
|
||||
(Customer.email.ilike(search_term))
|
||||
| (Customer.first_name.ilike(search_term))
|
||||
| (Customer.last_name.ilike(search_term))
|
||||
| (Customer.customer_number.ilike(search_term))
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(Customer.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get paginated results with vendor info
|
||||
customers = (
|
||||
query.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
|
||||
.order_by(Customer.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Format response
|
||||
result = []
|
||||
for row in customers:
|
||||
customer = row[0]
|
||||
vendor_name = row[1]
|
||||
vendor_code = row[2]
|
||||
|
||||
customer_dict = {
|
||||
"id": customer.id,
|
||||
"vendor_id": customer.vendor_id,
|
||||
"email": customer.email,
|
||||
"first_name": customer.first_name,
|
||||
"last_name": customer.last_name,
|
||||
"phone": customer.phone,
|
||||
"customer_number": customer.customer_number,
|
||||
"marketing_consent": customer.marketing_consent,
|
||||
"preferred_language": customer.preferred_language,
|
||||
"last_order_date": customer.last_order_date,
|
||||
"total_orders": customer.total_orders,
|
||||
"total_spent": float(customer.total_spent) if customer.total_spent else 0,
|
||||
"is_active": customer.is_active,
|
||||
"created_at": customer.created_at,
|
||||
"updated_at": customer.updated_at,
|
||||
"vendor_name": vendor_name,
|
||||
"vendor_code": vendor_code,
|
||||
}
|
||||
result.append(customer_dict)
|
||||
|
||||
return result, total
|
||||
|
||||
def get_customer_stats(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get customer statistics.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter
|
||||
|
||||
Returns:
|
||||
Dict with customer statistics
|
||||
"""
|
||||
query = db.query(Customer)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
total = query.count()
|
||||
active = query.filter(Customer.is_active == True).count() # noqa: E712
|
||||
inactive = query.filter(Customer.is_active == False).count() # noqa: E712
|
||||
with_orders = query.filter(Customer.total_orders > 0).count()
|
||||
|
||||
# Total spent across all customers
|
||||
total_spent_result = query.with_entities(func.sum(Customer.total_spent)).scalar()
|
||||
total_spent = float(total_spent_result) if total_spent_result else 0
|
||||
|
||||
# Average order value
|
||||
total_orders_result = query.with_entities(func.sum(Customer.total_orders)).scalar()
|
||||
total_orders = int(total_orders_result) if total_orders_result else 0
|
||||
avg_order_value = total_spent / total_orders if total_orders > 0 else 0
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"active": active,
|
||||
"inactive": inactive,
|
||||
"with_orders": with_orders,
|
||||
"total_spent": total_spent,
|
||||
"total_orders": total_orders,
|
||||
"avg_order_value": round(avg_order_value, 2),
|
||||
}
|
||||
|
||||
def get_customer(
|
||||
self,
|
||||
db: Session,
|
||||
customer_id: int,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get customer details by ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer dict with vendor info
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
result = (
|
||||
db.query(Customer)
|
||||
.join(Vendor, Customer.vendor_id == Vendor.id)
|
||||
.add_columns(Vendor.name.label("vendor_name"), Vendor.vendor_code)
|
||||
.filter(Customer.id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise CustomerNotFoundException(str(customer_id))
|
||||
|
||||
customer = result[0]
|
||||
return {
|
||||
"id": customer.id,
|
||||
"vendor_id": customer.vendor_id,
|
||||
"email": customer.email,
|
||||
"first_name": customer.first_name,
|
||||
"last_name": customer.last_name,
|
||||
"phone": customer.phone,
|
||||
"customer_number": customer.customer_number,
|
||||
"marketing_consent": customer.marketing_consent,
|
||||
"preferred_language": customer.preferred_language,
|
||||
"last_order_date": customer.last_order_date,
|
||||
"total_orders": customer.total_orders,
|
||||
"total_spent": float(customer.total_spent) if customer.total_spent else 0,
|
||||
"is_active": customer.is_active,
|
||||
"created_at": customer.created_at,
|
||||
"updated_at": customer.updated_at,
|
||||
"vendor_name": result[1],
|
||||
"vendor_code": result[2],
|
||||
}
|
||||
|
||||
def toggle_customer_status(
|
||||
self,
|
||||
db: Session,
|
||||
customer_id: int,
|
||||
admin_email: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Toggle customer active status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
customer_id: Customer ID
|
||||
admin_email: Admin user email for logging
|
||||
|
||||
Returns:
|
||||
Dict with customer ID, new status, and message
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
customer = db.query(Customer).filter(Customer.id == customer_id).first()
|
||||
|
||||
if not customer:
|
||||
raise CustomerNotFoundException(str(customer_id))
|
||||
|
||||
customer.is_active = not customer.is_active
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
status = "activated" if customer.is_active else "deactivated"
|
||||
logger.info(f"Customer {customer.email} {status} by admin {admin_email}")
|
||||
|
||||
return {
|
||||
"id": customer.id,
|
||||
"is_active": customer.is_active,
|
||||
"message": f"Customer {status} successfully",
|
||||
}
|
||||
|
||||
|
||||
# Singleton instance
|
||||
admin_customer_service = AdminCustomerService()
|
||||
__all__ = [
|
||||
"admin_customer_service",
|
||||
"AdminCustomerService",
|
||||
]
|
||||
|
||||
@@ -1,701 +1,37 @@
|
||||
# app/services/admin_notification_service.py
|
||||
"""
|
||||
Admin notification service.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Provides functionality for:
|
||||
- Creating and managing admin notifications
|
||||
- Managing platform alerts
|
||||
- Notification statistics and queries
|
||||
The canonical implementation is now in:
|
||||
app/modules/messaging/services/admin_notification_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.messaging.services import admin_notification_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, case, func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.admin import AdminNotification, PlatformAlert
|
||||
from models.schema.admin import AdminNotificationCreate, PlatformAlertCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION TYPES
|
||||
# ============================================================================
|
||||
|
||||
class NotificationType:
|
||||
"""Notification type constants."""
|
||||
|
||||
SYSTEM_ALERT = "system_alert"
|
||||
IMPORT_FAILURE = "import_failure"
|
||||
EXPORT_FAILURE = "export_failure"
|
||||
ORDER_SYNC_FAILURE = "order_sync_failure"
|
||||
VENDOR_ISSUE = "vendor_issue"
|
||||
CUSTOMER_MESSAGE = "customer_message"
|
||||
VENDOR_MESSAGE = "vendor_message"
|
||||
SECURITY_ALERT = "security_alert"
|
||||
PERFORMANCE_ALERT = "performance_alert"
|
||||
ORDER_EXCEPTION = "order_exception"
|
||||
CRITICAL_ERROR = "critical_error"
|
||||
|
||||
|
||||
class Priority:
|
||||
"""Priority level constants."""
|
||||
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class AlertType:
|
||||
"""Platform alert type constants."""
|
||||
|
||||
SECURITY = "security"
|
||||
PERFORMANCE = "performance"
|
||||
CAPACITY = "capacity"
|
||||
INTEGRATION = "integration"
|
||||
DATABASE = "database"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class Severity:
|
||||
"""Alert severity constants."""
|
||||
|
||||
INFO = "info"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADMIN NOTIFICATION SERVICE
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class AdminNotificationService:
|
||||
"""Service for managing admin notifications."""
|
||||
|
||||
def create_notification(
|
||||
self,
|
||||
db: Session,
|
||||
notification_type: str,
|
||||
title: str,
|
||||
message: str,
|
||||
priority: str = Priority.NORMAL,
|
||||
action_required: bool = False,
|
||||
action_url: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> AdminNotification:
|
||||
"""
|
||||
Create a new admin notification.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
notification_type: Type of notification
|
||||
title: Notification title
|
||||
message: Notification message
|
||||
priority: Priority level (low, normal, high, critical)
|
||||
action_required: Whether action is required
|
||||
action_url: URL to relevant admin page
|
||||
metadata: Additional contextual data
|
||||
|
||||
Returns:
|
||||
Created AdminNotification
|
||||
"""
|
||||
notification = AdminNotification(
|
||||
type=notification_type,
|
||||
title=title,
|
||||
message=message,
|
||||
priority=priority,
|
||||
action_required=action_required,
|
||||
action_url=action_url,
|
||||
notification_metadata=metadata,
|
||||
)
|
||||
db.add(notification)
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Created notification: {notification_type} - {title} (priority: {priority})"
|
||||
)
|
||||
|
||||
return notification
|
||||
|
||||
def create_from_schema(
|
||||
self,
|
||||
db: Session,
|
||||
data: AdminNotificationCreate,
|
||||
) -> AdminNotification:
|
||||
"""Create notification from Pydantic schema."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=data.type,
|
||||
title=data.title,
|
||||
message=data.message,
|
||||
priority=data.priority,
|
||||
action_required=data.action_required,
|
||||
action_url=data.action_url,
|
||||
metadata=data.metadata,
|
||||
)
|
||||
|
||||
def get_notifications(
|
||||
self,
|
||||
db: Session,
|
||||
priority: str | None = None,
|
||||
is_read: bool | None = None,
|
||||
notification_type: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[AdminNotification], int, int]:
|
||||
"""
|
||||
Get paginated admin notifications.
|
||||
|
||||
Returns:
|
||||
Tuple of (notifications, total_count, unread_count)
|
||||
"""
|
||||
query = db.query(AdminNotification)
|
||||
|
||||
# Apply filters
|
||||
if priority:
|
||||
query = query.filter(AdminNotification.priority == priority)
|
||||
|
||||
if is_read is not None:
|
||||
query = query.filter(AdminNotification.is_read == is_read)
|
||||
|
||||
if notification_type:
|
||||
query = query.filter(AdminNotification.type == notification_type)
|
||||
|
||||
# Get counts
|
||||
total = query.count()
|
||||
unread_count = (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.is_read == False) # noqa: E712
|
||||
.count()
|
||||
)
|
||||
|
||||
# Get paginated results ordered by priority and date
|
||||
priority_order = case(
|
||||
(AdminNotification.priority == "critical", 1),
|
||||
(AdminNotification.priority == "high", 2),
|
||||
(AdminNotification.priority == "normal", 3),
|
||||
(AdminNotification.priority == "low", 4),
|
||||
else_=5,
|
||||
)
|
||||
|
||||
notifications = (
|
||||
query.order_by(
|
||||
AdminNotification.is_read, # Unread first
|
||||
priority_order,
|
||||
AdminNotification.created_at.desc(),
|
||||
)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return notifications, total, unread_count
|
||||
|
||||
def get_unread_count(self, db: Session) -> int:
|
||||
"""Get count of unread notifications."""
|
||||
return (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.is_read == False) # noqa: E712
|
||||
.count()
|
||||
)
|
||||
|
||||
def get_recent_notifications(
|
||||
self,
|
||||
db: Session,
|
||||
limit: int = 5,
|
||||
) -> list[AdminNotification]:
|
||||
"""Get recent unread notifications for header dropdown."""
|
||||
priority_order = case(
|
||||
(AdminNotification.priority == "critical", 1),
|
||||
(AdminNotification.priority == "high", 2),
|
||||
(AdminNotification.priority == "normal", 3),
|
||||
(AdminNotification.priority == "low", 4),
|
||||
else_=5,
|
||||
)
|
||||
|
||||
return (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.is_read == False) # noqa: E712
|
||||
.order_by(priority_order, AdminNotification.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def mark_as_read(
|
||||
self,
|
||||
db: Session,
|
||||
notification_id: int,
|
||||
user_id: int,
|
||||
) -> AdminNotification | None:
|
||||
"""Mark a notification as read."""
|
||||
notification = (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.id == notification_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if notification and not notification.is_read:
|
||||
notification.is_read = True
|
||||
notification.read_at = datetime.utcnow()
|
||||
notification.read_by_user_id = user_id
|
||||
db.flush()
|
||||
|
||||
return notification
|
||||
|
||||
def mark_all_as_read(
|
||||
self,
|
||||
db: Session,
|
||||
user_id: int,
|
||||
) -> int:
|
||||
"""Mark all unread notifications as read. Returns count of updated."""
|
||||
now = datetime.utcnow()
|
||||
count = (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.is_read == False) # noqa: E712
|
||||
.update(
|
||||
{
|
||||
AdminNotification.is_read: True,
|
||||
AdminNotification.read_at: now,
|
||||
AdminNotification.read_by_user_id: user_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
db.flush()
|
||||
return count
|
||||
|
||||
def delete_old_notifications(
|
||||
self,
|
||||
db: Session,
|
||||
days: int = 30,
|
||||
) -> int:
|
||||
"""Delete notifications older than specified days."""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
count = (
|
||||
db.query(AdminNotification)
|
||||
.filter(
|
||||
and_(
|
||||
AdminNotification.is_read == True, # noqa: E712
|
||||
AdminNotification.created_at < cutoff,
|
||||
)
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
db.flush()
|
||||
return count
|
||||
|
||||
def delete_notification(
|
||||
self,
|
||||
db: Session,
|
||||
notification_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a notification by ID.
|
||||
|
||||
Returns:
|
||||
True if notification was deleted, False if not found.
|
||||
"""
|
||||
notification = (
|
||||
db.query(AdminNotification)
|
||||
.filter(AdminNotification.id == notification_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if notification:
|
||||
db.delete(notification)
|
||||
db.flush()
|
||||
logger.info(f"Deleted notification {notification_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# =========================================================================
|
||||
# CONVENIENCE METHODS FOR CREATING SPECIFIC NOTIFICATION TYPES
|
||||
# =========================================================================
|
||||
|
||||
def notify_import_failure(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_name: str,
|
||||
job_id: int,
|
||||
error_message: str,
|
||||
vendor_id: int | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for import job failure."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.IMPORT_FAILURE,
|
||||
title=f"Import Failed: {vendor_name}",
|
||||
message=error_message,
|
||||
priority=Priority.HIGH,
|
||||
action_required=True,
|
||||
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
|
||||
if vendor_id
|
||||
else "/admin/marketplace",
|
||||
metadata={"vendor_name": vendor_name, "job_id": job_id, "vendor_id": vendor_id},
|
||||
)
|
||||
|
||||
def notify_order_sync_failure(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_name: str,
|
||||
error_message: str,
|
||||
vendor_id: int | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for order sync failure."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.ORDER_SYNC_FAILURE,
|
||||
title=f"Order Sync Failed: {vendor_name}",
|
||||
message=error_message,
|
||||
priority=Priority.HIGH,
|
||||
action_required=True,
|
||||
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=jobs"
|
||||
if vendor_id
|
||||
else "/admin/marketplace/letzshop",
|
||||
metadata={"vendor_name": vendor_name, "vendor_id": vendor_id},
|
||||
)
|
||||
|
||||
def notify_order_exception(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_name: str,
|
||||
order_number: str,
|
||||
exception_count: int,
|
||||
vendor_id: int | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for order item exceptions."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.ORDER_EXCEPTION,
|
||||
title=f"Order Exception: {order_number}",
|
||||
message=f"{exception_count} item(s) need attention for order {order_number} ({vendor_name})",
|
||||
priority=Priority.NORMAL,
|
||||
action_required=True,
|
||||
action_url=f"/admin/marketplace/letzshop?vendor_id={vendor_id}&tab=exceptions"
|
||||
if vendor_id
|
||||
else "/admin/marketplace/letzshop",
|
||||
metadata={
|
||||
"vendor_name": vendor_name,
|
||||
"order_number": order_number,
|
||||
"exception_count": exception_count,
|
||||
"vendor_id": vendor_id,
|
||||
},
|
||||
)
|
||||
|
||||
def notify_critical_error(
|
||||
self,
|
||||
db: Session,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for critical application errors."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.CRITICAL_ERROR,
|
||||
title=f"Critical Error: {error_type}",
|
||||
message=error_message,
|
||||
priority=Priority.CRITICAL,
|
||||
action_required=True,
|
||||
action_url="/admin/logs",
|
||||
metadata=details,
|
||||
)
|
||||
|
||||
def notify_vendor_issue(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_name: str,
|
||||
issue_type: str,
|
||||
message: str,
|
||||
vendor_id: int | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for vendor-related issues."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.VENDOR_ISSUE,
|
||||
title=f"Vendor Issue: {vendor_name}",
|
||||
message=message,
|
||||
priority=Priority.HIGH,
|
||||
action_required=True,
|
||||
action_url=f"/admin/vendors/{vendor_id}" if vendor_id else "/admin/vendors",
|
||||
metadata={
|
||||
"vendor_name": vendor_name,
|
||||
"issue_type": issue_type,
|
||||
"vendor_id": vendor_id,
|
||||
},
|
||||
)
|
||||
|
||||
def notify_security_alert(
|
||||
self,
|
||||
db: Session,
|
||||
title: str,
|
||||
message: str,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> AdminNotification:
|
||||
"""Create notification for security-related alerts."""
|
||||
return self.create_notification(
|
||||
db=db,
|
||||
notification_type=NotificationType.SECURITY_ALERT,
|
||||
title=title,
|
||||
message=message,
|
||||
priority=Priority.CRITICAL,
|
||||
action_required=True,
|
||||
action_url="/admin/audit",
|
||||
metadata=details,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PLATFORM ALERT SERVICE
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class PlatformAlertService:
|
||||
"""Service for managing platform-wide alerts."""
|
||||
|
||||
def create_alert(
|
||||
self,
|
||||
db: Session,
|
||||
alert_type: str,
|
||||
severity: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
affected_vendors: list[int] | None = None,
|
||||
affected_systems: list[str] | None = None,
|
||||
auto_generated: bool = True,
|
||||
) -> PlatformAlert:
|
||||
"""Create a new platform alert."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
alert = PlatformAlert(
|
||||
alert_type=alert_type,
|
||||
severity=severity,
|
||||
title=title,
|
||||
description=description,
|
||||
affected_vendors=affected_vendors,
|
||||
affected_systems=affected_systems,
|
||||
auto_generated=auto_generated,
|
||||
first_occurred_at=now,
|
||||
last_occurred_at=now,
|
||||
)
|
||||
db.add(alert)
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Created platform alert: {alert_type} - {title} ({severity})")
|
||||
|
||||
return alert
|
||||
|
||||
def create_from_schema(
|
||||
self,
|
||||
db: Session,
|
||||
data: PlatformAlertCreate,
|
||||
) -> PlatformAlert:
|
||||
"""Create alert from Pydantic schema."""
|
||||
return self.create_alert(
|
||||
db=db,
|
||||
alert_type=data.alert_type,
|
||||
severity=data.severity,
|
||||
title=data.title,
|
||||
description=data.description,
|
||||
affected_vendors=data.affected_vendors,
|
||||
affected_systems=data.affected_systems,
|
||||
auto_generated=data.auto_generated,
|
||||
)
|
||||
|
||||
def get_alerts(
|
||||
self,
|
||||
db: Session,
|
||||
severity: str | None = None,
|
||||
alert_type: str | None = None,
|
||||
is_resolved: bool | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[PlatformAlert], int, int, int]:
|
||||
"""
|
||||
Get paginated platform alerts.
|
||||
|
||||
Returns:
|
||||
Tuple of (alerts, total_count, active_count, critical_count)
|
||||
"""
|
||||
query = db.query(PlatformAlert)
|
||||
|
||||
# Apply filters
|
||||
if severity:
|
||||
query = query.filter(PlatformAlert.severity == severity)
|
||||
|
||||
if alert_type:
|
||||
query = query.filter(PlatformAlert.alert_type == alert_type)
|
||||
|
||||
if is_resolved is not None:
|
||||
query = query.filter(PlatformAlert.is_resolved == is_resolved)
|
||||
|
||||
# Get counts
|
||||
total = query.count()
|
||||
active_count = (
|
||||
db.query(PlatformAlert)
|
||||
.filter(PlatformAlert.is_resolved == False) # noqa: E712
|
||||
.count()
|
||||
)
|
||||
critical_count = (
|
||||
db.query(PlatformAlert)
|
||||
.filter(
|
||||
and_(
|
||||
PlatformAlert.is_resolved == False, # noqa: E712
|
||||
PlatformAlert.severity == Severity.CRITICAL,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# Get paginated results
|
||||
severity_order = case(
|
||||
(PlatformAlert.severity == "critical", 1),
|
||||
(PlatformAlert.severity == "error", 2),
|
||||
(PlatformAlert.severity == "warning", 3),
|
||||
(PlatformAlert.severity == "info", 4),
|
||||
else_=5,
|
||||
)
|
||||
|
||||
alerts = (
|
||||
query.order_by(
|
||||
PlatformAlert.is_resolved, # Unresolved first
|
||||
severity_order,
|
||||
PlatformAlert.last_occurred_at.desc(),
|
||||
)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return alerts, total, active_count, critical_count
|
||||
|
||||
def resolve_alert(
|
||||
self,
|
||||
db: Session,
|
||||
alert_id: int,
|
||||
user_id: int,
|
||||
resolution_notes: str | None = None,
|
||||
) -> PlatformAlert | None:
|
||||
"""Resolve a platform alert."""
|
||||
alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
|
||||
|
||||
if alert and not alert.is_resolved:
|
||||
alert.is_resolved = True
|
||||
alert.resolved_at = datetime.utcnow()
|
||||
alert.resolved_by_user_id = user_id
|
||||
alert.resolution_notes = resolution_notes
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Resolved platform alert {alert_id}")
|
||||
|
||||
return alert
|
||||
|
||||
def get_statistics(self, db: Session) -> dict[str, int]:
|
||||
"""Get alert statistics."""
|
||||
today_start = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
total = db.query(PlatformAlert).count()
|
||||
active = (
|
||||
db.query(PlatformAlert)
|
||||
.filter(PlatformAlert.is_resolved == False) # noqa: E712
|
||||
.count()
|
||||
)
|
||||
critical = (
|
||||
db.query(PlatformAlert)
|
||||
.filter(
|
||||
and_(
|
||||
PlatformAlert.is_resolved == False, # noqa: E712
|
||||
PlatformAlert.severity == Severity.CRITICAL,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
resolved_today = (
|
||||
db.query(PlatformAlert)
|
||||
.filter(
|
||||
and_(
|
||||
PlatformAlert.is_resolved == True, # noqa: E712
|
||||
PlatformAlert.resolved_at >= today_start,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"total_alerts": total,
|
||||
"active_alerts": active,
|
||||
"critical_alerts": critical,
|
||||
"resolved_today": resolved_today,
|
||||
}
|
||||
|
||||
def increment_occurrence(
|
||||
self,
|
||||
db: Session,
|
||||
alert_id: int,
|
||||
) -> PlatformAlert | None:
|
||||
"""Increment occurrence count for repeated alert."""
|
||||
alert = db.query(PlatformAlert).filter(PlatformAlert.id == alert_id).first()
|
||||
|
||||
if alert:
|
||||
alert.occurrence_count += 1
|
||||
alert.last_occurred_at = datetime.utcnow()
|
||||
db.flush()
|
||||
|
||||
return alert
|
||||
|
||||
def find_similar_active_alert(
|
||||
self,
|
||||
db: Session,
|
||||
alert_type: str,
|
||||
title: str,
|
||||
) -> PlatformAlert | None:
|
||||
"""Find an active alert with same type and title."""
|
||||
return (
|
||||
db.query(PlatformAlert)
|
||||
.filter(
|
||||
and_(
|
||||
PlatformAlert.alert_type == alert_type,
|
||||
PlatformAlert.title == title,
|
||||
PlatformAlert.is_resolved == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def create_or_increment_alert(
|
||||
self,
|
||||
db: Session,
|
||||
alert_type: str,
|
||||
severity: str,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
affected_vendors: list[int] | None = None,
|
||||
affected_systems: list[str] | None = None,
|
||||
) -> PlatformAlert:
|
||||
"""Create alert or increment occurrence if similar exists."""
|
||||
existing = self.find_similar_active_alert(db, alert_type, title)
|
||||
|
||||
if existing:
|
||||
self.increment_occurrence(db, existing.id)
|
||||
return existing
|
||||
|
||||
return self.create_alert(
|
||||
db=db,
|
||||
alert_type=alert_type,
|
||||
severity=severity,
|
||||
title=title,
|
||||
description=description,
|
||||
affected_vendors=affected_vendors,
|
||||
affected_systems=affected_systems,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instances
|
||||
admin_notification_service = AdminNotificationService()
|
||||
platform_alert_service = PlatformAlertService()
|
||||
from app.modules.messaging.services.admin_notification_service import (
|
||||
admin_notification_service,
|
||||
AdminNotificationService,
|
||||
platform_alert_service,
|
||||
PlatformAlertService,
|
||||
# Constants
|
||||
NotificationType,
|
||||
Priority,
|
||||
AlertType,
|
||||
Severity,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"admin_notification_service",
|
||||
"AdminNotificationService",
|
||||
"platform_alert_service",
|
||||
"PlatformAlertService",
|
||||
# Constants
|
||||
"NotificationType",
|
||||
"Priority",
|
||||
"AlertType",
|
||||
"Severity",
|
||||
]
|
||||
|
||||
@@ -1,194 +1,23 @@
|
||||
# app/services/background_tasks_service.py
|
||||
"""
|
||||
Background Tasks Service
|
||||
Service for monitoring background tasks across the system
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
The canonical implementation is now in:
|
||||
app/modules/monitoring/services/background_tasks_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.monitoring.services import background_tasks_service
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from app.modules.monitoring.services.background_tasks_service import (
|
||||
background_tasks_service,
|
||||
BackgroundTasksService,
|
||||
)
|
||||
|
||||
from sqlalchemy import case, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.architecture_scan import ArchitectureScan
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.test_run import TestRun
|
||||
|
||||
|
||||
class BackgroundTasksService:
|
||||
"""Service for monitoring background tasks"""
|
||||
|
||||
def get_import_jobs(
|
||||
self, db: Session, status: str | None = None, limit: int = 50
|
||||
) -> list[MarketplaceImportJob]:
|
||||
"""Get import jobs with optional status filter"""
|
||||
query = db.query(MarketplaceImportJob)
|
||||
if status:
|
||||
query = query.filter(MarketplaceImportJob.status == status)
|
||||
return query.order_by(desc(MarketplaceImportJob.created_at)).limit(limit).all()
|
||||
|
||||
def get_test_runs(
|
||||
self, db: Session, status: str | None = None, limit: int = 50
|
||||
) -> list[TestRun]:
|
||||
"""Get test runs with optional status filter"""
|
||||
query = db.query(TestRun)
|
||||
if status:
|
||||
query = query.filter(TestRun.status == status)
|
||||
return query.order_by(desc(TestRun.timestamp)).limit(limit).all()
|
||||
|
||||
def get_running_imports(self, db: Session) -> list[MarketplaceImportJob]:
|
||||
"""Get currently running import jobs"""
|
||||
return (
|
||||
db.query(MarketplaceImportJob)
|
||||
.filter(MarketplaceImportJob.status == "processing")
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_running_test_runs(self, db: Session) -> list[TestRun]:
|
||||
"""Get currently running test runs"""
|
||||
# noqa: SVC-005 - Platform-level, TestRuns not vendor-scoped
|
||||
return db.query(TestRun).filter(TestRun.status == "running").all()
|
||||
|
||||
def get_import_stats(self, db: Session) -> dict:
|
||||
"""Get import job statistics"""
|
||||
today_start = datetime.now(UTC).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
stats = db.query(
|
||||
func.count(MarketplaceImportJob.id).label("total"),
|
||||
func.sum(
|
||||
case((MarketplaceImportJob.status == "processing", 1), else_=0)
|
||||
).label("running"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
MarketplaceImportJob.status.in_(
|
||||
["completed", "completed_with_errors"]
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("completed"),
|
||||
func.sum(
|
||||
case((MarketplaceImportJob.status == "failed", 1), else_=0)
|
||||
).label("failed"),
|
||||
).first()
|
||||
|
||||
today_count = (
|
||||
db.query(func.count(MarketplaceImportJob.id))
|
||||
.filter(MarketplaceImportJob.created_at >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total": stats.total or 0,
|
||||
"running": stats.running or 0,
|
||||
"completed": stats.completed or 0,
|
||||
"failed": stats.failed or 0,
|
||||
"today": today_count,
|
||||
}
|
||||
|
||||
def get_test_run_stats(self, db: Session) -> dict:
|
||||
"""Get test run statistics"""
|
||||
today_start = datetime.now(UTC).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
stats = db.query(
|
||||
func.count(TestRun.id).label("total"),
|
||||
func.sum(case((TestRun.status == "running", 1), else_=0)).label(
|
||||
"running"
|
||||
),
|
||||
func.sum(case((TestRun.status == "passed", 1), else_=0)).label(
|
||||
"completed"
|
||||
),
|
||||
func.sum(
|
||||
case((TestRun.status.in_(["failed", "error"]), 1), else_=0)
|
||||
).label("failed"),
|
||||
func.avg(TestRun.duration_seconds).label("avg_duration"),
|
||||
).first()
|
||||
|
||||
today_count = (
|
||||
db.query(func.count(TestRun.id))
|
||||
.filter(TestRun.timestamp >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total": stats.total or 0,
|
||||
"running": stats.running or 0,
|
||||
"completed": stats.completed or 0,
|
||||
"failed": stats.failed or 0,
|
||||
"today": today_count,
|
||||
"avg_duration": round(stats.avg_duration or 0, 1),
|
||||
}
|
||||
|
||||
def get_code_quality_scans(
|
||||
self, db: Session, status: str | None = None, limit: int = 50
|
||||
) -> list[ArchitectureScan]:
|
||||
"""Get code quality scans with optional status filter"""
|
||||
query = db.query(ArchitectureScan)
|
||||
if status:
|
||||
query = query.filter(ArchitectureScan.status == status)
|
||||
return query.order_by(desc(ArchitectureScan.timestamp)).limit(limit).all()
|
||||
|
||||
def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
|
||||
"""Get currently running code quality scans"""
|
||||
return (
|
||||
db.query(ArchitectureScan)
|
||||
.filter(ArchitectureScan.status.in_(["pending", "running"]))
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_scan_stats(self, db: Session) -> dict:
|
||||
"""Get code quality scan statistics"""
|
||||
today_start = datetime.now(UTC).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
stats = db.query(
|
||||
func.count(ArchitectureScan.id).label("total"),
|
||||
func.sum(
|
||||
case(
|
||||
(ArchitectureScan.status.in_(["pending", "running"]), 1), else_=0
|
||||
)
|
||||
).label("running"),
|
||||
func.sum(
|
||||
case(
|
||||
(
|
||||
ArchitectureScan.status.in_(
|
||||
["completed", "completed_with_warnings"]
|
||||
),
|
||||
1,
|
||||
),
|
||||
else_=0,
|
||||
)
|
||||
).label("completed"),
|
||||
func.sum(
|
||||
case((ArchitectureScan.status == "failed", 1), else_=0)
|
||||
).label("failed"),
|
||||
func.avg(ArchitectureScan.duration_seconds).label("avg_duration"),
|
||||
).first()
|
||||
|
||||
today_count = (
|
||||
db.query(func.count(ArchitectureScan.id))
|
||||
.filter(ArchitectureScan.timestamp >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total": stats.total or 0,
|
||||
"running": stats.running or 0,
|
||||
"completed": stats.completed or 0,
|
||||
"failed": stats.failed or 0,
|
||||
"today": today_count,
|
||||
"avg_duration": round(stats.avg_duration or 0, 1),
|
||||
}
|
||||
|
||||
|
||||
# Singleton instance
|
||||
background_tasks_service = BackgroundTasksService()
|
||||
__all__ = [
|
||||
"background_tasks_service",
|
||||
"BackgroundTasksService",
|
||||
]
|
||||
|
||||
@@ -1,588 +1,35 @@
|
||||
# app/services/billing_service.py
|
||||
"""
|
||||
Billing service for subscription and payment operations.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Provides:
|
||||
- Subscription status and usage queries
|
||||
- Tier management
|
||||
- Invoice history
|
||||
- Add-on management
|
||||
The canonical implementation is now in:
|
||||
app/modules/billing/services/billing_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.billing.services import billing_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.stripe_service import stripe_service
|
||||
from app.services.subscription_service import subscription_service
|
||||
from models.database.subscription import (
|
||||
AddOnProduct,
|
||||
BillingHistory,
|
||||
SubscriptionTier,
|
||||
VendorAddOn,
|
||||
VendorSubscription,
|
||||
from app.modules.billing.services.billing_service import (
|
||||
BillingService,
|
||||
billing_service,
|
||||
BillingServiceError,
|
||||
PaymentSystemNotConfiguredError,
|
||||
TierNotFoundError,
|
||||
StripePriceNotConfiguredError,
|
||||
NoActiveSubscriptionError,
|
||||
SubscriptionNotCancelledError,
|
||||
)
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Base exception for billing service errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PaymentSystemNotConfiguredError(BillingServiceError):
|
||||
"""Raised when Stripe is not configured."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Payment system not configured")
|
||||
|
||||
|
||||
class TierNotFoundError(BillingServiceError):
|
||||
"""Raised when a tier is not found."""
|
||||
|
||||
def __init__(self, tier_code: str):
|
||||
self.tier_code = tier_code
|
||||
super().__init__(f"Tier '{tier_code}' not found")
|
||||
|
||||
|
||||
class StripePriceNotConfiguredError(BillingServiceError):
|
||||
"""Raised when Stripe price is not configured for a tier."""
|
||||
|
||||
def __init__(self, tier_code: str):
|
||||
self.tier_code = tier_code
|
||||
super().__init__(f"Stripe price not configured for tier '{tier_code}'")
|
||||
|
||||
|
||||
class NoActiveSubscriptionError(BillingServiceError):
|
||||
"""Raised when no active subscription exists."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("No active subscription found")
|
||||
|
||||
|
||||
class SubscriptionNotCancelledError(BillingServiceError):
|
||||
"""Raised when trying to reactivate a non-cancelled subscription."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("Subscription is not cancelled")
|
||||
|
||||
|
||||
class BillingService:
|
||||
"""Service for billing operations."""
|
||||
|
||||
def get_subscription_with_tier(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> tuple[VendorSubscription, SubscriptionTier | None]:
|
||||
"""
|
||||
Get subscription and its tier info.
|
||||
|
||||
Returns:
|
||||
Tuple of (subscription, tier) where tier may be None
|
||||
"""
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(SubscriptionTier.code == subscription.tier)
|
||||
.first()
|
||||
)
|
||||
|
||||
return subscription, tier
|
||||
|
||||
def get_available_tiers(
|
||||
self, db: Session, current_tier: str
|
||||
) -> tuple[list[dict], dict[str, int]]:
|
||||
"""
|
||||
Get all available tiers with upgrade/downgrade flags.
|
||||
|
||||
Returns:
|
||||
Tuple of (tier_list, tier_order_map)
|
||||
"""
|
||||
tiers = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
SubscriptionTier.is_public == True, # noqa: E712
|
||||
)
|
||||
.order_by(SubscriptionTier.display_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
tier_order = {t.code: t.display_order for t in tiers}
|
||||
current_order = tier_order.get(current_tier, 0)
|
||||
|
||||
tier_list = []
|
||||
for tier in tiers:
|
||||
tier_list.append({
|
||||
"code": tier.code,
|
||||
"name": tier.name,
|
||||
"description": tier.description,
|
||||
"price_monthly_cents": tier.price_monthly_cents,
|
||||
"price_annual_cents": tier.price_annual_cents,
|
||||
"orders_per_month": tier.orders_per_month,
|
||||
"products_limit": tier.products_limit,
|
||||
"team_members": tier.team_members,
|
||||
"features": tier.features or [],
|
||||
"is_current": tier.code == current_tier,
|
||||
"can_upgrade": tier.display_order > current_order,
|
||||
"can_downgrade": tier.display_order < current_order,
|
||||
})
|
||||
|
||||
return tier_list, tier_order
|
||||
|
||||
def get_tier_by_code(self, db: Session, tier_code: str) -> SubscriptionTier:
|
||||
"""
|
||||
Get a tier by its code.
|
||||
|
||||
Raises:
|
||||
TierNotFoundError: If tier doesn't exist
|
||||
"""
|
||||
tier = (
|
||||
db.query(SubscriptionTier)
|
||||
.filter(
|
||||
SubscriptionTier.code == tier_code,
|
||||
SubscriptionTier.is_active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tier:
|
||||
raise TierNotFoundError(tier_code)
|
||||
|
||||
return tier
|
||||
|
||||
def get_vendor(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""
|
||||
Get vendor by ID.
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException from app.exceptions
|
||||
"""
|
||||
from app.exceptions import VendorNotFoundException
|
||||
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
return vendor
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
tier_code: str,
|
||||
is_annual: bool,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe checkout session.
|
||||
|
||||
Returns:
|
||||
Dict with checkout_url and session_id
|
||||
|
||||
Raises:
|
||||
PaymentSystemNotConfiguredError: If Stripe not configured
|
||||
TierNotFoundError: If tier doesn't exist
|
||||
StripePriceNotConfiguredError: If price not configured
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
raise PaymentSystemNotConfiguredError()
|
||||
|
||||
vendor = self.get_vendor(db, vendor_id)
|
||||
tier = self.get_tier_by_code(db, tier_code)
|
||||
|
||||
price_id = (
|
||||
tier.stripe_price_annual_id
|
||||
if is_annual and tier.stripe_price_annual_id
|
||||
else tier.stripe_price_monthly_id
|
||||
)
|
||||
|
||||
if not price_id:
|
||||
raise StripePriceNotConfiguredError(tier_code)
|
||||
|
||||
# Check if this is a new subscription (for trial)
|
||||
existing_sub = subscription_service.get_subscription(db, vendor_id)
|
||||
trial_days = None
|
||||
if not existing_sub or not existing_sub.stripe_subscription_id:
|
||||
from app.core.config import settings
|
||||
trial_days = settings.stripe_trial_days
|
||||
|
||||
session = stripe_service.create_checkout_session(
|
||||
db=db,
|
||||
vendor=vendor,
|
||||
price_id=price_id,
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
trial_days=trial_days,
|
||||
)
|
||||
|
||||
# Update subscription with tier info
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
subscription.tier = tier_code
|
||||
subscription.is_annual = is_annual
|
||||
|
||||
return {
|
||||
"checkout_url": session.url,
|
||||
"session_id": session.id,
|
||||
}
|
||||
|
||||
def create_portal_session(self, db: Session, vendor_id: int, return_url: str) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session.
|
||||
|
||||
Returns:
|
||||
Dict with portal_url
|
||||
|
||||
Raises:
|
||||
PaymentSystemNotConfiguredError: If Stripe not configured
|
||||
NoActiveSubscriptionError: If no subscription with customer ID
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
raise PaymentSystemNotConfiguredError()
|
||||
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_customer_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
session = stripe_service.create_portal_session(
|
||||
customer_id=subscription.stripe_customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
|
||||
return {"portal_url": session.url}
|
||||
|
||||
def get_invoices(
|
||||
self, db: Session, vendor_id: int, skip: int = 0, limit: int = 20
|
||||
) -> tuple[list[BillingHistory], int]:
|
||||
"""
|
||||
Get invoice history for a vendor.
|
||||
|
||||
Returns:
|
||||
Tuple of (invoices, total_count)
|
||||
"""
|
||||
query = db.query(BillingHistory).filter(BillingHistory.vendor_id == vendor_id)
|
||||
|
||||
total = query.count()
|
||||
|
||||
invoices = (
|
||||
query.order_by(BillingHistory.invoice_date.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return invoices, total
|
||||
|
||||
def get_available_addons(
|
||||
self, db: Session, category: str | None = None
|
||||
) -> list[AddOnProduct]:
|
||||
"""Get available add-on products."""
|
||||
query = db.query(AddOnProduct).filter(AddOnProduct.is_active == True) # noqa: E712
|
||||
|
||||
if category:
|
||||
query = query.filter(AddOnProduct.category == category)
|
||||
|
||||
return query.order_by(AddOnProduct.display_order).all()
|
||||
|
||||
def get_vendor_addons(self, db: Session, vendor_id: int) -> list[VendorAddOn]:
|
||||
"""Get vendor's purchased add-ons."""
|
||||
return (
|
||||
db.query(VendorAddOn)
|
||||
.filter(VendorAddOn.vendor_id == vendor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
def cancel_subscription(
|
||||
self, db: Session, vendor_id: int, reason: str | None, immediately: bool
|
||||
) -> dict:
|
||||
"""
|
||||
Cancel a subscription.
|
||||
|
||||
Returns:
|
||||
Dict with message and effective_date
|
||||
|
||||
Raises:
|
||||
NoActiveSubscriptionError: If no subscription to cancel
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_subscription_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
if stripe_service.is_configured:
|
||||
stripe_service.cancel_subscription(
|
||||
subscription_id=subscription.stripe_subscription_id,
|
||||
immediately=immediately,
|
||||
cancellation_reason=reason,
|
||||
)
|
||||
|
||||
subscription.cancelled_at = datetime.utcnow()
|
||||
subscription.cancellation_reason = reason
|
||||
|
||||
effective_date = (
|
||||
datetime.utcnow().isoformat()
|
||||
if immediately
|
||||
else subscription.period_end.isoformat()
|
||||
if subscription.period_end
|
||||
else datetime.utcnow().isoformat()
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Subscription cancelled successfully",
|
||||
"effective_date": effective_date,
|
||||
}
|
||||
|
||||
def reactivate_subscription(self, db: Session, vendor_id: int) -> dict:
|
||||
"""
|
||||
Reactivate a cancelled subscription.
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
|
||||
Raises:
|
||||
NoActiveSubscriptionError: If no subscription
|
||||
SubscriptionNotCancelledError: If not cancelled
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_subscription_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
if not subscription.cancelled_at:
|
||||
raise SubscriptionNotCancelledError()
|
||||
|
||||
if stripe_service.is_configured:
|
||||
stripe_service.reactivate_subscription(subscription.stripe_subscription_id)
|
||||
|
||||
subscription.cancelled_at = None
|
||||
subscription.cancellation_reason = None
|
||||
|
||||
return {"message": "Subscription reactivated successfully"}
|
||||
|
||||
def get_upcoming_invoice(self, db: Session, vendor_id: int) -> dict:
|
||||
"""
|
||||
Get upcoming invoice preview.
|
||||
|
||||
Returns:
|
||||
Dict with amount_due_cents, currency, next_payment_date, line_items
|
||||
|
||||
Raises:
|
||||
NoActiveSubscriptionError: If no subscription with customer ID
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_customer_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
if not stripe_service.is_configured:
|
||||
# Return empty preview if Stripe not configured
|
||||
return {
|
||||
"amount_due_cents": 0,
|
||||
"currency": "EUR",
|
||||
"next_payment_date": None,
|
||||
"line_items": [],
|
||||
}
|
||||
|
||||
invoice = stripe_service.get_upcoming_invoice(subscription.stripe_customer_id)
|
||||
|
||||
if not invoice:
|
||||
return {
|
||||
"amount_due_cents": 0,
|
||||
"currency": "EUR",
|
||||
"next_payment_date": None,
|
||||
"line_items": [],
|
||||
}
|
||||
|
||||
line_items = []
|
||||
if invoice.lines and invoice.lines.data:
|
||||
for line in invoice.lines.data:
|
||||
line_items.append({
|
||||
"description": line.description or "",
|
||||
"amount_cents": line.amount,
|
||||
"quantity": line.quantity or 1,
|
||||
})
|
||||
|
||||
return {
|
||||
"amount_due_cents": invoice.amount_due,
|
||||
"currency": invoice.currency.upper(),
|
||||
"next_payment_date": datetime.fromtimestamp(invoice.next_payment_attempt).isoformat()
|
||||
if invoice.next_payment_attempt
|
||||
else None,
|
||||
"line_items": line_items,
|
||||
}
|
||||
|
||||
def change_tier(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
new_tier_code: str,
|
||||
is_annual: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Change subscription tier (upgrade/downgrade).
|
||||
|
||||
Returns:
|
||||
Dict with message, new_tier, effective_immediately
|
||||
|
||||
Raises:
|
||||
TierNotFoundError: If tier doesn't exist
|
||||
NoActiveSubscriptionError: If no subscription
|
||||
StripePriceNotConfiguredError: If price not configured
|
||||
"""
|
||||
subscription = subscription_service.get_subscription(db, vendor_id)
|
||||
|
||||
if not subscription or not subscription.stripe_subscription_id:
|
||||
raise NoActiveSubscriptionError()
|
||||
|
||||
tier = self.get_tier_by_code(db, new_tier_code)
|
||||
|
||||
price_id = (
|
||||
tier.stripe_price_annual_id
|
||||
if is_annual and tier.stripe_price_annual_id
|
||||
else tier.stripe_price_monthly_id
|
||||
)
|
||||
|
||||
if not price_id:
|
||||
raise StripePriceNotConfiguredError(new_tier_code)
|
||||
|
||||
# Update in Stripe
|
||||
if stripe_service.is_configured:
|
||||
stripe_service.update_subscription(
|
||||
subscription_id=subscription.stripe_subscription_id,
|
||||
new_price_id=price_id,
|
||||
)
|
||||
|
||||
# Update local subscription
|
||||
old_tier = subscription.tier
|
||||
subscription.tier = new_tier_code
|
||||
subscription.tier_id = tier.id
|
||||
subscription.is_annual = is_annual
|
||||
subscription.updated_at = datetime.utcnow()
|
||||
|
||||
is_upgrade = self._is_upgrade(db, old_tier, new_tier_code)
|
||||
|
||||
return {
|
||||
"message": f"Subscription {'upgraded' if is_upgrade else 'changed'} to {tier.name}",
|
||||
"new_tier": new_tier_code,
|
||||
"effective_immediately": True,
|
||||
}
|
||||
|
||||
def _is_upgrade(self, db: Session, old_tier: str, new_tier: str) -> bool:
|
||||
"""Check if tier change is an upgrade."""
|
||||
old = db.query(SubscriptionTier).filter(SubscriptionTier.code == old_tier).first()
|
||||
new = db.query(SubscriptionTier).filter(SubscriptionTier.code == new_tier).first()
|
||||
|
||||
if not old or not new:
|
||||
return False
|
||||
|
||||
return new.display_order > old.display_order
|
||||
|
||||
def purchase_addon(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
addon_code: str,
|
||||
domain_name: str | None,
|
||||
quantity: int,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Create checkout session for add-on purchase.
|
||||
|
||||
Returns:
|
||||
Dict with checkout_url and session_id
|
||||
|
||||
Raises:
|
||||
PaymentSystemNotConfiguredError: If Stripe not configured
|
||||
AddonNotFoundError: If addon doesn't exist
|
||||
"""
|
||||
if not stripe_service.is_configured:
|
||||
raise PaymentSystemNotConfiguredError()
|
||||
|
||||
addon = (
|
||||
db.query(AddOnProduct)
|
||||
.filter(
|
||||
AddOnProduct.code == addon_code,
|
||||
AddOnProduct.is_active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not addon:
|
||||
raise BillingServiceError(f"Add-on '{addon_code}' not found")
|
||||
|
||||
if not addon.stripe_price_id:
|
||||
raise BillingServiceError(f"Stripe price not configured for add-on '{addon_code}'")
|
||||
|
||||
vendor = self.get_vendor(db, vendor_id)
|
||||
subscription = subscription_service.get_or_create_subscription(db, vendor_id)
|
||||
|
||||
# Create checkout session for add-on
|
||||
session = stripe_service.create_checkout_session(
|
||||
db=db,
|
||||
vendor=vendor,
|
||||
price_id=addon.stripe_price_id,
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
quantity=quantity,
|
||||
metadata={
|
||||
"addon_code": addon_code,
|
||||
"domain_name": domain_name or "",
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"checkout_url": session.url,
|
||||
"session_id": session.id,
|
||||
}
|
||||
|
||||
def cancel_addon(self, db: Session, vendor_id: int, addon_id: int) -> dict:
|
||||
"""
|
||||
Cancel a purchased add-on.
|
||||
|
||||
Returns:
|
||||
Dict with message and addon_code
|
||||
|
||||
Raises:
|
||||
BillingServiceError: If addon not found or not owned by vendor
|
||||
"""
|
||||
vendor_addon = (
|
||||
db.query(VendorAddOn)
|
||||
.filter(
|
||||
VendorAddOn.id == addon_id,
|
||||
VendorAddOn.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not vendor_addon:
|
||||
raise BillingServiceError("Add-on not found")
|
||||
|
||||
addon_code = vendor_addon.addon_product.code
|
||||
|
||||
# Cancel in Stripe if applicable
|
||||
if stripe_service.is_configured and vendor_addon.stripe_subscription_item_id:
|
||||
try:
|
||||
stripe_service.cancel_subscription_item(vendor_addon.stripe_subscription_item_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cancel addon in Stripe: {e}")
|
||||
|
||||
# Mark as cancelled
|
||||
vendor_addon.status = "cancelled"
|
||||
vendor_addon.cancelled_at = datetime.utcnow()
|
||||
|
||||
return {
|
||||
"message": "Add-on cancelled successfully",
|
||||
"addon_code": addon_code,
|
||||
}
|
||||
|
||||
|
||||
# Create service instance
|
||||
billing_service = BillingService()
|
||||
__all__ = [
|
||||
"BillingService",
|
||||
"billing_service",
|
||||
"BillingServiceError",
|
||||
"PaymentSystemNotConfiguredError",
|
||||
"TierNotFoundError",
|
||||
"StripePriceNotConfiguredError",
|
||||
"NoActiveSubscriptionError",
|
||||
"SubscriptionNotCancelledError",
|
||||
]
|
||||
|
||||
@@ -1,820 +1,35 @@
|
||||
# app/services/code_quality_service.py
|
||||
"""
|
||||
Code Quality Service
|
||||
Business logic for managing code quality scans and violations
|
||||
Supports multiple validator types: architecture, security, performance
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
The canonical implementation is now in:
|
||||
app/modules/dev_tools/services/code_quality_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.dev_tools.services import code_quality_service
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
ScanParseException,
|
||||
ScanTimeoutException,
|
||||
ViolationNotFoundException,
|
||||
)
|
||||
from models.database.architecture_scan import (
|
||||
ArchitectureScan,
|
||||
ArchitectureViolation,
|
||||
ViolationAssignment,
|
||||
ViolationComment,
|
||||
from app.modules.dev_tools.services.code_quality_service import (
|
||||
code_quality_service,
|
||||
CodeQualityService,
|
||||
VALIDATOR_ARCHITECTURE,
|
||||
VALIDATOR_SECURITY,
|
||||
VALIDATOR_PERFORMANCE,
|
||||
VALID_VALIDATOR_TYPES,
|
||||
VALIDATOR_SCRIPTS,
|
||||
VALIDATOR_NAMES,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Validator type constants
|
||||
VALIDATOR_ARCHITECTURE = "architecture"
|
||||
VALIDATOR_SECURITY = "security"
|
||||
VALIDATOR_PERFORMANCE = "performance"
|
||||
|
||||
VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE]
|
||||
|
||||
# Map validator types to their scripts
|
||||
VALIDATOR_SCRIPTS = {
|
||||
VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py",
|
||||
VALIDATOR_SECURITY: "scripts/validate_security.py",
|
||||
VALIDATOR_PERFORMANCE: "scripts/validate_performance.py",
|
||||
}
|
||||
|
||||
# Human-readable names
|
||||
VALIDATOR_NAMES = {
|
||||
VALIDATOR_ARCHITECTURE: "Architecture",
|
||||
VALIDATOR_SECURITY: "Security",
|
||||
VALIDATOR_PERFORMANCE: "Performance",
|
||||
}
|
||||
|
||||
|
||||
class CodeQualityService:
|
||||
"""Service for managing code quality scans and violations"""
|
||||
|
||||
def run_scan(
|
||||
self,
|
||||
db: Session,
|
||||
triggered_by: str = "manual",
|
||||
validator_type: str = VALIDATOR_ARCHITECTURE,
|
||||
) -> ArchitectureScan:
|
||||
"""
|
||||
Run a code quality validator and store results in database
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
triggered_by: Who/what triggered the scan ('manual', 'scheduled', 'ci/cd')
|
||||
validator_type: Type of validator ('architecture', 'security', 'performance')
|
||||
|
||||
Returns:
|
||||
ArchitectureScan object with results
|
||||
|
||||
Raises:
|
||||
ValueError: If validator_type is invalid
|
||||
ScanTimeoutException: If validator times out
|
||||
ScanParseException: If validator output cannot be parsed
|
||||
"""
|
||||
if validator_type not in VALID_VALIDATOR_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid validator type: {validator_type}. "
|
||||
f"Must be one of: {VALID_VALIDATOR_TYPES}"
|
||||
)
|
||||
|
||||
script_path = VALIDATOR_SCRIPTS[validator_type]
|
||||
validator_name = VALIDATOR_NAMES[validator_type]
|
||||
|
||||
logger.info(
|
||||
f"Starting {validator_name} scan (triggered by: {triggered_by})"
|
||||
)
|
||||
|
||||
# Get git commit hash
|
||||
git_commit = self._get_git_commit_hash()
|
||||
|
||||
# Run validator with JSON output
|
||||
start_time = datetime.now()
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["python", script_path, "--json"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minute timeout
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"{validator_name} scan timed out after 5 minutes")
|
||||
raise ScanTimeoutException(timeout_seconds=300)
|
||||
|
||||
duration = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
# Parse JSON output (get only the JSON part, skip progress messages)
|
||||
try:
|
||||
# Find the JSON part in stdout
|
||||
lines = result.stdout.strip().split("\n")
|
||||
json_start = -1
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("{"):
|
||||
json_start = i
|
||||
break
|
||||
|
||||
if json_start == -1:
|
||||
raise ValueError("No JSON output found")
|
||||
|
||||
json_output = "\n".join(lines[json_start:])
|
||||
data = json.loads(json_output)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"Failed to parse {validator_name} validator output: {e}")
|
||||
logger.error(f"Stdout: {result.stdout}")
|
||||
logger.error(f"Stderr: {result.stderr}")
|
||||
raise ScanParseException(reason=str(e))
|
||||
|
||||
# Create scan record
|
||||
scan = ArchitectureScan(
|
||||
timestamp=datetime.now(),
|
||||
validator_type=validator_type,
|
||||
total_files=data.get("files_checked", 0),
|
||||
total_violations=data.get("total_violations", 0),
|
||||
errors=data.get("errors", 0),
|
||||
warnings=data.get("warnings", 0),
|
||||
duration_seconds=duration,
|
||||
triggered_by=triggered_by,
|
||||
git_commit_hash=git_commit,
|
||||
)
|
||||
db.add(scan)
|
||||
db.flush() # Get scan.id
|
||||
|
||||
# Create violation records
|
||||
violations_data = data.get("violations", [])
|
||||
logger.info(f"Creating {len(violations_data)} {validator_name} violation records")
|
||||
|
||||
for v in violations_data:
|
||||
violation = ArchitectureViolation(
|
||||
scan_id=scan.id,
|
||||
validator_type=validator_type,
|
||||
rule_id=v["rule_id"],
|
||||
rule_name=v["rule_name"],
|
||||
severity=v["severity"],
|
||||
file_path=v["file_path"],
|
||||
line_number=v["line_number"],
|
||||
message=v["message"],
|
||||
context=v.get("context", ""),
|
||||
suggestion=v.get("suggestion", ""),
|
||||
status="open",
|
||||
)
|
||||
db.add(violation)
|
||||
|
||||
db.flush()
|
||||
db.refresh(scan)
|
||||
|
||||
logger.info(
|
||||
f"{validator_name} scan completed: {scan.total_violations} violations found"
|
||||
)
|
||||
return scan
|
||||
|
||||
def run_all_scans(
|
||||
self, db: Session, triggered_by: str = "manual"
|
||||
) -> list[ArchitectureScan]:
|
||||
"""
|
||||
Run all validators and return list of scans
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
triggered_by: Who/what triggered the scan
|
||||
|
||||
Returns:
|
||||
List of ArchitectureScan objects (one per validator)
|
||||
"""
|
||||
results = []
|
||||
for validator_type in VALID_VALIDATOR_TYPES:
|
||||
try:
|
||||
scan = self.run_scan(db, triggered_by, validator_type)
|
||||
results.append(scan)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run {validator_type} scan: {e}")
|
||||
# Continue with other validators even if one fails
|
||||
return results
|
||||
|
||||
def get_latest_scan(
|
||||
self, db: Session, validator_type: str = None
|
||||
) -> ArchitectureScan | None:
|
||||
"""
|
||||
Get the most recent scan
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
validator_type: Optional filter by validator type
|
||||
|
||||
Returns:
|
||||
Most recent ArchitectureScan or None
|
||||
"""
|
||||
query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
|
||||
|
||||
if validator_type:
|
||||
query = query.filter(ArchitectureScan.validator_type == validator_type)
|
||||
|
||||
return query.first()
|
||||
|
||||
def get_latest_scans_by_type(self, db: Session) -> dict[str, ArchitectureScan]:
|
||||
"""
|
||||
Get the most recent scan for each validator type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping validator_type to its latest scan
|
||||
"""
|
||||
result = {}
|
||||
for vtype in VALID_VALIDATOR_TYPES:
|
||||
scan = self.get_latest_scan(db, validator_type=vtype)
|
||||
if scan:
|
||||
result[vtype] = scan
|
||||
return result
|
||||
|
||||
def get_scan_by_id(self, db: Session, scan_id: int) -> ArchitectureScan | None:
|
||||
"""Get scan by ID"""
|
||||
return db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first()
|
||||
|
||||
def create_pending_scan(
|
||||
self, db: Session, validator_type: str, triggered_by: str
|
||||
) -> ArchitectureScan:
|
||||
"""
|
||||
Create a new scan record with pending status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
validator_type: Type of validator (architecture, security, performance)
|
||||
triggered_by: Who triggered the scan (e.g., "manual:username")
|
||||
|
||||
Returns:
|
||||
The created ArchitectureScan record with ID populated
|
||||
"""
|
||||
scan = ArchitectureScan(
|
||||
timestamp=datetime.now(UTC),
|
||||
validator_type=validator_type,
|
||||
status="pending",
|
||||
triggered_by=triggered_by,
|
||||
)
|
||||
db.add(scan)
|
||||
db.flush() # Get scan.id
|
||||
return scan
|
||||
|
||||
def get_running_scans(self, db: Session) -> list[ArchitectureScan]:
|
||||
"""
|
||||
Get all currently running scans (pending or running status).
|
||||
|
||||
Returns:
|
||||
List of scans with status 'pending' or 'running', newest first
|
||||
"""
|
||||
return (
|
||||
db.query(ArchitectureScan)
|
||||
.filter(ArchitectureScan.status.in_(["pending", "running"]))
|
||||
.order_by(ArchitectureScan.timestamp.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_scan_history(
|
||||
self, db: Session, limit: int = 30, validator_type: str = None
|
||||
) -> list[ArchitectureScan]:
|
||||
"""
|
||||
Get scan history for trend graphs
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
limit: Maximum number of scans to return
|
||||
validator_type: Optional filter by validator type
|
||||
|
||||
Returns:
|
||||
List of ArchitectureScan objects, newest first
|
||||
"""
|
||||
query = db.query(ArchitectureScan).order_by(desc(ArchitectureScan.timestamp))
|
||||
|
||||
if validator_type:
|
||||
query = query.filter(ArchitectureScan.validator_type == validator_type)
|
||||
|
||||
return query.limit(limit).all()
|
||||
|
||||
def get_violations(
|
||||
self,
|
||||
db: Session,
|
||||
scan_id: int = None,
|
||||
validator_type: str = None,
|
||||
severity: str = None,
|
||||
status: str = None,
|
||||
rule_id: str = None,
|
||||
file_path: str = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[ArchitectureViolation], int]:
|
||||
"""
|
||||
Get violations with filtering and pagination
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
scan_id: Filter by scan ID (if None, use latest scan(s))
|
||||
validator_type: Filter by validator type
|
||||
severity: Filter by severity ('error', 'warning')
|
||||
status: Filter by status ('open', 'assigned', 'resolved', etc.)
|
||||
rule_id: Filter by rule ID
|
||||
file_path: Filter by file path (partial match)
|
||||
limit: Page size
|
||||
offset: Page offset
|
||||
|
||||
Returns:
|
||||
Tuple of (violations list, total count)
|
||||
"""
|
||||
# Build query
|
||||
query = db.query(ArchitectureViolation)
|
||||
|
||||
# If scan_id specified, filter by it
|
||||
if scan_id is not None:
|
||||
query = query.filter(ArchitectureViolation.scan_id == scan_id)
|
||||
else:
|
||||
# If no scan_id, get violations from latest scan(s)
|
||||
if validator_type:
|
||||
# Get latest scan for specific validator type
|
||||
latest_scan = self.get_latest_scan(db, validator_type)
|
||||
if not latest_scan:
|
||||
return [], 0
|
||||
query = query.filter(ArchitectureViolation.scan_id == latest_scan.id)
|
||||
else:
|
||||
# Get violations from latest scans of all types
|
||||
latest_scans = self.get_latest_scans_by_type(db)
|
||||
if not latest_scans:
|
||||
return [], 0
|
||||
scan_ids = [s.id for s in latest_scans.values()]
|
||||
query = query.filter(ArchitectureViolation.scan_id.in_(scan_ids))
|
||||
|
||||
# Apply validator_type filter if specified (for scan_id queries)
|
||||
if validator_type and scan_id is not None:
|
||||
query = query.filter(ArchitectureViolation.validator_type == validator_type)
|
||||
|
||||
# Apply other filters
|
||||
if severity:
|
||||
query = query.filter(ArchitectureViolation.severity == severity)
|
||||
|
||||
if status:
|
||||
query = query.filter(ArchitectureViolation.status == status)
|
||||
|
||||
if rule_id:
|
||||
query = query.filter(ArchitectureViolation.rule_id == rule_id)
|
||||
|
||||
if file_path:
|
||||
query = query.filter(ArchitectureViolation.file_path.like(f"%{file_path}%"))
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get page of results
|
||||
violations = (
|
||||
query.order_by(
|
||||
ArchitectureViolation.severity.desc(),
|
||||
ArchitectureViolation.validator_type,
|
||||
ArchitectureViolation.file_path,
|
||||
)
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.all()
|
||||
)
|
||||
|
||||
return violations, total
|
||||
|
||||
def get_violation_by_id(
|
||||
self, db: Session, violation_id: int
|
||||
) -> ArchitectureViolation | None:
|
||||
"""Get single violation with details"""
|
||||
return (
|
||||
db.query(ArchitectureViolation)
|
||||
.filter(ArchitectureViolation.id == violation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def assign_violation(
|
||||
self,
|
||||
db: Session,
|
||||
violation_id: int,
|
||||
user_id: int,
|
||||
assigned_by: int,
|
||||
due_date: datetime = None,
|
||||
priority: str = "medium",
|
||||
) -> ViolationAssignment:
|
||||
"""
|
||||
Assign violation to a developer
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
violation_id: Violation ID
|
||||
user_id: User to assign to
|
||||
assigned_by: User who is assigning
|
||||
due_date: Due date (optional)
|
||||
priority: Priority level ('low', 'medium', 'high', 'critical')
|
||||
|
||||
Returns:
|
||||
ViolationAssignment object
|
||||
"""
|
||||
# Update violation status
|
||||
violation = self.get_violation_by_id(db, violation_id)
|
||||
if violation:
|
||||
violation.status = "assigned"
|
||||
violation.assigned_to = user_id
|
||||
|
||||
# Create assignment record
|
||||
assignment = ViolationAssignment(
|
||||
violation_id=violation_id,
|
||||
user_id=user_id,
|
||||
assigned_by=assigned_by,
|
||||
due_date=due_date,
|
||||
priority=priority,
|
||||
)
|
||||
db.add(assignment)
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Violation {violation_id} assigned to user {user_id}")
|
||||
return assignment
|
||||
|
||||
def resolve_violation(
|
||||
self, db: Session, violation_id: int, resolved_by: int, resolution_note: str
|
||||
) -> ArchitectureViolation:
|
||||
"""
|
||||
Mark violation as resolved
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
violation_id: Violation ID
|
||||
resolved_by: User who resolved it
|
||||
resolution_note: Note about resolution
|
||||
|
||||
Returns:
|
||||
Updated ArchitectureViolation object
|
||||
"""
|
||||
violation = self.get_violation_by_id(db, violation_id)
|
||||
if not violation:
|
||||
raise ViolationNotFoundException(violation_id)
|
||||
|
||||
violation.status = "resolved"
|
||||
violation.resolved_at = datetime.now()
|
||||
violation.resolved_by = resolved_by
|
||||
violation.resolution_note = resolution_note
|
||||
|
||||
db.flush()
|
||||
logger.info(f"Violation {violation_id} resolved by user {resolved_by}")
|
||||
return violation
|
||||
|
||||
def ignore_violation(
|
||||
self, db: Session, violation_id: int, ignored_by: int, reason: str
|
||||
) -> ArchitectureViolation:
|
||||
"""
|
||||
Mark violation as ignored/won't fix
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
violation_id: Violation ID
|
||||
ignored_by: User who ignored it
|
||||
reason: Reason for ignoring
|
||||
|
||||
Returns:
|
||||
Updated ArchitectureViolation object
|
||||
"""
|
||||
violation = self.get_violation_by_id(db, violation_id)
|
||||
if not violation:
|
||||
raise ViolationNotFoundException(violation_id)
|
||||
|
||||
violation.status = "ignored"
|
||||
violation.resolved_at = datetime.now()
|
||||
violation.resolved_by = ignored_by
|
||||
violation.resolution_note = f"Ignored: {reason}"
|
||||
|
||||
db.flush()
|
||||
logger.info(f"Violation {violation_id} ignored by user {ignored_by}")
|
||||
return violation
|
||||
|
||||
def add_comment(
|
||||
self, db: Session, violation_id: int, user_id: int, comment: str
|
||||
) -> ViolationComment:
|
||||
"""
|
||||
Add comment to violation
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
violation_id: Violation ID
|
||||
user_id: User posting comment
|
||||
comment: Comment text
|
||||
|
||||
Returns:
|
||||
ViolationComment object
|
||||
"""
|
||||
comment_obj = ViolationComment(
|
||||
violation_id=violation_id, user_id=user_id, comment=comment
|
||||
)
|
||||
db.add(comment_obj)
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Comment added to violation {violation_id} by user {user_id}")
|
||||
return comment_obj
|
||||
|
||||
def get_dashboard_stats(
|
||||
self, db: Session, validator_type: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
Get statistics for dashboard
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
validator_type: Optional filter by validator type. If None, returns combined stats.
|
||||
|
||||
Returns:
|
||||
Dictionary with various statistics including per-validator breakdown
|
||||
"""
|
||||
# Get latest scans by type
|
||||
latest_scans = self.get_latest_scans_by_type(db)
|
||||
|
||||
if not latest_scans:
|
||||
return self._empty_dashboard_stats()
|
||||
|
||||
# If specific validator type requested
|
||||
if validator_type and validator_type in latest_scans:
|
||||
scan = latest_scans[validator_type]
|
||||
return self._get_stats_for_scan(db, scan, validator_type)
|
||||
|
||||
# Combined stats across all validators
|
||||
return self._get_combined_stats(db, latest_scans)
|
||||
|
||||
def _empty_dashboard_stats(self) -> dict:
|
||||
"""Return empty dashboard stats structure"""
|
||||
return {
|
||||
"total_violations": 0,
|
||||
"errors": 0,
|
||||
"warnings": 0,
|
||||
"info": 0,
|
||||
"open": 0,
|
||||
"assigned": 0,
|
||||
"resolved": 0,
|
||||
"ignored": 0,
|
||||
"technical_debt_score": 100,
|
||||
"trend": [],
|
||||
"by_severity": {},
|
||||
"by_rule": {},
|
||||
"by_module": {},
|
||||
"top_files": [],
|
||||
"last_scan": None,
|
||||
"by_validator": {},
|
||||
}
|
||||
|
||||
def _get_stats_for_scan(
|
||||
self, db: Session, scan: ArchitectureScan, validator_type: str
|
||||
) -> dict:
|
||||
"""Get stats for a single scan/validator type"""
|
||||
# Get violation counts by status
|
||||
status_counts = (
|
||||
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
|
||||
.filter(ArchitectureViolation.scan_id == scan.id)
|
||||
.group_by(ArchitectureViolation.status)
|
||||
.all()
|
||||
)
|
||||
status_dict = {status: count for status, count in status_counts}
|
||||
|
||||
# Get violations by severity
|
||||
severity_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id == scan.id)
|
||||
.group_by(ArchitectureViolation.severity)
|
||||
.all()
|
||||
)
|
||||
by_severity = {sev: count for sev, count in severity_counts}
|
||||
|
||||
# Get violations by rule
|
||||
rule_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id == scan.id)
|
||||
.group_by(ArchitectureViolation.rule_id)
|
||||
.all()
|
||||
)
|
||||
by_rule = {
|
||||
rule: count
|
||||
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
|
||||
}
|
||||
|
||||
# Get top violating files
|
||||
file_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.file_path,
|
||||
func.count(ArchitectureViolation.id).label("count"),
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id == scan.id)
|
||||
.group_by(ArchitectureViolation.file_path)
|
||||
.order_by(desc("count"))
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
top_files = [{"file": file, "count": count} for file, count in file_counts]
|
||||
|
||||
# Get violations by module
|
||||
by_module = self._get_violations_by_module(db, scan.id)
|
||||
|
||||
# Get trend for this validator type
|
||||
trend_scans = self.get_scan_history(db, limit=7, validator_type=validator_type)
|
||||
trend = [
|
||||
{
|
||||
"timestamp": s.timestamp.isoformat(),
|
||||
"violations": s.total_violations,
|
||||
"errors": s.errors,
|
||||
"warnings": s.warnings,
|
||||
}
|
||||
for s in reversed(trend_scans)
|
||||
]
|
||||
|
||||
return {
|
||||
"total_violations": scan.total_violations,
|
||||
"errors": scan.errors,
|
||||
"warnings": scan.warnings,
|
||||
"info": by_severity.get("info", 0),
|
||||
"open": status_dict.get("open", 0),
|
||||
"assigned": status_dict.get("assigned", 0),
|
||||
"resolved": status_dict.get("resolved", 0),
|
||||
"ignored": status_dict.get("ignored", 0),
|
||||
"technical_debt_score": self._calculate_score(scan.errors, scan.warnings),
|
||||
"trend": trend,
|
||||
"by_severity": by_severity,
|
||||
"by_rule": by_rule,
|
||||
"by_module": by_module,
|
||||
"top_files": top_files,
|
||||
"last_scan": scan.timestamp.isoformat(),
|
||||
"validator_type": validator_type,
|
||||
"by_validator": {
|
||||
validator_type: {
|
||||
"total_violations": scan.total_violations,
|
||||
"errors": scan.errors,
|
||||
"warnings": scan.warnings,
|
||||
"last_scan": scan.timestamp.isoformat(),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def _get_combined_stats(
|
||||
self, db: Session, latest_scans: dict[str, ArchitectureScan]
|
||||
) -> dict:
|
||||
"""Get combined stats across all validators"""
|
||||
# Aggregate totals
|
||||
total_violations = sum(s.total_violations for s in latest_scans.values())
|
||||
total_errors = sum(s.errors for s in latest_scans.values())
|
||||
total_warnings = sum(s.warnings for s in latest_scans.values())
|
||||
|
||||
# Get all scan IDs
|
||||
scan_ids = [s.id for s in latest_scans.values()]
|
||||
|
||||
# Get violation counts by status
|
||||
status_counts = (
|
||||
db.query(ArchitectureViolation.status, func.count(ArchitectureViolation.id))
|
||||
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
|
||||
.group_by(ArchitectureViolation.status)
|
||||
.all()
|
||||
)
|
||||
status_dict = {status: count for status, count in status_counts}
|
||||
|
||||
# Get violations by severity
|
||||
severity_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.severity, func.count(ArchitectureViolation.id)
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
|
||||
.group_by(ArchitectureViolation.severity)
|
||||
.all()
|
||||
)
|
||||
by_severity = {sev: count for sev, count in severity_counts}
|
||||
|
||||
# Get violations by rule (across all validators)
|
||||
rule_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.rule_id, func.count(ArchitectureViolation.id)
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
|
||||
.group_by(ArchitectureViolation.rule_id)
|
||||
.all()
|
||||
)
|
||||
by_rule = {
|
||||
rule: count
|
||||
for rule, count in sorted(rule_counts, key=lambda x: x[1], reverse=True)[:10]
|
||||
}
|
||||
|
||||
# Get top violating files
|
||||
file_counts = (
|
||||
db.query(
|
||||
ArchitectureViolation.file_path,
|
||||
func.count(ArchitectureViolation.id).label("count"),
|
||||
)
|
||||
.filter(ArchitectureViolation.scan_id.in_(scan_ids))
|
||||
.group_by(ArchitectureViolation.file_path)
|
||||
.order_by(desc("count"))
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
top_files = [{"file": file, "count": count} for file, count in file_counts]
|
||||
|
||||
# Get violations by module
|
||||
by_module = {}
|
||||
for scan_id in scan_ids:
|
||||
module_counts = self._get_violations_by_module(db, scan_id)
|
||||
for module, count in module_counts.items():
|
||||
by_module[module] = by_module.get(module, 0) + count
|
||||
by_module = dict(
|
||||
sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
)
|
||||
|
||||
# Per-validator breakdown
|
||||
by_validator = {}
|
||||
for vtype, scan in latest_scans.items():
|
||||
by_validator[vtype] = {
|
||||
"total_violations": scan.total_violations,
|
||||
"errors": scan.errors,
|
||||
"warnings": scan.warnings,
|
||||
"last_scan": scan.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
# Get most recent scan timestamp
|
||||
most_recent = max(latest_scans.values(), key=lambda s: s.timestamp)
|
||||
|
||||
return {
|
||||
"total_violations": total_violations,
|
||||
"errors": total_errors,
|
||||
"warnings": total_warnings,
|
||||
"info": by_severity.get("info", 0),
|
||||
"open": status_dict.get("open", 0),
|
||||
"assigned": status_dict.get("assigned", 0),
|
||||
"resolved": status_dict.get("resolved", 0),
|
||||
"ignored": status_dict.get("ignored", 0),
|
||||
"technical_debt_score": self._calculate_score(total_errors, total_warnings),
|
||||
"trend": [], # Combined trend would need special handling
|
||||
"by_severity": by_severity,
|
||||
"by_rule": by_rule,
|
||||
"by_module": by_module,
|
||||
"top_files": top_files,
|
||||
"last_scan": most_recent.timestamp.isoformat(),
|
||||
"by_validator": by_validator,
|
||||
}
|
||||
|
||||
def _get_violations_by_module(self, db: Session, scan_id: int) -> dict[str, int]:
|
||||
"""Extract module from file paths and count violations"""
|
||||
by_module = {}
|
||||
violations = (
|
||||
db.query(ArchitectureViolation.file_path)
|
||||
.filter(ArchitectureViolation.scan_id == scan_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for v in violations:
|
||||
path_parts = v.file_path.split("/")
|
||||
if len(path_parts) >= 2:
|
||||
module = "/".join(path_parts[:2])
|
||||
else:
|
||||
module = path_parts[0]
|
||||
by_module[module] = by_module.get(module, 0) + 1
|
||||
|
||||
return dict(sorted(by_module.items(), key=lambda x: x[1], reverse=True)[:10])
|
||||
|
||||
def _calculate_score(self, errors: int, warnings: int) -> int:
|
||||
"""Calculate technical debt score (0-100)"""
|
||||
score = 100 - (errors * 0.5 + warnings * 0.05)
|
||||
return max(0, min(100, int(score)))
|
||||
|
||||
def calculate_technical_debt_score(
|
||||
self, db: Session, scan_id: int = None, validator_type: str = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate technical debt score (0-100)
|
||||
|
||||
Formula: 100 - (errors * 0.5 + warnings * 0.05)
|
||||
Capped at 0 minimum
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
scan_id: Scan ID (if None, use latest)
|
||||
validator_type: Filter by validator type
|
||||
|
||||
Returns:
|
||||
Score from 0-100
|
||||
"""
|
||||
if scan_id is None:
|
||||
latest_scan = self.get_latest_scan(db, validator_type)
|
||||
if not latest_scan:
|
||||
return 100
|
||||
scan_id = latest_scan.id
|
||||
|
||||
scan = self.get_scan_by_id(db, scan_id)
|
||||
if not scan:
|
||||
return 100
|
||||
|
||||
return self._calculate_score(scan.errors, scan.warnings)
|
||||
|
||||
def _get_git_commit_hash(self) -> str | None:
|
||||
"""Get current git commit hash"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "HEAD"], capture_output=True, text=True, timeout=5
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()[:40]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
code_quality_service = CodeQualityService()
|
||||
__all__ = [
|
||||
"code_quality_service",
|
||||
"CodeQualityService",
|
||||
"VALIDATOR_ARCHITECTURE",
|
||||
"VALIDATOR_SECURITY",
|
||||
"VALIDATOR_PERFORMANCE",
|
||||
"VALID_VALIDATOR_TYPES",
|
||||
"VALIDATOR_SCRIPTS",
|
||||
"VALIDATOR_NAMES",
|
||||
]
|
||||
|
||||
@@ -1,314 +1,23 @@
|
||||
# app/services/customer_address_service.py
|
||||
"""
|
||||
Customer Address Service
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Business logic for managing customer addresses with vendor isolation.
|
||||
The canonical implementation is now in:
|
||||
app/modules/customers/services/customer_address_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.customers.services import customer_address_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
AddressLimitExceededException,
|
||||
AddressNotFoundException,
|
||||
from app.modules.customers.services.customer_address_service import (
|
||||
customer_address_service,
|
||||
CustomerAddressService,
|
||||
)
|
||||
from models.database.customer import CustomerAddress
|
||||
from models.schema.customer import CustomerAddressCreate, CustomerAddressUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomerAddressService:
|
||||
"""Service for managing customer addresses with vendor isolation."""
|
||||
|
||||
MAX_ADDRESSES_PER_CUSTOMER = 10
|
||||
|
||||
def list_addresses(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> list[CustomerAddress]:
|
||||
"""
|
||||
Get all addresses for a customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
List of customer addresses
|
||||
"""
|
||||
return (
|
||||
db.query(CustomerAddress)
|
||||
.filter(
|
||||
CustomerAddress.vendor_id == vendor_id,
|
||||
CustomerAddress.customer_id == customer_id,
|
||||
)
|
||||
.order_by(CustomerAddress.is_default.desc(), CustomerAddress.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_address(
|
||||
self, db: Session, vendor_id: int, customer_id: int, address_id: int
|
||||
) -> CustomerAddress:
|
||||
"""
|
||||
Get a specific address with ownership validation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_id: Address ID
|
||||
|
||||
Returns:
|
||||
Customer address
|
||||
|
||||
Raises:
|
||||
AddressNotFoundException: If address not found or doesn't belong to customer
|
||||
"""
|
||||
address = (
|
||||
db.query(CustomerAddress)
|
||||
.filter(
|
||||
CustomerAddress.id == address_id,
|
||||
CustomerAddress.vendor_id == vendor_id,
|
||||
CustomerAddress.customer_id == customer_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not address:
|
||||
raise AddressNotFoundException(address_id)
|
||||
|
||||
return address
|
||||
|
||||
def get_default_address(
|
||||
self, db: Session, vendor_id: int, customer_id: int, address_type: str
|
||||
) -> CustomerAddress | None:
|
||||
"""
|
||||
Get the default address for a specific type.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_type: 'shipping' or 'billing'
|
||||
|
||||
Returns:
|
||||
Default address or None if not set
|
||||
"""
|
||||
return (
|
||||
db.query(CustomerAddress)
|
||||
.filter(
|
||||
CustomerAddress.vendor_id == vendor_id,
|
||||
CustomerAddress.customer_id == customer_id,
|
||||
CustomerAddress.address_type == address_type,
|
||||
CustomerAddress.is_default == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def create_address(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
address_data: CustomerAddressCreate,
|
||||
) -> CustomerAddress:
|
||||
"""
|
||||
Create a new address for a customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_data: Address creation data
|
||||
|
||||
Returns:
|
||||
Created customer address
|
||||
|
||||
Raises:
|
||||
AddressLimitExceededException: If max addresses reached
|
||||
"""
|
||||
# Check address limit
|
||||
current_count = (
|
||||
db.query(CustomerAddress)
|
||||
.filter(
|
||||
CustomerAddress.vendor_id == vendor_id,
|
||||
CustomerAddress.customer_id == customer_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if current_count >= self.MAX_ADDRESSES_PER_CUSTOMER:
|
||||
raise AddressLimitExceededException(self.MAX_ADDRESSES_PER_CUSTOMER)
|
||||
|
||||
# If setting as default, clear other defaults of same type
|
||||
if address_data.is_default:
|
||||
self._clear_other_defaults(
|
||||
db, vendor_id, customer_id, address_data.address_type
|
||||
)
|
||||
|
||||
# Create the address
|
||||
address = CustomerAddress(
|
||||
vendor_id=vendor_id,
|
||||
customer_id=customer_id,
|
||||
address_type=address_data.address_type,
|
||||
first_name=address_data.first_name,
|
||||
last_name=address_data.last_name,
|
||||
company=address_data.company,
|
||||
address_line_1=address_data.address_line_1,
|
||||
address_line_2=address_data.address_line_2,
|
||||
city=address_data.city,
|
||||
postal_code=address_data.postal_code,
|
||||
country_name=address_data.country_name,
|
||||
country_iso=address_data.country_iso,
|
||||
is_default=address_data.is_default,
|
||||
)
|
||||
|
||||
db.add(address)
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Created address {address.id} for customer {customer_id} "
|
||||
f"(type={address_data.address_type}, default={address_data.is_default})"
|
||||
)
|
||||
|
||||
return address
|
||||
|
||||
def update_address(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
address_id: int,
|
||||
address_data: CustomerAddressUpdate,
|
||||
) -> CustomerAddress:
|
||||
"""
|
||||
Update an existing address.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_id: Address ID
|
||||
address_data: Address update data
|
||||
|
||||
Returns:
|
||||
Updated customer address
|
||||
|
||||
Raises:
|
||||
AddressNotFoundException: If address not found
|
||||
"""
|
||||
address = self.get_address(db, vendor_id, customer_id, address_id)
|
||||
|
||||
# Update only provided fields
|
||||
update_data = address_data.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle default flag - clear others if setting to default
|
||||
if update_data.get("is_default") is True:
|
||||
# Use updated type if provided, otherwise current type
|
||||
address_type = update_data.get("address_type", address.address_type)
|
||||
self._clear_other_defaults(
|
||||
db, vendor_id, customer_id, address_type, exclude_id=address_id
|
||||
)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(address, field, value)
|
||||
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Updated address {address_id} for customer {customer_id}")
|
||||
|
||||
return address
|
||||
|
||||
def delete_address(
|
||||
self, db: Session, vendor_id: int, customer_id: int, address_id: int
|
||||
) -> None:
|
||||
"""
|
||||
Delete an address.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_id: Address ID
|
||||
|
||||
Raises:
|
||||
AddressNotFoundException: If address not found
|
||||
"""
|
||||
address = self.get_address(db, vendor_id, customer_id, address_id)
|
||||
|
||||
db.delete(address)
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Deleted address {address_id} for customer {customer_id}")
|
||||
|
||||
def set_default(
|
||||
self, db: Session, vendor_id: int, customer_id: int, address_id: int
|
||||
) -> CustomerAddress:
|
||||
"""
|
||||
Set an address as the default for its type.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_id: Address ID
|
||||
|
||||
Returns:
|
||||
Updated customer address
|
||||
|
||||
Raises:
|
||||
AddressNotFoundException: If address not found
|
||||
"""
|
||||
address = self.get_address(db, vendor_id, customer_id, address_id)
|
||||
|
||||
# Clear other defaults of same type
|
||||
self._clear_other_defaults(
|
||||
db, vendor_id, customer_id, address.address_type, exclude_id=address_id
|
||||
)
|
||||
|
||||
# Set this one as default
|
||||
address.is_default = True
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Set address {address_id} as default {address.address_type} "
|
||||
f"for customer {customer_id}"
|
||||
)
|
||||
|
||||
return address
|
||||
|
||||
def _clear_other_defaults(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
address_type: str,
|
||||
exclude_id: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Clear the default flag on other addresses of the same type.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID for isolation
|
||||
customer_id: Customer ID
|
||||
address_type: 'shipping' or 'billing'
|
||||
exclude_id: Address ID to exclude from clearing
|
||||
"""
|
||||
query = db.query(CustomerAddress).filter(
|
||||
CustomerAddress.vendor_id == vendor_id,
|
||||
CustomerAddress.customer_id == customer_id,
|
||||
CustomerAddress.address_type == address_type,
|
||||
CustomerAddress.is_default == True, # noqa: E712
|
||||
)
|
||||
|
||||
if exclude_id:
|
||||
query = query.filter(CustomerAddress.id != exclude_id)
|
||||
|
||||
query.update({"is_default": False}, synchronize_session=False)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
customer_address_service = CustomerAddressService()
|
||||
__all__ = [
|
||||
"customer_address_service",
|
||||
"CustomerAddressService",
|
||||
]
|
||||
|
||||
@@ -1,664 +1,23 @@
|
||||
# app/services/customer_service.py
|
||||
"""
|
||||
Customer management service.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Handles customer registration, authentication, and profile management
|
||||
with complete vendor isolation.
|
||||
The canonical implementation is now in:
|
||||
app/modules/customers/services/customer_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.customers.services import customer_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions.customer import (
|
||||
CustomerNotActiveException,
|
||||
CustomerNotFoundException,
|
||||
CustomerValidationException,
|
||||
DuplicateCustomerEmailException,
|
||||
InvalidCustomerCredentialsException,
|
||||
InvalidPasswordResetTokenException,
|
||||
PasswordTooShortException,
|
||||
from app.modules.customers.services.customer_service import (
|
||||
customer_service,
|
||||
CustomerService,
|
||||
)
|
||||
from app.exceptions.vendor import VendorNotActiveException, VendorNotFoundException
|
||||
from app.services.auth_service import AuthService
|
||||
from models.database.customer import Customer
|
||||
from models.database.password_reset_token import PasswordResetToken
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.auth import UserLogin
|
||||
from models.schema.customer import CustomerRegister, CustomerUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomerService:
|
||||
"""Service for managing vendor-scoped customers."""
|
||||
|
||||
def __init__(self):
|
||||
self.auth_service = AuthService()
|
||||
|
||||
def register_customer(
|
||||
self, db: Session, vendor_id: int, customer_data: CustomerRegister
|
||||
) -> Customer:
|
||||
"""
|
||||
Register a new customer for a specific vendor.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_data: Customer registration data
|
||||
|
||||
Returns:
|
||||
Customer: Created customer object
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException: If vendor doesn't exist
|
||||
VendorNotActiveException: If vendor is not active
|
||||
DuplicateCustomerEmailException: If email already exists for this vendor
|
||||
CustomerValidationException: If customer data is invalid
|
||||
"""
|
||||
# Verify vendor exists and is active
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
if not vendor.is_active:
|
||||
raise VendorNotActiveException(vendor.vendor_code)
|
||||
|
||||
# Check if email already exists for this vendor
|
||||
existing_customer = (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == customer_data.email.lower(),
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_customer:
|
||||
raise DuplicateCustomerEmailException(
|
||||
customer_data.email, vendor.vendor_code
|
||||
)
|
||||
|
||||
# Generate unique customer number for this vendor
|
||||
customer_number = self._generate_customer_number(
|
||||
db, vendor_id, vendor.vendor_code
|
||||
)
|
||||
|
||||
# Hash password
|
||||
hashed_password = self.auth_service.hash_password(customer_data.password)
|
||||
|
||||
# Create customer
|
||||
customer = Customer(
|
||||
vendor_id=vendor_id,
|
||||
email=customer_data.email.lower(),
|
||||
hashed_password=hashed_password,
|
||||
first_name=customer_data.first_name,
|
||||
last_name=customer_data.last_name,
|
||||
phone=customer_data.phone,
|
||||
customer_number=customer_number,
|
||||
marketing_consent=(
|
||||
customer_data.marketing_consent
|
||||
if hasattr(customer_data, "marketing_consent")
|
||||
else False
|
||||
),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(customer)
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
logger.info(
|
||||
f"Customer registered successfully: {customer.email} "
|
||||
f"(ID: {customer.id}, Number: {customer.customer_number}) "
|
||||
f"for vendor {vendor.vendor_code}"
|
||||
)
|
||||
|
||||
return customer
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering customer: {str(e)}")
|
||||
raise CustomerValidationException(
|
||||
message="Failed to register customer", details={"error": str(e)}
|
||||
)
|
||||
|
||||
def login_customer(
|
||||
self, db: Session, vendor_id: int, credentials: UserLogin
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Authenticate customer and generate JWT token.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
credentials: Login credentials
|
||||
|
||||
Returns:
|
||||
Dict containing customer and token data
|
||||
|
||||
Raises:
|
||||
VendorNotFoundException: If vendor doesn't exist
|
||||
InvalidCustomerCredentialsException: If credentials are invalid
|
||||
CustomerNotActiveException: If customer account is inactive
|
||||
"""
|
||||
# Verify vendor exists
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
|
||||
|
||||
# Find customer by email (vendor-scoped)
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == credentials.email_or_username.lower(),
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer:
|
||||
raise InvalidCustomerCredentialsException()
|
||||
|
||||
# Verify password using auth_manager directly
|
||||
if not self.auth_service.auth_manager.verify_password(
|
||||
credentials.password, customer.hashed_password
|
||||
):
|
||||
raise InvalidCustomerCredentialsException()
|
||||
|
||||
# Check if customer is active
|
||||
if not customer.is_active:
|
||||
raise CustomerNotActiveException(customer.email)
|
||||
|
||||
# Generate JWT token with customer context
|
||||
# Use auth_manager directly since Customer is not a User model
|
||||
from datetime import datetime
|
||||
|
||||
from jose import jwt
|
||||
|
||||
auth_manager = self.auth_service.auth_manager
|
||||
expires_delta = timedelta(minutes=auth_manager.token_expire_minutes)
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": str(customer.id),
|
||||
"email": customer.email,
|
||||
"vendor_id": vendor_id,
|
||||
"type": "customer",
|
||||
"exp": expire,
|
||||
"iat": datetime.now(UTC),
|
||||
}
|
||||
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
token_data = {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": auth_manager.token_expire_minutes * 60,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Customer login successful: {customer.email} "
|
||||
f"for vendor {vendor.vendor_code}"
|
||||
)
|
||||
|
||||
return {"customer": customer, "token_data": token_data}
|
||||
|
||||
def get_customer(self, db: Session, vendor_id: int, customer_id: int) -> Customer:
|
||||
"""
|
||||
Get customer by ID with vendor isolation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Customer object
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(and_(Customer.id == customer_id, Customer.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer:
|
||||
raise CustomerNotFoundException(str(customer_id))
|
||||
|
||||
return customer
|
||||
|
||||
def get_customer_by_email(
|
||||
self, db: Session, vendor_id: int, email: str
|
||||
) -> Customer | None:
|
||||
"""
|
||||
Get customer by email (vendor-scoped).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
email: Customer email
|
||||
|
||||
Returns:
|
||||
Optional[Customer]: Customer object or None
|
||||
"""
|
||||
return (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(Customer.vendor_id == vendor_id, Customer.email == email.lower())
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_vendor_customers(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
) -> tuple[list[Customer], int]:
|
||||
"""
|
||||
Get all customers for a vendor with filtering and pagination.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
search: Search in name/email
|
||||
is_active: Filter by active status
|
||||
|
||||
Returns:
|
||||
Tuple of (customers, total_count)
|
||||
"""
|
||||
from sqlalchemy import or_
|
||||
|
||||
query = db.query(Customer).filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
Customer.email.ilike(search_pattern),
|
||||
Customer.first_name.ilike(search_pattern),
|
||||
Customer.last_name.ilike(search_pattern),
|
||||
Customer.customer_number.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(Customer.is_active == is_active)
|
||||
|
||||
# Order by most recent first
|
||||
query = query.order_by(Customer.created_at.desc())
|
||||
|
||||
total = query.count()
|
||||
customers = query.offset(skip).limit(limit).all()
|
||||
|
||||
return customers, total
|
||||
|
||||
def get_customer_orders(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list, int]:
|
||||
"""
|
||||
Get orders for a specific customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
|
||||
Returns:
|
||||
Tuple of (orders, total_count)
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
from models.database.order import Order
|
||||
|
||||
# Verify customer belongs to vendor
|
||||
self.get_customer(db, vendor_id, customer_id)
|
||||
|
||||
# Get customer orders
|
||||
query = (
|
||||
db.query(Order)
|
||||
.filter(
|
||||
Order.customer_id == customer_id,
|
||||
Order.vendor_id == vendor_id,
|
||||
)
|
||||
.order_by(Order.created_at.desc())
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
orders = query.offset(skip).limit(limit).all()
|
||||
|
||||
return orders, total
|
||||
|
||||
def get_customer_statistics(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> dict:
|
||||
"""
|
||||
Get detailed statistics for a customer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Dict with customer statistics
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
from models.database.order import Order
|
||||
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
|
||||
# Get order statistics
|
||||
order_stats = (
|
||||
db.query(
|
||||
func.count(Order.id).label("total_orders"),
|
||||
func.sum(Order.total_cents).label("total_spent_cents"),
|
||||
func.avg(Order.total_cents).label("avg_order_cents"),
|
||||
func.max(Order.created_at).label("last_order_date"),
|
||||
)
|
||||
.filter(Order.customer_id == customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
total_orders = order_stats.total_orders or 0
|
||||
total_spent_cents = order_stats.total_spent_cents or 0
|
||||
avg_order_cents = order_stats.avg_order_cents or 0
|
||||
|
||||
return {
|
||||
"customer_id": customer_id,
|
||||
"total_orders": total_orders,
|
||||
"total_spent": total_spent_cents / 100, # Convert to euros
|
||||
"average_order_value": avg_order_cents / 100 if avg_order_cents else 0.0,
|
||||
"last_order_date": order_stats.last_order_date,
|
||||
"member_since": customer.created_at,
|
||||
"is_active": customer.is_active,
|
||||
}
|
||||
|
||||
def toggle_customer_status(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> Customer:
|
||||
"""
|
||||
Toggle customer active status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Updated customer
|
||||
"""
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
customer.is_active = not customer.is_active
|
||||
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
action = "activated" if customer.is_active else "deactivated"
|
||||
logger.info(f"Customer {action}: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
|
||||
def update_customer(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
customer_id: int,
|
||||
customer_data: CustomerUpdate,
|
||||
) -> Customer:
|
||||
"""
|
||||
Update customer profile.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
customer_data: Updated customer data
|
||||
|
||||
Returns:
|
||||
Customer: Updated customer object
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
CustomerValidationException: If update data is invalid
|
||||
"""
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
|
||||
# Update fields
|
||||
update_data = customer_data.model_dump(exclude_unset=True)
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field == "email" and value:
|
||||
# Check if new email already exists for this vendor
|
||||
existing = (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == value.lower(),
|
||||
Customer.id != customer_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
raise DuplicateCustomerEmailException(value, "vendor")
|
||||
|
||||
setattr(customer, field, value.lower())
|
||||
elif hasattr(customer, field):
|
||||
setattr(customer, field, value)
|
||||
|
||||
try:
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
logger.info(f"Customer updated: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating customer: {str(e)}")
|
||||
raise CustomerValidationException(
|
||||
message="Failed to update customer", details={"error": str(e)}
|
||||
)
|
||||
|
||||
def deactivate_customer(
|
||||
self, db: Session, vendor_id: int, customer_id: int
|
||||
) -> Customer:
|
||||
"""
|
||||
Deactivate customer account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
customer_id: Customer ID
|
||||
|
||||
Returns:
|
||||
Customer: Deactivated customer object
|
||||
|
||||
Raises:
|
||||
CustomerNotFoundException: If customer not found
|
||||
"""
|
||||
customer = self.get_customer(db, vendor_id, customer_id)
|
||||
customer.is_active = False
|
||||
|
||||
db.flush()
|
||||
db.refresh(customer)
|
||||
|
||||
logger.info(f"Customer deactivated: {customer.email} (ID: {customer.id})")
|
||||
|
||||
return customer
|
||||
|
||||
def update_customer_stats(
|
||||
self, db: Session, customer_id: int, order_total: float
|
||||
) -> None:
|
||||
"""
|
||||
Update customer statistics after order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
customer_id: Customer ID
|
||||
order_total: Order total amount
|
||||
"""
|
||||
customer = db.query(Customer).filter(Customer.id == customer_id).first()
|
||||
|
||||
if customer:
|
||||
customer.total_orders += 1
|
||||
customer.total_spent += order_total
|
||||
customer.last_order_date = datetime.utcnow()
|
||||
|
||||
logger.debug(f"Updated stats for customer {customer.email}")
|
||||
|
||||
def _generate_customer_number(
|
||||
self, db: Session, vendor_id: int, vendor_code: str
|
||||
) -> str:
|
||||
"""
|
||||
Generate unique customer number for vendor.
|
||||
|
||||
Format: {VENDOR_CODE}-CUST-{SEQUENCE}
|
||||
Example: VENDORA-CUST-00001
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
vendor_code: Vendor code
|
||||
|
||||
Returns:
|
||||
str: Unique customer number
|
||||
"""
|
||||
# Get count of customers for this vendor
|
||||
count = db.query(Customer).filter(Customer.vendor_id == vendor_id).count()
|
||||
|
||||
# Generate number with padding
|
||||
sequence = str(count + 1).zfill(5)
|
||||
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
|
||||
|
||||
# Ensure uniqueness (in case of deletions)
|
||||
while (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
and_(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.customer_number == customer_number,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
):
|
||||
count += 1
|
||||
sequence = str(count + 1).zfill(5)
|
||||
customer_number = f"{vendor_code.upper()}-CUST-{sequence}"
|
||||
|
||||
return customer_number
|
||||
|
||||
def get_customer_for_password_reset(
|
||||
self, db: Session, vendor_id: int, email: str
|
||||
) -> Customer | None:
|
||||
"""
|
||||
Get active customer by email for password reset.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
email: Customer email
|
||||
|
||||
Returns:
|
||||
Customer if found and active, None otherwise
|
||||
"""
|
||||
return (
|
||||
db.query(Customer)
|
||||
.filter(
|
||||
Customer.vendor_id == vendor_id,
|
||||
Customer.email == email.lower(),
|
||||
Customer.is_active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def validate_and_reset_password(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
reset_token: str,
|
||||
new_password: str,
|
||||
) -> Customer:
|
||||
"""
|
||||
Validate reset token and update customer password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
reset_token: Password reset token from email
|
||||
new_password: New password
|
||||
|
||||
Returns:
|
||||
Customer: Updated customer
|
||||
|
||||
Raises:
|
||||
PasswordTooShortException: If password too short
|
||||
InvalidPasswordResetTokenException: If token invalid/expired
|
||||
CustomerNotActiveException: If customer not active
|
||||
"""
|
||||
# Validate password length
|
||||
if len(new_password) < 8:
|
||||
raise PasswordTooShortException(min_length=8)
|
||||
|
||||
# Find valid token
|
||||
token_record = PasswordResetToken.find_valid_token(db, reset_token)
|
||||
|
||||
if not token_record:
|
||||
raise InvalidPasswordResetTokenException()
|
||||
|
||||
# Get the customer and verify they belong to this vendor
|
||||
customer = (
|
||||
db.query(Customer)
|
||||
.filter(Customer.id == token_record.customer_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not customer or customer.vendor_id != vendor_id:
|
||||
raise InvalidPasswordResetTokenException()
|
||||
|
||||
if not customer.is_active:
|
||||
raise CustomerNotActiveException(customer.email)
|
||||
|
||||
# Hash the new password and update customer
|
||||
hashed_password = self.auth_service.hash_password(new_password)
|
||||
customer.hashed_password = hashed_password
|
||||
|
||||
# Mark token as used
|
||||
token_record.mark_used(db)
|
||||
|
||||
logger.info(f"Password reset completed for customer {customer.id}")
|
||||
|
||||
return customer
|
||||
|
||||
|
||||
# Singleton instance
|
||||
customer_service = CustomerService()
|
||||
__all__ = [
|
||||
"customer_service",
|
||||
"CustomerService",
|
||||
]
|
||||
|
||||
@@ -1,250 +1,25 @@
|
||||
# app/services/inventory_import_service.py
|
||||
"""
|
||||
Inventory import service for bulk importing stock from TSV/CSV files.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Supports two formats:
|
||||
1. One row per unit (quantity = count of rows):
|
||||
BIN EAN PRODUCT
|
||||
SA-10-02 0810050910101 Product Name
|
||||
SA-10-02 0810050910101 Product Name (2nd unit)
|
||||
The canonical implementation is now in:
|
||||
app/modules/inventory/services/inventory_import_service.py
|
||||
|
||||
2. With explicit quantity column:
|
||||
BIN EAN PRODUCT QUANTITY
|
||||
SA-10-02 0810050910101 Product Name 12
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
Products are matched by GTIN/EAN to existing vendor products.
|
||||
from app.modules.inventory.services import inventory_import_service
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from app.modules.inventory.services.inventory_import_service import (
|
||||
inventory_import_service,
|
||||
InventoryImportService,
|
||||
ImportResult,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportResult:
|
||||
"""Result of an inventory import operation."""
|
||||
|
||||
success: bool = True
|
||||
total_rows: int = 0
|
||||
entries_created: int = 0
|
||||
entries_updated: int = 0
|
||||
quantity_imported: int = 0
|
||||
unmatched_gtins: list = field(default_factory=list)
|
||||
errors: list = field(default_factory=list)
|
||||
|
||||
|
||||
class InventoryImportService:
|
||||
"""Service for importing inventory from TSV/CSV files."""
|
||||
|
||||
def import_from_text(
|
||||
self,
|
||||
db: Session,
|
||||
content: str,
|
||||
vendor_id: int,
|
||||
warehouse: str = "strassen",
|
||||
delimiter: str = "\t",
|
||||
clear_existing: bool = False,
|
||||
) -> ImportResult:
|
||||
"""
|
||||
Import inventory from TSV/CSV text content.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
content: TSV/CSV content as string
|
||||
vendor_id: Vendor ID for inventory
|
||||
warehouse: Warehouse name (default: "strassen")
|
||||
delimiter: Column delimiter (default: tab)
|
||||
clear_existing: If True, clear existing inventory before import
|
||||
|
||||
Returns:
|
||||
ImportResult with summary and errors
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
try:
|
||||
# Parse CSV/TSV
|
||||
reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
|
||||
|
||||
# Normalize headers (case-insensitive, strip whitespace)
|
||||
if reader.fieldnames:
|
||||
reader.fieldnames = [h.strip().upper() for h in reader.fieldnames]
|
||||
|
||||
# Validate required columns
|
||||
required = {"BIN", "EAN"}
|
||||
if not reader.fieldnames or not required.issubset(set(reader.fieldnames)):
|
||||
result.success = False
|
||||
result.errors.append(
|
||||
f"Missing required columns. Found: {reader.fieldnames}, Required: {required}"
|
||||
)
|
||||
return result
|
||||
|
||||
has_quantity = "QUANTITY" in reader.fieldnames
|
||||
|
||||
# Group entries by (EAN, BIN)
|
||||
# Key: (ean, bin) -> quantity
|
||||
inventory_data: dict[tuple[str, str], int] = defaultdict(int)
|
||||
product_names: dict[str, str] = {} # EAN -> product name (for logging)
|
||||
|
||||
for row in reader:
|
||||
result.total_rows += 1
|
||||
|
||||
ean = row.get("EAN", "").strip()
|
||||
bin_loc = row.get("BIN", "").strip()
|
||||
product_name = row.get("PRODUCT", "").strip()
|
||||
|
||||
if not ean or not bin_loc:
|
||||
result.errors.append(f"Row {result.total_rows}: Missing EAN or BIN")
|
||||
continue
|
||||
|
||||
# Get quantity
|
||||
if has_quantity:
|
||||
try:
|
||||
qty = int(row.get("QUANTITY", "1").strip())
|
||||
except ValueError:
|
||||
result.errors.append(
|
||||
f"Row {result.total_rows}: Invalid quantity '{row.get('QUANTITY')}'"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
qty = 1 # Each row = 1 unit
|
||||
|
||||
inventory_data[(ean, bin_loc)] += qty
|
||||
if product_name:
|
||||
product_names[ean] = product_name
|
||||
|
||||
# Clear existing inventory if requested
|
||||
if clear_existing:
|
||||
db.query(Inventory).filter(
|
||||
Inventory.vendor_id == vendor_id,
|
||||
Inventory.warehouse == warehouse,
|
||||
).delete()
|
||||
db.flush()
|
||||
|
||||
# Build EAN to Product mapping for this vendor
|
||||
products = (
|
||||
db.query(Product)
|
||||
.filter(
|
||||
Product.vendor_id == vendor_id,
|
||||
Product.gtin.isnot(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
ean_to_product: dict[str, Product] = {p.gtin: p for p in products if p.gtin}
|
||||
|
||||
# Track unmatched GTINs
|
||||
unmatched: dict[str, int] = {} # EAN -> total quantity
|
||||
|
||||
# Process inventory entries
|
||||
for (ean, bin_loc), quantity in inventory_data.items():
|
||||
product = ean_to_product.get(ean)
|
||||
|
||||
if not product:
|
||||
# Track unmatched
|
||||
if ean not in unmatched:
|
||||
unmatched[ean] = 0
|
||||
unmatched[ean] += quantity
|
||||
continue
|
||||
|
||||
# Upsert inventory entry
|
||||
existing = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == product.id,
|
||||
Inventory.warehouse == warehouse,
|
||||
Inventory.bin_location == bin_loc,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.quantity = quantity
|
||||
existing.gtin = ean
|
||||
result.entries_updated += 1
|
||||
else:
|
||||
inv = Inventory(
|
||||
product_id=product.id,
|
||||
vendor_id=vendor_id,
|
||||
warehouse=warehouse,
|
||||
bin_location=bin_loc,
|
||||
location=bin_loc, # Legacy field
|
||||
quantity=quantity,
|
||||
gtin=ean,
|
||||
)
|
||||
db.add(inv)
|
||||
result.entries_created += 1
|
||||
|
||||
result.quantity_imported += quantity
|
||||
|
||||
db.flush()
|
||||
|
||||
# Format unmatched GTINs for result
|
||||
for ean, qty in unmatched.items():
|
||||
product_name = product_names.get(ean, "Unknown")
|
||||
result.unmatched_gtins.append(
|
||||
{"gtin": ean, "quantity": qty, "product_name": product_name}
|
||||
)
|
||||
|
||||
if result.unmatched_gtins:
|
||||
logger.warning(
|
||||
f"Import had {len(result.unmatched_gtins)} unmatched GTINs"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Inventory import failed")
|
||||
result.success = False
|
||||
result.errors.append(str(e))
|
||||
|
||||
return result
|
||||
|
||||
def import_from_file(
|
||||
self,
|
||||
db: Session,
|
||||
file_path: str,
|
||||
vendor_id: int,
|
||||
warehouse: str = "strassen",
|
||||
clear_existing: bool = False,
|
||||
) -> ImportResult:
|
||||
"""
|
||||
Import inventory from a TSV/CSV file.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
file_path: Path to TSV/CSV file
|
||||
vendor_id: Vendor ID for inventory
|
||||
warehouse: Warehouse name
|
||||
clear_existing: If True, clear existing inventory before import
|
||||
|
||||
Returns:
|
||||
ImportResult with summary and errors
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
return ImportResult(success=False, errors=[f"Failed to read file: {e}"])
|
||||
|
||||
# Detect delimiter
|
||||
first_line = content.split("\n")[0] if content else ""
|
||||
delimiter = "\t" if "\t" in first_line else ","
|
||||
|
||||
return self.import_from_text(
|
||||
db=db,
|
||||
content=content,
|
||||
vendor_id=vendor_id,
|
||||
warehouse=warehouse,
|
||||
delimiter=delimiter,
|
||||
clear_existing=clear_existing,
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
inventory_import_service = InventoryImportService()
|
||||
__all__ = [
|
||||
"inventory_import_service",
|
||||
"InventoryImportService",
|
||||
"ImportResult",
|
||||
]
|
||||
|
||||
@@ -1,949 +1,23 @@
|
||||
# app/services/inventory_service.py
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
"""
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
The canonical implementation is now in:
|
||||
app/modules/inventory/services/inventory_service.py
|
||||
|
||||
from app.exceptions import (
|
||||
InsufficientInventoryException,
|
||||
InvalidQuantityException,
|
||||
InventoryNotFoundException,
|
||||
InventoryValidationException,
|
||||
ProductNotFoundException,
|
||||
ValidationException,
|
||||
VendorNotFoundException,
|
||||
)
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.inventory import (
|
||||
AdminInventoryItem,
|
||||
AdminInventoryListResponse,
|
||||
AdminInventoryLocationsResponse,
|
||||
AdminInventoryStats,
|
||||
AdminLowStockItem,
|
||||
AdminVendorsWithInventoryResponse,
|
||||
AdminVendorWithInventory,
|
||||
InventoryAdjust,
|
||||
InventoryCreate,
|
||||
InventoryLocationResponse,
|
||||
InventoryReserve,
|
||||
InventoryUpdate,
|
||||
ProductInventorySummary,
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.inventory.services import inventory_service
|
||||
"""
|
||||
|
||||
from app.modules.inventory.services.inventory_service import (
|
||||
inventory_service,
|
||||
InventoryService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InventoryService:
|
||||
"""Service for inventory operations with vendor isolation."""
|
||||
|
||||
def set_inventory(
|
||||
self, db: Session, vendor_id: int, inventory_data: InventoryCreate
|
||||
) -> Inventory:
|
||||
"""
|
||||
Set exact inventory quantity for a product at a location (replaces existing).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID (from middleware)
|
||||
inventory_data: Inventory data
|
||||
|
||||
Returns:
|
||||
Inventory object
|
||||
"""
|
||||
try:
|
||||
# Validate product belongs to vendor
|
||||
product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
|
||||
|
||||
# Validate location
|
||||
location = self._validate_location(inventory_data.location)
|
||||
|
||||
# Validate quantity
|
||||
self._validate_quantity(inventory_data.quantity, allow_zero=True)
|
||||
|
||||
# Check if inventory entry exists
|
||||
existing = self._get_inventory_entry(
|
||||
db, inventory_data.product_id, location
|
||||
)
|
||||
|
||||
if existing:
|
||||
old_qty = existing.quantity
|
||||
existing.quantity = inventory_data.quantity
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(existing)
|
||||
|
||||
logger.info(
|
||||
f"Set inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"{old_qty} → {inventory_data.quantity}"
|
||||
)
|
||||
return existing
|
||||
# Create new inventory entry
|
||||
new_inventory = Inventory(
|
||||
product_id=inventory_data.product_id,
|
||||
vendor_id=vendor_id,
|
||||
warehouse="strassen", # Default warehouse
|
||||
bin_location=location, # Use location as bin location
|
||||
location=location, # Keep for backward compatibility
|
||||
quantity=inventory_data.quantity,
|
||||
gtin=product.marketplace_product.gtin, # Optional reference
|
||||
)
|
||||
db.add(new_inventory)
|
||||
db.flush()
|
||||
db.refresh(new_inventory)
|
||||
|
||||
logger.info(
|
||||
f"Created inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"{inventory_data.quantity}"
|
||||
)
|
||||
return new_inventory
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InvalidQuantityException,
|
||||
InventoryValidationException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error setting inventory: {str(e)}")
|
||||
raise ValidationException("Failed to set inventory")
|
||||
|
||||
def adjust_inventory(
|
||||
self, db: Session, vendor_id: int, inventory_data: InventoryAdjust
|
||||
) -> Inventory:
|
||||
"""
|
||||
Adjust inventory by adding or removing quantity.
|
||||
Positive quantity = add, negative = remove.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
inventory_data: Adjustment data
|
||||
|
||||
Returns:
|
||||
Updated Inventory object
|
||||
"""
|
||||
try:
|
||||
# Validate product belongs to vendor
|
||||
product = self._get_vendor_product(db, vendor_id, inventory_data.product_id)
|
||||
|
||||
# Validate location
|
||||
location = self._validate_location(inventory_data.location)
|
||||
|
||||
# Check if inventory exists
|
||||
existing = self._get_inventory_entry(
|
||||
db, inventory_data.product_id, location
|
||||
)
|
||||
|
||||
if not existing:
|
||||
# Create new if adding, error if removing
|
||||
if inventory_data.quantity < 0:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {inventory_data.product_id} at {location}"
|
||||
)
|
||||
|
||||
# Create with positive quantity
|
||||
new_inventory = Inventory(
|
||||
product_id=inventory_data.product_id,
|
||||
vendor_id=vendor_id,
|
||||
warehouse="strassen", # Default warehouse
|
||||
bin_location=location, # Use location as bin location
|
||||
location=location, # Keep for backward compatibility
|
||||
quantity=inventory_data.quantity,
|
||||
gtin=product.marketplace_product.gtin,
|
||||
)
|
||||
db.add(new_inventory)
|
||||
db.flush()
|
||||
db.refresh(new_inventory)
|
||||
|
||||
logger.info(
|
||||
f"Created inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"+{inventory_data.quantity}"
|
||||
)
|
||||
return new_inventory
|
||||
|
||||
# Adjust existing inventory
|
||||
old_qty = existing.quantity
|
||||
new_qty = old_qty + inventory_data.quantity
|
||||
|
||||
# Validate resulting quantity
|
||||
if new_qty < 0:
|
||||
raise InsufficientInventoryException(
|
||||
f"Insufficient inventory. Available: {old_qty}, "
|
||||
f"Requested removal: {abs(inventory_data.quantity)}"
|
||||
)
|
||||
|
||||
existing.quantity = new_qty
|
||||
existing.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(existing)
|
||||
|
||||
logger.info(
|
||||
f"Adjusted inventory for product {inventory_data.product_id} at {location}: "
|
||||
f"{old_qty} {'+' if inventory_data.quantity >= 0 else ''}{inventory_data.quantity} = {new_qty}"
|
||||
)
|
||||
return existing
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InventoryNotFoundException,
|
||||
InsufficientInventoryException,
|
||||
InventoryValidationException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error adjusting inventory: {str(e)}")
|
||||
raise ValidationException("Failed to adjust inventory")
|
||||
|
||||
def reserve_inventory(
|
||||
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
|
||||
) -> Inventory:
|
||||
"""
|
||||
Reserve inventory for an order (increases reserved_quantity).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
reserve_data: Reservation data
|
||||
|
||||
Returns:
|
||||
Updated Inventory object
|
||||
"""
|
||||
try:
|
||||
# Validate product
|
||||
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
|
||||
|
||||
# Validate location and quantity
|
||||
location = self._validate_location(reserve_data.location)
|
||||
self._validate_quantity(reserve_data.quantity, allow_zero=False)
|
||||
|
||||
# Get inventory entry
|
||||
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
|
||||
if not inventory:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {reserve_data.product_id} at {location}"
|
||||
)
|
||||
|
||||
# Check available quantity
|
||||
available = inventory.quantity - inventory.reserved_quantity
|
||||
if available < reserve_data.quantity:
|
||||
raise InsufficientInventoryException(
|
||||
f"Insufficient available inventory. Available: {available}, "
|
||||
f"Requested: {reserve_data.quantity}"
|
||||
)
|
||||
|
||||
# Reserve inventory
|
||||
inventory.reserved_quantity += reserve_data.quantity
|
||||
inventory.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(inventory)
|
||||
|
||||
logger.info(
|
||||
f"Reserved {reserve_data.quantity} units for product {reserve_data.product_id} "
|
||||
f"at {location}"
|
||||
)
|
||||
return inventory
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InventoryNotFoundException,
|
||||
InsufficientInventoryException,
|
||||
InvalidQuantityException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error reserving inventory: {str(e)}")
|
||||
raise ValidationException("Failed to reserve inventory")
|
||||
|
||||
def release_reservation(
|
||||
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
|
||||
) -> Inventory:
|
||||
"""
|
||||
Release reserved inventory (decreases reserved_quantity).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
reserve_data: Reservation data
|
||||
|
||||
Returns:
|
||||
Updated Inventory object
|
||||
"""
|
||||
try:
|
||||
# Validate product
|
||||
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
|
||||
|
||||
location = self._validate_location(reserve_data.location)
|
||||
self._validate_quantity(reserve_data.quantity, allow_zero=False)
|
||||
|
||||
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
|
||||
if not inventory:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {reserve_data.product_id} at {location}"
|
||||
)
|
||||
|
||||
# Validate reserved quantity
|
||||
if inventory.reserved_quantity < reserve_data.quantity:
|
||||
logger.warning(
|
||||
f"Attempting to release more than reserved. Reserved: {inventory.reserved_quantity}, "
|
||||
f"Requested: {reserve_data.quantity}"
|
||||
)
|
||||
inventory.reserved_quantity = 0
|
||||
else:
|
||||
inventory.reserved_quantity -= reserve_data.quantity
|
||||
|
||||
inventory.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(inventory)
|
||||
|
||||
logger.info(
|
||||
f"Released {reserve_data.quantity} units for product {reserve_data.product_id} "
|
||||
f"at {location}"
|
||||
)
|
||||
return inventory
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InventoryNotFoundException,
|
||||
InvalidQuantityException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error releasing reservation: {str(e)}")
|
||||
raise ValidationException("Failed to release reservation")
|
||||
|
||||
def fulfill_reservation(
|
||||
self, db: Session, vendor_id: int, reserve_data: InventoryReserve
|
||||
) -> Inventory:
|
||||
"""
|
||||
Fulfill a reservation (decreases both quantity and reserved_quantity).
|
||||
Use when order is shipped/completed.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
reserve_data: Reservation data
|
||||
|
||||
Returns:
|
||||
Updated Inventory object
|
||||
"""
|
||||
try:
|
||||
product = self._get_vendor_product(db, vendor_id, reserve_data.product_id)
|
||||
location = self._validate_location(reserve_data.location)
|
||||
self._validate_quantity(reserve_data.quantity, allow_zero=False)
|
||||
|
||||
inventory = self._get_inventory_entry(db, reserve_data.product_id, location)
|
||||
if not inventory:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {reserve_data.product_id} at {location}"
|
||||
)
|
||||
|
||||
# Validate quantities
|
||||
if inventory.quantity < reserve_data.quantity:
|
||||
raise InsufficientInventoryException(
|
||||
f"Insufficient inventory. Available: {inventory.quantity}, "
|
||||
f"Requested: {reserve_data.quantity}"
|
||||
)
|
||||
|
||||
if inventory.reserved_quantity < reserve_data.quantity:
|
||||
logger.warning(
|
||||
f"Fulfilling more than reserved. Reserved: {inventory.reserved_quantity}, "
|
||||
f"Fulfilling: {reserve_data.quantity}"
|
||||
)
|
||||
|
||||
# Fulfill (remove from both quantity and reserved)
|
||||
inventory.quantity -= reserve_data.quantity
|
||||
inventory.reserved_quantity = max(
|
||||
0, inventory.reserved_quantity - reserve_data.quantity
|
||||
)
|
||||
inventory.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(inventory)
|
||||
|
||||
logger.info(
|
||||
f"Fulfilled {reserve_data.quantity} units for product {reserve_data.product_id} "
|
||||
f"at {location}"
|
||||
)
|
||||
return inventory
|
||||
|
||||
except (
|
||||
ProductNotFoundException,
|
||||
InventoryNotFoundException,
|
||||
InsufficientInventoryException,
|
||||
InvalidQuantityException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error fulfilling reservation: {str(e)}")
|
||||
raise ValidationException("Failed to fulfill reservation")
|
||||
|
||||
def get_product_inventory(
|
||||
self, db: Session, vendor_id: int, product_id: int
|
||||
) -> ProductInventorySummary:
|
||||
"""
|
||||
Get inventory summary for a product across all locations.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
product_id: Product ID
|
||||
|
||||
Returns:
|
||||
ProductInventorySummary
|
||||
"""
|
||||
try:
|
||||
product = self._get_vendor_product(db, vendor_id, product_id)
|
||||
|
||||
inventory_entries = (
|
||||
db.query(Inventory).filter(Inventory.product_id == product_id).all()
|
||||
)
|
||||
|
||||
if not inventory_entries:
|
||||
return ProductInventorySummary(
|
||||
product_id=product_id,
|
||||
vendor_id=vendor_id,
|
||||
product_sku=product.vendor_sku,
|
||||
product_title=product.marketplace_product.get_title() or "",
|
||||
total_quantity=0,
|
||||
total_reserved=0,
|
||||
total_available=0,
|
||||
locations=[],
|
||||
)
|
||||
|
||||
total_qty = sum(inv.quantity for inv in inventory_entries)
|
||||
total_reserved = sum(inv.reserved_quantity for inv in inventory_entries)
|
||||
total_available = sum(inv.available_quantity for inv in inventory_entries)
|
||||
|
||||
locations = [
|
||||
InventoryLocationResponse(
|
||||
location=inv.location,
|
||||
quantity=inv.quantity,
|
||||
reserved_quantity=inv.reserved_quantity,
|
||||
available_quantity=inv.available_quantity,
|
||||
)
|
||||
for inv in inventory_entries
|
||||
]
|
||||
|
||||
return ProductInventorySummary(
|
||||
product_id=product_id,
|
||||
vendor_id=vendor_id,
|
||||
product_sku=product.vendor_sku,
|
||||
product_title=product.marketplace_product.get_title() or "",
|
||||
total_quantity=total_qty,
|
||||
total_reserved=total_reserved,
|
||||
total_available=total_available,
|
||||
locations=locations,
|
||||
)
|
||||
|
||||
except ProductNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting product inventory: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve product inventory")
|
||||
|
||||
def get_vendor_inventory(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
location: str | None = None,
|
||||
low_stock_threshold: int | None = None,
|
||||
) -> list[Inventory]:
|
||||
"""
|
||||
Get all inventory for a vendor with filtering.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
location: Filter by location
|
||||
low_stock_threshold: Filter items below threshold
|
||||
|
||||
Returns:
|
||||
List of Inventory objects
|
||||
"""
|
||||
try:
|
||||
query = db.query(Inventory).filter(Inventory.vendor_id == vendor_id)
|
||||
|
||||
if location:
|
||||
query = query.filter(Inventory.location.ilike(f"%{location}%"))
|
||||
|
||||
if low_stock_threshold is not None:
|
||||
query = query.filter(Inventory.quantity <= low_stock_threshold)
|
||||
|
||||
return query.offset(skip).limit(limit).all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting vendor inventory: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve vendor inventory")
|
||||
|
||||
def update_inventory(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
inventory_id: int,
|
||||
inventory_update: InventoryUpdate,
|
||||
) -> Inventory:
|
||||
"""Update inventory entry."""
|
||||
try:
|
||||
inventory = self._get_inventory_by_id(db, inventory_id)
|
||||
|
||||
# Verify ownership
|
||||
if inventory.vendor_id != vendor_id:
|
||||
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
|
||||
|
||||
# Update fields
|
||||
if inventory_update.quantity is not None:
|
||||
self._validate_quantity(inventory_update.quantity, allow_zero=True)
|
||||
inventory.quantity = inventory_update.quantity
|
||||
|
||||
if inventory_update.reserved_quantity is not None:
|
||||
self._validate_quantity(
|
||||
inventory_update.reserved_quantity, allow_zero=True
|
||||
)
|
||||
inventory.reserved_quantity = inventory_update.reserved_quantity
|
||||
|
||||
if inventory_update.location:
|
||||
inventory.location = self._validate_location(inventory_update.location)
|
||||
|
||||
inventory.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(inventory)
|
||||
|
||||
logger.info(f"Updated inventory {inventory_id}")
|
||||
return inventory
|
||||
|
||||
except (
|
||||
InventoryNotFoundException,
|
||||
InvalidQuantityException,
|
||||
InventoryValidationException,
|
||||
):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating inventory: {str(e)}")
|
||||
raise ValidationException("Failed to update inventory")
|
||||
|
||||
def delete_inventory(self, db: Session, vendor_id: int, inventory_id: int) -> bool:
|
||||
"""Delete inventory entry."""
|
||||
try:
|
||||
inventory = self._get_inventory_by_id(db, inventory_id)
|
||||
|
||||
# Verify ownership
|
||||
if inventory.vendor_id != vendor_id:
|
||||
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
|
||||
|
||||
db.delete(inventory)
|
||||
db.flush()
|
||||
|
||||
logger.info(f"Deleted inventory {inventory_id}")
|
||||
return True
|
||||
|
||||
except InventoryNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting inventory: {str(e)}")
|
||||
raise ValidationException("Failed to delete inventory")
|
||||
|
||||
# =========================================================================
|
||||
# Admin Methods (cross-vendor operations)
|
||||
# =========================================================================
|
||||
|
||||
def get_all_inventory_admin(
|
||||
self,
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
vendor_id: int | None = None,
|
||||
location: str | None = None,
|
||||
low_stock: int | None = None,
|
||||
search: str | None = None,
|
||||
) -> AdminInventoryListResponse:
|
||||
"""
|
||||
Get inventory across all vendors with filtering (admin only).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
vendor_id: Filter by vendor
|
||||
location: Filter by location
|
||||
low_stock: Filter items below threshold
|
||||
search: Search by product title or SKU
|
||||
|
||||
Returns:
|
||||
AdminInventoryListResponse
|
||||
"""
|
||||
query = db.query(Inventory).join(Product).join(Vendor)
|
||||
|
||||
# Apply filters
|
||||
if vendor_id is not None:
|
||||
query = query.filter(Inventory.vendor_id == vendor_id)
|
||||
|
||||
if location:
|
||||
query = query.filter(Inventory.location.ilike(f"%{location}%"))
|
||||
|
||||
if low_stock is not None:
|
||||
query = query.filter(Inventory.quantity <= low_stock)
|
||||
|
||||
if search:
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.marketplace_product_translation import (
|
||||
MarketplaceProductTranslation,
|
||||
)
|
||||
|
||||
query = (
|
||||
query.join(MarketplaceProduct)
|
||||
.outerjoin(MarketplaceProductTranslation)
|
||||
.filter(
|
||||
(MarketplaceProductTranslation.title.ilike(f"%{search}%"))
|
||||
| (Product.vendor_sku.ilike(f"%{search}%"))
|
||||
)
|
||||
)
|
||||
|
||||
# Get total count before pagination
|
||||
total = query.count()
|
||||
|
||||
# Apply pagination
|
||||
inventories = query.offset(skip).limit(limit).all()
|
||||
|
||||
# Build response with vendor/product info
|
||||
items = []
|
||||
for inv in inventories:
|
||||
product = inv.product
|
||||
vendor = inv.vendor
|
||||
title = None
|
||||
if product and product.marketplace_product:
|
||||
title = product.marketplace_product.get_title()
|
||||
|
||||
items.append(
|
||||
AdminInventoryItem(
|
||||
id=inv.id,
|
||||
product_id=inv.product_id,
|
||||
vendor_id=inv.vendor_id,
|
||||
vendor_name=vendor.name if vendor else None,
|
||||
vendor_code=vendor.vendor_code if vendor else None,
|
||||
product_title=title,
|
||||
product_sku=product.vendor_sku if product else None,
|
||||
location=inv.location,
|
||||
quantity=inv.quantity,
|
||||
reserved_quantity=inv.reserved_quantity,
|
||||
available_quantity=inv.available_quantity,
|
||||
gtin=inv.gtin,
|
||||
created_at=inv.created_at,
|
||||
updated_at=inv.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
return AdminInventoryListResponse(
|
||||
inventories=items,
|
||||
total=total,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
vendor_filter=vendor_id,
|
||||
location_filter=location,
|
||||
)
|
||||
|
||||
def get_inventory_stats_admin(self, db: Session) -> AdminInventoryStats:
|
||||
"""Get platform-wide inventory statistics (admin only)."""
|
||||
# Total entries
|
||||
total_entries = db.query(func.count(Inventory.id)).scalar() or 0
|
||||
|
||||
# Aggregate quantities
|
||||
totals = db.query(
|
||||
func.sum(Inventory.quantity).label("total_qty"),
|
||||
func.sum(Inventory.reserved_quantity).label("total_reserved"),
|
||||
).first()
|
||||
|
||||
total_quantity = totals.total_qty or 0
|
||||
total_reserved = totals.total_reserved or 0
|
||||
total_available = total_quantity - total_reserved
|
||||
|
||||
# Low stock count (default threshold: 10)
|
||||
low_stock_count = (
|
||||
db.query(func.count(Inventory.id))
|
||||
.filter(Inventory.quantity <= 10)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# Vendors with inventory
|
||||
vendors_with_inventory = (
|
||||
db.query(func.count(func.distinct(Inventory.vendor_id))).scalar() or 0
|
||||
)
|
||||
|
||||
# Unique locations
|
||||
unique_locations = (
|
||||
db.query(func.count(func.distinct(Inventory.location))).scalar() or 0
|
||||
)
|
||||
|
||||
return AdminInventoryStats(
|
||||
total_entries=total_entries,
|
||||
total_quantity=total_quantity,
|
||||
total_reserved=total_reserved,
|
||||
total_available=total_available,
|
||||
low_stock_count=low_stock_count,
|
||||
vendors_with_inventory=vendors_with_inventory,
|
||||
unique_locations=unique_locations,
|
||||
)
|
||||
|
||||
def get_low_stock_items_admin(
|
||||
self,
|
||||
db: Session,
|
||||
threshold: int = 10,
|
||||
vendor_id: int | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[AdminLowStockItem]:
|
||||
"""Get items with low stock levels (admin only)."""
|
||||
query = (
|
||||
db.query(Inventory)
|
||||
.join(Product)
|
||||
.join(Vendor)
|
||||
.filter(Inventory.quantity <= threshold)
|
||||
)
|
||||
|
||||
if vendor_id is not None:
|
||||
query = query.filter(Inventory.vendor_id == vendor_id)
|
||||
|
||||
# Order by quantity ascending (most critical first)
|
||||
query = query.order_by(Inventory.quantity.asc())
|
||||
|
||||
inventories = query.limit(limit).all()
|
||||
|
||||
items = []
|
||||
for inv in inventories:
|
||||
product = inv.product
|
||||
vendor = inv.vendor
|
||||
title = None
|
||||
if product and product.marketplace_product:
|
||||
title = product.marketplace_product.get_title()
|
||||
|
||||
items.append(
|
||||
AdminLowStockItem(
|
||||
id=inv.id,
|
||||
product_id=inv.product_id,
|
||||
vendor_id=inv.vendor_id,
|
||||
vendor_name=vendor.name if vendor else None,
|
||||
product_title=title,
|
||||
location=inv.location,
|
||||
quantity=inv.quantity,
|
||||
reserved_quantity=inv.reserved_quantity,
|
||||
available_quantity=inv.available_quantity,
|
||||
)
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
def get_vendors_with_inventory_admin(
|
||||
self, db: Session
|
||||
) -> AdminVendorsWithInventoryResponse:
|
||||
"""Get list of vendors that have inventory entries (admin only)."""
|
||||
# noqa: SVC-005 - Admin function, intentionally cross-vendor
|
||||
# Use subquery to avoid DISTINCT on JSON columns (PostgreSQL can't compare JSON)
|
||||
vendor_ids_subquery = (
|
||||
db.query(Inventory.vendor_id)
|
||||
.distinct()
|
||||
.subquery()
|
||||
)
|
||||
vendors = (
|
||||
db.query(Vendor)
|
||||
.filter(Vendor.id.in_(db.query(vendor_ids_subquery.c.vendor_id)))
|
||||
.order_by(Vendor.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
return AdminVendorsWithInventoryResponse(
|
||||
vendors=[
|
||||
AdminVendorWithInventory(
|
||||
id=v.id, name=v.name, vendor_code=v.vendor_code
|
||||
)
|
||||
for v in vendors
|
||||
]
|
||||
)
|
||||
|
||||
def get_inventory_locations_admin(
|
||||
self, db: Session, vendor_id: int | None = None
|
||||
) -> AdminInventoryLocationsResponse:
|
||||
"""Get list of unique inventory locations (admin only)."""
|
||||
query = db.query(func.distinct(Inventory.location))
|
||||
|
||||
if vendor_id is not None:
|
||||
query = query.filter(Inventory.vendor_id == vendor_id)
|
||||
|
||||
locations = [loc[0] for loc in query.all()]
|
||||
|
||||
return AdminInventoryLocationsResponse(locations=sorted(locations))
|
||||
|
||||
def get_vendor_inventory_admin(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
location: str | None = None,
|
||||
low_stock: int | None = None,
|
||||
) -> AdminInventoryListResponse:
|
||||
"""Get inventory for a specific vendor (admin only)."""
|
||||
# Verify vendor exists
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(f"Vendor {vendor_id} not found")
|
||||
|
||||
# Use the existing method
|
||||
inventories = self.get_vendor_inventory(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
location=location,
|
||||
low_stock_threshold=low_stock,
|
||||
)
|
||||
|
||||
# Build response with product info
|
||||
items = []
|
||||
for inv in inventories:
|
||||
product = inv.product
|
||||
title = None
|
||||
if product and product.marketplace_product:
|
||||
title = product.marketplace_product.get_title()
|
||||
|
||||
items.append(
|
||||
AdminInventoryItem(
|
||||
id=inv.id,
|
||||
product_id=inv.product_id,
|
||||
vendor_id=inv.vendor_id,
|
||||
vendor_name=vendor.name,
|
||||
vendor_code=vendor.vendor_code,
|
||||
product_title=title,
|
||||
product_sku=product.vendor_sku if product else None,
|
||||
location=inv.location,
|
||||
quantity=inv.quantity,
|
||||
reserved_quantity=inv.reserved_quantity,
|
||||
available_quantity=inv.available_quantity,
|
||||
gtin=inv.gtin,
|
||||
created_at=inv.created_at,
|
||||
updated_at=inv.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
total_query = db.query(func.count(Inventory.id)).filter(
|
||||
Inventory.vendor_id == vendor_id
|
||||
)
|
||||
if location:
|
||||
total_query = total_query.filter(Inventory.location.ilike(f"%{location}%"))
|
||||
if low_stock is not None:
|
||||
total_query = total_query.filter(Inventory.quantity <= low_stock)
|
||||
total = total_query.scalar() or 0
|
||||
|
||||
return AdminInventoryListResponse(
|
||||
inventories=items,
|
||||
total=total,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
vendor_filter=vendor_id,
|
||||
location_filter=location,
|
||||
)
|
||||
|
||||
def get_product_inventory_admin(
|
||||
self, db: Session, product_id: int
|
||||
) -> ProductInventorySummary:
|
||||
"""Get inventory summary for a product (admin only - no vendor check)."""
|
||||
product = db.query(Product).filter(Product.id == product_id).first()
|
||||
if not product:
|
||||
raise ProductNotFoundException(f"Product {product_id} not found")
|
||||
|
||||
# Use existing method with the product's vendor_id
|
||||
return self.get_product_inventory(db, product.vendor_id, product_id)
|
||||
|
||||
def verify_vendor_exists(self, db: Session, vendor_id: int) -> Vendor:
|
||||
"""Verify vendor exists and return it."""
|
||||
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
|
||||
if not vendor:
|
||||
raise VendorNotFoundException(f"Vendor {vendor_id} not found")
|
||||
return vendor
|
||||
|
||||
def get_inventory_by_id_admin(self, db: Session, inventory_id: int) -> Inventory:
|
||||
"""Get inventory by ID (admin only - returns inventory with vendor_id)."""
|
||||
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
if not inventory:
|
||||
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
|
||||
return inventory
|
||||
|
||||
# =========================================================================
|
||||
# Private helper methods
|
||||
# =========================================================================
|
||||
|
||||
def _get_vendor_product(
|
||||
self, db: Session, vendor_id: int, product_id: int
|
||||
) -> Product:
|
||||
"""Get product and verify it belongs to vendor."""
|
||||
product = (
|
||||
db.query(Product)
|
||||
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not product:
|
||||
raise ProductNotFoundException(
|
||||
f"Product {product_id} not found in your catalog"
|
||||
)
|
||||
|
||||
return product
|
||||
|
||||
def _get_inventory_entry(
|
||||
self, db: Session, product_id: int, location: str
|
||||
) -> Inventory | None:
|
||||
"""Get inventory entry by product and location."""
|
||||
return (
|
||||
db.query(Inventory)
|
||||
.filter(Inventory.product_id == product_id, Inventory.location == location)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _get_inventory_by_id(self, db: Session, inventory_id: int) -> Inventory:
|
||||
"""Get inventory by ID or raise exception."""
|
||||
inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
if not inventory:
|
||||
raise InventoryNotFoundException(f"Inventory {inventory_id} not found")
|
||||
return inventory
|
||||
|
||||
def _validate_location(self, location: str) -> str:
|
||||
"""Validate and normalize location."""
|
||||
if not location or not location.strip():
|
||||
raise InventoryValidationException("Location is required")
|
||||
return location.strip().upper()
|
||||
|
||||
def _validate_quantity(self, quantity: int, allow_zero: bool = True) -> None:
|
||||
"""Validate quantity value."""
|
||||
if quantity is None:
|
||||
raise InvalidQuantityException("Quantity is required")
|
||||
|
||||
if not isinstance(quantity, int):
|
||||
raise InvalidQuantityException("Quantity must be an integer")
|
||||
|
||||
if quantity < 0:
|
||||
raise InvalidQuantityException("Quantity cannot be negative")
|
||||
|
||||
if not allow_zero and quantity == 0:
|
||||
raise InvalidQuantityException("Quantity must be positive")
|
||||
|
||||
|
||||
# Create service instance
|
||||
inventory_service = InventoryService()
|
||||
__all__ = [
|
||||
"inventory_service",
|
||||
"InventoryService",
|
||||
]
|
||||
|
||||
@@ -1,431 +1,23 @@
|
||||
# app/services/inventory_transaction_service.py
|
||||
"""
|
||||
Inventory Transaction Service.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Provides query operations for inventory transaction history.
|
||||
All transaction WRITES are handled by OrderInventoryService.
|
||||
This service handles transaction READS for reporting and auditing.
|
||||
The canonical implementation is now in:
|
||||
app/modules/inventory/services/inventory_transaction_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.inventory.services import inventory_transaction_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from app.modules.inventory.services.inventory_transaction_service import (
|
||||
inventory_transaction_service,
|
||||
InventoryTransactionService,
|
||||
)
|
||||
|
||||
from app.exceptions import OrderNotFoundException, ProductNotFoundException
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.inventory_transaction import InventoryTransaction
|
||||
from models.database.order import Order
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InventoryTransactionService:
|
||||
"""Service for querying inventory transaction history."""
|
||||
|
||||
def get_vendor_transactions(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
product_id: int | None = None,
|
||||
transaction_type: str | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get inventory transactions for a vendor with optional filters.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
product_id: Optional product filter
|
||||
transaction_type: Optional transaction type filter
|
||||
|
||||
Returns:
|
||||
Tuple of (transactions with product details, total count)
|
||||
"""
|
||||
# Build query
|
||||
query = db.query(InventoryTransaction).filter(
|
||||
InventoryTransaction.vendor_id == vendor_id
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if product_id:
|
||||
query = query.filter(InventoryTransaction.product_id == product_id)
|
||||
if transaction_type:
|
||||
query = query.filter(
|
||||
InventoryTransaction.transaction_type == transaction_type
|
||||
)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get transactions with pagination (newest first)
|
||||
transactions = (
|
||||
query.order_by(InventoryTransaction.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build result with product details
|
||||
result = []
|
||||
for tx in transactions:
|
||||
product = db.query(Product).filter(Product.id == tx.product_id).first()
|
||||
product_title = None
|
||||
product_sku = None
|
||||
if product:
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
}
|
||||
)
|
||||
|
||||
return result, total
|
||||
|
||||
def get_product_history(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
product_id: int,
|
||||
limit: int = 50,
|
||||
) -> dict:
|
||||
"""
|
||||
Get transaction history for a specific product.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
product_id: Product ID
|
||||
limit: Max transactions to return
|
||||
|
||||
Returns:
|
||||
Dict with product info, current inventory, and transactions
|
||||
|
||||
Raises:
|
||||
ProductNotFoundException: If product not found or doesn't belong to vendor
|
||||
"""
|
||||
# Get product details
|
||||
product = (
|
||||
db.query(Product)
|
||||
.filter(Product.id == product_id, Product.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not product:
|
||||
raise ProductNotFoundException(
|
||||
f"Product {product_id} not found in vendor catalog"
|
||||
)
|
||||
|
||||
product_title = None
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
# Get current inventory
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(Inventory.product_id == product_id, Inventory.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
current_quantity = inventory.quantity if inventory else 0
|
||||
current_reserved = inventory.reserved_quantity if inventory else 0
|
||||
|
||||
# Get transactions
|
||||
transactions = (
|
||||
db.query(InventoryTransaction)
|
||||
.filter(
|
||||
InventoryTransaction.vendor_id == vendor_id,
|
||||
InventoryTransaction.product_id == product_id,
|
||||
)
|
||||
.order_by(InventoryTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
total = (
|
||||
db.query(func.count(InventoryTransaction.id))
|
||||
.filter(
|
||||
InventoryTransaction.vendor_id == vendor_id,
|
||||
InventoryTransaction.product_id == product_id,
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"product_id": product_id,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
"current_quantity": current_quantity,
|
||||
"current_reserved": current_reserved,
|
||||
"transactions": [
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
}
|
||||
for tx in transactions
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
def get_order_history(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Get all inventory transactions for a specific order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
Dict with order info and transactions
|
||||
|
||||
Raises:
|
||||
OrderNotFoundException: If order not found or doesn't belong to vendor
|
||||
"""
|
||||
# Verify order belongs to vendor
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
|
||||
# Get transactions for this order
|
||||
transactions = (
|
||||
db.query(InventoryTransaction)
|
||||
.filter(InventoryTransaction.order_id == order_id)
|
||||
.order_by(InventoryTransaction.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build result with product details
|
||||
result = []
|
||||
for tx in transactions:
|
||||
product = db.query(Product).filter(Product.id == tx.product_id).first()
|
||||
product_title = None
|
||||
product_sku = None
|
||||
if product:
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"transactions": result,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Admin Methods (cross-vendor operations)
|
||||
# =========================================================================
|
||||
|
||||
def get_all_transactions_admin(
|
||||
self,
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
vendor_id: int | None = None,
|
||||
product_id: int | None = None,
|
||||
transaction_type: str | None = None,
|
||||
order_id: int | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get inventory transactions across all vendors (admin only).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
vendor_id: Optional vendor filter
|
||||
product_id: Optional product filter
|
||||
transaction_type: Optional transaction type filter
|
||||
order_id: Optional order filter
|
||||
|
||||
Returns:
|
||||
Tuple of (transactions with details, total count)
|
||||
"""
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
# Build query
|
||||
query = db.query(InventoryTransaction)
|
||||
|
||||
# Apply filters
|
||||
if vendor_id:
|
||||
query = query.filter(InventoryTransaction.vendor_id == vendor_id)
|
||||
if product_id:
|
||||
query = query.filter(InventoryTransaction.product_id == product_id)
|
||||
if transaction_type:
|
||||
query = query.filter(
|
||||
InventoryTransaction.transaction_type == transaction_type
|
||||
)
|
||||
if order_id:
|
||||
query = query.filter(InventoryTransaction.order_id == order_id)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get transactions with pagination (newest first)
|
||||
transactions = (
|
||||
query.order_by(InventoryTransaction.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Build result with vendor and product details
|
||||
result = []
|
||||
for tx in transactions:
|
||||
vendor = db.query(Vendor).filter(Vendor.id == tx.vendor_id).first()
|
||||
product = db.query(Product).filter(Product.id == tx.product_id).first()
|
||||
|
||||
product_title = None
|
||||
product_sku = None
|
||||
if product:
|
||||
product_sku = product.vendor_sku
|
||||
if product.marketplace_product:
|
||||
product_title = product.marketplace_product.get_title()
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": tx.id,
|
||||
"vendor_id": tx.vendor_id,
|
||||
"vendor_name": vendor.name if vendor else None,
|
||||
"vendor_code": vendor.vendor_code if vendor else None,
|
||||
"product_id": tx.product_id,
|
||||
"inventory_id": tx.inventory_id,
|
||||
"transaction_type": (
|
||||
tx.transaction_type.value if tx.transaction_type else None
|
||||
),
|
||||
"quantity_change": tx.quantity_change,
|
||||
"quantity_after": tx.quantity_after,
|
||||
"reserved_after": tx.reserved_after,
|
||||
"location": tx.location,
|
||||
"warehouse": tx.warehouse,
|
||||
"order_id": tx.order_id,
|
||||
"order_number": tx.order_number,
|
||||
"reason": tx.reason,
|
||||
"created_by": tx.created_by,
|
||||
"created_at": tx.created_at,
|
||||
"product_title": product_title,
|
||||
"product_sku": product_sku,
|
||||
}
|
||||
)
|
||||
|
||||
return result, total
|
||||
|
||||
def get_transaction_stats_admin(self, db: Session) -> dict:
|
||||
"""
|
||||
Get transaction statistics across the platform (admin only).
|
||||
|
||||
Returns:
|
||||
Dict with transaction counts by type
|
||||
"""
|
||||
from sqlalchemy import func as sql_func
|
||||
|
||||
# Count by transaction type
|
||||
type_counts = (
|
||||
db.query(
|
||||
InventoryTransaction.transaction_type,
|
||||
sql_func.count(InventoryTransaction.id).label("count"),
|
||||
)
|
||||
.group_by(InventoryTransaction.transaction_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Total transactions
|
||||
total = db.query(sql_func.count(InventoryTransaction.id)).scalar() or 0
|
||||
|
||||
# Transactions today
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
today_count = (
|
||||
db.query(sql_func.count(InventoryTransaction.id))
|
||||
.filter(InventoryTransaction.created_at >= today_start)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_transactions": total,
|
||||
"transactions_today": today_count,
|
||||
"by_type": {tc.transaction_type.value: tc.count for tc in type_counts},
|
||||
}
|
||||
|
||||
|
||||
# Create service instance
|
||||
inventory_transaction_service = InventoryTransactionService()
|
||||
__all__ = [
|
||||
"inventory_transaction_service",
|
||||
"InventoryTransactionService",
|
||||
]
|
||||
|
||||
@@ -1,164 +1,23 @@
|
||||
# app/services/invoice_pdf_service.py
|
||||
"""
|
||||
Invoice PDF generation service using WeasyPrint.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Renders HTML invoice templates to PDF using Jinja2 + WeasyPrint.
|
||||
Stores generated PDFs in the configured storage location.
|
||||
The canonical implementation is now in:
|
||||
app/modules/orders/services/invoice_pdf_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.orders.services import invoice_pdf_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from app.modules.orders.services.invoice_pdf_service import (
|
||||
invoice_pdf_service,
|
||||
InvoicePDFService,
|
||||
)
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from models.database.invoice import Invoice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Template directory
|
||||
TEMPLATE_DIR = Path(__file__).parent.parent / "templates" / "invoices"
|
||||
|
||||
# PDF storage directory (relative to project root)
|
||||
PDF_STORAGE_DIR = Path("storage") / "invoices"
|
||||
|
||||
|
||||
class InvoicePDFService:
|
||||
"""Service for generating invoice PDFs."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the PDF service with Jinja2 environment."""
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(TEMPLATE_DIR)),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
def _ensure_storage_dir(self, vendor_id: int) -> Path:
|
||||
"""Ensure the storage directory exists for a vendor."""
|
||||
storage_path = PDF_STORAGE_DIR / str(vendor_id)
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
return storage_path
|
||||
|
||||
def _get_pdf_filename(self, invoice: Invoice) -> str:
|
||||
"""Generate PDF filename for an invoice."""
|
||||
# Sanitize invoice number for filename
|
||||
safe_number = invoice.invoice_number.replace("/", "-").replace("\\", "-")
|
||||
return f"{safe_number}.pdf"
|
||||
|
||||
def generate_pdf(
|
||||
self,
|
||||
db: Session,
|
||||
invoice: Invoice,
|
||||
force_regenerate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate PDF for an invoice.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
invoice: Invoice to generate PDF for
|
||||
force_regenerate: If True, regenerate even if PDF already exists
|
||||
|
||||
Returns:
|
||||
Path to the generated PDF file
|
||||
"""
|
||||
# Check if PDF already exists
|
||||
if invoice.pdf_path and not force_regenerate:
|
||||
if Path(invoice.pdf_path).exists():
|
||||
logger.debug(f"PDF already exists for invoice {invoice.invoice_number}")
|
||||
return invoice.pdf_path
|
||||
|
||||
# Ensure storage directory exists
|
||||
storage_dir = self._ensure_storage_dir(invoice.vendor_id)
|
||||
pdf_filename = self._get_pdf_filename(invoice)
|
||||
pdf_path = storage_dir / pdf_filename
|
||||
|
||||
# Render HTML template
|
||||
html_content = self._render_html(invoice)
|
||||
|
||||
# Generate PDF using WeasyPrint
|
||||
try:
|
||||
from weasyprint import HTML
|
||||
|
||||
html_doc = HTML(string=html_content, base_url=str(TEMPLATE_DIR))
|
||||
html_doc.write_pdf(str(pdf_path))
|
||||
|
||||
logger.info(f"Generated PDF for invoice {invoice.invoice_number} at {pdf_path}")
|
||||
except ImportError:
|
||||
logger.error("WeasyPrint not installed. Install with: pip install weasyprint")
|
||||
raise RuntimeError("WeasyPrint not installed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate PDF for invoice {invoice.invoice_number}: {e}")
|
||||
raise
|
||||
|
||||
# Update invoice record with PDF path and timestamp
|
||||
invoice.pdf_path = str(pdf_path)
|
||||
invoice.pdf_generated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
def _render_html(self, invoice: Invoice) -> str:
|
||||
"""Render the invoice HTML template."""
|
||||
template = self.env.get_template("invoice.html")
|
||||
|
||||
# Prepare template context
|
||||
context = {
|
||||
"invoice": invoice,
|
||||
"seller": invoice.seller_details,
|
||||
"buyer": invoice.buyer_details,
|
||||
"line_items": invoice.line_items,
|
||||
"bank_details": invoice.bank_details,
|
||||
"payment_terms": invoice.payment_terms,
|
||||
"footer_text": invoice.footer_text,
|
||||
"now": datetime.now(UTC),
|
||||
}
|
||||
|
||||
return template.render(**context)
|
||||
|
||||
def get_pdf_path(self, invoice: Invoice) -> str | None:
|
||||
"""Get the PDF path for an invoice if it exists."""
|
||||
if invoice.pdf_path and Path(invoice.pdf_path).exists():
|
||||
return invoice.pdf_path
|
||||
return None
|
||||
|
||||
def delete_pdf(self, invoice: Invoice, db: Session) -> bool:
|
||||
"""
|
||||
Delete the PDF file for an invoice.
|
||||
|
||||
Args:
|
||||
invoice: Invoice whose PDF to delete
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
if not invoice.pdf_path:
|
||||
return False
|
||||
|
||||
pdf_path = Path(invoice.pdf_path)
|
||||
if pdf_path.exists():
|
||||
try:
|
||||
pdf_path.unlink()
|
||||
logger.info(f"Deleted PDF for invoice {invoice.invoice_number}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete PDF {pdf_path}: {e}")
|
||||
return False
|
||||
|
||||
# Clear PDF fields
|
||||
invoice.pdf_path = None
|
||||
invoice.pdf_generated_at = None
|
||||
db.flush()
|
||||
|
||||
return True
|
||||
|
||||
def regenerate_pdf(self, db: Session, invoice: Invoice) -> str:
|
||||
"""Force regenerate PDF for an invoice."""
|
||||
return self.generate_pdf(db, invoice, force_regenerate=True)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
invoice_pdf_service = InvoicePDFService()
|
||||
__all__ = [
|
||||
"invoice_pdf_service",
|
||||
"InvoicePDFService",
|
||||
]
|
||||
|
||||
@@ -1,675 +1,23 @@
|
||||
# app/services/invoice_service.py
|
||||
"""
|
||||
Invoice service for generating and managing invoices.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Handles:
|
||||
- Vendor invoice settings management
|
||||
- Invoice generation from orders
|
||||
- VAT calculation (Luxembourg, EU, B2B reverse charge)
|
||||
- Invoice number sequencing
|
||||
- PDF generation (via separate module)
|
||||
The canonical implementation is now in:
|
||||
app/modules/orders/services/invoice_service.py
|
||||
|
||||
VAT Logic:
|
||||
- Luxembourg domestic: 17% (standard), 8% (reduced), 3% (super-reduced), 14% (intermediate)
|
||||
- EU cross-border B2C with OSS: Use destination country VAT rate
|
||||
- EU cross-border B2C without OSS: Use Luxembourg VAT rate (origin principle)
|
||||
- EU B2B with valid VAT number: Reverse charge (0% VAT)
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.orders.services import invoice_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
ValidationException,
|
||||
)
|
||||
from app.exceptions.invoice import (
|
||||
InvoiceNotFoundException,
|
||||
InvoicePDFGenerationException,
|
||||
InvoicePDFNotFoundException,
|
||||
InvoiceSettingsAlreadyExistException,
|
||||
InvoiceSettingsNotFoundException,
|
||||
InvoiceValidationException,
|
||||
OrderNotFoundException,
|
||||
)
|
||||
from models.database.invoice import (
|
||||
Invoice,
|
||||
InvoiceStatus,
|
||||
VATRegime,
|
||||
VendorInvoiceSettings,
|
||||
)
|
||||
from models.database.order import Order
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.invoice import (
|
||||
InvoiceBuyerDetails,
|
||||
InvoiceCreate,
|
||||
InvoiceLineItem,
|
||||
InvoiceManualCreate,
|
||||
InvoiceSellerDetails,
|
||||
VendorInvoiceSettingsCreate,
|
||||
VendorInvoiceSettingsUpdate,
|
||||
from app.modules.orders.services.invoice_service import (
|
||||
invoice_service,
|
||||
InvoiceService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# EU VAT rates by country code (2024 standard rates)
|
||||
EU_VAT_RATES: dict[str, Decimal] = {
|
||||
"AT": Decimal("20.00"), # Austria
|
||||
"BE": Decimal("21.00"), # Belgium
|
||||
"BG": Decimal("20.00"), # Bulgaria
|
||||
"HR": Decimal("25.00"), # Croatia
|
||||
"CY": Decimal("19.00"), # Cyprus
|
||||
"CZ": Decimal("21.00"), # Czech Republic
|
||||
"DK": Decimal("25.00"), # Denmark
|
||||
"EE": Decimal("22.00"), # Estonia
|
||||
"FI": Decimal("24.00"), # Finland
|
||||
"FR": Decimal("20.00"), # France
|
||||
"DE": Decimal("19.00"), # Germany
|
||||
"GR": Decimal("24.00"), # Greece
|
||||
"HU": Decimal("27.00"), # Hungary
|
||||
"IE": Decimal("23.00"), # Ireland
|
||||
"IT": Decimal("22.00"), # Italy
|
||||
"LV": Decimal("21.00"), # Latvia
|
||||
"LT": Decimal("21.00"), # Lithuania
|
||||
"LU": Decimal("17.00"), # Luxembourg (standard)
|
||||
"MT": Decimal("18.00"), # Malta
|
||||
"NL": Decimal("21.00"), # Netherlands
|
||||
"PL": Decimal("23.00"), # Poland
|
||||
"PT": Decimal("23.00"), # Portugal
|
||||
"RO": Decimal("19.00"), # Romania
|
||||
"SK": Decimal("20.00"), # Slovakia
|
||||
"SI": Decimal("22.00"), # Slovenia
|
||||
"ES": Decimal("21.00"), # Spain
|
||||
"SE": Decimal("25.00"), # Sweden
|
||||
}
|
||||
|
||||
# Luxembourg specific VAT rates
|
||||
LU_VAT_RATES = {
|
||||
"standard": Decimal("17.00"),
|
||||
"intermediate": Decimal("14.00"),
|
||||
"reduced": Decimal("8.00"),
|
||||
"super_reduced": Decimal("3.00"),
|
||||
}
|
||||
|
||||
|
||||
class InvoiceService:
|
||||
"""Service for invoice operations."""
|
||||
|
||||
# =========================================================================
|
||||
# VAT Calculation
|
||||
# =========================================================================
|
||||
|
||||
def get_vat_rate_for_country(self, country_iso: str) -> Decimal:
|
||||
"""Get standard VAT rate for EU country."""
|
||||
return EU_VAT_RATES.get(country_iso.upper(), Decimal("0.00"))
|
||||
|
||||
def get_vat_rate_label(self, country_iso: str, vat_rate: Decimal) -> str:
|
||||
"""Get human-readable VAT rate label."""
|
||||
country_names = {
|
||||
"AT": "Austria",
|
||||
"BE": "Belgium",
|
||||
"BG": "Bulgaria",
|
||||
"HR": "Croatia",
|
||||
"CY": "Cyprus",
|
||||
"CZ": "Czech Republic",
|
||||
"DK": "Denmark",
|
||||
"EE": "Estonia",
|
||||
"FI": "Finland",
|
||||
"FR": "France",
|
||||
"DE": "Germany",
|
||||
"GR": "Greece",
|
||||
"HU": "Hungary",
|
||||
"IE": "Ireland",
|
||||
"IT": "Italy",
|
||||
"LV": "Latvia",
|
||||
"LT": "Lithuania",
|
||||
"LU": "Luxembourg",
|
||||
"MT": "Malta",
|
||||
"NL": "Netherlands",
|
||||
"PL": "Poland",
|
||||
"PT": "Portugal",
|
||||
"RO": "Romania",
|
||||
"SK": "Slovakia",
|
||||
"SI": "Slovenia",
|
||||
"ES": "Spain",
|
||||
"SE": "Sweden",
|
||||
}
|
||||
country_name = country_names.get(country_iso.upper(), country_iso)
|
||||
return f"{country_name} VAT {vat_rate}%"
|
||||
|
||||
def determine_vat_regime(
|
||||
self,
|
||||
seller_country: str,
|
||||
buyer_country: str,
|
||||
buyer_vat_number: str | None,
|
||||
seller_oss_registered: bool,
|
||||
) -> tuple[VATRegime, Decimal, str | None]:
|
||||
"""
|
||||
Determine VAT regime and rate for invoice.
|
||||
|
||||
Returns: (regime, vat_rate, destination_country)
|
||||
"""
|
||||
seller_country = seller_country.upper()
|
||||
buyer_country = buyer_country.upper()
|
||||
|
||||
# Same country = domestic VAT
|
||||
if seller_country == buyer_country:
|
||||
vat_rate = self.get_vat_rate_for_country(seller_country)
|
||||
return VATRegime.DOMESTIC, vat_rate, None
|
||||
|
||||
# Different EU countries
|
||||
if buyer_country in EU_VAT_RATES:
|
||||
# B2B with valid VAT number = reverse charge
|
||||
if buyer_vat_number:
|
||||
return VATRegime.REVERSE_CHARGE, Decimal("0.00"), buyer_country
|
||||
|
||||
# B2C cross-border
|
||||
if seller_oss_registered:
|
||||
# OSS: use destination country VAT
|
||||
vat_rate = self.get_vat_rate_for_country(buyer_country)
|
||||
return VATRegime.OSS, vat_rate, buyer_country
|
||||
else:
|
||||
# No OSS: use origin country VAT
|
||||
vat_rate = self.get_vat_rate_for_country(seller_country)
|
||||
return VATRegime.ORIGIN, vat_rate, buyer_country
|
||||
|
||||
# Non-EU = VAT exempt (export)
|
||||
return VATRegime.EXEMPT, Decimal("0.00"), buyer_country
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Settings Management
|
||||
# =========================================================================
|
||||
|
||||
def get_settings(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorInvoiceSettings | None:
|
||||
"""Get vendor invoice settings."""
|
||||
return (
|
||||
db.query(VendorInvoiceSettings)
|
||||
.filter(VendorInvoiceSettings.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_settings_or_raise(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Get vendor invoice settings or raise exception."""
|
||||
settings = self.get_settings(db, vendor_id)
|
||||
if not settings:
|
||||
raise InvoiceSettingsNotFoundException(vendor_id)
|
||||
return settings
|
||||
|
||||
def create_settings(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: VendorInvoiceSettingsCreate,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Create vendor invoice settings."""
|
||||
# Check if settings already exist
|
||||
existing = self.get_settings(db, vendor_id)
|
||||
if existing:
|
||||
raise ValidationException(
|
||||
"Invoice settings already exist for this vendor"
|
||||
)
|
||||
|
||||
settings = VendorInvoiceSettings(
|
||||
vendor_id=vendor_id,
|
||||
**data.model_dump(),
|
||||
)
|
||||
db.add(settings)
|
||||
db.flush()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Created invoice settings for vendor {vendor_id}")
|
||||
return settings
|
||||
|
||||
def update_settings(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
data: VendorInvoiceSettingsUpdate,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""Update vendor invoice settings."""
|
||||
settings = self.get_settings_or_raise(db, vendor_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(settings, key, value)
|
||||
|
||||
settings.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Updated invoice settings for vendor {vendor_id}")
|
||||
return settings
|
||||
|
||||
def create_settings_from_vendor(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor,
|
||||
) -> VendorInvoiceSettings:
|
||||
"""
|
||||
Create invoice settings from vendor/company info.
|
||||
|
||||
Used for initial setup based on existing vendor data.
|
||||
"""
|
||||
company = vendor.company
|
||||
|
||||
settings = VendorInvoiceSettings(
|
||||
vendor_id=vendor.id,
|
||||
company_name=company.legal_name if company else vendor.name,
|
||||
company_address=vendor.effective_business_address,
|
||||
company_city=None, # Would need to parse from address
|
||||
company_postal_code=None,
|
||||
company_country="LU",
|
||||
vat_number=vendor.effective_tax_number,
|
||||
is_vat_registered=bool(vendor.effective_tax_number),
|
||||
)
|
||||
db.add(settings)
|
||||
db.flush()
|
||||
db.refresh(settings)
|
||||
|
||||
logger.info(f"Created invoice settings from vendor data for vendor {vendor.id}")
|
||||
return settings
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Number Generation
|
||||
# =========================================================================
|
||||
|
||||
def _get_next_invoice_number(
|
||||
self, db: Session, settings: VendorInvoiceSettings
|
||||
) -> str:
|
||||
"""Generate next invoice number and increment counter."""
|
||||
number = str(settings.invoice_next_number).zfill(settings.invoice_number_padding)
|
||||
invoice_number = f"{settings.invoice_prefix}{number}"
|
||||
|
||||
# Increment counter
|
||||
settings.invoice_next_number += 1
|
||||
db.flush()
|
||||
|
||||
return invoice_number
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Creation
|
||||
# =========================================================================
|
||||
|
||||
def create_invoice_from_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
notes: str | None = None,
|
||||
) -> Invoice:
|
||||
"""
|
||||
Create an invoice from an order.
|
||||
|
||||
Captures snapshots of seller/buyer details and calculates VAT.
|
||||
"""
|
||||
# Get invoice settings
|
||||
settings = self.get_settings_or_raise(db, vendor_id)
|
||||
|
||||
# Get order
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(and_(Order.id == order_id, Order.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
|
||||
# Check for existing invoice
|
||||
existing = (
|
||||
db.query(Invoice)
|
||||
.filter(and_(Invoice.order_id == order_id, Invoice.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValidationException(f"Invoice already exists for order {order_id}")
|
||||
|
||||
# Determine VAT regime
|
||||
buyer_country = order.bill_country_iso
|
||||
vat_regime, vat_rate, destination_country = self.determine_vat_regime(
|
||||
seller_country=settings.company_country,
|
||||
buyer_country=buyer_country,
|
||||
buyer_vat_number=None, # TODO: Add B2B VAT number support
|
||||
seller_oss_registered=settings.is_oss_registered,
|
||||
)
|
||||
|
||||
# Build seller details snapshot
|
||||
seller_details = {
|
||||
"company_name": settings.company_name,
|
||||
"address": settings.company_address,
|
||||
"city": settings.company_city,
|
||||
"postal_code": settings.company_postal_code,
|
||||
"country": settings.company_country,
|
||||
"vat_number": settings.vat_number,
|
||||
}
|
||||
|
||||
# Build buyer details snapshot
|
||||
buyer_details = {
|
||||
"name": f"{order.bill_first_name} {order.bill_last_name}".strip(),
|
||||
"email": order.customer_email,
|
||||
"address": order.bill_address_line_1,
|
||||
"city": order.bill_city,
|
||||
"postal_code": order.bill_postal_code,
|
||||
"country": order.bill_country_iso,
|
||||
"vat_number": None, # TODO: B2B support
|
||||
}
|
||||
if order.bill_company:
|
||||
buyer_details["company"] = order.bill_company
|
||||
|
||||
# Build line items from order items
|
||||
line_items = []
|
||||
for item in order.items:
|
||||
line_items.append({
|
||||
"description": item.product_name,
|
||||
"quantity": item.quantity,
|
||||
"unit_price_cents": item.unit_price_cents,
|
||||
"total_cents": item.total_price_cents,
|
||||
"sku": item.product_sku,
|
||||
"ean": item.gtin,
|
||||
})
|
||||
|
||||
# Calculate amounts
|
||||
subtotal_cents = sum(item["total_cents"] for item in line_items)
|
||||
|
||||
# Calculate VAT
|
||||
if vat_rate > 0:
|
||||
vat_amount_cents = int(
|
||||
subtotal_cents * float(vat_rate) / 100
|
||||
)
|
||||
else:
|
||||
vat_amount_cents = 0
|
||||
|
||||
total_cents = subtotal_cents + vat_amount_cents
|
||||
|
||||
# Get VAT label
|
||||
vat_rate_label = None
|
||||
if vat_rate > 0:
|
||||
if destination_country:
|
||||
vat_rate_label = self.get_vat_rate_label(destination_country, vat_rate)
|
||||
else:
|
||||
vat_rate_label = self.get_vat_rate_label(settings.company_country, vat_rate)
|
||||
|
||||
# Generate invoice number
|
||||
invoice_number = self._get_next_invoice_number(db, settings)
|
||||
|
||||
# Create invoice
|
||||
invoice = Invoice(
|
||||
vendor_id=vendor_id,
|
||||
order_id=order_id,
|
||||
invoice_number=invoice_number,
|
||||
invoice_date=datetime.now(UTC),
|
||||
status=InvoiceStatus.DRAFT.value,
|
||||
seller_details=seller_details,
|
||||
buyer_details=buyer_details,
|
||||
line_items=line_items,
|
||||
vat_regime=vat_regime.value,
|
||||
destination_country=destination_country,
|
||||
vat_rate=vat_rate,
|
||||
vat_rate_label=vat_rate_label,
|
||||
currency=order.currency,
|
||||
subtotal_cents=subtotal_cents,
|
||||
vat_amount_cents=vat_amount_cents,
|
||||
total_cents=total_cents,
|
||||
payment_terms=settings.payment_terms,
|
||||
bank_details={
|
||||
"bank_name": settings.bank_name,
|
||||
"iban": settings.bank_iban,
|
||||
"bic": settings.bank_bic,
|
||||
} if settings.bank_iban else None,
|
||||
footer_text=settings.footer_text,
|
||||
notes=notes,
|
||||
)
|
||||
|
||||
db.add(invoice)
|
||||
db.flush()
|
||||
db.refresh(invoice)
|
||||
|
||||
logger.info(
|
||||
f"Created invoice {invoice_number} for order {order_id} "
|
||||
f"(vendor={vendor_id}, total={total_cents/100:.2f} EUR, VAT={vat_regime.value})"
|
||||
)
|
||||
|
||||
return invoice
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Retrieval
|
||||
# =========================================================================
|
||||
|
||||
def get_invoice(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice | None:
|
||||
"""Get invoice by ID."""
|
||||
return (
|
||||
db.query(Invoice)
|
||||
.filter(and_(Invoice.id == invoice_id, Invoice.vendor_id == vendor_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_invoice_or_raise(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Get invoice by ID or raise exception."""
|
||||
invoice = self.get_invoice(db, vendor_id, invoice_id)
|
||||
if not invoice:
|
||||
raise InvoiceNotFoundException(invoice_id)
|
||||
return invoice
|
||||
|
||||
def get_invoice_by_number(
|
||||
self, db: Session, vendor_id: int, invoice_number: str
|
||||
) -> Invoice | None:
|
||||
"""Get invoice by invoice number."""
|
||||
return (
|
||||
db.query(Invoice)
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.invoice_number == invoice_number,
|
||||
Invoice.vendor_id == vendor_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_invoice_by_order_id(
|
||||
self, db: Session, vendor_id: int, order_id: int
|
||||
) -> Invoice | None:
|
||||
"""Get invoice by order ID."""
|
||||
return (
|
||||
db.query(Invoice)
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.order_id == order_id,
|
||||
Invoice.vendor_id == vendor_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def list_invoices(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
status: str | None = None,
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
) -> tuple[list[Invoice], int]:
|
||||
"""
|
||||
List invoices for vendor with pagination.
|
||||
|
||||
Returns: (invoices, total_count)
|
||||
"""
|
||||
query = db.query(Invoice).filter(Invoice.vendor_id == vendor_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Invoice.status == status)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Apply pagination and order
|
||||
invoices = (
|
||||
query.order_by(Invoice.invoice_date.desc())
|
||||
.offset((page - 1) * per_page)
|
||||
.limit(per_page)
|
||||
.all()
|
||||
)
|
||||
|
||||
return invoices, total
|
||||
|
||||
# =========================================================================
|
||||
# Invoice Status Management
|
||||
# =========================================================================
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
new_status: str,
|
||||
) -> Invoice:
|
||||
"""Update invoice status."""
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
|
||||
# Validate status transition
|
||||
valid_statuses = [s.value for s in InvoiceStatus]
|
||||
if new_status not in valid_statuses:
|
||||
raise ValidationException(f"Invalid status: {new_status}")
|
||||
|
||||
# Cannot change cancelled invoices
|
||||
if invoice.status == InvoiceStatus.CANCELLED.value:
|
||||
raise ValidationException("Cannot change status of cancelled invoice")
|
||||
|
||||
invoice.status = new_status
|
||||
invoice.updated_at = datetime.now(UTC)
|
||||
db.flush()
|
||||
db.refresh(invoice)
|
||||
|
||||
logger.info(f"Updated invoice {invoice.invoice_number} status to {new_status}")
|
||||
return invoice
|
||||
|
||||
def mark_as_issued(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Mark invoice as issued."""
|
||||
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.ISSUED.value)
|
||||
|
||||
def mark_as_paid(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Mark invoice as paid."""
|
||||
return self.update_status(db, vendor_id, invoice_id, InvoiceStatus.PAID.value)
|
||||
|
||||
def cancel_invoice(
|
||||
self, db: Session, vendor_id: int, invoice_id: int
|
||||
) -> Invoice:
|
||||
"""Cancel invoice."""
|
||||
return self.update_status(
|
||||
db, vendor_id, invoice_id, InvoiceStatus.CANCELLED.value
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_invoice_stats(
|
||||
self, db: Session, vendor_id: int
|
||||
) -> dict[str, Any]:
|
||||
"""Get invoice statistics for vendor."""
|
||||
total_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(Invoice.vendor_id == vendor_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
total_revenue = (
|
||||
db.query(func.sum(Invoice.total_cents))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status.in_([
|
||||
InvoiceStatus.ISSUED.value,
|
||||
InvoiceStatus.PAID.value,
|
||||
]),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
draft_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status == InvoiceStatus.DRAFT.value,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
paid_count = (
|
||||
db.query(func.count(Invoice.id))
|
||||
.filter(
|
||||
and_(
|
||||
Invoice.vendor_id == vendor_id,
|
||||
Invoice.status == InvoiceStatus.PAID.value,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_invoices": total_count,
|
||||
"total_revenue_cents": total_revenue,
|
||||
"total_revenue": total_revenue / 100 if total_revenue else 0,
|
||||
"draft_count": draft_count,
|
||||
"paid_count": paid_count,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PDF Generation
|
||||
# =========================================================================
|
||||
|
||||
def generate_pdf(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
force_regenerate: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Generate PDF for an invoice.
|
||||
|
||||
Returns path to the generated PDF.
|
||||
"""
|
||||
from app.services.invoice_pdf_service import invoice_pdf_service
|
||||
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
return invoice_pdf_service.generate_pdf(db, invoice, force_regenerate)
|
||||
|
||||
def get_pdf_path(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
invoice_id: int,
|
||||
) -> str | None:
|
||||
"""Get PDF path for an invoice if it exists."""
|
||||
from app.services.invoice_pdf_service import invoice_pdf_service
|
||||
|
||||
invoice = self.get_invoice_or_raise(db, vendor_id, invoice_id)
|
||||
return invoice_pdf_service.get_pdf_path(invoice)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
invoice_service = InvoiceService()
|
||||
__all__ = [
|
||||
"invoice_service",
|
||||
"InvoiceService",
|
||||
]
|
||||
|
||||
@@ -1,33 +1,33 @@
|
||||
# app/services/letzshop/__init__.py
|
||||
"""
|
||||
Letzshop marketplace integration services.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Provides:
|
||||
- GraphQL client for API communication
|
||||
- Credential management service
|
||||
- Order import service
|
||||
- Fulfillment sync service
|
||||
- Vendor directory sync service
|
||||
The canonical implementation is now in:
|
||||
app/modules/marketplace/services/letzshop/
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.marketplace.services.letzshop import LetzshopClient
|
||||
"""
|
||||
|
||||
from .client_service import (
|
||||
LetzshopAPIError,
|
||||
LetzshopAuthError,
|
||||
from app.modules.marketplace.services.letzshop import (
|
||||
# Client
|
||||
LetzshopClient,
|
||||
LetzshopClientError,
|
||||
LetzshopAuthError,
|
||||
LetzshopAPIError,
|
||||
LetzshopConnectionError,
|
||||
)
|
||||
from .credentials_service import (
|
||||
# Credentials
|
||||
LetzshopCredentialsService,
|
||||
CredentialsError,
|
||||
CredentialsNotFoundError,
|
||||
LetzshopCredentialsService,
|
||||
)
|
||||
from .order_service import (
|
||||
# Order Service
|
||||
LetzshopOrderService,
|
||||
OrderNotFoundError,
|
||||
VendorNotFoundError,
|
||||
)
|
||||
from .vendor_sync_service import (
|
||||
# Vendor Sync Service
|
||||
LetzshopVendorSyncService,
|
||||
get_vendor_sync_service,
|
||||
)
|
||||
|
||||
@@ -1,338 +1,25 @@
|
||||
# app/services/letzshop_export_service.py
|
||||
"""
|
||||
Service for exporting products to Letzshop CSV format.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Generates Google Shopping compatible CSV files for Letzshop marketplace.
|
||||
The canonical implementation is now in:
|
||||
app/modules/marketplace/services/letzshop_export_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.marketplace.services import letzshop_export_service
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from app.modules.marketplace.services.letzshop_export_service import (
|
||||
LetzshopExportService,
|
||||
letzshop_export_service,
|
||||
LETZSHOP_CSV_COLUMNS,
|
||||
)
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from models.database.letzshop import LetzshopSyncLog
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Letzshop CSV columns in order
|
||||
LETZSHOP_CSV_COLUMNS = [
|
||||
"id",
|
||||
"title",
|
||||
"description",
|
||||
"link",
|
||||
"image_link",
|
||||
"additional_image_link",
|
||||
"availability",
|
||||
"price",
|
||||
"sale_price",
|
||||
"brand",
|
||||
"gtin",
|
||||
"mpn",
|
||||
"google_product_category",
|
||||
"product_type",
|
||||
"condition",
|
||||
"adult",
|
||||
"multipack",
|
||||
"is_bundle",
|
||||
"age_group",
|
||||
"color",
|
||||
"gender",
|
||||
"material",
|
||||
"pattern",
|
||||
"size",
|
||||
"size_type",
|
||||
"size_system",
|
||||
"item_group_id",
|
||||
"custom_label_0",
|
||||
"custom_label_1",
|
||||
"custom_label_2",
|
||||
"custom_label_3",
|
||||
"custom_label_4",
|
||||
"identifier_exists",
|
||||
"unit_pricing_measure",
|
||||
"unit_pricing_base_measure",
|
||||
"shipping",
|
||||
"atalanda:tax_rate",
|
||||
"atalanda:quantity",
|
||||
"atalanda:boost_sort",
|
||||
"atalanda:delivery_method",
|
||||
__all__ = [
|
||||
"LetzshopExportService",
|
||||
"letzshop_export_service",
|
||||
"LETZSHOP_CSV_COLUMNS",
|
||||
]
|
||||
|
||||
|
||||
class LetzshopExportService:
|
||||
"""Service for exporting products to Letzshop CSV format."""
|
||||
|
||||
def __init__(self, default_tax_rate: float = 17.0):
|
||||
"""
|
||||
Initialize the export service.
|
||||
|
||||
Args:
|
||||
default_tax_rate: Default VAT rate for Luxembourg (17%)
|
||||
"""
|
||||
self.default_tax_rate = default_tax_rate
|
||||
|
||||
def export_vendor_products(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
language: str = "en",
|
||||
include_inactive: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Export all products for a vendor in Letzshop CSV format.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID to export products for
|
||||
language: Language for title/description (en, fr, de)
|
||||
include_inactive: Whether to include inactive products
|
||||
|
||||
Returns:
|
||||
CSV string content
|
||||
"""
|
||||
# Query products for this vendor with their marketplace product data
|
||||
query = (
|
||||
db.query(Product)
|
||||
.filter(Product.vendor_id == vendor_id)
|
||||
.options(
|
||||
joinedload(Product.marketplace_product).joinedload(
|
||||
MarketplaceProduct.translations
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(Product.is_active == True)
|
||||
|
||||
products = query.all()
|
||||
|
||||
logger.info(
|
||||
f"Exporting {len(products)} products for vendor {vendor_id} in {language}"
|
||||
)
|
||||
|
||||
return self._generate_csv(products, language)
|
||||
|
||||
def export_marketplace_products(
|
||||
self,
|
||||
db: Session,
|
||||
marketplace: str = "Letzshop",
|
||||
language: str = "en",
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Export marketplace products directly (admin use).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
marketplace: Filter by marketplace source
|
||||
language: Language for title/description
|
||||
limit: Optional limit on number of products
|
||||
|
||||
Returns:
|
||||
CSV string content
|
||||
"""
|
||||
query = (
|
||||
db.query(MarketplaceProduct)
|
||||
.filter(MarketplaceProduct.is_active == True)
|
||||
.options(joinedload(MarketplaceProduct.translations))
|
||||
)
|
||||
|
||||
if marketplace:
|
||||
query = query.filter(
|
||||
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
products = query.all()
|
||||
|
||||
logger.info(
|
||||
f"Exporting {len(products)} marketplace products for {marketplace} in {language}"
|
||||
)
|
||||
|
||||
return self._generate_csv_from_marketplace_products(products, language)
|
||||
|
||||
def _generate_csv(self, products: list[Product], language: str) -> str:
|
||||
"""Generate CSV from vendor Product objects."""
|
||||
output = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
output,
|
||||
fieldnames=LETZSHOP_CSV_COLUMNS,
|
||||
delimiter="\t",
|
||||
quoting=csv.QUOTE_MINIMAL,
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
for product in products:
|
||||
if product.marketplace_product:
|
||||
row = self._product_to_row(product, language)
|
||||
writer.writerow(row)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def _generate_csv_from_marketplace_products(
|
||||
self, products: list[MarketplaceProduct], language: str
|
||||
) -> str:
|
||||
"""Generate CSV from MarketplaceProduct objects directly."""
|
||||
output = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
output,
|
||||
fieldnames=LETZSHOP_CSV_COLUMNS,
|
||||
delimiter="\t",
|
||||
quoting=csv.QUOTE_MINIMAL,
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
for mp in products:
|
||||
row = self._marketplace_product_to_row(mp, language)
|
||||
writer.writerow(row)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def _product_to_row(self, product: Product, language: str) -> dict:
|
||||
"""Convert a Product (with MarketplaceProduct) to a CSV row."""
|
||||
mp = product.marketplace_product
|
||||
return self._marketplace_product_to_row(
|
||||
mp, language, vendor_sku=product.vendor_sku
|
||||
)
|
||||
|
||||
def _marketplace_product_to_row(
|
||||
self,
|
||||
mp: MarketplaceProduct,
|
||||
language: str,
|
||||
vendor_sku: str | None = None,
|
||||
) -> dict:
|
||||
"""Convert a MarketplaceProduct to a CSV row dict."""
|
||||
# Get localized title and description
|
||||
title = mp.get_title(language) or ""
|
||||
description = mp.get_description(language) or ""
|
||||
|
||||
# Format price with currency
|
||||
price = ""
|
||||
if mp.price_numeric:
|
||||
price = f"{mp.price_numeric:.2f} {mp.currency or 'EUR'}"
|
||||
elif mp.price:
|
||||
price = mp.price
|
||||
|
||||
# Format sale price
|
||||
sale_price = ""
|
||||
if mp.sale_price_numeric:
|
||||
sale_price = f"{mp.sale_price_numeric:.2f} {mp.currency or 'EUR'}"
|
||||
elif mp.sale_price:
|
||||
sale_price = mp.sale_price
|
||||
|
||||
# Additional images - join with comma if multiple
|
||||
additional_images = ""
|
||||
if mp.additional_images:
|
||||
additional_images = ",".join(mp.additional_images)
|
||||
elif mp.additional_image_link:
|
||||
additional_images = mp.additional_image_link
|
||||
|
||||
# Determine identifier_exists
|
||||
identifier_exists = mp.identifier_exists
|
||||
if not identifier_exists:
|
||||
identifier_exists = "yes" if (mp.gtin or mp.mpn) else "no"
|
||||
|
||||
return {
|
||||
"id": vendor_sku or mp.marketplace_product_id,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"link": mp.link or mp.source_url or "",
|
||||
"image_link": mp.image_link or "",
|
||||
"additional_image_link": additional_images,
|
||||
"availability": mp.availability or "in stock",
|
||||
"price": price,
|
||||
"sale_price": sale_price,
|
||||
"brand": mp.brand or "",
|
||||
"gtin": mp.gtin or "",
|
||||
"mpn": mp.mpn or "",
|
||||
"google_product_category": mp.google_product_category or "",
|
||||
"product_type": mp.product_type_raw or "",
|
||||
"condition": mp.condition or "new",
|
||||
"adult": mp.adult or "no",
|
||||
"multipack": str(mp.multipack) if mp.multipack else "",
|
||||
"is_bundle": mp.is_bundle or "no",
|
||||
"age_group": mp.age_group or "",
|
||||
"color": mp.color or "",
|
||||
"gender": mp.gender or "",
|
||||
"material": mp.material or "",
|
||||
"pattern": mp.pattern or "",
|
||||
"size": mp.size or "",
|
||||
"size_type": mp.size_type or "",
|
||||
"size_system": mp.size_system or "",
|
||||
"item_group_id": mp.item_group_id or "",
|
||||
"custom_label_0": mp.custom_label_0 or "",
|
||||
"custom_label_1": mp.custom_label_1 or "",
|
||||
"custom_label_2": mp.custom_label_2 or "",
|
||||
"custom_label_3": mp.custom_label_3 or "",
|
||||
"custom_label_4": mp.custom_label_4 or "",
|
||||
"identifier_exists": identifier_exists,
|
||||
"unit_pricing_measure": mp.unit_pricing_measure or "",
|
||||
"unit_pricing_base_measure": mp.unit_pricing_base_measure or "",
|
||||
"shipping": mp.shipping or "",
|
||||
"atalanda:tax_rate": str(self.default_tax_rate),
|
||||
"atalanda:quantity": "", # Would need inventory data
|
||||
"atalanda:boost_sort": "",
|
||||
"atalanda:delivery_method": "",
|
||||
}
|
||||
|
||||
def log_export(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
started_at: datetime,
|
||||
completed_at: datetime,
|
||||
files_processed: int,
|
||||
files_succeeded: int,
|
||||
files_failed: int,
|
||||
products_exported: int,
|
||||
triggered_by: str,
|
||||
error_details: dict | None = None,
|
||||
) -> LetzshopSyncLog:
|
||||
"""
|
||||
Log an export operation to the sync log.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
started_at: When the export started
|
||||
completed_at: When the export completed
|
||||
files_processed: Number of language files to export (e.g., 3)
|
||||
files_succeeded: Number of files successfully exported
|
||||
files_failed: Number of files that failed
|
||||
products_exported: Total products in the export
|
||||
triggered_by: Who triggered the export (e.g., "admin:123")
|
||||
error_details: Optional error details if any failures
|
||||
|
||||
Returns:
|
||||
Created LetzshopSyncLog entry
|
||||
"""
|
||||
sync_log = LetzshopSyncLog(
|
||||
vendor_id=vendor_id,
|
||||
operation_type="product_export",
|
||||
direction="outbound",
|
||||
status="completed" if files_failed == 0 else "partial",
|
||||
records_processed=files_processed,
|
||||
records_succeeded=files_succeeded,
|
||||
records_failed=files_failed,
|
||||
started_at=started_at,
|
||||
completed_at=completed_at,
|
||||
duration_seconds=int((completed_at - started_at).total_seconds()),
|
||||
triggered_by=triggered_by,
|
||||
error_details={
|
||||
"products_exported": products_exported,
|
||||
**(error_details or {}),
|
||||
} if products_exported or error_details else None,
|
||||
)
|
||||
db.add(sync_log)
|
||||
db.flush()
|
||||
return sync_log
|
||||
|
||||
|
||||
# Singleton instance
|
||||
letzshop_export_service = LetzshopExportService()
|
||||
|
||||
@@ -1,334 +1,23 @@
|
||||
# app/services/marketplace_import_job_service.py
|
||||
import logging
|
||||
"""
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
The canonical implementation is now in:
|
||||
app/modules/marketplace/services/marketplace_import_job_service.py
|
||||
|
||||
from app.exceptions import (
|
||||
ImportJobNotFoundException,
|
||||
ImportJobNotOwnedException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.database.marketplace_import_job import (
|
||||
MarketplaceImportError,
|
||||
MarketplaceImportJob,
|
||||
)
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.marketplace_import_job import (
|
||||
AdminMarketplaceImportJobResponse,
|
||||
MarketplaceImportJobRequest,
|
||||
MarketplaceImportJobResponse,
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.marketplace.services import marketplace_import_job_service
|
||||
"""
|
||||
|
||||
from app.modules.marketplace.services.marketplace_import_job_service import (
|
||||
MarketplaceImportJobService,
|
||||
marketplace_import_job_service,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketplaceImportJobService:
|
||||
"""Service class for Marketplace operations."""
|
||||
|
||||
def create_import_job(
|
||||
self,
|
||||
db: Session,
|
||||
request: MarketplaceImportJobRequest,
|
||||
vendor: Vendor, # CHANGED: Vendor object from middleware
|
||||
user: User,
|
||||
) -> MarketplaceImportJob:
|
||||
"""
|
||||
Create a new marketplace import job.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
request: Import request data
|
||||
vendor: Vendor object (from middleware)
|
||||
user: User creating the job
|
||||
|
||||
Returns:
|
||||
Created MarketplaceImportJob object
|
||||
"""
|
||||
try:
|
||||
# Create marketplace import job record
|
||||
import_job = MarketplaceImportJob(
|
||||
status="pending",
|
||||
source_url=request.source_url,
|
||||
marketplace=request.marketplace,
|
||||
language=request.language,
|
||||
vendor_id=vendor.id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
db.add(import_job)
|
||||
db.flush()
|
||||
db.refresh(import_job)
|
||||
|
||||
logger.info(
|
||||
f"Created marketplace import job {import_job.id}: "
|
||||
f"{request.marketplace} -> {vendor.name} (code: {vendor.vendor_code}) "
|
||||
f"by user {user.username}"
|
||||
)
|
||||
|
||||
return import_job
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating import job: {str(e)}")
|
||||
raise ValidationException("Failed to create import job")
|
||||
|
||||
def get_import_job_by_id(
|
||||
self, db: Session, job_id: int, user: User
|
||||
) -> MarketplaceImportJob:
|
||||
"""Get a marketplace import job by ID with access control."""
|
||||
try:
|
||||
job = (
|
||||
db.query(MarketplaceImportJob)
|
||||
.filter(MarketplaceImportJob.id == job_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not job:
|
||||
raise ImportJobNotFoundException(job_id)
|
||||
|
||||
# Users can only see their own jobs, admins can see all
|
||||
if user.role != "admin" and job.user_id != user.id:
|
||||
raise ImportJobNotOwnedException(job_id, user.id)
|
||||
|
||||
return job
|
||||
|
||||
except (ImportJobNotFoundException, ImportJobNotOwnedException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting import job {job_id}: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve import job")
|
||||
|
||||
def get_import_job_for_vendor(
|
||||
self, db: Session, job_id: int, vendor_id: int
|
||||
) -> MarketplaceImportJob:
|
||||
"""
|
||||
Get a marketplace import job by ID with vendor access control.
|
||||
|
||||
Validates that the job belongs to the specified vendor.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
job_id: Import job ID
|
||||
vendor_id: Vendor ID from token (to verify ownership)
|
||||
|
||||
Raises:
|
||||
ImportJobNotFoundException: If job not found
|
||||
UnauthorizedVendorAccessException: If job doesn't belong to vendor
|
||||
"""
|
||||
from app.exceptions import UnauthorizedVendorAccessException
|
||||
|
||||
try:
|
||||
job = (
|
||||
db.query(MarketplaceImportJob)
|
||||
.filter(MarketplaceImportJob.id == job_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not job:
|
||||
raise ImportJobNotFoundException(job_id)
|
||||
|
||||
# Verify job belongs to vendor (service layer validation)
|
||||
if job.vendor_id != vendor_id:
|
||||
raise UnauthorizedVendorAccessException(
|
||||
vendor_code=str(vendor_id),
|
||||
user_id=0, # Not user-specific, but vendor mismatch
|
||||
)
|
||||
|
||||
return job
|
||||
|
||||
except (ImportJobNotFoundException, UnauthorizedVendorAccessException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}"
|
||||
)
|
||||
raise ValidationException("Failed to retrieve import job")
|
||||
|
||||
def get_import_jobs(
|
||||
self,
|
||||
db: Session,
|
||||
vendor: Vendor, # ADDED: Vendor filter
|
||||
user: User,
|
||||
marketplace: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[MarketplaceImportJob]:
|
||||
"""Get marketplace import jobs for a specific vendor."""
|
||||
try:
|
||||
query = db.query(MarketplaceImportJob).filter(
|
||||
MarketplaceImportJob.vendor_id == vendor.id
|
||||
)
|
||||
|
||||
# Users can only see their own jobs, admins can see all vendor jobs
|
||||
if user.role != "admin":
|
||||
query = query.filter(MarketplaceImportJob.user_id == user.id)
|
||||
|
||||
# Apply marketplace filter
|
||||
if marketplace:
|
||||
query = query.filter(
|
||||
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
|
||||
)
|
||||
|
||||
# Order by creation date (newest first) and apply pagination
|
||||
jobs = (
|
||||
query.order_by(
|
||||
MarketplaceImportJob.created_at.desc(),
|
||||
MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp
|
||||
)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting import jobs: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve import jobs")
|
||||
|
||||
def convert_to_response_model(
|
||||
self, job: MarketplaceImportJob
|
||||
) -> MarketplaceImportJobResponse:
|
||||
"""Convert database model to API response model."""
|
||||
return MarketplaceImportJobResponse(
|
||||
job_id=job.id,
|
||||
status=job.status,
|
||||
marketplace=job.marketplace,
|
||||
language=job.language,
|
||||
vendor_id=job.vendor_id,
|
||||
vendor_code=job.vendor.vendor_code if job.vendor else None,
|
||||
vendor_name=job.vendor.name if job.vendor else None,
|
||||
source_url=job.source_url,
|
||||
imported=job.imported_count or 0,
|
||||
updated=job.updated_count or 0,
|
||||
total_processed=job.total_processed or 0,
|
||||
error_count=job.error_count or 0,
|
||||
error_message=job.error_message,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
)
|
||||
|
||||
def convert_to_admin_response_model(
|
||||
self, job: MarketplaceImportJob
|
||||
) -> AdminMarketplaceImportJobResponse:
|
||||
"""Convert database model to admin API response model with extra fields."""
|
||||
return AdminMarketplaceImportJobResponse(
|
||||
id=job.id,
|
||||
job_id=job.id,
|
||||
status=job.status,
|
||||
marketplace=job.marketplace,
|
||||
language=job.language,
|
||||
vendor_id=job.vendor_id,
|
||||
vendor_code=job.vendor.vendor_code if job.vendor else None,
|
||||
vendor_name=job.vendor.name if job.vendor else None,
|
||||
source_url=job.source_url,
|
||||
imported=job.imported_count or 0,
|
||||
updated=job.updated_count or 0,
|
||||
total_processed=job.total_processed or 0,
|
||||
error_count=job.error_count or 0,
|
||||
error_message=job.error_message,
|
||||
error_details=[],
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
created_by_name=job.user.username if job.user else None,
|
||||
)
|
||||
|
||||
def get_all_import_jobs_paginated(
|
||||
self,
|
||||
db: Session,
|
||||
marketplace: str | None = None,
|
||||
status: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[MarketplaceImportJob], int]:
|
||||
"""Get all marketplace import jobs with pagination (for admin)."""
|
||||
try:
|
||||
query = db.query(MarketplaceImportJob)
|
||||
|
||||
if marketplace:
|
||||
query = query.filter(
|
||||
MarketplaceImportJob.marketplace.ilike(f"%{marketplace}%")
|
||||
)
|
||||
if status:
|
||||
query = query.filter(MarketplaceImportJob.status == status)
|
||||
|
||||
total = query.count()
|
||||
skip = (page - 1) * limit
|
||||
jobs = (
|
||||
query.order_by(
|
||||
MarketplaceImportJob.created_at.desc(),
|
||||
MarketplaceImportJob.id.desc(), # Tiebreaker for same timestamp
|
||||
)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return jobs, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all import jobs: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve import jobs")
|
||||
|
||||
def get_import_job_by_id_admin(
|
||||
self, db: Session, job_id: int
|
||||
) -> MarketplaceImportJob:
|
||||
"""Get a marketplace import job by ID (admin - no access control)."""
|
||||
job = (
|
||||
db.query(MarketplaceImportJob)
|
||||
.filter(MarketplaceImportJob.id == job_id)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise ImportJobNotFoundException(job_id)
|
||||
return job
|
||||
|
||||
def get_import_job_errors(
|
||||
self,
|
||||
db: Session,
|
||||
job_id: int,
|
||||
error_type: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[MarketplaceImportError], int]:
|
||||
"""
|
||||
Get import errors for a specific job with pagination.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
job_id: Import job ID
|
||||
error_type: Optional filter by error type
|
||||
page: Page number (1-indexed)
|
||||
limit: Number of items per page
|
||||
|
||||
Returns:
|
||||
Tuple of (list of errors, total count)
|
||||
"""
|
||||
try:
|
||||
query = db.query(MarketplaceImportError).filter(
|
||||
MarketplaceImportError.import_job_id == job_id
|
||||
)
|
||||
|
||||
if error_type:
|
||||
query = query.filter(MarketplaceImportError.error_type == error_type)
|
||||
|
||||
total = query.count()
|
||||
|
||||
offset = (page - 1) * limit
|
||||
errors = (
|
||||
query.order_by(MarketplaceImportError.row_number)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return errors, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting import job errors for job {job_id}: {str(e)}")
|
||||
raise ValidationException("Failed to retrieve import errors")
|
||||
|
||||
|
||||
marketplace_import_job_service = MarketplaceImportJobService()
|
||||
__all__ = [
|
||||
"MarketplaceImportJobService",
|
||||
"marketplace_import_job_service",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,225 +1,23 @@
|
||||
# app/services/message_attachment_service.py
|
||||
"""
|
||||
Attachment handling service for messaging system.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Handles file upload, validation, storage, and retrieval.
|
||||
The canonical implementation is now in:
|
||||
app/modules/messaging/services/message_attachment_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.messaging.services import message_attachment_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from app.modules.messaging.services.message_attachment_service import (
|
||||
message_attachment_service,
|
||||
MessageAttachmentService,
|
||||
)
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.admin_settings_service import admin_settings_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Allowed MIME types for attachments
|
||||
ALLOWED_MIME_TYPES = {
|
||||
# Images
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
# Documents
|
||||
"application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
# Archives
|
||||
"application/zip",
|
||||
# Text
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
}
|
||||
|
||||
IMAGE_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"}
|
||||
|
||||
# Default max file size in MB
|
||||
DEFAULT_MAX_FILE_SIZE_MB = 10
|
||||
|
||||
|
||||
class MessageAttachmentService:
|
||||
"""Service for handling message attachments."""
|
||||
|
||||
def __init__(self, storage_base: str = "uploads/messages"):
|
||||
self.storage_base = storage_base
|
||||
|
||||
def get_max_file_size_bytes(self, db: Session) -> int:
|
||||
"""Get maximum file size from platform settings."""
|
||||
max_mb = admin_settings_service.get_setting_value(
|
||||
db,
|
||||
"message_attachment_max_size_mb",
|
||||
default=DEFAULT_MAX_FILE_SIZE_MB,
|
||||
)
|
||||
try:
|
||||
max_mb = int(max_mb)
|
||||
except (TypeError, ValueError):
|
||||
max_mb = DEFAULT_MAX_FILE_SIZE_MB
|
||||
return max_mb * 1024 * 1024 # Convert to bytes
|
||||
|
||||
def validate_file_type(self, mime_type: str) -> bool:
|
||||
"""Check if file type is allowed."""
|
||||
return mime_type in ALLOWED_MIME_TYPES
|
||||
|
||||
def is_image(self, mime_type: str) -> bool:
|
||||
"""Check if file is an image."""
|
||||
return mime_type in IMAGE_MIME_TYPES
|
||||
|
||||
async def validate_and_store(
|
||||
self,
|
||||
db: Session,
|
||||
file: UploadFile,
|
||||
conversation_id: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and store an uploaded file.
|
||||
|
||||
Returns dict with file metadata for MessageAttachment creation.
|
||||
|
||||
Raises:
|
||||
ValueError: If file type or size is invalid
|
||||
"""
|
||||
# Validate MIME type
|
||||
content_type = file.content_type or "application/octet-stream"
|
||||
if not self.validate_file_type(content_type):
|
||||
raise ValueError(
|
||||
f"File type '{content_type}' not allowed. "
|
||||
"Allowed types: images (JPEG, PNG, GIF, WebP), "
|
||||
"PDF, Office documents, ZIP, text files."
|
||||
)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
# Validate file size
|
||||
max_size = self.get_max_file_size_bytes(db)
|
||||
if file_size > max_size:
|
||||
raise ValueError(
|
||||
f"File size {file_size / 1024 / 1024:.1f}MB exceeds "
|
||||
f"maximum allowed size of {max_size / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
|
||||
# Generate unique filename
|
||||
original_filename = file.filename or "attachment"
|
||||
ext = Path(original_filename).suffix.lower()
|
||||
unique_filename = f"{uuid.uuid4().hex}{ext}"
|
||||
|
||||
# Create storage path: uploads/messages/YYYY/MM/conversation_id/filename
|
||||
now = datetime.utcnow()
|
||||
relative_path = os.path.join(
|
||||
self.storage_base,
|
||||
str(now.year),
|
||||
f"{now.month:02d}",
|
||||
str(conversation_id),
|
||||
)
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(relative_path, exist_ok=True)
|
||||
|
||||
# Full file path
|
||||
file_path = os.path.join(relative_path, unique_filename)
|
||||
|
||||
# Write file
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Prepare metadata
|
||||
is_image = self.is_image(content_type)
|
||||
metadata = {
|
||||
"filename": unique_filename,
|
||||
"original_filename": original_filename,
|
||||
"file_path": file_path,
|
||||
"file_size": file_size,
|
||||
"mime_type": content_type,
|
||||
"is_image": is_image,
|
||||
}
|
||||
|
||||
# Generate thumbnail for images
|
||||
if is_image:
|
||||
thumbnail_data = self._create_thumbnail(content, file_path)
|
||||
metadata.update(thumbnail_data)
|
||||
|
||||
logger.info(
|
||||
f"Stored attachment {unique_filename} for conversation {conversation_id} "
|
||||
f"({file_size} bytes, type: {content_type})"
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
def _create_thumbnail(self, content: bytes, original_path: str) -> dict:
|
||||
"""Create thumbnail for image attachments."""
|
||||
try:
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
img = Image.open(io.BytesIO(content))
|
||||
width, height = img.size
|
||||
|
||||
# Create thumbnail
|
||||
img.thumbnail((200, 200))
|
||||
|
||||
thumb_path = original_path.replace(".", "_thumb.")
|
||||
img.save(thumb_path)
|
||||
|
||||
return {
|
||||
"image_width": width,
|
||||
"image_height": height,
|
||||
"thumbnail_path": thumb_path,
|
||||
}
|
||||
except ImportError:
|
||||
logger.warning("PIL not installed, skipping thumbnail generation")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create thumbnail: {e}")
|
||||
return {}
|
||||
|
||||
def delete_attachment(
|
||||
self, file_path: str, thumbnail_path: str | None = None
|
||||
) -> bool:
|
||||
"""Delete attachment files from storage."""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logger.info(f"Deleted attachment file: {file_path}")
|
||||
|
||||
if thumbnail_path and os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
logger.info(f"Deleted thumbnail: {thumbnail_path}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete attachment {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def get_download_url(self, file_path: str) -> str:
|
||||
"""
|
||||
Get download URL for an attachment.
|
||||
|
||||
For local storage, returns a relative path that can be served
|
||||
by the static file handler or a dedicated download endpoint.
|
||||
"""
|
||||
# Convert local path to URL path
|
||||
# Assumes files are served from /static/uploads or similar
|
||||
return f"/static/{file_path}"
|
||||
|
||||
def get_file_content(self, file_path: str) -> bytes | None:
|
||||
"""Read file content from storage."""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read file {file_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Singleton instance
|
||||
message_attachment_service = MessageAttachmentService()
|
||||
__all__ = [
|
||||
"message_attachment_service",
|
||||
"MessageAttachmentService",
|
||||
]
|
||||
|
||||
@@ -1,684 +1,23 @@
|
||||
# app/services/messaging_service.py
|
||||
"""
|
||||
Messaging service for conversation and message management.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
Provides functionality for:
|
||||
- Creating conversations between different participant types
|
||||
- Sending messages with attachments
|
||||
- Managing read status and unread counts
|
||||
- Conversation listing with filters
|
||||
- Multi-tenant data isolation
|
||||
The canonical implementation is now in:
|
||||
app/modules/messaging/services/messaging_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.messaging.services import messaging_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from models.database.customer import Customer
|
||||
from models.database.message import (
|
||||
Conversation,
|
||||
ConversationParticipant,
|
||||
ConversationType,
|
||||
Message,
|
||||
MessageAttachment,
|
||||
ParticipantType,
|
||||
from app.modules.messaging.services.messaging_service import (
|
||||
messaging_service,
|
||||
MessagingService,
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagingService:
|
||||
"""Service for managing conversations and messages."""
|
||||
|
||||
# =========================================================================
|
||||
# CONVERSATION MANAGEMENT
|
||||
# =========================================================================
|
||||
|
||||
def create_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_type: ConversationType,
|
||||
subject: str,
|
||||
initiator_type: ParticipantType,
|
||||
initiator_id: int,
|
||||
recipient_type: ParticipantType,
|
||||
recipient_id: int,
|
||||
vendor_id: int | None = None,
|
||||
initial_message: str | None = None,
|
||||
) -> Conversation:
|
||||
"""
|
||||
Create a new conversation between two participants.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
conversation_type: Type of conversation channel
|
||||
subject: Conversation subject line
|
||||
initiator_type: Type of initiating participant
|
||||
initiator_id: ID of initiating participant
|
||||
recipient_type: Type of receiving participant
|
||||
recipient_id: ID of receiving participant
|
||||
vendor_id: Required for vendor_customer/admin_customer types
|
||||
initial_message: Optional first message content
|
||||
|
||||
Returns:
|
||||
Created Conversation object
|
||||
"""
|
||||
# Validate vendor_id requirement
|
||||
if conversation_type in [
|
||||
ConversationType.VENDOR_CUSTOMER,
|
||||
ConversationType.ADMIN_CUSTOMER,
|
||||
]:
|
||||
if not vendor_id:
|
||||
raise ValueError(
|
||||
f"vendor_id required for {conversation_type.value} conversations"
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation = Conversation(
|
||||
conversation_type=conversation_type,
|
||||
subject=subject,
|
||||
vendor_id=vendor_id,
|
||||
)
|
||||
db.add(conversation)
|
||||
db.flush()
|
||||
|
||||
# Add participants
|
||||
initiator_vendor_id = (
|
||||
vendor_id if initiator_type == ParticipantType.VENDOR else None
|
||||
)
|
||||
recipient_vendor_id = (
|
||||
vendor_id if recipient_type == ParticipantType.VENDOR else None
|
||||
)
|
||||
|
||||
initiator = ConversationParticipant(
|
||||
conversation_id=conversation.id,
|
||||
participant_type=initiator_type,
|
||||
participant_id=initiator_id,
|
||||
vendor_id=initiator_vendor_id,
|
||||
unread_count=0, # Initiator has read their own message
|
||||
)
|
||||
recipient = ConversationParticipant(
|
||||
conversation_id=conversation.id,
|
||||
participant_type=recipient_type,
|
||||
participant_id=recipient_id,
|
||||
vendor_id=recipient_vendor_id,
|
||||
unread_count=1 if initial_message else 0,
|
||||
)
|
||||
|
||||
db.add(initiator)
|
||||
db.add(recipient)
|
||||
db.flush()
|
||||
|
||||
# Add initial message if provided
|
||||
if initial_message:
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation.id,
|
||||
sender_type=initiator_type,
|
||||
sender_id=initiator_id,
|
||||
content=initial_message,
|
||||
_skip_unread_update=True, # Already set above
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created {conversation_type.value} conversation {conversation.id}: "
|
||||
f"{initiator_type.value}:{initiator_id} -> {recipient_type.value}:{recipient_id}"
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def get_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
) -> Conversation | None:
|
||||
"""
|
||||
Get conversation if participant has access.
|
||||
|
||||
Validates that the requester is a participant.
|
||||
"""
|
||||
conversation = (
|
||||
db.query(Conversation)
|
||||
.options(
|
||||
joinedload(Conversation.participants),
|
||||
joinedload(Conversation.messages).joinedload(Message.attachments),
|
||||
)
|
||||
.filter(Conversation.id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
# Verify participant access
|
||||
has_access = any(
|
||||
p.participant_type == participant_type
|
||||
and p.participant_id == participant_id
|
||||
for p in conversation.participants
|
||||
)
|
||||
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
f"Access denied to conversation {conversation_id} for "
|
||||
f"{participant_type.value}:{participant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
return conversation
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
db: Session,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
vendor_id: int | None = None,
|
||||
conversation_type: ConversationType | None = None,
|
||||
is_closed: bool | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
) -> tuple[list[Conversation], int, int]:
|
||||
"""
|
||||
List conversations for a participant with filters.
|
||||
|
||||
Returns:
|
||||
Tuple of (conversations, total_count, total_unread)
|
||||
"""
|
||||
# Base query: conversations where user is a participant
|
||||
query = (
|
||||
db.query(Conversation)
|
||||
.join(ConversationParticipant)
|
||||
.filter(
|
||||
and_(
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Multi-tenant filter for vendor users
|
||||
if participant_type == ParticipantType.VENDOR and vendor_id:
|
||||
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
|
||||
|
||||
# Customer vendor isolation
|
||||
if participant_type == ParticipantType.CUSTOMER and vendor_id:
|
||||
query = query.filter(Conversation.vendor_id == vendor_id)
|
||||
|
||||
# Type filter
|
||||
if conversation_type:
|
||||
query = query.filter(Conversation.conversation_type == conversation_type)
|
||||
|
||||
# Status filter
|
||||
if is_closed is not None:
|
||||
query = query.filter(Conversation.is_closed == is_closed)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Get total unread across all conversations
|
||||
unread_query = db.query(
|
||||
func.sum(ConversationParticipant.unread_count)
|
||||
).filter(
|
||||
and_(
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if participant_type == ParticipantType.VENDOR and vendor_id:
|
||||
unread_query = unread_query.filter(
|
||||
ConversationParticipant.vendor_id == vendor_id
|
||||
)
|
||||
|
||||
total_unread = unread_query.scalar() or 0
|
||||
|
||||
# Get paginated results, ordered by last activity
|
||||
conversations = (
|
||||
query.options(joinedload(Conversation.participants))
|
||||
.order_by(Conversation.last_message_at.desc().nullslast())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return conversations, total, total_unread
|
||||
|
||||
def close_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
closer_type: ParticipantType,
|
||||
closer_id: int,
|
||||
) -> Conversation | None:
|
||||
"""Close a conversation thread."""
|
||||
conversation = self.get_conversation(
|
||||
db, conversation_id, closer_type, closer_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
conversation.is_closed = True
|
||||
conversation.closed_at = datetime.now(UTC)
|
||||
conversation.closed_by_type = closer_type
|
||||
conversation.closed_by_id = closer_id
|
||||
|
||||
# Add system message
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
sender_type=closer_type,
|
||||
sender_id=closer_id,
|
||||
content="Conversation closed",
|
||||
is_system_message=True,
|
||||
)
|
||||
|
||||
db.flush()
|
||||
return conversation
|
||||
|
||||
def reopen_conversation(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
opener_type: ParticipantType,
|
||||
opener_id: int,
|
||||
) -> Conversation | None:
|
||||
"""Reopen a closed conversation."""
|
||||
conversation = self.get_conversation(
|
||||
db, conversation_id, opener_type, opener_id
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
conversation.is_closed = False
|
||||
conversation.closed_at = None
|
||||
conversation.closed_by_type = None
|
||||
conversation.closed_by_id = None
|
||||
|
||||
# Add system message
|
||||
self.send_message(
|
||||
db=db,
|
||||
conversation_id=conversation_id,
|
||||
sender_type=opener_type,
|
||||
sender_id=opener_id,
|
||||
content="Conversation reopened",
|
||||
is_system_message=True,
|
||||
)
|
||||
|
||||
db.flush()
|
||||
return conversation
|
||||
|
||||
# =========================================================================
|
||||
# MESSAGE MANAGEMENT
|
||||
# =========================================================================
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
sender_type: ParticipantType,
|
||||
sender_id: int,
|
||||
content: str,
|
||||
attachments: list[dict[str, Any]] | None = None,
|
||||
is_system_message: bool = False,
|
||||
_skip_unread_update: bool = False,
|
||||
) -> Message:
|
||||
"""
|
||||
Send a message in a conversation.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
conversation_id: Target conversation ID
|
||||
sender_type: Type of sender
|
||||
sender_id: ID of sender
|
||||
content: Message text content
|
||||
attachments: List of attachment dicts with file metadata
|
||||
is_system_message: Whether this is a system-generated message
|
||||
_skip_unread_update: Internal flag to skip unread increment
|
||||
|
||||
Returns:
|
||||
Created Message object
|
||||
"""
|
||||
# Create message
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
sender_type=sender_type,
|
||||
sender_id=sender_id,
|
||||
content=content,
|
||||
is_system_message=is_system_message,
|
||||
)
|
||||
db.add(message)
|
||||
db.flush()
|
||||
|
||||
# Add attachments if any
|
||||
if attachments:
|
||||
for att_data in attachments:
|
||||
attachment = MessageAttachment(
|
||||
message_id=message.id,
|
||||
filename=att_data["filename"],
|
||||
original_filename=att_data["original_filename"],
|
||||
file_path=att_data["file_path"],
|
||||
file_size=att_data["file_size"],
|
||||
mime_type=att_data["mime_type"],
|
||||
is_image=att_data.get("is_image", False),
|
||||
image_width=att_data.get("image_width"),
|
||||
image_height=att_data.get("image_height"),
|
||||
thumbnail_path=att_data.get("thumbnail_path"),
|
||||
)
|
||||
db.add(attachment)
|
||||
|
||||
# Update conversation metadata
|
||||
conversation = (
|
||||
db.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
)
|
||||
|
||||
if conversation:
|
||||
conversation.last_message_at = datetime.now(UTC)
|
||||
conversation.message_count += 1
|
||||
|
||||
# Update unread counts for other participants
|
||||
if not _skip_unread_update:
|
||||
db.query(ConversationParticipant).filter(
|
||||
and_(
|
||||
ConversationParticipant.conversation_id == conversation_id,
|
||||
or_(
|
||||
ConversationParticipant.participant_type != sender_type,
|
||||
ConversationParticipant.participant_id != sender_id,
|
||||
),
|
||||
)
|
||||
).update(
|
||||
{
|
||||
ConversationParticipant.unread_count: ConversationParticipant.unread_count
|
||||
+ 1
|
||||
}
|
||||
)
|
||||
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Message {message.id} sent in conversation {conversation_id} "
|
||||
f"by {sender_type.value}:{sender_id}"
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
def delete_message(
|
||||
self,
|
||||
db: Session,
|
||||
message_id: int,
|
||||
deleter_type: ParticipantType,
|
||||
deleter_id: int,
|
||||
) -> Message | None:
|
||||
"""Soft delete a message (for moderation)."""
|
||||
message = db.query(Message).filter(Message.id == message_id).first()
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
# Verify deleter has access to conversation
|
||||
conversation = self.get_conversation(
|
||||
db, message.conversation_id, deleter_type, deleter_id
|
||||
)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
message.is_deleted = True
|
||||
message.deleted_at = datetime.now(UTC)
|
||||
message.deleted_by_type = deleter_type
|
||||
message.deleted_by_id = deleter_id
|
||||
|
||||
db.flush()
|
||||
return message
|
||||
|
||||
def mark_conversation_read(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
reader_type: ParticipantType,
|
||||
reader_id: int,
|
||||
) -> bool:
|
||||
"""Mark all messages in conversation as read for participant."""
|
||||
result = (
|
||||
db.query(ConversationParticipant)
|
||||
.filter(
|
||||
and_(
|
||||
ConversationParticipant.conversation_id == conversation_id,
|
||||
ConversationParticipant.participant_type == reader_type,
|
||||
ConversationParticipant.participant_id == reader_id,
|
||||
)
|
||||
)
|
||||
.update(
|
||||
{
|
||||
ConversationParticipant.unread_count: 0,
|
||||
ConversationParticipant.last_read_at: datetime.now(UTC),
|
||||
}
|
||||
)
|
||||
)
|
||||
db.flush()
|
||||
return result > 0
|
||||
|
||||
def get_unread_count(
|
||||
self,
|
||||
db: Session,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
vendor_id: int | None = None,
|
||||
) -> int:
|
||||
"""Get total unread message count for a participant."""
|
||||
query = db.query(func.sum(ConversationParticipant.unread_count)).filter(
|
||||
and_(
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(ConversationParticipant.vendor_id == vendor_id)
|
||||
|
||||
return query.scalar() or 0
|
||||
|
||||
# =========================================================================
|
||||
# PARTICIPANT HELPERS
|
||||
# =========================================================================
|
||||
|
||||
def get_participant_info(
|
||||
self,
|
||||
db: Session,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get display info for a participant (name, email, avatar)."""
|
||||
if participant_type in [ParticipantType.ADMIN, ParticipantType.VENDOR]:
|
||||
user = db.query(User).filter(User.id == participant_id).first()
|
||||
if user:
|
||||
return {
|
||||
"id": user.id,
|
||||
"type": participant_type.value,
|
||||
"name": f"{user.first_name or ''} {user.last_name or ''}".strip()
|
||||
or user.username,
|
||||
"email": user.email,
|
||||
"avatar_url": None, # Could add avatar support later
|
||||
}
|
||||
elif participant_type == ParticipantType.CUSTOMER:
|
||||
customer = db.query(Customer).filter(Customer.id == participant_id).first()
|
||||
if customer:
|
||||
return {
|
||||
"id": customer.id,
|
||||
"type": participant_type.value,
|
||||
"name": f"{customer.first_name or ''} {customer.last_name or ''}".strip()
|
||||
or customer.email,
|
||||
"email": customer.email,
|
||||
"avatar_url": None,
|
||||
}
|
||||
return None
|
||||
|
||||
def get_other_participant(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
my_type: ParticipantType,
|
||||
my_id: int,
|
||||
) -> ConversationParticipant | None:
|
||||
"""Get the other participant in a conversation."""
|
||||
for p in conversation.participants:
|
||||
if p.participant_type != my_type or p.participant_id != my_id:
|
||||
return p
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# NOTIFICATION PREFERENCES
|
||||
# =========================================================================
|
||||
|
||||
def update_notification_preferences(
|
||||
self,
|
||||
db: Session,
|
||||
conversation_id: int,
|
||||
participant_type: ParticipantType,
|
||||
participant_id: int,
|
||||
email_notifications: bool | None = None,
|
||||
muted: bool | None = None,
|
||||
) -> bool:
|
||||
"""Update notification preferences for a participant in a conversation."""
|
||||
updates = {}
|
||||
if email_notifications is not None:
|
||||
updates[ConversationParticipant.email_notifications] = email_notifications
|
||||
if muted is not None:
|
||||
updates[ConversationParticipant.muted] = muted
|
||||
|
||||
if not updates:
|
||||
return False
|
||||
|
||||
result = (
|
||||
db.query(ConversationParticipant)
|
||||
.filter(
|
||||
and_(
|
||||
ConversationParticipant.conversation_id == conversation_id,
|
||||
ConversationParticipant.participant_type == participant_type,
|
||||
ConversationParticipant.participant_id == participant_id,
|
||||
)
|
||||
)
|
||||
.update(updates)
|
||||
)
|
||||
db.flush()
|
||||
return result > 0
|
||||
|
||||
# =========================================================================
|
||||
# RECIPIENT QUERIES
|
||||
# =========================================================================
|
||||
|
||||
def get_vendor_recipients(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get list of vendor users as potential recipients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter
|
||||
search: Search term for name/email
|
||||
skip: Pagination offset
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Tuple of (recipients list, total count)
|
||||
"""
|
||||
from models.database.vendor import VendorUser
|
||||
|
||||
query = (
|
||||
db.query(User, VendorUser)
|
||||
.join(VendorUser, User.id == VendorUser.user_id)
|
||||
.filter(User.is_active == True) # noqa: E712
|
||||
)
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(VendorUser.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
(User.username.ilike(search_pattern))
|
||||
| (User.email.ilike(search_pattern))
|
||||
| (User.first_name.ilike(search_pattern))
|
||||
| (User.last_name.ilike(search_pattern))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
results = query.offset(skip).limit(limit).all()
|
||||
|
||||
recipients = []
|
||||
for user, vendor_user in results:
|
||||
name = f"{user.first_name or ''} {user.last_name or ''}".strip() or user.username
|
||||
recipients.append({
|
||||
"id": user.id,
|
||||
"type": ParticipantType.VENDOR,
|
||||
"name": name,
|
||||
"email": user.email,
|
||||
"vendor_id": vendor_user.vendor_id,
|
||||
"vendor_name": vendor_user.vendor.name if vendor_user.vendor else None,
|
||||
})
|
||||
|
||||
return recipients, total
|
||||
|
||||
def get_customer_recipients(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""
|
||||
Get list of customers as potential recipients.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor ID filter (required for vendor users)
|
||||
search: Search term for name/email
|
||||
skip: Pagination offset
|
||||
limit: Max results
|
||||
|
||||
Returns:
|
||||
Tuple of (recipients list, total count)
|
||||
"""
|
||||
query = db.query(Customer).filter(Customer.is_active == True) # noqa: E712
|
||||
|
||||
if vendor_id:
|
||||
query = query.filter(Customer.vendor_id == vendor_id)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
(Customer.email.ilike(search_pattern))
|
||||
| (Customer.first_name.ilike(search_pattern))
|
||||
| (Customer.last_name.ilike(search_pattern))
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
results = query.offset(skip).limit(limit).all()
|
||||
|
||||
recipients = []
|
||||
for customer in results:
|
||||
name = f"{customer.first_name or ''} {customer.last_name or ''}".strip()
|
||||
recipients.append({
|
||||
"id": customer.id,
|
||||
"type": ParticipantType.CUSTOMER,
|
||||
"name": name or customer.email,
|
||||
"email": customer.email,
|
||||
"vendor_id": customer.vendor_id,
|
||||
})
|
||||
|
||||
return recipients, total
|
||||
|
||||
|
||||
# Singleton instance
|
||||
messaging_service = MessagingService()
|
||||
__all__ = [
|
||||
"messaging_service",
|
||||
"MessagingService",
|
||||
]
|
||||
|
||||
@@ -1,735 +1,23 @@
|
||||
# app/services/order_inventory_service.py
|
||||
"""
|
||||
Order-Inventory Integration Service.
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
This service orchestrates inventory operations for orders:
|
||||
- Reserve inventory when orders are confirmed
|
||||
- Fulfill (deduct) inventory when orders are shipped
|
||||
- Release reservations when orders are cancelled
|
||||
The canonical implementation is now in:
|
||||
app/modules/orders/services/order_inventory_service.py
|
||||
|
||||
This is the critical link between the order and inventory systems
|
||||
that ensures stock accuracy.
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
All operations are logged to the inventory_transactions table for audit trail.
|
||||
from app.modules.orders.services import order_inventory_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
InsufficientInventoryException,
|
||||
InventoryNotFoundException,
|
||||
OrderNotFoundException,
|
||||
ValidationException,
|
||||
from app.modules.orders.services.order_inventory_service import (
|
||||
order_inventory_service,
|
||||
OrderInventoryService,
|
||||
)
|
||||
from app.services.inventory_service import inventory_service
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.inventory_transaction import InventoryTransaction, TransactionType
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.schema.inventory import InventoryReserve
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default location for inventory operations
|
||||
DEFAULT_LOCATION = "DEFAULT"
|
||||
|
||||
|
||||
class OrderInventoryService:
|
||||
"""
|
||||
Orchestrate order and inventory operations together.
|
||||
|
||||
This service ensures that:
|
||||
1. When orders are confirmed, inventory is reserved
|
||||
2. When orders are shipped, inventory is fulfilled (deducted)
|
||||
3. When orders are cancelled, reservations are released
|
||||
|
||||
Note: Letzshop orders with unmatched products (placeholder) skip
|
||||
inventory operations for those items.
|
||||
"""
|
||||
|
||||
def get_order_with_items(
|
||||
self, db: Session, vendor_id: int, order_id: int
|
||||
) -> Order:
|
||||
"""Get order with items or raise OrderNotFoundException."""
|
||||
order = (
|
||||
db.query(Order)
|
||||
.filter(Order.id == order_id, Order.vendor_id == vendor_id)
|
||||
.first()
|
||||
)
|
||||
if not order:
|
||||
raise OrderNotFoundException(f"Order {order_id} not found")
|
||||
return order
|
||||
|
||||
def _find_inventory_location(
|
||||
self, db: Session, product_id: int, vendor_id: int
|
||||
) -> str | None:
|
||||
"""
|
||||
Find the location with available inventory for a product.
|
||||
|
||||
Returns the first location with available quantity, or None if no
|
||||
inventory exists.
|
||||
"""
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
Inventory.quantity > Inventory.reserved_quantity,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return inventory.location if inventory else None
|
||||
|
||||
def _is_placeholder_product(self, order_item: OrderItem) -> bool:
|
||||
"""Check if the order item uses a placeholder product."""
|
||||
if not order_item.product:
|
||||
return True
|
||||
# Check if it's the placeholder product (GTIN 0000000000000)
|
||||
return order_item.product.gtin == "0000000000000"
|
||||
|
||||
def _log_transaction(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
product_id: int,
|
||||
inventory: Inventory,
|
||||
transaction_type: TransactionType,
|
||||
quantity_change: int,
|
||||
order: Order,
|
||||
reason: str | None = None,
|
||||
) -> InventoryTransaction:
|
||||
"""
|
||||
Create an inventory transaction record for audit trail.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
product_id: Product ID
|
||||
inventory: Inventory record after the operation
|
||||
transaction_type: Type of transaction
|
||||
quantity_change: Change in quantity (positive = add, negative = remove)
|
||||
order: Order associated with this transaction
|
||||
reason: Optional reason for the transaction
|
||||
|
||||
Returns:
|
||||
Created InventoryTransaction
|
||||
"""
|
||||
transaction = InventoryTransaction.create_transaction(
|
||||
vendor_id=vendor_id,
|
||||
product_id=product_id,
|
||||
inventory_id=inventory.id if inventory else None,
|
||||
transaction_type=transaction_type,
|
||||
quantity_change=quantity_change,
|
||||
quantity_after=inventory.quantity if inventory else 0,
|
||||
reserved_after=inventory.reserved_quantity if inventory else 0,
|
||||
location=inventory.location if inventory else None,
|
||||
warehouse=inventory.warehouse if inventory else None,
|
||||
order_id=order.id,
|
||||
order_number=order.order_number,
|
||||
reason=reason,
|
||||
created_by="system",
|
||||
)
|
||||
db.add(transaction)
|
||||
return transaction
|
||||
|
||||
def reserve_for_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Reserve inventory for all items in an order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
skip_missing: If True, skip items without inventory instead of failing
|
||||
|
||||
Returns:
|
||||
Dict with reserved count and any skipped items
|
||||
|
||||
Raises:
|
||||
InsufficientInventoryException: If skip_missing=False and inventory unavailable
|
||||
"""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
reserved_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
# Skip placeholder products
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
# Find inventory location
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=item.quantity,
|
||||
)
|
||||
updated_inventory = inventory_service.reserve_inventory(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
reserved_count += 1
|
||||
|
||||
# Log transaction for audit trail
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.RESERVE,
|
||||
quantity_change=0, # Reserve doesn't change quantity, only reserved_quantity
|
||||
order=order,
|
||||
reason=f"Reserved for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Reserved {item.quantity} units of product {item.product_id} "
|
||||
f"for order {order.order_number}"
|
||||
)
|
||||
except InsufficientInventoryException:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "insufficient_inventory",
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: reserved {reserved_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"reserved_count": reserved_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def fulfill_order(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Fulfill (deduct) inventory when an order is shipped.
|
||||
|
||||
This decreases both the total quantity and reserved quantity,
|
||||
effectively consuming the reserved stock.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
skip_missing: If True, skip items without inventory
|
||||
|
||||
Returns:
|
||||
Dict with fulfilled count and any skipped items
|
||||
"""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
fulfilled_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
# Skip already fully shipped items
|
||||
if item.is_fully_shipped:
|
||||
continue
|
||||
|
||||
# Skip placeholder products
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
# Only fulfill remaining quantity
|
||||
quantity_to_fulfill = item.remaining_quantity
|
||||
|
||||
# Find inventory location
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
# Also check for inventory with reserved quantity
|
||||
if not location:
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if inventory:
|
||||
location = inventory.location
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=quantity_to_fulfill,
|
||||
)
|
||||
updated_inventory = inventory_service.fulfill_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
fulfilled_count += 1
|
||||
|
||||
# Update item shipped quantity
|
||||
item.shipped_quantity = item.quantity
|
||||
item.inventory_fulfilled = True
|
||||
|
||||
# Log transaction for audit trail
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.FULFILL,
|
||||
quantity_change=-quantity_to_fulfill, # Negative because stock is consumed
|
||||
order=order,
|
||||
reason=f"Fulfilled for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Fulfilled {quantity_to_fulfill} units of product {item.product_id} "
|
||||
f"for order {order.order_number}"
|
||||
)
|
||||
except (InsufficientInventoryException, InventoryNotFoundException) as e:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": str(e),
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: fulfilled {fulfilled_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"fulfilled_count": fulfilled_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def fulfill_item(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
item_id: int,
|
||||
quantity: int | None = None,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Fulfill (deduct) inventory for a specific order item.
|
||||
|
||||
Supports partial fulfillment - ship some units now, rest later.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
item_id: Order item ID
|
||||
quantity: Quantity to ship (defaults to remaining quantity)
|
||||
skip_missing: If True, skip if inventory not found
|
||||
|
||||
Returns:
|
||||
Dict with fulfillment result
|
||||
|
||||
Raises:
|
||||
ValidationException: If quantity exceeds remaining
|
||||
"""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
# Find the item
|
||||
item = None
|
||||
for order_item in order.items:
|
||||
if order_item.id == item_id:
|
||||
item = order_item
|
||||
break
|
||||
|
||||
if not item:
|
||||
raise ValidationException(f"Item {item_id} not found in order {order_id}")
|
||||
|
||||
# Check if already fully shipped
|
||||
if item.is_fully_shipped:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Item already fully shipped",
|
||||
}
|
||||
|
||||
# Default to remaining quantity
|
||||
quantity_to_fulfill = quantity or item.remaining_quantity
|
||||
|
||||
# Validate quantity
|
||||
if quantity_to_fulfill > item.remaining_quantity:
|
||||
raise ValidationException(
|
||||
f"Cannot ship {quantity_to_fulfill} units - only {item.remaining_quantity} remaining"
|
||||
)
|
||||
|
||||
if quantity_to_fulfill <= 0:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Nothing to fulfill",
|
||||
}
|
||||
|
||||
# Skip placeholder products
|
||||
if self._is_placeholder_product(item):
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "Placeholder product - skipped",
|
||||
}
|
||||
|
||||
# Find inventory location
|
||||
location = self._find_inventory_location(db, item.product_id, vendor_id)
|
||||
|
||||
if not location:
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if inventory:
|
||||
location = inventory.location
|
||||
|
||||
if not location:
|
||||
if skip_missing:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": "No inventory found",
|
||||
}
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=location,
|
||||
quantity=quantity_to_fulfill,
|
||||
)
|
||||
updated_inventory = inventory_service.fulfill_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
|
||||
# Update item shipped quantity
|
||||
item.shipped_quantity += quantity_to_fulfill
|
||||
|
||||
# Mark as fulfilled only if fully shipped
|
||||
if item.is_fully_shipped:
|
||||
item.inventory_fulfilled = True
|
||||
|
||||
# Log transaction
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.FULFILL,
|
||||
quantity_change=-quantity_to_fulfill,
|
||||
order=order,
|
||||
reason=f"Partial shipment for order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Fulfilled {quantity_to_fulfill} of {item.quantity} units "
|
||||
f"for item {item_id} in order {order.order_number}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": quantity_to_fulfill,
|
||||
"shipped_quantity": item.shipped_quantity,
|
||||
"remaining_quantity": item.remaining_quantity,
|
||||
"is_fully_shipped": item.is_fully_shipped,
|
||||
}
|
||||
|
||||
except (InsufficientInventoryException, InventoryNotFoundException) as e:
|
||||
if skip_missing:
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"item_id": item_id,
|
||||
"fulfilled_quantity": 0,
|
||||
"message": str(e),
|
||||
}
|
||||
else:
|
||||
raise
|
||||
|
||||
def release_order_reservation(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
skip_missing: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Release reserved inventory when an order is cancelled.
|
||||
|
||||
This decreases the reserved quantity, making the stock available again.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
skip_missing: If True, skip items without inventory
|
||||
|
||||
Returns:
|
||||
Dict with released count and any skipped items
|
||||
"""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
released_count = 0
|
||||
skipped_items = []
|
||||
|
||||
for item in order.items:
|
||||
# Skip placeholder products
|
||||
if self._is_placeholder_product(item):
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"reason": "placeholder_product",
|
||||
})
|
||||
continue
|
||||
|
||||
# Find inventory - look for any inventory for this product
|
||||
inventory = (
|
||||
db.query(Inventory)
|
||||
.filter(
|
||||
Inventory.product_id == item.product_id,
|
||||
Inventory.vendor_id == vendor_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not inventory:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": "no_inventory",
|
||||
})
|
||||
continue
|
||||
else:
|
||||
raise InventoryNotFoundException(
|
||||
f"No inventory found for product {item.product_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
reserve_data = InventoryReserve(
|
||||
product_id=item.product_id,
|
||||
location=inventory.location,
|
||||
quantity=item.quantity,
|
||||
)
|
||||
updated_inventory = inventory_service.release_reservation(
|
||||
db, vendor_id, reserve_data
|
||||
)
|
||||
released_count += 1
|
||||
|
||||
# Log transaction for audit trail
|
||||
self._log_transaction(
|
||||
db=db,
|
||||
vendor_id=vendor_id,
|
||||
product_id=item.product_id,
|
||||
inventory=updated_inventory,
|
||||
transaction_type=TransactionType.RELEASE,
|
||||
quantity_change=0, # Release doesn't change quantity, only reserved_quantity
|
||||
order=order,
|
||||
reason=f"Released for cancelled order {order.order_number}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Released {item.quantity} units of product {item.product_id} "
|
||||
f"for cancelled order {order.order_number}"
|
||||
)
|
||||
except Exception as e:
|
||||
if skip_missing:
|
||||
skipped_items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"reason": str(e),
|
||||
})
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Order {order.order_number}: released {released_count} items, "
|
||||
f"skipped {len(skipped_items)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"released_count": released_count,
|
||||
"skipped_items": skipped_items,
|
||||
}
|
||||
|
||||
def handle_status_change(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
old_status: str | None,
|
||||
new_status: str,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Handle inventory operations based on order status changes.
|
||||
|
||||
Status transitions that trigger inventory operations:
|
||||
- Any → processing: Reserve inventory (if not already reserved)
|
||||
- processing → shipped: Fulfill inventory (deduct from stock)
|
||||
- processing → partially_shipped: Partial fulfillment already done via fulfill_item
|
||||
- Any → cancelled: Release reservations
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
old_status: Previous status (can be None for new orders)
|
||||
new_status: New status
|
||||
|
||||
Returns:
|
||||
Result of inventory operation, or None if no operation needed
|
||||
"""
|
||||
# Skip if status didn't change
|
||||
if old_status == new_status:
|
||||
return None
|
||||
|
||||
result = None
|
||||
|
||||
# Transitioning to processing - reserve inventory
|
||||
if new_status == "processing":
|
||||
result = self.reserve_for_order(db, vendor_id, order_id, skip_missing=True)
|
||||
logger.info(f"Order {order_id} confirmed: inventory reserved")
|
||||
|
||||
# Transitioning to shipped - fulfill remaining inventory
|
||||
elif new_status == "shipped":
|
||||
result = self.fulfill_order(db, vendor_id, order_id, skip_missing=True)
|
||||
logger.info(f"Order {order_id} shipped: inventory fulfilled")
|
||||
|
||||
# partially_shipped - no automatic fulfillment (handled via fulfill_item)
|
||||
elif new_status == "partially_shipped":
|
||||
logger.info(
|
||||
f"Order {order_id} partially shipped: use fulfill_item for item-level fulfillment"
|
||||
)
|
||||
result = {"order_id": order_id, "status": "partially_shipped"}
|
||||
|
||||
# Transitioning to cancelled - release reservations
|
||||
elif new_status == "cancelled":
|
||||
# Only release if there was a previous status (order was in progress)
|
||||
if old_status and old_status not in ("cancelled", "refunded"):
|
||||
result = self.release_order_reservation(
|
||||
db, vendor_id, order_id, skip_missing=True
|
||||
)
|
||||
logger.info(f"Order {order_id} cancelled: reservations released")
|
||||
|
||||
return result
|
||||
|
||||
def get_shipment_status(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
order_id: int,
|
||||
) -> dict:
|
||||
"""
|
||||
Get detailed shipment status for an order.
|
||||
|
||||
Returns item-level shipment status for partial shipment tracking.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
Dict with shipment status details
|
||||
"""
|
||||
order = self.get_order_with_items(db, vendor_id, order_id)
|
||||
|
||||
items = []
|
||||
for item in order.items:
|
||||
items.append({
|
||||
"item_id": item.id,
|
||||
"product_id": item.product_id,
|
||||
"product_name": item.product_name,
|
||||
"quantity": item.quantity,
|
||||
"shipped_quantity": item.shipped_quantity,
|
||||
"remaining_quantity": item.remaining_quantity,
|
||||
"is_fully_shipped": item.is_fully_shipped,
|
||||
"is_partially_shipped": item.is_partially_shipped,
|
||||
})
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"order_number": order.order_number,
|
||||
"order_status": order.status,
|
||||
"is_fully_shipped": order.is_fully_shipped,
|
||||
"is_partially_shipped": order.is_partially_shipped,
|
||||
"shipped_item_count": order.shipped_item_count,
|
||||
"total_item_count": len(order.items),
|
||||
"total_shipped_units": order.total_shipped_units,
|
||||
"total_ordered_units": order.total_ordered_units,
|
||||
"items": items,
|
||||
}
|
||||
|
||||
|
||||
# Create service instance
|
||||
order_inventory_service = OrderInventoryService()
|
||||
__all__ = [
|
||||
"order_inventory_service",
|
||||
"OrderInventoryService",
|
||||
]
|
||||
|
||||
@@ -1,632 +1,23 @@
|
||||
# app/services/order_item_exception_service.py
|
||||
"""
|
||||
Service for managing order item exceptions (unmatched products).
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
This service handles:
|
||||
- Creating exceptions when products are not found during order import
|
||||
- Resolving exceptions by assigning products
|
||||
- Auto-matching when new products are imported
|
||||
- Querying and statistics for exceptions
|
||||
The canonical implementation is now in:
|
||||
app/modules/orders/services/order_item_exception_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.orders.services import order_item_exception_service
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.exceptions import (
|
||||
ExceptionAlreadyResolvedException,
|
||||
InvalidProductForExceptionException,
|
||||
OrderItemExceptionNotFoundException,
|
||||
ProductNotFoundException,
|
||||
from app.modules.orders.services.order_item_exception_service import (
|
||||
order_item_exception_service,
|
||||
OrderItemExceptionService,
|
||||
)
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.database.order_item_exception import OrderItemException
|
||||
from models.database.product import Product
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrderItemExceptionService:
|
||||
"""Service for order item exception CRUD and resolution workflow."""
|
||||
|
||||
# =========================================================================
|
||||
# Exception Creation
|
||||
# =========================================================================
|
||||
|
||||
def create_exception(
|
||||
self,
|
||||
db: Session,
|
||||
order_item: OrderItem,
|
||||
vendor_id: int,
|
||||
original_gtin: str | None,
|
||||
original_product_name: str | None,
|
||||
original_sku: str | None,
|
||||
exception_type: str = "product_not_found",
|
||||
) -> OrderItemException:
|
||||
"""
|
||||
Create an exception record for an unmatched order item.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
order_item: The order item that couldn't be matched
|
||||
vendor_id: Vendor ID (denormalized for efficient queries)
|
||||
original_gtin: Original GTIN from marketplace
|
||||
original_product_name: Original product name from marketplace
|
||||
original_sku: Original SKU from marketplace
|
||||
exception_type: Type of exception (product_not_found, gtin_mismatch, etc.)
|
||||
|
||||
Returns:
|
||||
Created OrderItemException
|
||||
"""
|
||||
exception = OrderItemException(
|
||||
order_item_id=order_item.id,
|
||||
vendor_id=vendor_id,
|
||||
original_gtin=original_gtin,
|
||||
original_product_name=original_product_name,
|
||||
original_sku=original_sku,
|
||||
exception_type=exception_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(exception)
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Created order item exception {exception.id} for order item "
|
||||
f"{order_item.id}, GTIN: {original_gtin}"
|
||||
)
|
||||
|
||||
return exception
|
||||
|
||||
# =========================================================================
|
||||
# Exception Retrieval
|
||||
# =========================================================================
|
||||
|
||||
def get_exception_by_id(
|
||||
self,
|
||||
db: Session,
|
||||
exception_id: int,
|
||||
vendor_id: int | None = None,
|
||||
) -> OrderItemException:
|
||||
"""
|
||||
Get an exception by ID, optionally filtered by vendor.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
exception_id: Exception ID
|
||||
vendor_id: Optional vendor ID filter (for vendor-scoped access)
|
||||
|
||||
Returns:
|
||||
OrderItemException
|
||||
|
||||
Raises:
|
||||
OrderItemExceptionNotFoundException: If not found
|
||||
"""
|
||||
query = db.query(OrderItemException).filter(
|
||||
OrderItemException.id == exception_id
|
||||
)
|
||||
|
||||
if vendor_id is not None:
|
||||
query = query.filter(OrderItemException.vendor_id == vendor_id)
|
||||
|
||||
exception = query.first()
|
||||
|
||||
if not exception:
|
||||
raise OrderItemExceptionNotFoundException(exception_id)
|
||||
|
||||
return exception
|
||||
|
||||
def get_pending_exceptions(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[OrderItemException], int]:
|
||||
"""
|
||||
Get exceptions with pagination and filtering.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor filter
|
||||
status: Optional status filter (pending, resolved, ignored)
|
||||
search: Optional search in GTIN, product name, or order number
|
||||
skip: Pagination offset
|
||||
limit: Pagination limit
|
||||
|
||||
Returns:
|
||||
Tuple of (list of exceptions, total count)
|
||||
"""
|
||||
query = (
|
||||
db.query(OrderItemException)
|
||||
.join(OrderItem)
|
||||
.join(Order)
|
||||
.options(
|
||||
joinedload(OrderItemException.order_item).joinedload(OrderItem.order)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if vendor_id is not None:
|
||||
query = query.filter(OrderItemException.vendor_id == vendor_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(OrderItemException.status == status)
|
||||
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
OrderItemException.original_gtin.ilike(search_pattern),
|
||||
OrderItemException.original_product_name.ilike(search_pattern),
|
||||
OrderItemException.original_sku.ilike(search_pattern),
|
||||
Order.order_number.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Apply pagination and ordering
|
||||
exceptions = (
|
||||
query.order_by(OrderItemException.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return exceptions, total
|
||||
|
||||
def get_exceptions_for_order(
|
||||
self,
|
||||
db: Session,
|
||||
order_id: int,
|
||||
) -> list[OrderItemException]:
|
||||
"""
|
||||
Get all exceptions for items in an order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
List of exceptions for the order
|
||||
"""
|
||||
return (
|
||||
db.query(OrderItemException)
|
||||
.join(OrderItem)
|
||||
.filter(OrderItem.order_id == order_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Exception Statistics
|
||||
# =========================================================================
|
||||
|
||||
def get_exception_stats(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Get exception counts by status.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Optional vendor filter
|
||||
|
||||
Returns:
|
||||
Dict with pending, resolved, ignored, total counts
|
||||
"""
|
||||
query = db.query(
|
||||
OrderItemException.status,
|
||||
func.count(OrderItemException.id).label("count"),
|
||||
)
|
||||
|
||||
if vendor_id is not None:
|
||||
query = query.filter(OrderItemException.vendor_id == vendor_id)
|
||||
|
||||
results = query.group_by(OrderItemException.status).all()
|
||||
|
||||
stats = {
|
||||
"pending": 0,
|
||||
"resolved": 0,
|
||||
"ignored": 0,
|
||||
"total": 0,
|
||||
}
|
||||
|
||||
for status, count in results:
|
||||
if status in stats:
|
||||
stats[status] = count
|
||||
stats["total"] += count
|
||||
|
||||
# Count orders with pending exceptions
|
||||
orders_query = (
|
||||
db.query(func.count(func.distinct(OrderItem.order_id)))
|
||||
.join(OrderItemException)
|
||||
.filter(OrderItemException.status == "pending")
|
||||
)
|
||||
|
||||
if vendor_id is not None:
|
||||
orders_query = orders_query.filter(
|
||||
OrderItemException.vendor_id == vendor_id
|
||||
)
|
||||
|
||||
stats["orders_with_exceptions"] = orders_query.scalar() or 0
|
||||
|
||||
return stats
|
||||
|
||||
# =========================================================================
|
||||
# Exception Resolution
|
||||
# =========================================================================
|
||||
|
||||
def resolve_exception(
|
||||
self,
|
||||
db: Session,
|
||||
exception_id: int,
|
||||
product_id: int,
|
||||
resolved_by: int,
|
||||
notes: str | None = None,
|
||||
vendor_id: int | None = None,
|
||||
) -> OrderItemException:
|
||||
"""
|
||||
Resolve an exception by assigning a product.
|
||||
|
||||
This updates:
|
||||
- The exception record (status, resolved_product_id, etc.)
|
||||
- The order item (product_id, needs_product_match)
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
exception_id: Exception ID to resolve
|
||||
product_id: Product ID to assign
|
||||
resolved_by: User ID who resolved
|
||||
notes: Optional resolution notes
|
||||
vendor_id: Optional vendor filter (for scoped access)
|
||||
|
||||
Returns:
|
||||
Updated OrderItemException
|
||||
|
||||
Raises:
|
||||
OrderItemExceptionNotFoundException: If exception not found
|
||||
ExceptionAlreadyResolvedException: If already resolved
|
||||
InvalidProductForExceptionException: If product is invalid
|
||||
"""
|
||||
exception = self.get_exception_by_id(db, exception_id, vendor_id)
|
||||
|
||||
if exception.status == "resolved":
|
||||
raise ExceptionAlreadyResolvedException(exception_id)
|
||||
|
||||
# Validate product exists and belongs to vendor
|
||||
product = db.query(Product).filter(Product.id == product_id).first()
|
||||
if not product:
|
||||
raise ProductNotFoundException(product_id)
|
||||
|
||||
if product.vendor_id != exception.vendor_id:
|
||||
raise InvalidProductForExceptionException(
|
||||
product_id, "Product belongs to a different vendor"
|
||||
)
|
||||
|
||||
if not product.is_active:
|
||||
raise InvalidProductForExceptionException(
|
||||
product_id, "Product is not active"
|
||||
)
|
||||
|
||||
# Update exception
|
||||
exception.status = "resolved"
|
||||
exception.resolved_product_id = product_id
|
||||
exception.resolved_at = datetime.now(UTC)
|
||||
exception.resolved_by = resolved_by
|
||||
exception.resolution_notes = notes
|
||||
|
||||
# Update order item
|
||||
order_item = exception.order_item
|
||||
order_item.product_id = product_id
|
||||
order_item.needs_product_match = False
|
||||
|
||||
# Update product snapshot on order item
|
||||
if product.marketplace_product:
|
||||
order_item.product_name = product.marketplace_product.get_title("en")
|
||||
order_item.product_sku = product.vendor_sku or order_item.product_sku
|
||||
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Resolved exception {exception_id} with product {product_id} "
|
||||
f"by user {resolved_by}"
|
||||
)
|
||||
|
||||
return exception
|
||||
|
||||
def ignore_exception(
|
||||
self,
|
||||
db: Session,
|
||||
exception_id: int,
|
||||
resolved_by: int,
|
||||
notes: str,
|
||||
vendor_id: int | None = None,
|
||||
) -> OrderItemException:
|
||||
"""
|
||||
Mark an exception as ignored.
|
||||
|
||||
Note: Ignored exceptions still block order confirmation.
|
||||
This is for tracking purposes (e.g., product will never be matched).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
exception_id: Exception ID
|
||||
resolved_by: User ID who ignored
|
||||
notes: Reason for ignoring (required)
|
||||
vendor_id: Optional vendor filter
|
||||
|
||||
Returns:
|
||||
Updated OrderItemException
|
||||
"""
|
||||
exception = self.get_exception_by_id(db, exception_id, vendor_id)
|
||||
|
||||
if exception.status == "resolved":
|
||||
raise ExceptionAlreadyResolvedException(exception_id)
|
||||
|
||||
exception.status = "ignored"
|
||||
exception.resolved_at = datetime.now(UTC)
|
||||
exception.resolved_by = resolved_by
|
||||
exception.resolution_notes = notes
|
||||
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Ignored exception {exception_id} by user {resolved_by}: {notes}"
|
||||
)
|
||||
|
||||
return exception
|
||||
|
||||
# =========================================================================
|
||||
# Auto-Matching
|
||||
# =========================================================================
|
||||
|
||||
def auto_match_by_gtin(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
gtin: str,
|
||||
product_id: int,
|
||||
) -> list[OrderItemException]:
|
||||
"""
|
||||
Auto-resolve pending exceptions matching a GTIN.
|
||||
|
||||
Called after a product is imported with a GTIN.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
gtin: GTIN to match
|
||||
product_id: Product ID to assign
|
||||
|
||||
Returns:
|
||||
List of resolved exceptions
|
||||
"""
|
||||
if not gtin:
|
||||
return []
|
||||
|
||||
# Find pending exceptions for this GTIN
|
||||
pending = (
|
||||
db.query(OrderItemException)
|
||||
.filter(
|
||||
and_(
|
||||
OrderItemException.vendor_id == vendor_id,
|
||||
OrderItemException.original_gtin == gtin,
|
||||
OrderItemException.status == "pending",
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not pending:
|
||||
return []
|
||||
|
||||
# Get product for snapshot update
|
||||
product = db.query(Product).filter(Product.id == product_id).first()
|
||||
if not product:
|
||||
logger.warning(f"Product {product_id} not found for auto-match")
|
||||
return []
|
||||
|
||||
resolved = []
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for exception in pending:
|
||||
exception.status = "resolved"
|
||||
exception.resolved_product_id = product_id
|
||||
exception.resolved_at = now
|
||||
exception.resolution_notes = "Auto-matched during product import"
|
||||
|
||||
# Update order item
|
||||
order_item = exception.order_item
|
||||
order_item.product_id = product_id
|
||||
order_item.needs_product_match = False
|
||||
if product.marketplace_product:
|
||||
order_item.product_name = product.marketplace_product.get_title("en")
|
||||
|
||||
resolved.append(exception)
|
||||
|
||||
if resolved:
|
||||
db.flush()
|
||||
logger.info(
|
||||
f"Auto-matched {len(resolved)} exceptions for GTIN {gtin} "
|
||||
f"with product {product_id}"
|
||||
)
|
||||
|
||||
return resolved
|
||||
|
||||
def auto_match_batch(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
gtin_to_product: dict[str, int],
|
||||
) -> int:
|
||||
"""
|
||||
Batch auto-match multiple GTINs after bulk import.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
gtin_to_product: Dict mapping GTIN to product ID
|
||||
|
||||
Returns:
|
||||
Total number of resolved exceptions
|
||||
"""
|
||||
if not gtin_to_product:
|
||||
return 0
|
||||
|
||||
total_resolved = 0
|
||||
|
||||
for gtin, product_id in gtin_to_product.items():
|
||||
resolved = self.auto_match_by_gtin(db, vendor_id, gtin, product_id)
|
||||
total_resolved += len(resolved)
|
||||
|
||||
return total_resolved
|
||||
|
||||
# =========================================================================
|
||||
# Confirmation Checks
|
||||
# =========================================================================
|
||||
|
||||
def order_has_unresolved_exceptions(
|
||||
self,
|
||||
db: Session,
|
||||
order_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if order has any unresolved exceptions.
|
||||
|
||||
An order cannot be confirmed if it has pending or ignored exceptions.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
True if order has unresolved exceptions
|
||||
"""
|
||||
count = (
|
||||
db.query(func.count(OrderItemException.id))
|
||||
.join(OrderItem)
|
||||
.filter(
|
||||
and_(
|
||||
OrderItem.order_id == order_id,
|
||||
OrderItemException.status.in_(["pending", "ignored"]),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
return count > 0
|
||||
|
||||
def get_unresolved_exception_count(
|
||||
self,
|
||||
db: Session,
|
||||
order_id: int,
|
||||
) -> int:
|
||||
"""
|
||||
Get count of unresolved exceptions for an order.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
Count of unresolved exceptions
|
||||
"""
|
||||
return (
|
||||
db.query(func.count(OrderItemException.id))
|
||||
.join(OrderItem)
|
||||
.filter(
|
||||
and_(
|
||||
OrderItem.order_id == order_id,
|
||||
OrderItemException.status.in_(["pending", "ignored"]),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
) or 0
|
||||
|
||||
# =========================================================================
|
||||
# Bulk Operations
|
||||
# =========================================================================
|
||||
|
||||
def bulk_resolve_by_gtin(
|
||||
self,
|
||||
db: Session,
|
||||
vendor_id: int,
|
||||
gtin: str,
|
||||
product_id: int,
|
||||
resolved_by: int,
|
||||
notes: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Bulk resolve all pending exceptions for a GTIN.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
vendor_id: Vendor ID
|
||||
gtin: GTIN to match
|
||||
product_id: Product ID to assign
|
||||
resolved_by: User ID who resolved
|
||||
notes: Optional notes
|
||||
|
||||
Returns:
|
||||
Number of resolved exceptions
|
||||
"""
|
||||
# Validate product
|
||||
product = db.query(Product).filter(Product.id == product_id).first()
|
||||
if not product:
|
||||
raise ProductNotFoundException(product_id)
|
||||
|
||||
if product.vendor_id != vendor_id:
|
||||
raise InvalidProductForExceptionException(
|
||||
product_id, "Product belongs to a different vendor"
|
||||
)
|
||||
|
||||
# Find and resolve all pending exceptions for this GTIN
|
||||
pending = (
|
||||
db.query(OrderItemException)
|
||||
.filter(
|
||||
and_(
|
||||
OrderItemException.vendor_id == vendor_id,
|
||||
OrderItemException.original_gtin == gtin,
|
||||
OrderItemException.status == "pending",
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
resolution_notes = notes or f"Bulk resolved for GTIN {gtin}"
|
||||
|
||||
for exception in pending:
|
||||
exception.status = "resolved"
|
||||
exception.resolved_product_id = product_id
|
||||
exception.resolved_at = now
|
||||
exception.resolved_by = resolved_by
|
||||
exception.resolution_notes = resolution_notes
|
||||
|
||||
# Update order item
|
||||
order_item = exception.order_item
|
||||
order_item.product_id = product_id
|
||||
order_item.needs_product_match = False
|
||||
if product.marketplace_product:
|
||||
order_item.product_name = product.marketplace_product.get_title("en")
|
||||
|
||||
db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Bulk resolved {len(pending)} exceptions for GTIN {gtin} "
|
||||
f"with product {product_id} by user {resolved_by}"
|
||||
)
|
||||
|
||||
return len(pending)
|
||||
|
||||
|
||||
# Global service instance
|
||||
order_item_exception_service = OrderItemExceptionService()
|
||||
__all__ = [
|
||||
"order_item_exception_service",
|
||||
"OrderItemExceptionService",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,507 +1,23 @@
|
||||
# app/services/test_runner_service.py
|
||||
"""
|
||||
Test Runner Service
|
||||
Service for running pytest and storing results
|
||||
LEGACY LOCATION - Re-exports from module for backwards compatibility.
|
||||
|
||||
The canonical implementation is now in:
|
||||
app/modules/dev_tools/services/test_runner_service.py
|
||||
|
||||
This file exists to maintain backwards compatibility with code that
|
||||
imports from the old location. All new code should import directly
|
||||
from the module:
|
||||
|
||||
from app.modules.dev_tools.services import test_runner_service
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.database.test_run import TestCollection, TestResult, TestRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestRunnerService:
|
||||
"""Service for managing pytest test runs"""
|
||||
|
||||
def __init__(self):
|
||||
self.project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
def create_test_run(
|
||||
self,
|
||||
db: Session,
|
||||
test_path: str = "tests",
|
||||
triggered_by: str = "manual",
|
||||
) -> TestRun:
|
||||
"""Create a test run record without executing tests"""
|
||||
test_run = TestRun(
|
||||
timestamp=datetime.now(UTC),
|
||||
triggered_by=triggered_by,
|
||||
test_path=test_path,
|
||||
status="running",
|
||||
git_commit_hash=self._get_git_commit(),
|
||||
git_branch=self._get_git_branch(),
|
||||
)
|
||||
db.add(test_run)
|
||||
db.flush()
|
||||
return test_run
|
||||
|
||||
def run_tests(
|
||||
self,
|
||||
db: Session,
|
||||
test_path: str = "tests",
|
||||
triggered_by: str = "manual",
|
||||
extra_args: list[str] | None = None,
|
||||
) -> TestRun:
|
||||
"""
|
||||
Run pytest synchronously and store results in database
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
test_path: Path to tests (relative to project root)
|
||||
triggered_by: Who triggered the run
|
||||
extra_args: Additional pytest arguments
|
||||
|
||||
Returns:
|
||||
TestRun object with results
|
||||
"""
|
||||
test_run = self.create_test_run(db, test_path, triggered_by)
|
||||
self._execute_tests(db, test_run, test_path, extra_args)
|
||||
return test_run
|
||||
|
||||
def _execute_tests(
|
||||
self,
|
||||
db: Session,
|
||||
test_run: TestRun,
|
||||
test_path: str,
|
||||
extra_args: list[str] | None,
|
||||
) -> None:
|
||||
"""Execute pytest and update the test run record"""
|
||||
try:
|
||||
# Build pytest command with JSON output
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as f:
|
||||
json_report_path = f.name
|
||||
|
||||
pytest_args = [
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
test_path,
|
||||
"--json-report",
|
||||
f"--json-report-file={json_report_path}",
|
||||
"-v",
|
||||
"--tb=short",
|
||||
]
|
||||
|
||||
if extra_args:
|
||||
pytest_args.extend(extra_args)
|
||||
|
||||
test_run.pytest_args = " ".join(pytest_args)
|
||||
|
||||
# Run pytest
|
||||
start_time = datetime.now(UTC)
|
||||
result = subprocess.run(
|
||||
pytest_args,
|
||||
cwd=str(self.project_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600, # 10 minute timeout
|
||||
)
|
||||
end_time = datetime.now(UTC)
|
||||
|
||||
test_run.duration_seconds = (end_time - start_time).total_seconds()
|
||||
|
||||
# Parse JSON report
|
||||
try:
|
||||
with open(json_report_path) as f:
|
||||
report = json.load(f)
|
||||
|
||||
self._process_json_report(db, test_run, report)
|
||||
except (FileNotFoundError, json.JSONDecodeError) as e:
|
||||
# Fallback to parsing stdout if JSON report failed
|
||||
logger.warning(f"JSON report unavailable ({e}), parsing stdout")
|
||||
self._parse_pytest_output(test_run, result.stdout, result.stderr)
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
Path(json_report_path).unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Set final status
|
||||
if test_run.failed > 0 or test_run.errors > 0:
|
||||
test_run.status = "failed"
|
||||
else:
|
||||
test_run.status = "passed"
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
test_run.status = "error"
|
||||
logger.error("Pytest run timed out")
|
||||
except Exception as e:
|
||||
test_run.status = "error"
|
||||
logger.error(f"Error running tests: {e}")
|
||||
|
||||
def _process_json_report(self, db: Session, test_run: TestRun, report: dict):
|
||||
"""Process pytest-json-report output"""
|
||||
summary = report.get("summary", {})
|
||||
|
||||
test_run.total_tests = summary.get("total", 0)
|
||||
test_run.passed = summary.get("passed", 0)
|
||||
test_run.failed = summary.get("failed", 0)
|
||||
test_run.errors = summary.get("error", 0)
|
||||
test_run.skipped = summary.get("skipped", 0)
|
||||
test_run.xfailed = summary.get("xfailed", 0)
|
||||
test_run.xpassed = summary.get("xpassed", 0)
|
||||
|
||||
# Process individual test results
|
||||
tests = report.get("tests", [])
|
||||
for test in tests:
|
||||
node_id = test.get("nodeid", "")
|
||||
outcome = test.get("outcome", "unknown")
|
||||
|
||||
# Parse node_id to get file, class, function
|
||||
test_file, test_class, test_name = self._parse_node_id(node_id)
|
||||
|
||||
# Get failure details
|
||||
error_message = None
|
||||
traceback = None
|
||||
if outcome in ("failed", "error"):
|
||||
call_info = test.get("call", {})
|
||||
if "longrepr" in call_info:
|
||||
traceback = call_info["longrepr"]
|
||||
# Extract error message from traceback
|
||||
if isinstance(traceback, str):
|
||||
lines = traceback.strip().split("\n")
|
||||
if lines:
|
||||
error_message = lines[-1][:500] # Last line, limited length
|
||||
|
||||
test_result = TestResult(
|
||||
run_id=test_run.id,
|
||||
node_id=node_id,
|
||||
test_name=test_name,
|
||||
test_file=test_file,
|
||||
test_class=test_class,
|
||||
outcome=outcome,
|
||||
duration_seconds=test.get("duration", 0.0),
|
||||
error_message=error_message,
|
||||
traceback=traceback,
|
||||
markers=test.get("keywords", []),
|
||||
)
|
||||
db.add(test_result)
|
||||
|
||||
def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]:
|
||||
"""Parse pytest node_id into file, class, function"""
|
||||
# Format: tests/unit/test_foo.py::TestClass::test_method
|
||||
# or: tests/unit/test_foo.py::test_function
|
||||
parts = node_id.split("::")
|
||||
|
||||
test_file = parts[0] if parts else ""
|
||||
test_class = None
|
||||
test_name = parts[-1] if parts else ""
|
||||
|
||||
if len(parts) == 3:
|
||||
test_class = parts[1]
|
||||
elif len(parts) == 2:
|
||||
# Could be Class::method or file::function
|
||||
if parts[1].startswith("Test"):
|
||||
test_class = parts[1]
|
||||
test_name = parts[1]
|
||||
|
||||
# Handle parametrized tests
|
||||
if "[" in test_name:
|
||||
test_name = test_name.split("[")[0]
|
||||
|
||||
return test_file, test_class, test_name
|
||||
|
||||
def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str):
|
||||
"""Fallback parser for pytest text output"""
|
||||
# Parse summary line like: "10 passed, 2 failed, 1 skipped"
|
||||
summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)"
|
||||
|
||||
for match in re.finditer(summary_pattern, stdout):
|
||||
count = int(match.group(1))
|
||||
status = match.group(2)
|
||||
|
||||
if status == "passed":
|
||||
test_run.passed = count
|
||||
elif status == "failed":
|
||||
test_run.failed = count
|
||||
elif status == "error":
|
||||
test_run.errors = count
|
||||
elif status == "skipped":
|
||||
test_run.skipped = count
|
||||
elif status == "xfailed":
|
||||
test_run.xfailed = count
|
||||
elif status == "xpassed":
|
||||
test_run.xpassed = count
|
||||
|
||||
test_run.total_tests = (
|
||||
test_run.passed
|
||||
+ test_run.failed
|
||||
+ test_run.errors
|
||||
+ test_run.skipped
|
||||
+ test_run.xfailed
|
||||
+ test_run.xpassed
|
||||
)
|
||||
|
||||
def _get_git_commit(self) -> str | None:
|
||||
"""Get current git commit hash"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "HEAD"],
|
||||
cwd=str(self.project_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.stdout.strip()[:40] if result.returncode == 0 else None
|
||||
except:
|
||||
return None
|
||||
|
||||
def _get_git_branch(self) -> str | None:
|
||||
"""Get current git branch"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
cwd=str(self.project_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.stdout.strip() if result.returncode == 0 else None
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
|
||||
"""Get recent test run history"""
|
||||
return db.query(TestRun).order_by(desc(TestRun.timestamp)).limit(limit).all()
|
||||
|
||||
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
|
||||
"""Get a specific test run with results"""
|
||||
return db.query(TestRun).filter(TestRun.id == run_id).first()
|
||||
|
||||
def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]:
|
||||
"""Get failed tests from a run"""
|
||||
return (
|
||||
db.query(TestResult)
|
||||
.filter(
|
||||
TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"])
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_run_results(
|
||||
self, db: Session, run_id: int, outcome: str | None = None
|
||||
) -> list[TestResult]:
|
||||
"""Get test results for a specific run, optionally filtered by outcome"""
|
||||
query = db.query(TestResult).filter(TestResult.run_id == run_id)
|
||||
|
||||
if outcome:
|
||||
query = query.filter(TestResult.outcome == outcome)
|
||||
|
||||
return query.all()
|
||||
|
||||
def get_dashboard_stats(self, db: Session) -> dict:
|
||||
"""Get statistics for the testing dashboard"""
|
||||
# Get latest run
|
||||
latest_run = (
|
||||
db.query(TestRun)
|
||||
.filter(TestRun.status != "running")
|
||||
.order_by(desc(TestRun.timestamp))
|
||||
.first()
|
||||
)
|
||||
|
||||
# Get test collection info (or calculate from latest run)
|
||||
collection = (
|
||||
db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
|
||||
)
|
||||
|
||||
# Get trend data (last 10 runs)
|
||||
trend_runs = (
|
||||
db.query(TestRun)
|
||||
.filter(TestRun.status != "running")
|
||||
.order_by(desc(TestRun.timestamp))
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Calculate stats by category from latest run
|
||||
by_category = {}
|
||||
if latest_run:
|
||||
results = (
|
||||
db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
|
||||
)
|
||||
for result in results:
|
||||
# Categorize by test path
|
||||
if "unit" in result.test_file:
|
||||
category = "Unit Tests"
|
||||
elif "integration" in result.test_file:
|
||||
category = "Integration Tests"
|
||||
elif "performance" in result.test_file:
|
||||
category = "Performance Tests"
|
||||
elif "system" in result.test_file:
|
||||
category = "System Tests"
|
||||
else:
|
||||
category = "Other"
|
||||
|
||||
if category not in by_category:
|
||||
by_category[category] = {"total": 0, "passed": 0, "failed": 0}
|
||||
by_category[category]["total"] += 1
|
||||
if result.outcome == "passed":
|
||||
by_category[category]["passed"] += 1
|
||||
elif result.outcome in ("failed", "error"):
|
||||
by_category[category]["failed"] += 1
|
||||
|
||||
# Get top failing tests (across recent runs)
|
||||
top_failing = (
|
||||
db.query(
|
||||
TestResult.test_name,
|
||||
TestResult.test_file,
|
||||
func.count(TestResult.id).label("failure_count"),
|
||||
)
|
||||
.filter(TestResult.outcome.in_(["failed", "error"]))
|
||||
.group_by(TestResult.test_name, TestResult.test_file)
|
||||
.order_by(desc("failure_count"))
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
# Current run stats
|
||||
"total_tests": latest_run.total_tests if latest_run else 0,
|
||||
"passed": latest_run.passed if latest_run else 0,
|
||||
"failed": latest_run.failed if latest_run else 0,
|
||||
"errors": latest_run.errors if latest_run else 0,
|
||||
"skipped": latest_run.skipped if latest_run else 0,
|
||||
"pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
|
||||
"duration_seconds": round(latest_run.duration_seconds, 2)
|
||||
if latest_run
|
||||
else 0,
|
||||
"coverage_percent": latest_run.coverage_percent if latest_run else None,
|
||||
"last_run": latest_run.timestamp.isoformat() if latest_run else None,
|
||||
"last_run_status": latest_run.status if latest_run else None,
|
||||
# Collection stats
|
||||
"total_test_files": collection.total_files if collection else 0,
|
||||
"collected_tests": collection.total_tests if collection else 0,
|
||||
"unit_tests": collection.unit_tests if collection else 0,
|
||||
"integration_tests": collection.integration_tests if collection else 0,
|
||||
"performance_tests": collection.performance_tests if collection else 0,
|
||||
"system_tests": collection.system_tests if collection else 0,
|
||||
"last_collected": collection.collected_at.isoformat()
|
||||
if collection
|
||||
else None,
|
||||
# Trend data
|
||||
"trend": [
|
||||
{
|
||||
"timestamp": run.timestamp.isoformat(),
|
||||
"total": run.total_tests,
|
||||
"passed": run.passed,
|
||||
"failed": run.failed,
|
||||
"pass_rate": round(run.pass_rate, 1),
|
||||
"duration": round(run.duration_seconds, 1),
|
||||
}
|
||||
for run in reversed(trend_runs)
|
||||
],
|
||||
# By category
|
||||
"by_category": by_category,
|
||||
# Top failing tests
|
||||
"top_failing": [
|
||||
{
|
||||
"test_name": t.test_name,
|
||||
"test_file": t.test_file,
|
||||
"failure_count": t.failure_count,
|
||||
}
|
||||
for t in top_failing
|
||||
],
|
||||
}
|
||||
|
||||
def collect_tests(self, db: Session) -> TestCollection:
|
||||
"""Collect test information without running tests"""
|
||||
collection = TestCollection(
|
||||
collected_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
try:
|
||||
# Run pytest --collect-only with JSON report
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as f:
|
||||
json_report_path = f.name
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"pytest",
|
||||
"--collect-only",
|
||||
"--json-report",
|
||||
f"--json-report-file={json_report_path}",
|
||||
"tests",
|
||||
],
|
||||
cwd=str(self.project_root),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
# Parse JSON report
|
||||
json_path = Path(json_report_path)
|
||||
if json_path.exists():
|
||||
with open(json_path) as f:
|
||||
report = json.load(f)
|
||||
|
||||
# Get total from summary
|
||||
collection.total_tests = report.get("summary", {}).get("collected", 0)
|
||||
|
||||
# Parse collectors to get test files and counts
|
||||
test_files = {}
|
||||
for collector in report.get("collectors", []):
|
||||
for item in collector.get("result", []):
|
||||
if item.get("type") == "Function":
|
||||
node_id = item.get("nodeid", "")
|
||||
if "::" in node_id:
|
||||
file_path = node_id.split("::")[0]
|
||||
if file_path not in test_files:
|
||||
test_files[file_path] = 0
|
||||
test_files[file_path] += 1
|
||||
|
||||
# Count files and categorize
|
||||
for file_path, count in test_files.items():
|
||||
collection.total_files += 1
|
||||
|
||||
if "/unit/" in file_path or file_path.startswith("tests/unit"):
|
||||
collection.unit_tests += count
|
||||
elif "/integration/" in file_path or file_path.startswith(
|
||||
"tests/integration"
|
||||
):
|
||||
collection.integration_tests += count
|
||||
elif "/performance/" in file_path or file_path.startswith(
|
||||
"tests/performance"
|
||||
):
|
||||
collection.performance_tests += count
|
||||
elif "/system/" in file_path or file_path.startswith(
|
||||
"tests/system"
|
||||
):
|
||||
collection.system_tests += count
|
||||
|
||||
collection.test_files = [
|
||||
{"file": f, "count": c}
|
||||
for f, c in sorted(test_files.items(), key=lambda x: -x[1])
|
||||
]
|
||||
|
||||
# Cleanup
|
||||
json_path.unlink(missing_ok=True)
|
||||
|
||||
logger.info(
|
||||
f"Collected {collection.total_tests} tests from {collection.total_files} files"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting tests: {e}", exc_info=True)
|
||||
|
||||
db.add(collection)
|
||||
return collection
|
||||
|
||||
|
||||
# Singleton instance
|
||||
test_runner_service = TestRunnerService()
|
||||
from app.modules.dev_tools.services.test_runner_service import (
|
||||
test_runner_service,
|
||||
TestRunnerService,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"test_runner_service",
|
||||
"TestRunnerService",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user