feat: add Celery/Redis task queue with feature flag support

Migrate background tasks from FastAPI BackgroundTasks to Celery with Redis
for persistent task queuing, retries, and scheduled jobs.

Key changes:
- Add Celery configuration with Redis broker/backend
- Create task dispatcher with USE_CELERY feature flag for gradual rollout
- Add Celery task wrappers for all background operations:
  - Marketplace imports
  - Letzshop historical imports
  - Product exports
  - Code quality scans
  - Test runs
  - Subscription scheduled tasks (via Celery Beat)
- Add celery_task_id column to job tables for Flower integration
- Add Flower dashboard link to admin background tasks page
- Update docker-compose.yml with worker, beat, and flower services
- Add Makefile targets: celery-worker, celery-beat, celery-dev, flower

When USE_CELERY=false (default), system falls back to FastAPI BackgroundTasks
for development without Redis dependency.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-11 17:35:16 +01:00
parent 879ac0caea
commit 2792414395
30 changed files with 2218 additions and 79 deletions

View File

@@ -0,0 +1,16 @@
# app/tasks/celery_tasks/__init__.py
"""
Celery task modules for Wizamart.

This package contains Celery task wrappers for background processing:
- marketplace: Product import tasks
- letzshop: Historical import tasks
- subscription: Scheduled subscription management
- export: Product export tasks
- code_quality: Code quality scan tasks
- test_runner: Test execution tasks
"""
from app.tasks.celery_tasks.base import DatabaseTask

# Public API of the package: the shared base class for DB-backed tasks.
__all__ = ["DatabaseTask"]

View File

@@ -0,0 +1,91 @@
# app/tasks/celery_tasks/base.py
"""
Base Celery task class with database session management.
Provides a DatabaseTask base class that handles:
- Database session lifecycle (create/close)
- Context manager pattern for session usage
- Proper cleanup on task completion or failure
"""
import logging
from contextlib import contextmanager
from celery import Task
from app.core.database import SessionLocal
logger = logging.getLogger(__name__)
class DatabaseTask(Task):
    """
    Celery base task that owns the lifecycle of a SQLAlchemy session.

    Tasks declared with ``base=DatabaseTask`` obtain a session through
    :meth:`get_db`; the session is rolled back on error and always closed
    when the block exits. Lifecycle hooks log failure, success, and retry
    events with the task name and ID.

    Usage:
        @celery_app.task(bind=True, base=DatabaseTask)
        def my_task(self, arg1, arg2):
            with self.get_db() as db:
                result = db.query(Model).all()
            return result
    """

    # Prevent Celery from registering this base class itself as a task.
    abstract = True

    @contextmanager
    def get_db(self):
        """
        Yield a fresh SQLAlchemy session, guaranteeing cleanup.

        Rolls back and re-raises on any exception raised inside the
        ``with`` block; closes the session unconditionally.

        Yields:
            Session: SQLAlchemy database session

        Example:
            with self.get_db() as db:
                vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        """
        session = SessionLocal()
        try:
            yield session
        except Exception as err:
            logger.error(f"Database error in task {self.name}: {err}")
            session.rollback()
            raise
        finally:
            session.close()

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """Log the full task context when a task fails, for debugging."""
        logger.error(
            f"Task {self.name}[{task_id}] failed: {exc}\n"
            f"Args: {args}\n"
            f"Kwargs: {kwargs}\n"
            f"Traceback: {einfo}"
        )

    def on_success(self, retval, task_id, args, kwargs):
        """Log successful completion with the task ID."""
        logger.info(f"Task {self.name}[{task_id}] completed successfully")

    def on_retry(self, exc, task_id, args, kwargs, einfo):
        """Log the reason and attempt number when the task is retried."""
        logger.warning(
            f"Task {self.name}[{task_id}] retrying due to: {exc}\n"
            f"Retry count: {self.request.retries}"
        )

View File

@@ -0,0 +1,236 @@
# app/tasks/celery_tasks/code_quality.py
"""
Celery tasks for code quality scans.
Wraps the existing execute_code_quality_scan function for Celery execution.
"""
import json
import logging
import subprocess
from datetime import UTC, datetime
from app.core.celery_config import celery_app
from app.services.admin_notification_service import admin_notification_service
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.architecture_scan import ArchitectureScan, ArchitectureViolation
logger = logging.getLogger(__name__)
# Validator type constants
VALIDATOR_ARCHITECTURE = "architecture"
VALIDATOR_SECURITY = "security"
VALIDATOR_PERFORMANCE = "performance"

# All validator types accepted by execute_code_quality_scan; anything else
# fails the scan record immediately.
VALID_VALIDATOR_TYPES = [VALIDATOR_ARCHITECTURE, VALIDATOR_SECURITY, VALIDATOR_PERFORMANCE]

# Map validator types to their scripts (invoked as `python <script> --json`)
VALIDATOR_SCRIPTS = {
    VALIDATOR_ARCHITECTURE: "scripts/validate_architecture.py",
    VALIDATOR_SECURITY: "scripts/validate_security.py",
    VALIDATOR_PERFORMANCE: "scripts/validate_performance.py",
}

