fix: correct tojson|safe usage in templates and update validator

- Remove |safe from |tojson in HTML attributes (x-data) - quotes must
  become &amp;quot; for browsers to parse correctly
- Update LANG-002 and LANG-003 architecture rules to document correct
  |tojson usage patterns:
  - HTML attributes: |tojson (no |safe)
  - Script blocks: |tojson|safe
- Fix validator to warn when |tojson|safe is used in x-data (breaks
  HTML attribute parsing)
- Improve code quality across services, APIs, and tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-13 22:59:51 +01:00
parent 94d268f330
commit 9920430b9e
123 changed files with 1408 additions and 840 deletions

View File

@@ -35,7 +35,7 @@ from middleware.auth import AuthManager
from models.database.company import Company
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
from models.database.vendor import Role, Vendor
from models.schema.marketplace_import_job import MarketplaceImportJobResponse
from models.schema.vendor import VendorCreate
@@ -143,7 +143,9 @@ class AdminService:
# Apply pagination
skip = (page - 1) * per_page
users = query.order_by(User.created_at.desc()).offset(skip).limit(per_page).all()
users = (
query.order_by(User.created_at.desc()).offset(skip).limit(per_page).all()
)
return users, total, pages
@@ -199,7 +201,9 @@ class AdminService:
"""
user = (
db.query(User)
.options(joinedload(User.owned_companies), joinedload(User.vendor_memberships))
.options(
joinedload(User.owned_companies), joinedload(User.vendor_memberships)
)
.filter(User.id == user_id)
.first()
)
@@ -243,12 +247,16 @@ class AdminService:
# Check email uniqueness if changing
if email and email != user.email:
if db.query(User).filter(User.email == email).first():
raise UserAlreadyExistsException("Email already registered", field="email")
raise UserAlreadyExistsException(
"Email already registered", field="email"
)
# Check username uniqueness if changing
if username and username != user.username:
if db.query(User).filter(User.username == username).first():
raise UserAlreadyExistsException("Username already taken", field="username")
raise UserAlreadyExistsException(
"Username already taken", field="username"
)
# Update fields
if email is not None:
@@ -322,7 +330,9 @@ class AdminService:
search_term = f"%{query.lower()}%"
users = (
db.query(User)
.filter(or_(User.username.ilike(search_term), User.email.ilike(search_term)))
.filter(
or_(User.username.ilike(search_term), User.email.ilike(search_term))
)
.limit(limit)
.all()
)
@@ -360,14 +370,20 @@ class AdminService:
"""
try:
# Validate company exists
company = db.query(Company).filter(Company.id == vendor_data.company_id).first()
company = (
db.query(Company).filter(Company.id == vendor_data.company_id).first()
)
if not company:
raise ValidationException(f"Company with ID {vendor_data.company_id} not found")
raise ValidationException(
f"Company with ID {vendor_data.company_id} not found"
)
# Check if vendor code already exists
existing_vendor = (
db.query(Vendor)
.filter(func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper())
.filter(
func.upper(Vendor.vendor_code) == vendor_data.vendor_code.upper()
)
.first()
)
if existing_vendor:
@@ -613,7 +629,13 @@ class AdminService:
update_data["tax_number"] = None
# Convert empty strings to None for contact fields (empty = inherit)
contact_fields = ["contact_email", "contact_phone", "website", "business_address", "tax_number"]
contact_fields = [
"contact_email",
"contact_phone",
"website",
"business_address",
"tax_number",
]
for field in contact_fields:
if field in update_data and update_data[field] == "":
update_data[field] = None

View File

