# app/modules/task_base.py
"""
Base Celery task class for module tasks.

Provides a ModuleTask base class that handles:
- Database session lifecycle (create/close)
- Context manager pattern for session usage
- Proper cleanup on task completion or failure
- Structured logging with task context

This is the standard base class for all Celery tasks defined within modules.
It replaces the legacy app/tasks/celery_tasks/base.py for new module tasks.

Usage:
    from app.core.celery_config import celery_app
    from app.modules.task_base import ModuleTask

    @celery_app.task(bind=True, base=ModuleTask)
    def my_task(self, arg1, arg2):
        with self.get_db() as db:
            # Use db session
            result = db.query(Model).filter(...).all()
            return result
"""

import logging
from contextlib import contextmanager
from typing import Any, Iterator

from celery import Task
from sqlalchemy.orm import Session

from app.core.database import SessionLocal

logger = logging.getLogger(__name__)


class ModuleTask(Task):
    """
    Base Celery task with database session management.

    Provides:
    - Database session context manager (``get_db``)
    - Automatic session commit/rollback/close on success/failure
    - Structured logging for task lifecycle events (success, failure, retry)
    - Retry tracking in retry log records

    All module tasks should use this as their base class for consistent
    behavior and proper resource management.

    Example:
        @celery_app.task(bind=True, base=ModuleTask)
        def process_subscription(self, subscription_id: int):
            with self.get_db() as db:
                subscription = db.query(Subscription).get(subscription_id)
                # Process subscription...
                return {"status": "processed"}
    """

    # Mark as abstract so Celery doesn't register this as a task
    abstract = True

    @contextmanager
    def get_db(self) -> Iterator[Session]:
        """
        Context manager for a database session.

        Opens a session from ``SessionLocal`` and yields it. When the
        with-block exits normally the session is committed; if the block
        (or the commit itself) raises, the error is logged, the session is
        rolled back, and the exception is re-raised. The session is always
        closed.

        Yields:
            Session: SQLAlchemy database session

        Raises:
            Exception: any exception raised inside the with-block or by
                ``commit()`` is re-raised after rollback.

        Example:
            with self.get_db() as db:
                store = db.query(Store).filter(Store.id == store_id).first()
                store.status = "active"
                # Session commits automatically on exit
        """
        db = SessionLocal()
        try:
            yield db
            # Commit is inside the try so a failing commit is also rolled back.
            db.commit()
        except Exception as e:
            # NOTE(review): this catches *any* Exception raised in the
            # with-block (task logic included), not only database errors,
            # so the "Database error" wording can be misleading.
            logger.error(
                f"Database error in task {self.name}: {e}",
                extra={
                    "task_name": self.name,
                    "task_id": getattr(self.request, "id", None),
                    "error": str(e),
                },
            )
            db.rollback()
            raise
        finally:
            db.close()

    def on_failure(
        self,
        exc: Exception,
        task_id: str,
        args: tuple,
        kwargs: dict,
        einfo: Any,
    ) -> None:
        """
        Called by Celery when the task fails.

        Logs the failure at ERROR level with structured context (task name,
        id, call arguments, exception text and traceback) for debugging and
        monitoring.

        Args:
            exc: The exception raised by the task.
            task_id: Id of the failed task.
            args: Positional arguments the task was called with.
            kwargs: Keyword arguments the task was called with.
            einfo: Exception info object supplied by Celery; logged via str().
        """
        logger.error(
            f"Task {self.name}[{task_id}] failed: {exc}",
            extra={
                "task_name": self.name,
                "task_id": task_id,
                # "args" is a built-in LogRecord attribute; using it as an
                # `extra` key makes Logger.makeRecord raise KeyError, so the
                # call arguments are logged under prefixed keys instead.
                "task_args": args,
                "task_kwargs": kwargs,
                "exception": str(exc),
                "traceback": str(einfo),
            },
            # Pass the exception itself so the traceback is attached even if
            # no exception is currently being handled (exc_info=True would
            # then log "NoneType: None").
            exc_info=exc,
        )

    def on_success(
        self,
        retval: Any,
        task_id: str,
        args: tuple,
        kwargs: dict,
    ) -> None:
        """
        Called by Celery when the task succeeds.

        Logs successful completion at INFO level with the task name and id.
        The return value and call arguments are not logged.
        """
        logger.info(
            f"Task {self.name}[{task_id}] completed successfully",
            extra={
                "task_name": self.name,
                "task_id": task_id,
            },
        )

    def on_retry(
        self,
        exc: Exception,
        task_id: str,
        args: tuple,
        kwargs: dict,
        einfo: Any,
    ) -> None:
        """
        Called by Celery when the task is being retried.

        Logs a WARNING with the retry reason. The attempt number shown is
        ``self.request.retries + 1`` (``retries`` defaults to 0 if absent).
        """
        retry_count = getattr(self.request, "retries", 0)
        logger.warning(
            f"Task {self.name}[{task_id}] retrying (attempt {retry_count + 1}): {exc}",
            extra={
                "task_name": self.name,
                "task_id": task_id,
                "retry_count": retry_count,
                "exception": str(exc),
            },
        )


# Alias for backward compatibility and clarity
DatabaseTask = ModuleTask

__all__ = ["ModuleTask", "DatabaseTask"]