# Human-readable names used in progress messages, logs, and notifications
VALIDATOR_NAMES = {
    VALIDATOR_ARCHITECTURE: "Architecture",
    VALIDATOR_SECURITY: "Security",
    VALIDATOR_PERFORMANCE: "Performance",
}
def _get_git_commit_hash() -> str | None:
"""Get current git commit hash."""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
return result.stdout.strip()[:40]
except Exception:
pass
return None
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.code_quality.execute_code_quality_scan",
    max_retries=1,
    time_limit=700,  # 11+ minutes hard limit
    soft_time_limit=600,  # 10 minutes soft limit
)
def execute_code_quality_scan(self, scan_id: int):
    """
    Celery task to execute a code quality scan.

    This task:
    1. Gets the scan record from DB
    2. Updates status to 'running'
    3. Runs the validator script as a subprocess with --json
    4. Parses JSON output and creates violation records
    5. Updates scan with results and status 'completed' or 'failed'

    Args:
        scan_id: ID of the ArchitectureScan record

    Returns:
        dict: Scan results summary, or an ``{"error": ...}`` dict when the
        scan record is missing/invalid, parsing fails, or the subprocess
        times out.

    Raises:
        Exception: unexpected errors are re-raised after the scan row is
            marked failed, so Celery can retry (max_retries=1).
    """
    with self.get_db() as db:
        # Get the scan record
        scan = db.query(ArchitectureScan).filter(ArchitectureScan.id == scan_id).first()
        if not scan:
            logger.error(f"Code quality scan {scan_id} not found")
            return {"error": f"Scan {scan_id} not found"}
        # Store Celery task ID (links the DB row to Flower's task view)
        scan.celery_task_id = self.request.id
        validator_type = scan.validator_type
        if validator_type not in VALID_VALIDATOR_TYPES:
            scan.status = "failed"
            scan.error_message = f"Invalid validator type: {validator_type}"
            db.commit()
            return {"error": f"Invalid validator type: {validator_type}"}
        script_path = VALIDATOR_SCRIPTS[validator_type]
        validator_name = VALIDATOR_NAMES[validator_type]
        try:
            # Update status to running; committed immediately so pollers
            # see progress before the (long) subprocess starts.
            scan.status = "running"
            scan.started_at = datetime.now(UTC)
            scan.progress_message = f"Running {validator_name} validator..."
            scan.git_commit_hash = _get_git_commit_hash()
            db.commit()
            logger.info(f"Starting {validator_name} scan (scan_id={scan_id})")
            # Run validator with JSON output
            start_time = datetime.now(UTC)
            try:
                result = subprocess.run(
                    ["python", script_path, "--json"],
                    capture_output=True,
                    text=True,
                    timeout=600,  # 10 minute timeout
                )
            except subprocess.TimeoutExpired:
                logger.error(f"{validator_name} scan {scan_id} timed out after 10 minutes")
                scan.status = "failed"
                scan.error_message = "Scan timed out after 10 minutes"
                scan.completed_at = datetime.now(UTC)
                db.commit()
                return {"error": "Scan timed out"}
            duration = (datetime.now(UTC) - start_time).total_seconds()
            # Update progress
            scan.progress_message = "Parsing results..."
            db.commit()
            # Parse JSON output. The validator may print non-JSON noise
            # before the payload, so skip to the first line starting "{".
            try:
                lines = result.stdout.strip().split("\n")
                json_start = -1
                for i, line in enumerate(lines):
                    if line.strip().startswith("{"):
                        json_start = i
                        break
                if json_start == -1:
                    raise ValueError("No JSON output found in validator output")
                json_output = "\n".join(lines[json_start:])
                data = json.loads(json_output)
            except (json.JSONDecodeError, ValueError) as e:
                logger.error(f"Failed to parse {validator_name} validator output: {e}")
                scan.status = "failed"
                scan.error_message = f"Failed to parse validator output: {e}"
                scan.completed_at = datetime.now(UTC)
                scan.duration_seconds = duration
                db.commit()
                return {"error": str(e)}
            # Update progress
            scan.progress_message = "Storing violations..."
            db.commit()
            # Create violation records (defaults guard against missing keys
            # in the validator's JSON)
            violations_data = data.get("violations", [])
            logger.info(f"Creating {len(violations_data)} {validator_name} violation records")
            for v in violations_data:
                violation = ArchitectureViolation(
                    scan_id=scan.id,
                    validator_type=validator_type,
                    rule_id=v.get("rule_id", "UNKNOWN"),
                    rule_name=v.get("rule_name", "Unknown Rule"),
                    severity=v.get("severity", "warning"),
                    file_path=v.get("file_path", ""),
                    line_number=v.get("line_number", 0),
                    message=v.get("message", ""),
                    context=v.get("context", ""),
                    suggestion=v.get("suggestion", ""),
                    status="open",
                )
                db.add(violation)
            # Update scan with results
            scan.total_files = data.get("files_checked", 0)
            scan.total_violations = data.get("total_violations", len(violations_data))
            scan.errors = data.get("errors", 0)
            scan.warnings = data.get("warnings", 0)
            scan.duration_seconds = duration
            scan.completed_at = datetime.now(UTC)
            scan.progress_message = None
            # Set final status based on results.
            # NOTE(review): error-level violations map to
            # "completed_with_warnings" rather than a failure status —
            # confirm this is the intended severity mapping.
            if scan.errors > 0:
                scan.status = "completed_with_warnings"
            else:
                scan.status = "completed"
            db.commit()
            logger.info(
                f"{validator_name} scan {scan_id} completed: "
                f"files={scan.total_files}, violations={scan.total_violations}, "
                f"errors={scan.errors}, warnings={scan.warnings}, "
                f"duration={duration:.1f}s"
            )
            return {
                "scan_id": scan_id,
                "validator_type": validator_type,
                "status": scan.status,
                "total_files": scan.total_files,
                "total_violations": scan.total_violations,
                "errors": scan.errors,
                "warnings": scan.warnings,
                "duration_seconds": duration,
            }
        except Exception as e:
            logger.error(f"Code quality scan {scan_id} failed: {e}", exc_info=True)
            scan.status = "failed"
            scan.error_message = str(e)[:500]
            scan.completed_at = datetime.now(UTC)
            scan.progress_message = None
            # Create admin notification for scan failure
            admin_notification_service.create_notification(
                db=db,
                title="Code Quality Scan Failed",
                message=f"{VALIDATOR_NAMES.get(scan.validator_type, 'Unknown')} scan failed: {str(e)[:200]}",
                notification_type="error",
                category="code_quality",
                action_url="/admin/code-quality",
            )
            db.commit()
            raise  # Re-raise for Celery

View File

