# tests/integration/tasks/test_subscription_tasks.py
"""Integration tests for subscription background tasks.

The tasks obtain their database session from ``SessionLocal``; every test
patches that factory to return the test ``db`` session so the task mutates
exactly the rows the assertions later read back.
"""

from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, PropertyMock, patch

import pytest

from app.tasks.subscription_tasks import (
    check_trial_expirations,
    cleanup_stale_subscriptions,
    reset_period_counters,
    sync_stripe_status,
)
from models.database.subscription import SubscriptionStatus, VendorSubscription

# Patch targets inside the module under test.
SESSION_LOCAL_PATH = "app.tasks.subscription_tasks.SessionLocal"
STRIPE_SERVICE_PATH = "app.tasks.subscription_tasks.stripe_service"


def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware datetime.

    The database may hand back timezone-naive datetimes; those are treated as
    UTC so they can be compared against ``datetime.now(UTC)``. Aware values
    are returned unchanged.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=UTC)
    return value


@pytest.fixture
def active_subscription(db, test_vendor):
    """Create an active subscription."""
    subscription = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.ACTIVE.value,
        is_annual=False,
        period_start=datetime.now(UTC) - timedelta(days=25),
        period_end=datetime.now(UTC) - timedelta(days=1),  # Ended yesterday
        orders_this_period=50,
    )
    db.add(subscription)
    db.commit()
    db.refresh(subscription)
    return subscription


@pytest.fixture
def trial_subscription_expired_no_payment(db, test_vendor):
    """Create an expired trial subscription without payment method."""
    subscription = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.TRIAL.value,
        trial_ends_at=datetime.now(UTC) - timedelta(days=1),  # Ended yesterday
        period_start=datetime.now(UTC) - timedelta(days=14),
        period_end=datetime.now(UTC) + timedelta(days=16),
        stripe_payment_method_id=None,  # No payment method
    )
    db.add(subscription)
    db.commit()
    db.refresh(subscription)
    return subscription


@pytest.fixture
def trial_subscription_expired_with_payment(db, test_vendor):
    """Create an expired trial subscription with payment method."""
    subscription = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.TRIAL.value,
        trial_ends_at=datetime.now(UTC) - timedelta(days=1),  # Ended yesterday
        period_start=datetime.now(UTC) - timedelta(days=14),
        period_end=datetime.now(UTC) + timedelta(days=16),
        stripe_payment_method_id="pm_test_123",  # Has payment method
    )
    db.add(subscription)
    db.commit()
    db.refresh(subscription)
    return subscription


@pytest.fixture
def cancelled_subscription(db, test_vendor):
    """Create a cancelled subscription past period end."""
    subscription = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.CANCELLED.value,
        period_start=datetime.now(UTC) - timedelta(days=60),
        period_end=datetime.now(UTC) - timedelta(days=35),  # Ended 35 days ago
    )
    db.add(subscription)
    db.commit()
    db.refresh(subscription)
    return subscription


@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestResetPeriodCounters:
    """Test reset_period_counters task."""

    async def test_resets_expired_period_counters(self, db, active_subscription):
        """Test that period counters are reset for expired periods."""
        subscription_id = active_subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 1

        # Verify subscription was updated
        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.orders_this_period == 0
        assert updated.orders_limit_reached_at is None
        assert _as_utc(updated.period_end) > datetime.now(UTC)

    async def test_does_not_reset_active_period(self, db, test_vendor):
        """Test that active periods are not reset."""
        # Create subscription with future period end
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.ACTIVE.value,
            period_start=datetime.now(UTC) - timedelta(days=15),
            period_end=datetime.now(UTC) + timedelta(days=15),  # Still active
            orders_this_period=25,
        )
        db.add(subscription)
        db.commit()
        # Capture the id before the task runs: after commit() the attribute is
        # expired, and reading it later would fail if the task detached the
        # instance (e.g. by closing the session it was given).
        subscription_id = subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 0

        # Orders should not be reset
        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.orders_this_period == 25

    async def test_handles_annual_subscription(self, db, test_vendor):
        """Test that annual subscriptions get 365-day periods."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="professional",
            status=SubscriptionStatus.ACTIVE.value,
            is_annual=True,
            period_start=datetime.now(UTC) - timedelta(days=370),
            period_end=datetime.now(UTC) - timedelta(days=5),
            orders_this_period=500,
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 1

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        # Should be ~365 days from now
        expected_end = datetime.now(UTC) + timedelta(days=365)
        period_end = _as_utc(updated.period_end)
        assert abs((period_end - expected_end).total_seconds()) < 60


@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCheckTrialExpirations:
    """Test check_trial_expirations task."""

    async def test_expires_trial_without_payment(
        self, db, trial_subscription_expired_no_payment
    ):
        """Test that trials without payment method are expired."""
        subscription_id = trial_subscription_expired_no_payment.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await check_trial_expirations()

        assert result["expired_count"] == 1
        assert result["activated_count"] == 0

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.EXPIRED.value

    async def test_activates_trial_with_payment(
        self, db, trial_subscription_expired_with_payment
    ):
        """Test that trials with payment method are activated."""
        subscription_id = trial_subscription_expired_with_payment.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await check_trial_expirations()

        assert result["expired_count"] == 0
        assert result["activated_count"] == 1

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.ACTIVE.value

    async def test_does_not_affect_active_trial(self, db, test_vendor):
        """Test that active trials are not affected."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.TRIAL.value,
            trial_ends_at=datetime.now(UTC) + timedelta(days=7),  # Still active
            period_start=datetime.now(UTC),
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await check_trial_expirations()

        assert result["expired_count"] == 0
        assert result["activated_count"] == 0

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.TRIAL.value


