# app/tasks/subscription_tasks.py
"""
Background tasks for subscription management.

Provides scheduled tasks for:
- Resetting period counters at billing period end
- Expiring trials without payment methods
- Syncing subscription status with Stripe
- Cleaning up stale cancelled subscriptions
- Capturing daily capacity snapshots

NOTE(review): every task is declared ``async`` but performs only blocking work
(synchronous SQLAlchemy session calls and synchronous ``stripe_service``
calls). If they are awaited on a shared event loop they block it for their
whole duration — confirm how the scheduler invokes them.
"""

import logging
from datetime import UTC, datetime, timedelta

from app.core.database import SessionLocal
from app.services.stripe_service import stripe_service
from models.database.subscription import SubscriptionStatus, VendorSubscription

logger = logging.getLogger(__name__)

# Length, in days, of a locally-renewed billing period.
_ANNUAL_PERIOD_DAYS = 365
_MONTHLY_PERIOD_DAYS = 30

# Cancelled subscriptions whose period ended more than this long ago are
# treated as stale and moved to EXPIRED by cleanup_stale_subscriptions().
_STALE_CANCELLED_GRACE = timedelta(days=30)

# Stripe subscription status -> local SubscriptionStatus value. A Stripe
# status missing from this map leaves the local status untouched.
_STRIPE_STATUS_MAP = {
    "active": SubscriptionStatus.ACTIVE.value,
    "trialing": SubscriptionStatus.TRIAL.value,
    "past_due": SubscriptionStatus.PAST_DUE.value,
    "canceled": SubscriptionStatus.CANCELLED.value,
    "unpaid": SubscriptionStatus.PAST_DUE.value,
    "incomplete": SubscriptionStatus.TRIAL.value,
    "incomplete_expired": SubscriptionStatus.EXPIRED.value,
}


async def reset_period_counters():
    """
    Reset order counters for subscriptions whose billing period has ended.

    Should run daily. For every ACTIVE or TRIAL subscription whose
    ``period_end`` is at or before now: sets ``orders_this_period`` to 0,
    clears ``orders_limit_reached_at``, and starts a new period beginning now
    (365 days long for annual plans, 30 days otherwise).

    Returns:
        dict: ``{"reset_count": <number of subscriptions reset>}``.

    Raises:
        Exception: any error is logged with its traceback, the transaction is
            rolled back, and the error is re-raised.
    """
    db = SessionLocal()
    now = datetime.now(UTC)
    reset_count = 0

    try:
        # Find subscriptions where the period has ended. Use the enum values
        # (as the other tasks do) instead of hard-coded strings.
        expired_periods = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.period_end <= now,
                VendorSubscription.status.in_(
                    [SubscriptionStatus.ACTIVE.value, SubscriptionStatus.TRIAL.value]
                ),
            )
            .all()
        )

        for subscription in expired_periods:
            old_period_end = subscription.period_end

            subscription.orders_this_period = 0
            subscription.orders_limit_reached_at = None

            # NOTE(review): the new period is anchored at "now", not at the old
            # period_end, so period boundaries drift by however late this job
            # runs. For Stripe-backed subscriptions, sync_stripe_status later
            # overwrites these dates when Stripe returns period timestamps.
            period_days = (
                _ANNUAL_PERIOD_DAYS if subscription.is_annual else _MONTHLY_PERIOD_DAYS
            )
            subscription.period_start = now
            subscription.period_end = now + timedelta(days=period_days)
            subscription.updated_at = now
            reset_count += 1

            logger.info(
                "Reset period counters for vendor %s: "
                "old_period_end=%s, new_period_end=%s",
                subscription.vendor_id,
                old_period_end,
                subscription.period_end,
            )

        db.commit()
        logger.info("Reset period counters for %s subscriptions", reset_count)

    except Exception as e:
        logger.exception("Error resetting period counters: %s", e)
        db.rollback()
        raise
    finally:
        db.close()

    return {"reset_count": reset_count}