@@ -0,0 +1,199 @@
# app/tasks/celery_tasks/export.py
"""
Celery tasks for product export operations.
Handles exporting vendor products to various formats (e.g., Letzshop CSV).
"""
import logging
import os
from datetime import UTC, datetime
from pathlib import Path
from app.core.celery_config import celery_app
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.export.export_vendor_products_to_folder",
    max_retries=3,
    default_retry_delay=60,
)
def export_vendor_products_to_folder(
    self,
    vendor_id: int,
    triggered_by: str,
    include_inactive: bool = False,
):
    """
    Export a vendor's products to disk in all three languages (en, fr, de).

    Args:
        vendor_id: ID of the vendor to export
        triggered_by: User identifier who triggered the export
        include_inactive: Whether to include inactive products

    Returns:
        dict: Per-language outcomes with file paths, plus timing metadata,
        or an ``{"error": ...}`` dict when the vendor does not exist.
    """
    # Imported here rather than at module top — presumably to avoid an
    # import cycle at module load time.
    from app.services.letzshop_export_service import letzshop_export_service

    languages = ["en", "fr", "de"]
    per_language = {}

    with self.get_db() as db:
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if vendor is None:
            logger.error(f"Vendor {vendor_id} not found for export")
            return {"error": f"Vendor {vendor_id} not found"}
        vendor_code = vendor.vendor_code

        # One folder per vendor; re-exports overwrite the previous files.
        export_dir = Path(f"exports/letzshop/{vendor_code}")
        export_dir.mkdir(parents=True, exist_ok=True)

        started_at = datetime.now(UTC)
        logger.info(f"Starting product export for vendor {vendor_code} (ID: {vendor_id})")

        for lang in languages:
            try:
                csv_content = letzshop_export_service.export_vendor_products(
                    db=db,
                    vendor_id=vendor_id,
                    language=lang,
                    include_inactive=include_inactive,
                )
                file_name = f"{vendor_code}_products_{lang}.csv"
                file_path = export_dir / file_name
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(csv_content)
                per_language[lang] = {
                    "success": True,
                    "path": str(file_path),
                    "file_name": file_name,
                }
                logger.info(f"Exported {lang} products to {file_path}")
            except Exception as err:
                logger.error(f"Error exporting {lang} products for vendor {vendor_id}: {err}")
                per_language[lang] = {
                    "success": False,
                    "error": str(err),
                }

        completed_at = datetime.now(UTC)
        duration = (completed_at - started_at).total_seconds()
        succeeded = sum(1 for outcome in per_language.values() if outcome.get("success"))

        # Best-effort audit logging: a logging failure must not fail the
        # export itself.
        try:
            letzshop_export_service.log_export(
                db=db,
                vendor_id=vendor_id,
                triggered_by=triggered_by,
                records_processed=len(languages),
                records_succeeded=succeeded,
                records_failed=len(languages) - succeeded,
                duration_seconds=duration,
            )
            db.commit()
        except Exception as err:
            logger.error(f"Failed to log export: {err}")

        logger.info(
            f"Product export complete for vendor {vendor_code}: "
            f"{succeeded}/{len(languages)} languages exported in {duration:.2f}s"
        )
        return {
            "vendor_id": vendor_id,
            "vendor_code": vendor_code,
            "export_dir": str(export_dir),
            "results": per_language,
            "duration_seconds": duration,
            "triggered_by": triggered_by,
        }
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.export.export_marketplace_products",
    max_retries=3,
    default_retry_delay=60,
)
def export_marketplace_products(
    self,
    language: str = "en",
    triggered_by: str = "system",
):
    """
    Export all marketplace products to a timestamped CSV (admin use).

    Args:
        language: Language code for export (en, fr, de)
        triggered_by: User identifier who triggered the export

    Returns:
        dict: ``success`` flag plus file path and timing, or the error
        message when generation/writing failed.
    """
    # Imported here rather than at module top — presumably to avoid an
    # import cycle at module load time.
    from app.services.letzshop_export_service import letzshop_export_service

    with self.get_db() as db:
        started_at = datetime.now(UTC)
        logger.info(f"Starting marketplace product export ({language})")
        try:
            # Shared output folder; files are timestamped so they never clash.
            target_dir = Path("exports/marketplace")
            target_dir.mkdir(parents=True, exist_ok=True)

            csv_content = letzshop_export_service.export_marketplace_products(
                db=db,
                language=language,
            )

            timestamp = started_at.strftime("%Y%m%d_%H%M%S")
            file_name = f"marketplace_products_{language}_{timestamp}.csv"
            target = target_dir / file_name
            with open(target, "w", encoding="utf-8") as handle:
                handle.write(csv_content)

            elapsed = (datetime.now(UTC) - started_at).total_seconds()
            logger.info(f"Marketplace export complete: {target} ({elapsed:.2f}s)")
            return {
                "success": True,
                "path": str(target),
                "file_name": file_name,
                "language": language,
                "duration_seconds": elapsed,
                "triggered_by": triggered_by,
            }
        except Exception as err:
            logger.error(f"Error exporting marketplace products: {err}")
            return {
                "success": False,
                "error": str(err),
                "language": language,
            }

View File

@@ -0,0 +1,272 @@
# app/tasks/celery_tasks/letzshop.py
"""
Celery tasks for Letzshop historical order imports.
Wraps the existing process_historical_import function for Celery execution.
"""
import logging
from datetime import UTC, datetime
from typing import Callable
from app.core.celery_config import celery_app
from app.services.admin_notification_service import admin_notification_service
from app.services.letzshop import LetzshopClientError
from app.services.letzshop.credentials_service import LetzshopCredentialsService
from app.services.letzshop.order_service import LetzshopOrderService
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.letzshop import LetzshopHistoricalImportJob
logger = logging.getLogger(__name__)
def _get_credentials_service(db) -> LetzshopCredentialsService:
    """Create a credentials service instance bound to the given DB session."""
    return LetzshopCredentialsService(db)


