# app/modules/billing/tasks/subscription.py
"""
Celery tasks for subscription management.

Scheduled tasks for:
- Resetting period counters
- Checking trial expirations
- Syncing with Stripe
- Cleaning up stale subscriptions
"""

import logging
from datetime import UTC, datetime, timedelta

from app.core.celery_config import celery_app
from app.modules.billing.models import MerchantSubscription, SubscriptionStatus
from app.modules.billing.services import stripe_service
from app.modules.task_base import ModuleTask

logger = logging.getLogger(__name__)

# Length of the new billing period assigned by reset_period_counters.
_ANNUAL_PERIOD = timedelta(days=365)
_MONTHLY_PERIOD = timedelta(days=30)

# How long a cancelled subscription may remain past its period_end before
# cleanup_stale_subscriptions marks it as expired.
_STALE_CANCELLED_GRACE = timedelta(days=30)

# Stripe subscription status -> local SubscriptionStatus value.
# Built once at import time instead of on every loop iteration.
# Stripe statuses that are absent from this map leave the local status untouched.
_STRIPE_STATUS_MAP = {
    "active": SubscriptionStatus.ACTIVE.value,
    "trialing": SubscriptionStatus.TRIAL.value,
    "past_due": SubscriptionStatus.PAST_DUE.value,
    "canceled": SubscriptionStatus.CANCELLED.value,
    "unpaid": SubscriptionStatus.PAST_DUE.value,
    "incomplete": SubscriptionStatus.TRIAL.value,
    "incomplete_expired": SubscriptionStatus.EXPIRED.value,
}


@celery_app.task(
    bind=True,
    base=ModuleTask,
    name="app.modules.billing.tasks.subscription.reset_period_counters",
)
def reset_period_counters(self) -> dict[str, int]:
    """
    Reset billing period dates for subscriptions whose billing period has ended.

    Runs daily at 00:05.

    Selects active and trial subscriptions whose ``period_end`` is at or before
    the current UTC time and starts a new period at "now": 365 days long for
    annual subscriptions, 30 days otherwise.

    Returns:
        ``{"reset_count": <number of subscriptions updated>}``
    """
    now = datetime.now(UTC)
    reset_count = 0

    # NOTE(review): objects are only mutated here; persisting the changes is
    # presumably handled by the get_db() context manager - confirm in ModuleTask.
    with self.get_db() as db:
        # Find subscriptions where period has ended.
        # Enum values are used (rather than string literals) to stay consistent
        # with how status is written and queried everywhere else in this module.
        expired_periods = (
            db.query(MerchantSubscription)
            .filter(
                MerchantSubscription.period_end <= now,
                MerchantSubscription.status.in_(
                    [
                        SubscriptionStatus.ACTIVE.value,
                        SubscriptionStatus.TRIAL.value,
                    ]
                ),
            )
            .all()
        )

        for subscription in expired_periods:
            old_period_end = subscription.period_end

            # The new period starts at the time of this run, not at the old
            # period_end, so a late run shifts the cycle forward accordingly.
            period_length = _ANNUAL_PERIOD if subscription.is_annual else _MONTHLY_PERIOD
            subscription.period_start = now
            subscription.period_end = now + period_length
            subscription.updated_at = now
            reset_count += 1

            logger.info(
                f"Reset period for merchant {subscription.merchant_id}: "
                f"old_period_end={old_period_end}, new_period_end={subscription.period_end}"
            )

    logger.info(f"Reset period counters for {reset_count} subscriptions")
    return {"reset_count": reset_count}


@celery_app.task(
    bind=True,
    base=ModuleTask,
    name="app.modules.billing.tasks.subscription.check_trial_expirations",
)
def check_trial_expirations(self) -> dict[str, int]:
    """
    Check for expired trials and update their status.

    Runs daily at 01:00.

    - Trials without payment method -> expired
    - Trials with payment method -> active

    Returns:
        ``{"expired_count": ..., "activated_count": ...}``
    """
    now = datetime.now(UTC)
    expired_count = 0
    activated_count = 0

    with self.get_db() as db:
        # Find trials whose trial_ends_at is at or before now
        expired_trials = (
            db.query(MerchantSubscription)
            .filter(
                MerchantSubscription.status == SubscriptionStatus.TRIAL.value,
                MerchantSubscription.trial_ends_at <= now,
            )
            .all()
        )

        for subscription in expired_trials:
            if subscription.stripe_payment_method_id:
                # Has payment method - activate
                subscription.status = SubscriptionStatus.ACTIVE.value
                activated_count += 1
                logger.info(
                    f"Activated subscription for merchant {subscription.merchant_id} "
                    f"(trial ended with payment method)"
                )
            else:
                # No payment method - expire
                subscription.status = SubscriptionStatus.EXPIRED.value
                expired_count += 1
                logger.info(
                    f"Expired trial for merchant {subscription.merchant_id} "
                    f"(no payment method)"
                )
            subscription.updated_at = now

    logger.info(f"Trial expiration check: {expired_count} expired, {activated_count} activated")
    return {"expired_count": expired_count, "activated_count": activated_count}