@@ -48,7 +48,9 @@ class BackgroundTasksService:
def get_import_stats(self, db: Session) -> dict:
"""Get import job statistics"""
today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
today_start = datetime.now(UTC).replace(
hour=0, minute=0, second=0, microsecond=0
)
stats = db.query(
func.count(MarketplaceImportJob.id).label("total"),
@@ -57,7 +59,12 @@ class BackgroundTasksService:
).label("running"),
func.sum(
func.case(
(MarketplaceImportJob.status.in_(["completed", "completed_with_errors"]), 1),
(
MarketplaceImportJob.status.in_(
["completed", "completed_with_errors"]
),
1,
),
else_=0,
)
).label("completed"),
@@ -83,12 +90,18 @@ class BackgroundTasksService:
def get_test_run_stats(self, db: Session) -> dict:
"""Get test run statistics"""
today_start = datetime.now(UTC).replace(hour=0, minute=0, second=0, microsecond=0)
today_start = datetime.now(UTC).replace(
hour=0, minute=0, second=0, microsecond=0
)
stats = db.query(
func.count(TestRun.id).label("total"),
func.sum(func.case((TestRun.status == "running", 1), else_=0)).label("running"),
func.sum(func.case((TestRun.status == "passed", 1), else_=0)).label("completed"),
func.sum(func.case((TestRun.status == "running", 1), else_=0)).label(
"running"
),
func.sum(func.case((TestRun.status == "passed", 1), else_=0)).label(
"completed"
),
func.sum(
func.case((TestRun.status.in_(["failed", "error"]), 1), else_=0)
).label("failed"),

View File

@@ -8,7 +8,6 @@ This service handles CRUD operations for companies and company-vendor relationships
import logging
import secrets
import string
from typing import List, Optional
from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload
@@ -26,7 +25,6 @@ class CompanyService:
def __init__(self):
"""Initialize company service."""
pass
def create_company_with_owner(
self, db: Session, company_data: CompanyCreate
@@ -106,11 +104,15 @@ class CompanyService:
Raises:
CompanyNotFoundException: If company not found
"""
company = db.execute(
select(Company)
.where(Company.id == company_id)
.options(joinedload(Company.vendors))
).unique().scalar_one_or_none()
company = (
db.execute(
select(Company)
.where(Company.id == company_id)
.options(joinedload(Company.vendors))
)
.unique()
.scalar_one_or_none()
)
if not company:
raise CompanyNotFoundException(company_id)
@@ -125,7 +127,7 @@ class CompanyService:
search: str | None = None,
is_active: bool | None = None,
is_verified: bool | None = None,
) -> tuple[List[Company], int]:
) -> tuple[list[Company], int]:
"""
Get paginated list of companies with optional filters.
@@ -209,7 +211,9 @@ class CompanyService:
db.flush()
logger.info(f"Deleted company ID {company_id} and associated vendors")
def toggle_verification(self, db: Session, company_id: int, is_verified: bool) -> Company:
def toggle_verification(
self, db: Session, company_id: int, is_verified: bool
) -> Company:
"""
Toggle company verification status.
@@ -227,9 +231,7 @@ class CompanyService:
company = self.get_company_by_id(db, company_id)
company.is_verified = is_verified
db.flush()
logger.info(
f"Company ID {company_id} verification set to {is_verified}"
)
logger.info(f"Company ID {company_id} verification set to {is_verified}")
return company
@@ -251,9 +253,7 @@ class CompanyService:
company = self.get_company_by_id(db, company_id)
company.is_active = is_active
db.flush()
logger.info(
f"Company ID {company_id} active status set to {is_active}"
)
logger.info(f"Company ID {company_id} active status set to {is_active}")
return company

View File

@@ -526,7 +526,9 @@ class ContentPageService:
return (
db.query(ContentPage)
.filter(and_(*filters) if filters else True)
.order_by(ContentPage.vendor_id, ContentPage.display_order, ContentPage.title)
.order_by(
ContentPage.vendor_id, ContentPage.display_order, ContentPage.title
)
.all()
)

View File

@@ -30,20 +30,14 @@ class LetzshopClientError(Exception):
class LetzshopAuthError(LetzshopClientError):
"""Raised when authentication fails."""
pass
class LetzshopAPIError(LetzshopClientError):
"""Raised when the API returns an error response."""
pass
class LetzshopConnectionError(LetzshopClientError):
"""Raised when connection to the API fails."""
pass
# ============================================================================
# GraphQL Queries

View File

@@ -6,7 +6,7 @@ Handles secure storage and retrieval of per-vendor Letzshop API credentials.
"""
import logging
from datetime import datetime, timezone
from datetime import UTC, datetime
from sqlalchemy.orm import Session
@@ -24,14 +24,10 @@ DEFAULT_ENDPOINT = "https://letzshop.lu/graphql"
class CredentialsError(Exception):
"""Base exception for credentials errors."""
pass
class CredentialsNotFoundError(CredentialsError):
"""Raised when credentials are not found for a vendor."""
pass
class LetzshopCredentialsService:
"""
@@ -54,9 +50,7 @@ class LetzshopCredentialsService:
# CRUD Operations
# ========================================================================
def get_credentials(
self, vendor_id: int
) -> VendorLetzshopCredentials | None:
def get_credentials(self, vendor_id: int) -> VendorLetzshopCredentials | None:
"""
Get Letzshop credentials for a vendor.
@@ -72,9 +66,7 @@ class LetzshopCredentialsService:
.first()
)
def get_credentials_or_raise(
self, vendor_id: int
) -> VendorLetzshopCredentials:
def get_credentials_or_raise(self, vendor_id: int) -> VendorLetzshopCredentials:
"""
Get Letzshop credentials for a vendor or raise an exception.
@@ -293,9 +285,7 @@ class LetzshopCredentialsService:
# Connection Testing
# ========================================================================
def test_connection(
self, vendor_id: int
) -> tuple[bool, float | None, str | None]:
def test_connection(self, vendor_id: int) -> tuple[bool, float | None, str | None]:
"""
Test the connection for a vendor's credentials.
@@ -364,7 +354,7 @@ class LetzshopCredentialsService:
if credentials is None:
return None
credentials.last_sync_at = datetime.now(timezone.utc)
credentials.last_sync_at = datetime.now(UTC)
credentials.last_sync_status = status
credentials.last_sync_error = error

View File

@@ -7,7 +7,7 @@ architecture rules (API-002: endpoints should not contain business logic).
"""
import logging
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func
@@ -21,21 +21,16 @@ from models.database.letzshop import (
)
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
class VendorNotFoundError(Exception):
"""Raised when a vendor is not found."""
pass
class OrderNotFoundError(Exception):
"""Raised when a Letzshop order is not found."""
pass
class LetzshopOrderService:
"""Service for Letzshop order database operations."""
@@ -114,17 +109,23 @@ class LetzshopOrderService:
or 0
)
vendor_overviews.append({
"vendor_id": vendor.id,
"vendor_name": vendor.name,
"vendor_code": vendor.vendor_code,
"is_configured": credentials is not None,
"auto_sync_enabled": credentials.auto_sync_enabled if credentials else False,
"last_sync_at": credentials.last_sync_at if credentials else None,
"last_sync_status": credentials.last_sync_status if credentials else None,
"pending_orders": pending_orders,
"total_orders": total_orders,
})
vendor_overviews.append(
{
"vendor_id": vendor.id,
"vendor_name": vendor.name,
"vendor_code": vendor.vendor_code,
"is_configured": credentials is not None,
"auto_sync_enabled": credentials.auto_sync_enabled
if credentials
else False,
"last_sync_at": credentials.last_sync_at if credentials else None,
"last_sync_status": credentials.last_sync_status
if credentials
else None,
"pending_orders": pending_orders,
"total_orders": total_orders,
}
)
return vendor_overviews, total
@@ -210,9 +211,7 @@ class LetzshopOrderService:
letzshop_order_number=order_data.get("number"),
letzshop_state=shipment_data.get("state"),
customer_email=order_data.get("email"),
total_amount=str(
order_data.get("totalPrice", {}).get("amount", "")
),
total_amount=str(order_data.get("totalPrice", {}).get("amount", "")),
currency=order_data.get("totalPrice", {}).get("currency", "EUR"),
raw_order_data=shipment_data,
inventory_units=[
@@ -236,13 +235,13 @@ class LetzshopOrderService:
def mark_order_confirmed(self, order: LetzshopOrder) -> LetzshopOrder:
"""Mark an order as confirmed."""
order.confirmed_at = datetime.now(timezone.utc)
order.confirmed_at = datetime.now(UTC)
order.sync_status = "confirmed"
return order
def mark_order_rejected(self, order: LetzshopOrder) -> LetzshopOrder:
"""Mark an order as rejected."""
order.rejected_at = datetime.now(timezone.utc)
order.rejected_at = datetime.now(UTC)
order.sync_status = "rejected"
return order
@@ -255,7 +254,7 @@ class LetzshopOrderService:
"""Set tracking information for an order."""
order.tracking_number = tracking_number
order.tracking_carrier = tracking_carrier
order.tracking_set_at = datetime.now(timezone.utc)
order.tracking_set_at = datetime.now(UTC)
order.sync_status = "shipped"
return order

View File

@@ -4,10 +4,10 @@ Service for exporting products to Letzshop CSV format.
Generates Google Shopping compatible CSV files for Letzshop marketplace.
"""
import csv
import io
import logging
from typing import BinaryIO
from sqlalchemy.orm import Session, joinedload
@@ -140,7 +140,9 @@ class LetzshopExportService:
)
if marketplace:
query = query.filter(MarketplaceProduct.marketplace.ilike(f"%{marketplace}%"))
query = query.filter(
MarketplaceProduct.marketplace.ilike(f"%{marketplace}%")
)
if limit:
query = query.limit(limit)
@@ -193,7 +195,9 @@ class LetzshopExportService:
def _product_to_row(self, product: Product, language: str) -> dict:
"""Convert a Product (with MarketplaceProduct) to a CSV row."""
mp = product.marketplace_product
return self._marketplace_product_to_row(mp, language, vendor_sku=product.vendor_sku)
return self._marketplace_product_to_row(
mp, language, vendor_sku=product.vendor_sku
)
def _marketplace_product_to_row(
self,

View File

@@ -11,7 +11,6 @@ This module provides functions for:
"""
import logging
import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
@@ -58,7 +57,9 @@ class LogService:
conditions.append(ApplicationLog.level == filters.level.upper())
if filters.logger_name:
conditions.append(ApplicationLog.logger_name.like(f"%{filters.logger_name}%"))
conditions.append(
ApplicationLog.logger_name.like(f"%{filters.logger_name}%")
)
if filters.module:
conditions.append(ApplicationLog.module.like(f"%{filters.module}%"))
@@ -215,7 +216,8 @@ class LogService:
except Exception as e:
logger.error(f"Failed to get log statistics: {e}")
raise AdminOperationException(
operation="get_log_statistics", reason=f"Database query failed: {str(e)}"
operation="get_log_statistics",
reason=f"Database query failed: {str(e)}",
)
def get_file_logs(
@@ -252,7 +254,7 @@ class LogService:
stat = log_file.stat()
# Read last N lines efficiently
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
with open(log_file, encoding="utf-8", errors="replace") as f:
# For large files, seek to end and read backwards
all_lines = f.readlines()
log_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
@@ -349,16 +351,21 @@ class LogService:
db.rollback()
logger.error(f"Failed to cleanup old logs: {e}")
raise AdminOperationException(
operation="cleanup_old_logs", reason=f"Delete operation failed: {str(e)}"
operation="cleanup_old_logs",
reason=f"Delete operation failed: {str(e)}",
)
def delete_log(self, db: Session, log_id: int) -> str:
"""Delete a specific log entry."""
try:
log_entry = db.query(ApplicationLog).filter(ApplicationLog.id == log_id).first()
log_entry = (
db.query(ApplicationLog).filter(ApplicationLog.id == log_id).first()
)
if not log_entry:
raise ResourceNotFoundException(resource_type="log", identifier=str(log_id))
raise ResourceNotFoundException(
resource_type="log", identifier=str(log_id)
)
db.delete(log_entry)
db.commit()

View File

@@ -8,7 +8,10 @@ from app.exceptions import (
ImportJobNotOwnedException,
ValidationException,
)
from models.database.marketplace_import_job import MarketplaceImportError, MarketplaceImportJob
from models.database.marketplace_import_job import (
MarketplaceImportError,
MarketplaceImportJob,
)
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import (
@@ -136,7 +139,9 @@ class MarketplaceImportJobService:
except (ImportJobNotFoundException, UnauthorizedVendorAccessException):
raise
except Exception as e:
logger.error(f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}")
logger.error(
f"Error getting import job {job_id} for vendor {vendor_id}: {str(e)}"
)
raise ValidationException("Failed to retrieve import job")
def get_import_jobs(

View File

@@ -32,7 +32,9 @@ from app.exceptions import (
from app.utils.data_processing import GTINProcessor, PriceProcessor
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.database.marketplace_product_translation import MarketplaceProductTranslation
from models.database.marketplace_product_translation import (
MarketplaceProductTranslation,
)
from models.schema.inventory import InventoryLocationResponse, InventorySummaryResponse
from models.schema.marketplace_product import (
MarketplaceProductCreate,
@@ -602,7 +604,6 @@ class MarketplaceProductService:
return normalized
# =========================================================================
# Admin-specific methods for marketplace product management
# =========================================================================
@@ -632,22 +633,28 @@ class MarketplaceProductService:
if search:
search_term = f"%{search}%"
query = query.outerjoin(MarketplaceProductTranslation).filter(
or_(
MarketplaceProductTranslation.title.ilike(search_term),
MarketplaceProduct.gtin.ilike(search_term),
MarketplaceProduct.sku.ilike(search_term),
MarketplaceProduct.brand.ilike(search_term),
MarketplaceProduct.mpn.ilike(search_term),
MarketplaceProduct.marketplace_product_id.ilike(search_term),
query = (
query.outerjoin(MarketplaceProductTranslation)
.filter(
or_(
MarketplaceProductTranslation.title.ilike(search_term),
MarketplaceProduct.gtin.ilike(search_term),
MarketplaceProduct.sku.ilike(search_term),
MarketplaceProduct.brand.ilike(search_term),
MarketplaceProduct.mpn.ilike(search_term),
MarketplaceProduct.marketplace_product_id.ilike(search_term),
)
)
).distinct()
.distinct()
)
if marketplace:
query = query.filter(MarketplaceProduct.marketplace == marketplace)
if vendor_name:
query = query.filter(MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%"))
query = query.filter(
MarketplaceProduct.vendor_name.ilike(f"%{vendor_name}%")
)
if availability:
query = query.filter(MarketplaceProduct.availability == availability)
@@ -787,8 +794,12 @@ class MarketplaceProductService:
"weight": product.weight,
"weight_unit": product.weight_unit,
"translations": translations,
"created_at": product.created_at.isoformat() if product.created_at else None,
"updated_at": product.updated_at.isoformat() if product.updated_at else None,
"created_at": product.created_at.isoformat()
if product.created_at
else None,
"updated_at": product.updated_at.isoformat()
if product.updated_at
else None,
}
def copy_to_vendor_catalog(
@@ -810,6 +821,7 @@ class MarketplaceProductService:
vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
if not vendor:
from app.exceptions import VendorNotFoundException
raise VendorNotFoundException(str(vendor_id), identifier_type="id")
marketplace_products = (
@@ -839,11 +851,13 @@ class MarketplaceProductService:
if existing:
skipped += 1
details.append({
"id": mp.id,
"status": "skipped",
"reason": "Already exists in catalog",
})
details.append(
{
"id": mp.id,
"status": "skipped",
"reason": "Already exists in catalog",
}
)
continue
product = Product(
@@ -876,7 +890,9 @@ class MarketplaceProductService:
"details": details if len(details) <= 100 else None,
}
def _build_admin_product_item(self, product: MarketplaceProduct, title: str | None) -> dict:
def _build_admin_product_item(
self, product: MarketplaceProduct, title: str | None
) -> dict:
"""Build a product list item dict for admin view."""
return {
"id": product.id,
@@ -894,8 +910,12 @@ class MarketplaceProductService:
"is_active": product.is_active,
"is_digital": product.is_digital,
"product_type_enum": product.product_type_enum,
"created_at": product.created_at.isoformat() if product.created_at else None,
"updated_at": product.updated_at.isoformat() if product.updated_at else None,
"created_at": product.created_at.isoformat()
if product.created_at
else None,
"updated_at": product.updated_at.isoformat()
if product.updated_at
else None,
}

View File

@@ -160,7 +160,9 @@ class StatsService:
# Inventory stats
"total_inventory_quantity": int(total_inventory),
"reserved_inventory_quantity": int(reserved_inventory),
"available_inventory_quantity": int(total_inventory - reserved_inventory),
"available_inventory_quantity": int(
total_inventory - reserved_inventory
),
"inventory_locations_count": inventory_locations,
}

View File

@@ -77,11 +77,15 @@ class TestRunnerService:
"""Execute pytest and update the test run record"""
try:
# Build pytest command with JSON output
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json_report_path = f.name
pytest_args = [
"python", "-m", "pytest",
"python",
"-m",
"pytest",
test_path,
"--json-report",
f"--json-report-file={json_report_path}",
@@ -109,7 +113,7 @@ class TestRunnerService:
# Parse JSON report
try:
with open(json_report_path, 'r') as f:
with open(json_report_path) as f:
report = json.load(f)
self._process_json_report(db, test_run, report)
@@ -167,7 +171,7 @@ class TestRunnerService:
traceback = call_info["longrepr"]
# Extract error message from traceback
if isinstance(traceback, str):
lines = traceback.strip().split('\n')
lines = traceback.strip().split("\n")
if lines:
error_message = lines[-1][:500] # Last line, limited length
@@ -232,8 +236,12 @@ class TestRunnerService:
test_run.xpassed = count
test_run.total_tests = (
test_run.passed + test_run.failed + test_run.errors +
test_run.skipped + test_run.xfailed + test_run.xpassed
test_run.passed
+ test_run.failed
+ test_run.errors
+ test_run.skipped
+ test_run.xfailed
+ test_run.xpassed
)
def _get_git_commit(self) -> str | None:
@@ -266,12 +274,7 @@ class TestRunnerService:
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
"""Get recent test run history"""
return (
db.query(TestRun)
.order_by(desc(TestRun.timestamp))
.limit(limit)
.all()
)
return db.query(TestRun).order_by(desc(TestRun.timestamp)).limit(limit).all()
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
"""Get a specific test run with results"""
@@ -282,8 +285,7 @@ class TestRunnerService:
return (
db.query(TestResult)
.filter(
TestResult.run_id == run_id,
TestResult.outcome.in_(["failed", "error"])
TestResult.run_id == run_id, TestResult.outcome.in_(["failed", "error"])
)
.all()
)
@@ -310,7 +312,9 @@ class TestRunnerService:
)
# Get test collection info (or calculate from latest run)
collection = db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
collection = (
db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
)
# Get trend data (last 10 runs)
trend_runs = (
@@ -324,7 +328,9 @@ class TestRunnerService:
# Calculate stats by category from latest run
by_category = {}
if latest_run:
results = db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
results = (
db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
)
for result in results:
# Categorize by test path
if "unit" in result.test_file:
@@ -351,7 +357,7 @@ class TestRunnerService:
db.query(
TestResult.test_name,
TestResult.test_file,
func.count(TestResult.id).label("failure_count")
func.count(TestResult.id).label("failure_count"),
)
.filter(TestResult.outcome.in_(["failed", "error"]))
.group_by(TestResult.test_name, TestResult.test_file)
@@ -368,11 +374,12 @@ class TestRunnerService:
"errors": latest_run.errors if latest_run else 0,
"skipped": latest_run.skipped if latest_run else 0,
"pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
"duration_seconds": round(latest_run.duration_seconds, 2) if latest_run else 0,
"duration_seconds": round(latest_run.duration_seconds, 2)
if latest_run
else 0,
"coverage_percent": latest_run.coverage_percent if latest_run else None,
"last_run": latest_run.timestamp.isoformat() if latest_run else None,
"last_run_status": latest_run.status if latest_run else None,
# Collection stats
"total_test_files": collection.total_files if collection else 0,
"collected_tests": collection.total_tests if collection else 0,
@@ -380,8 +387,9 @@ class TestRunnerService:
"integration_tests": collection.integration_tests if collection else 0,
"performance_tests": collection.performance_tests if collection else 0,
"system_tests": collection.system_tests if collection else 0,
"last_collected": collection.collected_at.isoformat() if collection else None,
"last_collected": collection.collected_at.isoformat()
if collection
else None,
# Trend data
"trend": [
{
@@ -394,10 +402,8 @@ class TestRunnerService:
}
for run in reversed(trend_runs)
],
# By category
"by_category": by_category,
# Top failing tests
"top_failing": [
{
@@ -417,16 +423,20 @@ class TestRunnerService:
try:
# Run pytest --collect-only with JSON report
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as f:
json_report_path = f.name
result = subprocess.run(
[
"python", "-m", "pytest",
"python",
"-m",
"pytest",
"--collect-only",
"--json-report",
f"--json-report-file={json_report_path}",
"tests"
"tests",
],
cwd=str(self.project_root),
capture_output=True,
@@ -461,11 +471,17 @@ class TestRunnerService:
if "/unit/" in file_path or file_path.startswith("tests/unit"):
collection.unit_tests += count
elif "/integration/" in file_path or file_path.startswith("tests/integration"):
elif "/integration/" in file_path or file_path.startswith(
"tests/integration"
):
collection.integration_tests += count
elif "/performance/" in file_path or file_path.startswith("tests/performance"):
elif "/performance/" in file_path or file_path.startswith(
"tests/performance"
):
collection.performance_tests += count
elif "/system/" in file_path or file_path.startswith("tests/system"):
elif "/system/" in file_path or file_path.startswith(
"tests/system"
):
collection.system_tests += count
collection.test_files = [
@@ -476,7 +492,9 @@ class TestRunnerService:
# Cleanup
json_path.unlink(missing_ok=True)
logger.info(f"Collected {collection.total_tests} tests from {collection.total_files} files")
logger.info(
f"Collected {collection.total_tests} tests from {collection.total_files} files"
)
except Exception as e:
logger.error(f"Error collecting tests: {e}", exc_info=True)

View File

@@ -66,10 +66,7 @@ class VendorProductService:
total = query.count()
products = (
query.order_by(Product.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
query.order_by(Product.updated_at.desc()).offset(skip).limit(limit).all()
)
result = []
@@ -138,8 +135,7 @@ class VendorProductService:
.all()
)
return [
{"id": v.id, "name": v.name, "vendor_code": v.vendor_code}
for v in vendors
{"id": v.id, "name": v.name, "vendor_code": v.vendor_code} for v in vendors
]
def get_product_detail(self, db: Session, product_id: int) -> dict:
@@ -212,8 +208,12 @@ class VendorProductService:
"marketplace_translations": mp_translations,
"vendor_translations": vendor_translations,
# Timestamps
"created_at": product.created_at.isoformat() if product.created_at else None,
"updated_at": product.updated_at.isoformat() if product.updated_at else None,
"created_at": product.created_at.isoformat()
if product.created_at
else None,
"updated_at": product.updated_at.isoformat()
if product.updated_at
else None,
}
def remove_product(self, db: Session, product_id: int) -> dict:
@@ -257,8 +257,12 @@ class VendorProductService:
"image_url": product.effective_primary_image_url,
"source_marketplace": mp.marketplace if mp else None,
"source_vendor": mp.vendor_name if mp else None,
"created_at": product.created_at.isoformat() if product.created_at else None,
"updated_at": product.updated_at.isoformat() if product.updated_at else None,
"created_at": product.created_at.isoformat()
if product.created_at
else None,
"updated_at": product.updated_at.isoformat()
if product.updated_at
else None,
}

View File

@@ -67,20 +67,26 @@ class VendorService:
try:
# Validate company_id is provided
if not hasattr(vendor_data, 'company_id') or not vendor_data.company_id:
if not hasattr(vendor_data, "company_id") or not vendor_data.company_id:
raise InvalidVendorDataException(
"company_id is required to create a vendor", field="company_id"
)
# Get company and verify ownership
company = db.query(Company).filter(Company.id == vendor_data.company_id).first()
company = (
db.query(Company).filter(Company.id == vendor_data.company_id).first()
)
if not company:
raise InvalidVendorDataException(
f"Company with ID {vendor_data.company_id} not found", field="company_id"
f"Company with ID {vendor_data.company_id} not found",
field="company_id",
)
# Check if user is company owner or admin
if current_user.role != "admin" and company.owner_user_id != current_user.id:
if (
current_user.role != "admin"
and company.owner_user_id != current_user.id
):
raise UnauthorizedVendorAccessException(
f"company-{vendor_data.company_id}", current_user.id
)
@@ -163,9 +169,7 @@ class VendorService:
)
query = query.filter(
(Vendor.is_active == True)
& (
(Vendor.is_verified == True) | (Vendor.id.in_(owned_vendor_ids))
)
& ((Vendor.is_verified == True) | (Vendor.id.in_(owned_vendor_ids)))
)
else:
# Admin can apply filters
@@ -238,6 +242,7 @@ class VendorService:
VendorNotFoundException: If vendor not found
"""
from sqlalchemy.orm import joinedload
from models.database.company import Company
vendor = (
@@ -272,6 +277,7 @@ class VendorService:
VendorNotFoundException: If vendor not found or inactive
"""
from sqlalchemy.orm import joinedload
from models.database.company import Company
vendor = (
@@ -305,6 +311,7 @@ class VendorService:
VendorNotFoundException: If vendor not found
"""
from sqlalchemy.orm import joinedload
from models.database.company import Company
# Try as integer ID first