def _get_order_service(db) -> LetzshopOrderService:
    """Create an order service instance bound to the given DB session."""
    return LetzshopOrderService(db)
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.letzshop.process_historical_import",
    max_retries=2,
    default_retry_delay=120,
    autoretry_for=(Exception,),
    retry_backoff=True,
)
def process_historical_import(self, job_id: int, vendor_id: int):
    """
    Celery task for historical order import with progress tracking.

    Imports both confirmed and unconfirmed orders from the Letzshop API,
    updating job progress in the database for frontend polling.

    Args:
        job_id: ID of the LetzshopHistoricalImportJob record
        vendor_id: ID of the vendor to import orders for

    Returns:
        dict: Import statistics for both phases

    Raises:
        LetzshopClientError: API failures, re-raised after marking the job
            failed so Celery's autoretry (max_retries=2, backoff) runs.
        Exception: any other error, same failure/retry handling.
    """
    with self.get_db() as db:
        # Get the import job
        job = (
            db.query(LetzshopHistoricalImportJob)
            .filter(LetzshopHistoricalImportJob.id == job_id)
            .first()
        )
        if not job:
            logger.error(f"Historical import job {job_id} not found")
            return {"error": f"Job {job_id} not found"}
        # Store Celery task ID (links the job row to Flower)
        job.celery_task_id = self.request.id
        try:
            # Mark as started. On a Celery retry this also resets a
            # previously-failed job back to "fetching".
            job.status = "fetching"
            job.started_at = datetime.now(UTC)
            db.commit()
            creds_service = _get_credentials_service(db)
            order_service = _get_order_service(db)

            # Create progress callback for fetching; commits each update so
            # the frontend poller sees live progress.
            def fetch_progress_callback(page: int, total_fetched: int):
                """Update fetch progress in database."""
                job.current_page = page
                job.shipments_fetched = total_fetched
                db.commit()

            # Create progress callback for processing
            def create_processing_callback(
                phase: str,
            ) -> Callable[[int, int, int, int], None]:
                """Create a processing progress callback for a phase."""
                # NOTE(review): `phase` is not used inside the callback —
                # confirm whether per-phase labeling was intended here.
                def callback(processed: int, imported: int, updated: int, skipped: int):
                    job.orders_processed = processed
                    job.orders_imported = imported
                    job.orders_updated = updated
                    job.orders_skipped = skipped
                    db.commit()
                return callback

            with creds_service.create_client(vendor_id) as client:
                # ================================================================
                # Phase 1: Import confirmed orders
                # ================================================================
                job.current_phase = "confirmed"
                job.current_page = 0
                job.shipments_fetched = 0
                db.commit()
                logger.info(f"Job {job_id}: Fetching confirmed shipments for vendor {vendor_id}")
                confirmed_shipments = client.get_all_shipments_paginated(
                    state="confirmed",
                    page_size=50,
                    progress_callback=fetch_progress_callback,
                )
                logger.info(f"Job {job_id}: Fetched {len(confirmed_shipments)} confirmed shipments")
                # Process confirmed shipments; counters reset for this phase
                job.status = "processing"
                job.orders_processed = 0
                job.orders_imported = 0
                job.orders_updated = 0
                job.orders_skipped = 0
                db.commit()
                confirmed_stats = order_service.import_historical_shipments(
                    vendor_id=vendor_id,
                    shipments=confirmed_shipments,
                    match_products=True,
                    progress_callback=create_processing_callback("confirmed"),
                )
                # Store confirmed stats
                job.confirmed_stats = {
                    "total": confirmed_stats["total"],
                    "imported": confirmed_stats["imported"],
                    "updated": confirmed_stats["updated"],
                    "skipped": confirmed_stats["skipped"],
                    "products_matched": confirmed_stats["products_matched"],
                    "products_not_found": confirmed_stats["products_not_found"],
                }
                job.products_matched = confirmed_stats["products_matched"]
                job.products_not_found = confirmed_stats["products_not_found"]
                db.commit()
                logger.info(
                    f"Job {job_id}: Confirmed phase complete - "
                    f"imported={confirmed_stats['imported']}, "
                    f"updated={confirmed_stats['updated']}, "
                    f"skipped={confirmed_stats['skipped']}"
                )
                # ================================================================
                # Phase 2: Import unconfirmed (pending) orders
                # ================================================================
                job.current_phase = "unconfirmed"
                job.status = "fetching"
                job.current_page = 0
                job.shipments_fetched = 0
                db.commit()
                logger.info(f"Job {job_id}: Fetching unconfirmed shipments for vendor {vendor_id}")
                unconfirmed_shipments = client.get_all_shipments_paginated(
                    state="unconfirmed",
                    page_size=50,
                    progress_callback=fetch_progress_callback,
                )
                logger.info(
                    f"Job {job_id}: Fetched {len(unconfirmed_shipments)} unconfirmed shipments"
                )
                # Process unconfirmed shipments
                job.status = "processing"
                job.orders_processed = 0
                db.commit()
                unconfirmed_stats = order_service.import_historical_shipments(
                    vendor_id=vendor_id,
                    shipments=unconfirmed_shipments,
                    match_products=True,
                    progress_callback=create_processing_callback("unconfirmed"),
                )
                # Store unconfirmed stats.
                # NOTE(review): these land in `declined_stats` — the column
                # name looks historical; confirm it is the intended field.
                job.declined_stats = {
                    "total": unconfirmed_stats["total"],
                    "imported": unconfirmed_stats["imported"],
                    "updated": unconfirmed_stats["updated"],
                    "skipped": unconfirmed_stats["skipped"],
                    "products_matched": unconfirmed_stats["products_matched"],
                    "products_not_found": unconfirmed_stats["products_not_found"],
                }
                # Add to cumulative product matching stats
                job.products_matched += unconfirmed_stats["products_matched"]
                job.products_not_found += unconfirmed_stats["products_not_found"]
                logger.info(
                    f"Job {job_id}: Unconfirmed phase complete - "
                    f"imported={unconfirmed_stats['imported']}, "
                    f"updated={unconfirmed_stats['updated']}, "
                    f"skipped={unconfirmed_stats['skipped']}"
                )
                # ================================================================
                # Complete
                # ================================================================
                job.status = "completed"
                job.completed_at = datetime.now(UTC)
                db.commit()
                # Update credentials sync status
                creds_service.update_sync_status(vendor_id, "success", None)
                logger.info(f"Job {job_id}: Historical import completed successfully")
                return {
                    "job_id": job_id,
                    "confirmed": confirmed_stats,
                    "unconfirmed": unconfirmed_stats,
                }
        except LetzshopClientError as e:
            logger.error(f"Job {job_id}: Letzshop API error: {e}")
            job.status = "failed"
            job.error_message = f"Letzshop API error: {e}"
            job.completed_at = datetime.now(UTC)
            # Get vendor name for notification
            order_service = _get_order_service(db)
            vendor = order_service.get_vendor(vendor_id)
            vendor_name = vendor.name if vendor else f"Vendor {vendor_id}"
            # Create admin notification
            admin_notification_service.notify_order_sync_failure(
                db=db,
                vendor_name=vendor_name,
                error_message=f"Historical import failed: {str(e)[:150]}",
                vendor_id=vendor_id,
            )
            db.commit()
            creds_service = _get_credentials_service(db)
            creds_service.update_sync_status(vendor_id, "failed", str(e))
            raise  # Re-raise for Celery retry
        except Exception as e:
            logger.error(f"Job {job_id}: Unexpected error: {e}", exc_info=True)
            job.status = "failed"
            job.error_message = str(e)[:500]
            job.completed_at = datetime.now(UTC)
            # Get vendor name for notification
            order_service = _get_order_service(db)
            vendor = order_service.get_vendor(vendor_id)
            vendor_name = vendor.name if vendor else f"Vendor {vendor_id}"
            # Create admin notification
            admin_notification_service.notify_critical_error(
                db=db,
                error_type="Historical Import",
                error_message=f"Import job {job_id} failed for {vendor_name}: {str(e)[:150]}",
                details={"job_id": job_id, "vendor_id": vendor_id, "vendor_name": vendor_name},
            )
            db.commit()
            raise  # Re-raise for Celery retry