# NOTE(review): max_retries/default_retry_delay are configured, but the task
# body never calls self.retry(); they only matter if ModuleTask (or autoretry
# configuration elsewhere) triggers retries - confirm.
@celery_app.task(
    bind=True,
    base=ModuleTask,
    name="app.modules.billing.tasks.subscription.sync_stripe_status",
    max_retries=3,
    default_retry_delay=300,
)
def sync_stripe_status(self) -> dict[str, int | bool]:
    """
    Sync subscription status with Stripe.

    Runs hourly at :30.

    Fetches current status from Stripe and updates local records: status
    (via ``_STRIPE_STATUS_MAP``), period start/end, and payment method ID.
    A failure on one subscription is logged and counted; the loop continues
    with the remaining subscriptions.

    Returns:
        ``{"synced_count": ..., "error_count": ...}``. When Stripe is not
        configured, the result additionally carries the legacy keys
        ``"synced": 0`` and ``"skipped": True``.
    """
    if not stripe_service.is_configured:
        logger.warning("Stripe not configured, skipping sync")
        # "synced"/"skipped" are kept for existing consumers; the *_count keys
        # make this result shape-compatible with the normal return value.
        return {"synced": 0, "skipped": True, "synced_count": 0, "error_count": 0}

    synced_count = 0
    error_count = 0

    with self.get_db() as db:
        # Find subscriptions with Stripe IDs
        subscriptions = (
            db.query(MerchantSubscription)
            .filter(MerchantSubscription.stripe_subscription_id.isnot(None))
            .all()
        )

        for subscription in subscriptions:
            try:
                # Fetch from Stripe
                stripe_sub = stripe_service.get_subscription(subscription.stripe_subscription_id)

                if not stripe_sub:
                    # Not counted as synced or as an error
                    logger.warning(
                        f"Stripe subscription {subscription.stripe_subscription_id} "
                        f"not found for merchant {subscription.merchant_id}"
                    )
                    continue

                # Map Stripe status to local status (unknown statuses -> None -> no change)
                new_status = _STRIPE_STATUS_MAP.get(stripe_sub.status)
                if new_status and new_status != subscription.status:
                    old_status = subscription.status
                    subscription.status = new_status
                    subscription.updated_at = datetime.now(UTC)
                    logger.info(
                        f"Updated merchant {subscription.merchant_id} status: "
                        f"{old_status} -> {new_status} (from Stripe)"
                    )

                # Update period dates from Stripe (timestamps converted to aware UTC datetimes)
                if stripe_sub.current_period_start:
                    subscription.period_start = datetime.fromtimestamp(
                        stripe_sub.current_period_start, tz=UTC
                    )
                if stripe_sub.current_period_end:
                    subscription.period_end = datetime.fromtimestamp(
                        stripe_sub.current_period_end, tz=UTC
                    )

                # Update payment method; the field may be a plain ID string or
                # an object exposing the ID as ``.id``.
                payment_method = stripe_sub.default_payment_method
                if payment_method:
                    subscription.stripe_payment_method_id = (
                        payment_method
                        if isinstance(payment_method, str)
                        else payment_method.id
                    )

                synced_count += 1

            except Exception as e:
                # Deliberately broad: one bad subscription must not abort the
                # whole sync. logger.exception keeps the traceback in the log.
                logger.exception(
                    f"Error syncing subscription {subscription.stripe_subscription_id}: {e}"
                )
                error_count += 1

    logger.info(f"Stripe sync complete: {synced_count} synced, {error_count} errors")
    return {"synced_count": synced_count, "error_count": error_count}


@celery_app.task(
    bind=True,
    base=ModuleTask,
    name="app.modules.billing.tasks.subscription.cleanup_stale_subscriptions",
)
def cleanup_stale_subscriptions(self) -> dict[str, int]:
    """
    Clean up subscriptions in inconsistent states.

    Runs weekly on Sunday at 03:00.

    Marks cancelled subscriptions as expired once their ``period_end`` lies
    more than 30 days in the past.

    Returns:
        ``{"cleaned_count": <number of subscriptions marked expired>}``
    """
    now = datetime.now(UTC)
    cleaned_count = 0

    with self.get_db() as db:
        # Find cancelled subscriptions well past their period end
        stale_cancelled = (
            db.query(MerchantSubscription)
            .filter(
                MerchantSubscription.status == SubscriptionStatus.CANCELLED.value,
                MerchantSubscription.period_end < now - _STALE_CANCELLED_GRACE,
            )
            .all()
        )

        for subscription in stale_cancelled:
            # Mark as expired (fully terminated)
            subscription.status = SubscriptionStatus.EXPIRED.value
            subscription.updated_at = now
            cleaned_count += 1
            logger.info(
                f"Marked stale cancelled subscription as expired: merchant {subscription.merchant_id}"
            )

    logger.info(f"Cleaned up {cleaned_count} stale subscriptions")
    return {"cleaned_count": cleaned_count}