@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestSyncStripeStatus:
    """Test sync_stripe_status task."""

    async def test_skips_when_stripe_not_configured(self, db):
        """Test that sync is skipped when Stripe is not configured."""
        mock_stripe = MagicMock()
        mock_stripe.is_configured = False

        with (
            patch(SESSION_LOCAL_PATH, return_value=db),
            patch(STRIPE_SERVICE_PATH, mock_stripe),
        ):
            result = await sync_stripe_status()

        # The skip path reports "synced" rather than "synced_count".
        assert result["skipped"] is True
        assert result["synced"] == 0

    async def test_syncs_subscription_status(self, db, test_vendor):
        """Test that subscription status is synced from Stripe."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.TRIAL.value,
            stripe_subscription_id="sub_test_123",
            period_start=datetime.now(UTC),
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        # Mock Stripe response (period bounds are Unix timestamps, as Stripe sends them)
        mock_stripe_sub = MagicMock()
        mock_stripe_sub.status = "active"
        mock_stripe_sub.current_period_start = int(
            (datetime.now(UTC) - timedelta(days=5)).timestamp()
        )
        mock_stripe_sub.current_period_end = int(
            (datetime.now(UTC) + timedelta(days=25)).timestamp()
        )
        mock_stripe_sub.default_payment_method = "pm_test_456"

        mock_stripe = MagicMock()
        mock_stripe.is_configured = True
        mock_stripe.get_subscription.return_value = mock_stripe_sub

        with (
            patch(SESSION_LOCAL_PATH, return_value=db),
            patch(STRIPE_SERVICE_PATH, mock_stripe),
        ):
            result = await sync_stripe_status()

        assert result["synced_count"] == 1
        assert result["error_count"] == 0

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.ACTIVE.value
        assert updated.stripe_payment_method_id == "pm_test_456"

    async def test_handles_missing_stripe_subscription(self, db, test_vendor):
        """Test handling when Stripe subscription is not found."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.ACTIVE.value,
            stripe_subscription_id="sub_deleted_123",
            period_start=datetime.now(UTC),
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(subscription)
        db.commit()

        mock_stripe = MagicMock()
        mock_stripe.is_configured = True
        mock_stripe.get_subscription.return_value = None

        with (
            patch(SESSION_LOCAL_PATH, return_value=db),
            patch(STRIPE_SERVICE_PATH, mock_stripe),
        ):
            result = await sync_stripe_status()

        # Should not count as synced (subscription not found in Stripe)
        assert result["synced_count"] == 0


@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCleanupStaleSubscriptions:
    """Test cleanup_stale_subscriptions task."""

    async def test_cleans_old_cancelled_subscriptions(self, db, cancelled_subscription):
        """Test that old cancelled subscriptions are marked as expired."""
        subscription_id = cancelled_subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await cleanup_stale_subscriptions()

        assert result["cleaned_count"] == 1

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.EXPIRED.value

    async def test_does_not_clean_recent_cancelled(self, db, test_vendor):
        """Test that recently cancelled subscriptions are not cleaned."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.CANCELLED.value,
            period_start=datetime.now(UTC) - timedelta(days=25),
            period_end=datetime.now(UTC) - timedelta(days=5),  # Only 5 days ago
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        with patch(SESSION_LOCAL_PATH, return_value=db):
            result = await cleanup_stale_subscriptions()

        assert result["cleaned_count"] == 0

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.CANCELLED.value