View File

@@ -0,0 +1,160 @@
# app/tasks/celery_tasks/marketplace.py
"""
Celery tasks for marketplace product imports.
Wraps the existing process_marketplace_import function for Celery execution.
"""
import asyncio
import logging
from datetime import UTC, datetime
from app.core.celery_config import celery_app
from app.services.admin_notification_service import admin_notification_service
from app.tasks.celery_tasks.base import DatabaseTask
from app.utils.csv_processor import CSVProcessor
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
logger = logging.getLogger(__name__)
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.marketplace.process_marketplace_import",
    max_retries=3,
    default_retry_delay=60,
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_backoff_max=600,
    retry_jitter=True,
)
def process_marketplace_import(
    self,
    job_id: int,
    url: str,
    marketplace: str,
    vendor_id: int,
    batch_size: int = 1000,
    language: str = "en",
):
    """
    Celery task to process a marketplace CSV import.

    Downloads and processes the CSV at *url* for the given vendor and
    records progress/results on the MarketplaceImportJob row so the
    frontend can poll it.

    Args:
        job_id: ID of the MarketplaceImportJob record
        url: URL to the CSV file
        marketplace: Name of the marketplace (e.g., 'Letzshop')
        vendor_id: ID of the vendor
        batch_size: Number of rows to process per batch
        language: Language code for translations (default: 'en')

    Returns:
        dict: Import results with counts, or an ``{"error": ...}`` dict
        when the job or vendor record is missing.

    Raises:
        Exception: re-raised after marking the job failed so Celery's
            autoretry/backoff policy can re-run the task.
    """
    csv_processor = CSVProcessor()
    with self.get_db() as db:
        # Get the import job
        job = db.query(MarketplaceImportJob).filter(MarketplaceImportJob.id == job_id).first()
        if not job:
            logger.error(f"Import job {job_id} not found")
            return {"error": f"Import job {job_id} not found"}
        # Store Celery task ID on job (enables Flower cross-referencing)
        job.celery_task_id = self.request.id
        # Get vendor information
        vendor = db.query(Vendor).filter(Vendor.id == vendor_id).first()
        if not vendor:
            logger.error(f"Vendor {vendor_id} not found for import job {job_id}")
            job.status = "failed"
            job.error_message = f"Vendor {vendor_id} not found"
            job.completed_at = datetime.now(UTC)
            db.commit()
            return {"error": f"Vendor {vendor_id} not found"}
        # Update job status
        job.status = "processing"
        job.started_at = datetime.now(UTC)
        db.commit()
        logger.info(
            f"Processing import: Job {job_id}, Marketplace: {marketplace}, "
            f"Vendor: {vendor.name} ({vendor.vendor_code}), Language: {language}"
        )
        try:
            # Run the async CSV processor in this sync worker context.
            # asyncio.run() creates a fresh event loop, runs the coroutine,
            # and fully tears the loop down (shutting down async generators)
            # — unlike the manual new_event_loop()/close() pattern, it never
            # leaves a closed loop installed as the thread's current loop.
            result = asyncio.run(
                csv_processor.process_marketplace_csv_from_url(
                    url=url,
                    marketplace=marketplace,
                    vendor_name=vendor.name,
                    batch_size=batch_size,
                    db=db,
                    language=language,
                    import_job_id=job_id,
                )
            )
            # Update job with results
            job.status = "completed"
            job.completed_at = datetime.now(UTC)
            job.imported_count = result["imported"]
            job.updated_count = result["updated"]
            job.error_count = result.get("errors", 0)
            job.total_processed = result["total_processed"]
            if result.get("errors", 0) > 0:
                job.status = "completed_with_errors"
                job.error_message = f"{result['errors']} rows had errors"
                # Notify admin if error count is significant
                if result.get("errors", 0) >= 5:
                    admin_notification_service.notify_import_failure(
                        db=db,
                        vendor_name=vendor.name,
                        job_id=job_id,
                        error_message=f"Import completed with {result['errors']} errors out of {result['total_processed']} rows",
                        vendor_id=vendor_id,
                    )
            db.commit()
            logger.info(
                f"Import job {job_id} completed: "
                f"imported={result['imported']}, updated={result['updated']}, "
                f"errors={result.get('errors', 0)}"
            )
            return {
                "job_id": job_id,
                "imported": result["imported"],
                "updated": result["updated"],
                "errors": result.get("errors", 0),
                "total_processed": result["total_processed"],
            }
        except Exception as e:
            logger.error(f"Import job {job_id} failed: {e}", exc_info=True)
            job.status = "failed"
            job.error_message = str(e)[:500]  # Truncate long errors
            job.completed_at = datetime.now(UTC)
            # Create admin notification for import failure
            admin_notification_service.notify_import_failure(
                db=db,
                vendor_name=vendor.name,
                job_id=job_id,
                error_message=str(e)[:200],
                vendor_id=vendor_id,
            )
            db.commit()
            raise  # Re-raise for Celery retry

View File