async def check_trial_expirations():
    """
    Check for expired trials and update their status.

    Should run daily. For every TRIAL subscription whose ``trial_ends_at`` is
    at or before now:
    - with a ``stripe_payment_method_id``: status becomes ACTIVE;
    - without one: status becomes EXPIRED.

    Returns:
        dict: ``{"expired_count": ..., "activated_count": ...}``.

    Raises:
        Exception: any error is logged with its traceback, the transaction is
            rolled back, and the error is re-raised.
    """
    db = SessionLocal()
    now = datetime.now(UTC)
    expired_count = 0
    activated_count = 0

    try:
        expired_trials = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.status == SubscriptionStatus.TRIAL.value,
                VendorSubscription.trial_ends_at <= now,
            )
            .all()
        )

        for subscription in expired_trials:
            if subscription.stripe_payment_method_id:
                subscription.status = SubscriptionStatus.ACTIVE.value
                activated_count += 1
                logger.info(
                    "Activated subscription for vendor %s "
                    "(trial ended with payment method)",
                    subscription.vendor_id,
                )
            else:
                subscription.status = SubscriptionStatus.EXPIRED.value
                expired_count += 1
                logger.info(
                    "Expired trial for vendor %s (no payment method)",
                    subscription.vendor_id,
                )

            subscription.updated_at = now

        db.commit()
        logger.info(
            "Trial expiration check: %s expired, %s activated",
            expired_count,
            activated_count,
        )

    except Exception as e:
        logger.exception("Error checking trial expirations: %s", e)
        db.rollback()
        raise
    finally:
        db.close()

    return {"expired_count": expired_count, "activated_count": activated_count}


def _apply_stripe_subscription(subscription, stripe_sub):
    """
    Copy Stripe state onto a local ``VendorSubscription`` (no commit).

    Updates the status (via ``_STRIPE_STATUS_MAP``, touching ``updated_at``
    only when the status actually changes), the current period dates when
    Stripe provides them, and the default payment method ID when one is set.
    """
    new_status = _STRIPE_STATUS_MAP.get(stripe_sub.status)
    if new_status and new_status != subscription.status:
        old_status = subscription.status
        subscription.status = new_status
        subscription.updated_at = datetime.now(UTC)
        logger.info(
            "Updated vendor %s status: %s -> %s (from Stripe)",
            subscription.vendor_id,
            old_status,
            new_status,
        )

    # Stripe timestamps are Unix seconds; store them as aware UTC datetimes.
    if stripe_sub.current_period_start:
        subscription.period_start = datetime.fromtimestamp(
            stripe_sub.current_period_start, tz=UTC
        )
    if stripe_sub.current_period_end:
        subscription.period_end = datetime.fromtimestamp(
            stripe_sub.current_period_end, tz=UTC
        )

    payment_method = stripe_sub.default_payment_method
    if payment_method:
        # The field is either a bare ID string or an expanded object with `.id`.
        subscription.stripe_payment_method_id = (
            payment_method if isinstance(payment_method, str) else payment_method.id
        )