@@ -0,0 +1,290 @@
# app/tasks/celery_tasks/subscription.py
"""
Celery tasks for subscription management.
Scheduled tasks for:
- Resetting period counters
- Checking trial expirations
- Syncing with Stripe
- Cleaning up stale subscriptions
- Capturing capacity snapshots
"""
import logging
from datetime import UTC, datetime, timedelta
from app.core.celery_config import celery_app
from app.services.stripe_service import stripe_service
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.subscription import SubscriptionStatus, VendorSubscription
logger = logging.getLogger(__name__)
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.subscription.reset_period_counters",
)
def reset_period_counters(self):
    """
    Reset order counters for subscriptions whose billing period has ended.

    Runs daily at 00:05 (via Celery Beat). For each active/trial
    subscription whose ``period_end`` has passed, resets
    ``orders_this_period`` to 0 and starts a new billing period.

    Returns:
        dict: {"reset_count": number of subscriptions reset}
    """
    now = datetime.now(UTC)
    reset_count = 0
    with self.get_db() as db:
        # Find subscriptions whose billing period has ended. Status values
        # come from the SubscriptionStatus enum, consistent with the other
        # tasks in this module (previously raw string literals).
        expired_periods = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.period_end <= now,
                VendorSubscription.status.in_(
                    [SubscriptionStatus.ACTIVE.value, SubscriptionStatus.TRIAL.value]
                ),
            )
            .all()
        )
        for subscription in expired_periods:
            old_period_end = subscription.period_end
            # Reset usage counters for the new period
            subscription.orders_this_period = 0
            subscription.orders_limit_reached_at = None
            # Start the new period from "now".
            # NOTE(review): anchoring at `now` (rather than at the old
            # period_end) lets billing periods drift later each cycle when
            # the task runs after expiry — confirm this is intended.
            period_length = timedelta(days=365 if subscription.is_annual else 30)
            subscription.period_start = now
            subscription.period_end = now + period_length
            subscription.updated_at = now
            reset_count += 1
            logger.info(
                f"Reset period counters for vendor {subscription.vendor_id}: "
                f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
            )
        db.commit()
        logger.info(f"Reset period counters for {reset_count} subscriptions")
        return {"reset_count": reset_count}
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.subscription.check_trial_expirations",
)
def check_trial_expirations(self):
    """
    Resolve trials whose end date has passed.

    Runs daily at 01:00. A finished trial becomes:
    - active, when a Stripe payment method is on file
    - expired, otherwise

    Returns:
        dict: counts of expired and activated subscriptions
    """
    now = datetime.now(UTC)
    expired_count = 0
    activated_count = 0
    with self.get_db() as db:
        # Trials still marked as such whose end timestamp has passed
        finished_trials = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.status == SubscriptionStatus.TRIAL.value,
                VendorSubscription.trial_ends_at <= now,
            )
            .all()
        )
        for sub in finished_trials:
            if sub.stripe_payment_method_id:
                # Payment method on file — promote to a paying subscription
                sub.status = SubscriptionStatus.ACTIVE.value
                activated_count += 1
                logger.info(
                    f"Activated subscription for vendor {sub.vendor_id} "
                    f"(trial ended with payment method)"
                )
            else:
                # Nothing to bill — let the trial lapse
                sub.status = SubscriptionStatus.EXPIRED.value
                expired_count += 1
                logger.info(
                    f"Expired trial for vendor {sub.vendor_id} "
                    f"(no payment method)"
                )
            sub.updated_at = now
        db.commit()
        logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
        return {"expired_count": expired_count, "activated_count": activated_count}
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.subscription.sync_stripe_status",
    max_retries=3,
    default_retry_delay=300,
)
def sync_stripe_status(self):
    """
    Sync subscription status with Stripe.

    Runs hourly at :30. Fetches the current subscription state from Stripe
    and mirrors status, period dates, and payment method onto local records.
    Per-subscription failures are counted and logged but do not abort the
    sweep.

    Returns:
        dict: {"synced_count", "error_count"}, or
        {"synced": 0, "skipped": True} when Stripe is not configured.
    """
    if not stripe_service.is_configured:
        logger.warning("Stripe not configured, skipping sync")
        return {"synced": 0, "skipped": True}
    synced_count = 0
    error_count = 0
    with self.get_db() as db:
        # Find subscriptions with Stripe IDs
        subscriptions = (
            db.query(VendorSubscription)
            .filter(VendorSubscription.stripe_subscription_id.isnot(None))
            .all()
        )
        for subscription in subscriptions:
            try:
                # Fetch from Stripe
                stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)
                if not stripe_sub:
                    logger.warning(
                        f"Stripe subscription {subscription.stripe_subscription_id} "
                        f"not found for vendor {subscription.vendor_id}"
                    )
                    continue
                # Map Stripe status to local status.
                # NOTE(review): 'unpaid' maps to PAST_DUE and 'incomplete'
                # to TRIAL — confirm these are the intended local states.
                status_map = {
                    "active": SubscriptionStatus.ACTIVE.value,
                    "trialing": SubscriptionStatus.TRIAL.value,
                    "past_due": SubscriptionStatus.PAST_DUE.value,
                    "canceled": SubscriptionStatus.CANCELLED.value,
                    "unpaid": SubscriptionStatus.PAST_DUE.value,
                    "incomplete": SubscriptionStatus.TRIAL.value,
                    "incomplete_expired": SubscriptionStatus.EXPIRED.value,
                }
                new_status = status_map.get(stripe_sub.status)
                if new_status and new_status != subscription.status:
                    old_status = subscription.status
                    subscription.status = new_status
                    subscription.updated_at = datetime.now(UTC)
                    logger.info(
                        f"Updated vendor {subscription.vendor_id} status: "
                        f"{old_status} -> {new_status} (from Stripe)"
                    )
                # Update period dates from Stripe (epoch seconds -> aware UTC)
                if stripe_sub.current_period_start:
                    subscription.period_start = datetime.fromtimestamp(
                        stripe_sub.current_period_start, tz=UTC
                    )
                if stripe_sub.current_period_end:
                    subscription.period_end = datetime.fromtimestamp(
                        stripe_sub.current_period_end, tz=UTC
                    )
                # Update payment method (Stripe may return either a plain ID
                # string or an expanded PaymentMethod object)
                if stripe_sub.default_payment_method:
                    subscription.stripe_payment_method_id = (
                        stripe_sub.default_payment_method
                        if isinstance(stripe_sub.default_payment_method, str)
                        else stripe_sub.default_payment_method.id
                    )
                # Counted as synced even when nothing changed
                synced_count += 1
            except Exception as e:
                logger.error(f"Error syncing subscription {subscription.stripe_subscription_id}: {e}")
                error_count += 1
        # Single commit for the whole sweep
        db.commit()
        logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
        return {"synced_count": synced_count, "error_count": error_count}
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.subscription.cleanup_stale_subscriptions",
)
def cleanup_stale_subscriptions(self):
    """
    Clean up subscriptions in inconsistent states.

    Cancelled subscriptions whose billing period ended more than 30 days
    ago are transitioned to EXPIRED (fully terminated).

    Runs weekly on Sunday at 03:00.

    Returns:
        dict: Count of subscriptions transitioned to EXPIRED.
    """
    now = datetime.now(UTC)
    cutoff = now - timedelta(days=30)
    expired_total = 0
    with self.get_db() as db:
        # Cancelled subscriptions whose period end predates the cutoff
        # are considered stale.
        stale_cancelled = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
                VendorSubscription.period_end < cutoff,
            )
            .all()
        )
        for sub in stale_cancelled:
            # Mark as expired (fully terminated).
            sub.status = SubscriptionStatus.EXPIRED.value
            sub.updated_at = now
            expired_total += 1
            logger.info(
                f"Marked stale cancelled subscription as expired: vendor {sub.vendor_id}"
            )
        db.commit()
    logger.info(f"Cleaned up {expired_total} stale subscriptions")
    return {"cleaned_count": expired_total}
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.subscription.capture_capacity_snapshot",
)
def capture_capacity_snapshot(self):
    """
    Capture a daily snapshot of platform capacity metrics.

    Delegates to the capacity forecast service to record current
    vendor/product totals.

    Runs daily at midnight.

    Returns:
        dict: Snapshot id, ISO date, and headline vendor/product counts.
    """
    # Service import deferred to task execution time.
    from app.services.capacity_forecast_service import capacity_forecast_service

    with self.get_db() as db:
        snapshot = capacity_forecast_service.capture_daily_snapshot(db)
        db.commit()
        logger.info(
            f"Captured capacity snapshot: {snapshot.total_vendors} vendors, "
            f"{snapshot.total_products} products"
        )
        # Read snapshot attributes while the session is still open.
        return {
            "snapshot_id": snapshot.id,
            "snapshot_date": snapshot.snapshot_date.isoformat(),
            "total_vendors": snapshot.total_vendors,
            "total_products": snapshot.total_products,
        }

View File

@@ -0,0 +1,83 @@
# app/tasks/celery_tasks/test_runner.py
"""
Celery tasks for test execution.
Wraps the existing execute_test_run function for Celery execution.
"""
import logging
from app.core.celery_config import celery_app
from app.services.test_runner_service import test_runner_service
from app.tasks.celery_tasks.base import DatabaseTask
from models.database.test_run import TestRun
logger = logging.getLogger(__name__)
@celery_app.task(
    bind=True,
    base=DatabaseTask,
    name="app.tasks.celery_tasks.test_runner.execute_test_run",
    max_retries=1,
    time_limit=3600,  # 1 hour hard limit
    soft_time_limit=3300,  # 55 minutes soft limit
)
def execute_test_run(
    self,
    run_id: int,
    test_path: str = "tests",
    extra_args: list[str] | None = None,
):
    """
    Celery task to execute pytest tests.

    Args:
        run_id: ID of the TestRun record
        test_path: Path to tests (relative to project root)
        extra_args: Additional pytest arguments

    Returns:
        dict: Test run results summary, or an error dict when the
            TestRun record does not exist.

    Raises:
        Exception: Re-raised after marking the TestRun as "error" so
            Celery records the task failure.
    """
    with self.get_db() as db:
        # Get the test run record
        test_run = db.query(TestRun).filter(TestRun.id == run_id).first()
        if not test_run:
            logger.error(f"Test run {run_id} not found")
            return {"error": f"Test run {run_id} not found"}
        # Store Celery task ID (used for Flower integration)
        test_run.celery_task_id = self.request.id
        db.commit()
        try:
            logger.info(f"Starting test execution: Run {run_id}, Path: {test_path}")
            # Execute the tests
            test_runner_service._execute_tests(db, test_run, test_path, extra_args)
            db.commit()
            # duration_seconds may be unset if the runner bailed out early;
            # guard the formatting so logging cannot mask a completed run.
            duration = (
                f"{test_run.duration_seconds:.1f}s"
                if test_run.duration_seconds is not None
                else "unknown"
            )
            logger.info(
                f"Test run {run_id} completed: "
                f"status={test_run.status}, passed={test_run.passed}, "
                f"failed={test_run.failed}, duration={duration}"
            )
            return {
                "run_id": run_id,
                "status": test_run.status,
                "total_tests": test_run.total_tests,
                "passed": test_run.passed,
                "failed": test_run.failed,
                "errors": test_run.errors,
                "skipped": test_run.skipped,
                "coverage_percent": test_run.coverage_percent,
                "duration_seconds": test_run.duration_seconds,
            }
        except Exception as e:
            logger.error(f"Test run {run_id} failed: {e}", exc_info=True)
            # Discard any half-flushed state first; committing on a session
            # left in a failed state would itself raise and hide the error.
            db.rollback()
            test_run.status = "error"
            db.commit()
            raise  # Re-raise for Celery

286
app/tasks/dispatcher.py Normal file
View File