async def sync_stripe_status():
    """
    Sync subscription status with Stripe.

    Should run hourly. For every subscription that has a
    ``stripe_subscription_id``, fetches the Stripe subscription and copies its
    status, current period dates, and default payment method onto the local
    record. A failure for one subscription is logged and counted. It does not
    stop the rest of the sync.

    Returns:
        dict: ``{"synced": 0, "skipped": True}`` when Stripe is not
        configured, otherwise ``{"synced_count": ..., "error_count": ...}``.

    Raises:
        Exception: errors outside the per-subscription loop (query, commit)
            are logged with their traceback, rolled back, and re-raised.
    """
    if not stripe_service.is_configured:
        logger.warning("Stripe not configured, skipping sync")
        # NOTE(review): this key set differs from the normal return shape
        # ("synced" vs "synced_count"); kept as-is for existing callers.
        return {"synced": 0, "skipped": True}

    db = SessionLocal()
    synced_count = 0
    error_count = 0

    try:
        subscriptions = (
            db.query(VendorSubscription)
            .filter(VendorSubscription.stripe_subscription_id.isnot(None))
            .all()
        )

        for subscription in subscriptions:
            try:
                stripe_sub = stripe_service.get_subscription(
                    subscription.stripe_subscription_id
                )
                if not stripe_sub:
                    logger.warning(
                        "Stripe subscription %s not found for vendor %s",
                        subscription.stripe_subscription_id,
                        subscription.vendor_id,
                    )
                    continue

                _apply_stripe_subscription(subscription, stripe_sub)
                synced_count += 1

            except Exception as e:
                # Best-effort per subscription: record the failure and move on.
                logger.exception(
                    "Error syncing subscription %s: %s",
                    subscription.stripe_subscription_id,
                    e,
                )
                error_count += 1

        db.commit()
        logger.info(
            "Stripe sync complete: %s synced, %s errors", synced_count, error_count
        )

    except Exception as e:
        logger.exception("Error in Stripe sync task: %s", e)
        db.rollback()
        raise
    finally:
        db.close()

    return {"synced_count": synced_count, "error_count": error_count}


async def cleanup_stale_subscriptions():
    """
    Clean up subscriptions in inconsistent states.

    Should run weekly. Currently handles one case: CANCELLED subscriptions
    whose ``period_end`` is more than 30 days in the past are marked EXPIRED
    (fully terminated).

    Returns:
        dict: ``{"cleaned_count": <number of subscriptions expired>}``.

    Raises:
        Exception: any error is logged with its traceback, the transaction is
            rolled back, and the error is re-raised.
    """
    db = SessionLocal()
    now = datetime.now(UTC)
    cleaned_count = 0

    try:
        stale_cancelled = (
            db.query(VendorSubscription)
            .filter(
                VendorSubscription.status == SubscriptionStatus.CANCELLED.value,
                VendorSubscription.period_end < now - _STALE_CANCELLED_GRACE,
            )
            .all()
        )

        for subscription in stale_cancelled:
            subscription.status = SubscriptionStatus.EXPIRED.value
            subscription.updated_at = now
            cleaned_count += 1
            logger.info(
                "Marked stale cancelled subscription as expired: vendor %s",
                subscription.vendor_id,
            )

        db.commit()
        logger.info("Cleaned up %s stale subscriptions", cleaned_count)

    except Exception as e:
        logger.exception("Error cleaning up stale subscriptions: %s", e)
        db.rollback()
        raise
    finally:
        db.close()

    return {"cleaned_count": cleaned_count}


async def capture_capacity_snapshot():
    """
    Capture a daily snapshot of platform capacity metrics.

    Used for growth trending and capacity forecasting.
    Should run daily (e.g., at midnight).

    Returns:
        dict: ``snapshot_id``, ISO-formatted ``snapshot_date``,
        ``total_vendors`` and ``total_products`` of the captured snapshot.

    Raises:
        Exception: any error is logged with its traceback, the transaction is
            rolled back, and the error is re-raised.
    """
    # Imported lazily; presumably to avoid an import cycle — TODO confirm.
    from app.services.capacity_forecast_service import capacity_forecast_service

    db = SessionLocal()
    try:
        snapshot = capacity_forecast_service.capture_daily_snapshot(db)
        db.commit()
        logger.info(
            "Captured capacity snapshot: %s vendors, %s products",
            snapshot.total_vendors,
            snapshot.total_products,
        )
        return {
            "snapshot_id": snapshot.id,
            "snapshot_date": snapshot.snapshot_date.isoformat(),
            "total_vendors": snapshot.total_vendors,
            "total_products": snapshot.total_products,
        }
    except Exception as e:
        logger.exception("Error capturing capacity snapshot: %s", e)
        db.rollback()
        raise
    finally:
        db.close()