@@ -0,0 +1,286 @@
# app/tasks/dispatcher.py
"""
Task dispatcher with feature flag for gradual Celery migration.
This module provides a unified interface for dispatching background tasks.
Based on the USE_CELERY setting, tasks are either:
- Sent to Celery for persistent, reliable execution
- Run via FastAPI BackgroundTasks (fire-and-forget)
This allows for gradual rollout and instant rollback.
"""
import logging
from typing import Any
from fastapi import BackgroundTasks
from app.core.config import settings
logger = logging.getLogger(__name__)
class TaskDispatcher:
    """
    Routes background work to either Celery or FastAPI BackgroundTasks.

    The backend is chosen per-call from the ``use_celery`` setting, which
    enables gradual rollout to Celery and instant rollback to in-process
    BackgroundTasks.

    Usage:
        from app.tasks.dispatcher import task_dispatcher

        # In an API endpoint:
        task_id = task_dispatcher.dispatch_marketplace_import(
            background_tasks=background_tasks,
            job_id=job.id,
            url=url,
            marketplace=marketplace,
            vendor_id=vendor_id,
        )
    """

    @property
    def use_celery(self) -> bool:
        """Whether tasks should be routed to Celery."""
        return settings.use_celery

    def dispatch_marketplace_import(
        self,
        background_tasks: BackgroundTasks,
        job_id: int,
        url: str,
        marketplace: str,
        vendor_id: int,
        batch_size: int = 1000,
        language: str = "en",
    ) -> str | None:
        """
        Dispatch a marketplace CSV import task.

        Args:
            background_tasks: FastAPI BackgroundTasks instance
            job_id: ID of the MarketplaceImportJob record
            url: URL to the CSV file
            marketplace: Name of the marketplace
            vendor_id: ID of the vendor
            batch_size: Number of rows per batch
            language: Language code for translations

        Returns:
            str | None: Celery task ID if using Celery, None otherwise
        """
        if not self.use_celery:
            # In-process fallback: fire-and-forget via FastAPI.
            from app.tasks.background_tasks import process_marketplace_import

            background_tasks.add_task(
                process_marketplace_import,
                job_id=job_id,
                url=url,
                marketplace=marketplace,
                vendor_id=vendor_id,
                batch_size=batch_size,
                language=language,
            )
            logger.info("Dispatched marketplace import to BackgroundTasks")
            return None

        # Celery path: persistent queue, retries, Flower visibility.
        from app.tasks.celery_tasks.marketplace import process_marketplace_import

        async_result = process_marketplace_import.delay(
            job_id=job_id,
            url=url,
            marketplace=marketplace,
            vendor_id=vendor_id,
            batch_size=batch_size,
            language=language,
        )
        logger.info(f"Dispatched marketplace import to Celery: task_id={async_result.id}")
        return async_result.id

    def dispatch_historical_import(
        self,
        background_tasks: BackgroundTasks,
        job_id: int,
        vendor_id: int,
    ) -> str | None:
        """
        Dispatch a Letzshop historical import task.

        Args:
            background_tasks: FastAPI BackgroundTasks instance
            job_id: ID of the LetzshopHistoricalImportJob record
            vendor_id: ID of the vendor

        Returns:
            str | None: Celery task ID if using Celery, None otherwise
        """
        if not self.use_celery:
            from app.tasks.letzshop_tasks import process_historical_import

            background_tasks.add_task(
                process_historical_import,
                job_id=job_id,
                vendor_id=vendor_id,
            )
            logger.info("Dispatched historical import to BackgroundTasks")
            return None

        from app.tasks.celery_tasks.letzshop import process_historical_import

        async_result = process_historical_import.delay(job_id=job_id, vendor_id=vendor_id)
        logger.info(f"Dispatched historical import to Celery: task_id={async_result.id}")
        return async_result.id

    def dispatch_code_quality_scan(
        self,
        background_tasks: BackgroundTasks,
        scan_id: int,
    ) -> str | None:
        """
        Dispatch a code quality scan task.

        Args:
            background_tasks: FastAPI BackgroundTasks instance
            scan_id: ID of the ArchitectureScan record

        Returns:
            str | None: Celery task ID if using Celery, None otherwise
        """
        if not self.use_celery:
            from app.tasks.code_quality_tasks import execute_code_quality_scan

            background_tasks.add_task(execute_code_quality_scan, scan_id=scan_id)
            logger.info("Dispatched code quality scan to BackgroundTasks")
            return None

        from app.tasks.celery_tasks.code_quality import execute_code_quality_scan

        async_result = execute_code_quality_scan.delay(scan_id=scan_id)
        logger.info(f"Dispatched code quality scan to Celery: task_id={async_result.id}")
        return async_result.id

    def dispatch_test_run(
        self,
        background_tasks: BackgroundTasks,
        run_id: int,
        test_path: str = "tests",
        extra_args: list[str] | None = None,
    ) -> str | None:
        """
        Dispatch a test run task.

        Args:
            background_tasks: FastAPI BackgroundTasks instance
            run_id: ID of the TestRun record
            test_path: Path to tests
            extra_args: Additional pytest arguments

        Returns:
            str | None: Celery task ID if using Celery, None otherwise
        """
        if not self.use_celery:
            from app.tasks.test_runner_tasks import execute_test_run

            background_tasks.add_task(
                execute_test_run,
                run_id=run_id,
                test_path=test_path,
                extra_args=extra_args,
            )
            logger.info("Dispatched test run to BackgroundTasks")
            return None

        from app.tasks.celery_tasks.test_runner import execute_test_run

        async_result = execute_test_run.delay(
            run_id=run_id,
            test_path=test_path,
            extra_args=extra_args,
        )
        logger.info(f"Dispatched test run to Celery: task_id={async_result.id}")
        return async_result.id

    def dispatch_product_export(
        self,
        vendor_id: int,
        triggered_by: str,
        include_inactive: bool = False,
    ) -> str | None:
        """
        Dispatch a product export task (Celery only).

        This task is only available via Celery as it's designed for
        asynchronous batch exports. For synchronous exports, use
        the export service directly.

        Args:
            vendor_id: ID of the vendor to export
            triggered_by: User identifier
            include_inactive: Whether to include inactive products

        Returns:
            str | None: Celery task ID if using Celery, None otherwise
        """
        if not self.use_celery:
            # No BackgroundTasks fallback exists for this task.
            logger.warning(
                "Product export task requires Celery. "
                "Use letzshop_export_service directly for synchronous export."
            )
            return None

        from app.tasks.celery_tasks.export import export_vendor_products_to_folder

        async_result = export_vendor_products_to_folder.delay(
            vendor_id=vendor_id,
            triggered_by=triggered_by,
            include_inactive=include_inactive,
        )
        logger.info(f"Dispatched product export to Celery: task_id={async_result.id}")
        return async_result.id

    def get_task_status(self, task_id: str) -> dict[str, Any]:
        """
        Get the status of a Celery task.

        Args:
            task_id: Celery task ID

        Returns:
            dict: Task status info including state and result
        """
        if not self.use_celery:
            return {"error": "Celery not enabled"}

        from app.core.celery_config import celery_app

        result = celery_app.AsyncResult(task_id)
        done = result.ready()
        status: dict[str, Any] = {
            "task_id": task_id,
            "state": result.state,
            "ready": done,
        }
        if done:
            status["successful"] = result.successful()
            status["result"] = result.result
        else:
            # Outcome fields are unknown until the task finishes.
            status["successful"] = None
            status["result"] = None
        return status

    def revoke_task(self, task_id: str, terminate: bool = False) -> bool:
        """
        Revoke (cancel) a Celery task.

        Args:
            task_id: Celery task ID to revoke
            terminate: If True, terminate running task; if False, just prevent execution

        Returns:
            bool: True if revocation was sent
        """
        if not self.use_celery:
            logger.warning("Cannot revoke task: Celery not enabled")
            return False

        from app.core.celery_config import celery_app

        celery_app.control.revoke(task_id, terminate=terminate)
        logger.info(f"Revoked Celery task: task_id={task_id}, terminate={terminate}")
        return True
# Module-level singleton: callers should import `task_dispatcher`
# rather than instantiating TaskDispatcher themselves.
task_dispatcher = TaskDispatcher()