Add comprehensive integration tests for:

- Vendor orders API (list, detail, status updates, pagination, filtering)
- Letzshop historical import background task (success, errors, progress)
- Subscription background tasks (period reset, trial expiration, Stripe sync)

These 39 new tests improve coverage for background task functionality and vendor order management.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
363 lines
13 KiB
Python
363 lines
13 KiB
Python
# tests/integration/tasks/test_subscription_tasks.py
|
|
"""Integration tests for subscription background tasks."""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
from unittest.mock import MagicMock, PropertyMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.tasks.subscription_tasks import (
|
|
check_trial_expirations,
|
|
cleanup_stale_subscriptions,
|
|
reset_period_counters,
|
|
sync_stripe_status,
|
|
)
|
|
from models.database.subscription import SubscriptionStatus, VendorSubscription
|
|
|
|
|
|
@pytest.fixture
def active_subscription(db, test_vendor):
    """Persist an ACTIVE, non-annual subscription whose period is over.

    The status is ``ACTIVE`` but ``period_end`` lies one day in the past and
    50 orders are already counted for the period, which is the situation
    ``reset_period_counters`` is tested against.
    """
    now = datetime.now(UTC)
    record = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.ACTIVE.value,
        is_annual=False,
        period_start=now - timedelta(days=25),
        period_end=now - timedelta(days=1),  # period ended yesterday
        orders_this_period=50,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def trial_subscription_expired_no_payment(db, test_vendor):
    """Persist a TRIAL subscription whose trial ended yesterday, with no card.

    ``stripe_payment_method_id`` is explicitly ``None``; the trial-expiration
    tests expect such a record to end up ``EXPIRED``.
    """
    now = datetime.now(UTC)
    record = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.TRIAL.value,
        trial_ends_at=now - timedelta(days=1),  # trial ran out yesterday
        period_start=now - timedelta(days=14),
        period_end=now + timedelta(days=16),
        stripe_payment_method_id=None,  # nothing to charge
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def trial_subscription_expired_with_payment(db, test_vendor):
    """Persist a TRIAL subscription whose trial ended yesterday, with a card.

    A payment method id is stored; the trial-expiration tests expect such a
    record to be promoted to ``ACTIVE``.
    """
    now = datetime.now(UTC)
    record = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.TRIAL.value,
        trial_ends_at=now - timedelta(days=1),  # trial ran out yesterday
        period_start=now - timedelta(days=14),
        period_end=now + timedelta(days=16),
        stripe_payment_method_id="pm_test_123",  # card on file
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@pytest.fixture
def cancelled_subscription(db, test_vendor):
    """Persist a CANCELLED subscription whose period ended 35 days ago.

    Used to check that ``cleanup_stale_subscriptions`` moves long-finished
    cancellations to ``EXPIRED``.
    """
    now = datetime.now(UTC)
    record = VendorSubscription(
        vendor_id=test_vendor.id,
        tier="essential",
        status=SubscriptionStatus.CANCELLED.value,
        period_start=now - timedelta(days=60),
        period_end=now - timedelta(days=35),  # long past the period end
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return record
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestResetPeriodCounters:
    """Test reset_period_counters task."""

    @staticmethod
    def _as_utc(value):
        """Return *value* as a timezone-aware datetime.

        The database may hand back naive datetimes. Those are tagged as UTC so
        they can be compared with ``datetime.now(UTC)``. Aware values are
        returned unchanged.
        """
        return value.replace(tzinfo=UTC) if value.tzinfo is None else value

    async def test_resets_expired_period_counters(self, db, active_subscription):
        """Test that period counters are reset for expired periods."""
        subscription_id = active_subscription.id
        # Precondition: there is something to reset. Otherwise the
        # ``orders_this_period == 0`` check below would pass vacuously.
        assert active_subscription.orders_this_period > 0

        with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 1

        # Verify subscription was updated
        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.orders_this_period == 0
        assert updated.orders_limit_reached_at is None
        # The new period must end in the future.
        assert self._as_utc(updated.period_end) > datetime.now(UTC)

    async def test_does_not_reset_active_period(self, db, test_vendor):
        """Test that active periods are not reset."""
        # Create subscription with future period end
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.ACTIVE.value,
            period_start=datetime.now(UTC) - timedelta(days=15),
            period_end=datetime.now(UTC) + timedelta(days=15),  # Still active
            orders_this_period=25,
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 0

        # Orders should not be reset
        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.orders_this_period == 25

    async def test_handles_annual_subscription(self, db, test_vendor):
        """Test that annual subscriptions get 365-day periods."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="professional",
            status=SubscriptionStatus.ACTIVE.value,
            is_annual=True,
            period_start=datetime.now(UTC) - timedelta(days=370),
            period_end=datetime.now(UTC) - timedelta(days=5),
            orders_this_period=500,
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
            result = await reset_period_counters()

        assert result["reset_count"] == 1

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        # Should be ~365 days from now (60 s tolerance for test runtime)
        expected_end = datetime.now(UTC) + timedelta(days=365)
        period_end = self._as_utc(updated.period_end)
        assert abs((period_end - expected_end).total_seconds()) < 60
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCheckTrialExpirations:
    """Exercise check_trial_expirations against a real test session."""

    @staticmethod
    async def _run_check(db):
        """Run ``check_trial_expirations`` with ``SessionLocal`` bound to *db*."""
        with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
            return await check_trial_expirations()

    @staticmethod
    def _reload(db, subscription_id):
        """Drop cached ORM state and read the subscription back from the DB."""
        db.expire_all()
        return db.query(VendorSubscription).filter_by(id=subscription_id).first()

    async def test_expires_trial_without_payment(
        self, db, trial_subscription_expired_no_payment
    ):
        """A finished trial with no payment method becomes EXPIRED."""
        sub_id = trial_subscription_expired_no_payment.id

        outcome = await self._run_check(db)

        assert outcome["expired_count"] == 1
        assert outcome["activated_count"] == 0
        assert self._reload(db, sub_id).status == SubscriptionStatus.EXPIRED.value

    async def test_activates_trial_with_payment(
        self, db, trial_subscription_expired_with_payment
    ):
        """A finished trial with a payment method becomes ACTIVE."""
        sub_id = trial_subscription_expired_with_payment.id

        outcome = await self._run_check(db)

        assert outcome["expired_count"] == 0
        assert outcome["activated_count"] == 1
        assert self._reload(db, sub_id).status == SubscriptionStatus.ACTIVE.value

    async def test_does_not_affect_active_trial(self, db, test_vendor):
        """A trial that has not yet ended is left in TRIAL."""
        started = datetime.now(UTC)
        running_trial = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.TRIAL.value,
            trial_ends_at=datetime.now(UTC) + timedelta(days=7),  # still running
            period_start=started,
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(running_trial)
        db.commit()
        sub_id = running_trial.id

        outcome = await self._run_check(db)

        assert outcome["expired_count"] == 0
        assert outcome["activated_count"] == 0
        assert self._reload(db, sub_id).status == SubscriptionStatus.TRIAL.value
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestSyncStripeStatus:
    """Test sync_stripe_status task."""

    @staticmethod
    def _stripe_service(*, configured, stripe_subscription=None):
        """Build a stand-in for the module-level ``stripe_service``.

        ``get_subscription`` returns *stripe_subscription* for any arguments.
        Calls are recorded, so tests can check whether a lookup happened.
        """
        service = MagicMock()
        service.is_configured = configured
        service.get_subscription.return_value = stripe_subscription
        return service

    @staticmethod
    async def _run_sync(db, stripe_service):
        """Run ``sync_stripe_status`` with the test session and a fake Stripe."""
        with (
            patch("app.tasks.subscription_tasks.SessionLocal", return_value=db),
            patch("app.tasks.subscription_tasks.stripe_service", stripe_service),
        ):
            return await sync_stripe_status()

    async def test_skips_when_stripe_not_configured(self, db):
        """Test that sync is skipped when Stripe is not configured."""
        mock_stripe = self._stripe_service(configured=False)

        result = await self._run_sync(db, mock_stripe)

        # NOTE(review): the skipped result is read via "synced", while the
        # normal path below uses "synced_count"/"error_count". Confirm the
        # task really returns different keys on the two paths.
        assert result["skipped"] is True
        assert result["synced"] == 0
        # Skipping must short-circuit before any Stripe lookup.
        mock_stripe.get_subscription.assert_not_called()

    async def test_syncs_subscription_status(self, db, test_vendor):
        """Test that subscription status is synced from Stripe."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.TRIAL.value,
            stripe_subscription_id="sub_test_123",
            period_start=datetime.now(UTC),
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(subscription)
        db.commit()
        subscription_id = subscription.id

        # Mock Stripe response (period bounds are Unix timestamps)
        mock_stripe_sub = MagicMock()
        mock_stripe_sub.status = "active"
        mock_stripe_sub.current_period_start = int(
            (datetime.now(UTC) - timedelta(days=5)).timestamp()
        )
        mock_stripe_sub.current_period_end = int(
            (datetime.now(UTC) + timedelta(days=25)).timestamp()
        )
        mock_stripe_sub.default_payment_method = "pm_test_456"

        mock_stripe = self._stripe_service(
            configured=True, stripe_subscription=mock_stripe_sub
        )

        result = await self._run_sync(db, mock_stripe)

        assert result["synced_count"] == 1
        assert result["error_count"] == 0

        db.expire_all()
        updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        assert updated.status == SubscriptionStatus.ACTIVE.value
        assert updated.stripe_payment_method_id == "pm_test_456"

    async def test_handles_missing_stripe_subscription(self, db, test_vendor):
        """Test handling when Stripe subscription is not found."""
        subscription = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.ACTIVE.value,
            stripe_subscription_id="sub_deleted_123",
            period_start=datetime.now(UTC),
            period_end=datetime.now(UTC) + timedelta(days=30),
        )
        db.add(subscription)
        db.commit()

        mock_stripe = self._stripe_service(configured=True, stripe_subscription=None)

        result = await self._run_sync(db, mock_stripe)

        # The lookup must actually happen. Otherwise synced_count == 0 would
        # also hold if the subscription were never considered at all.
        assert mock_stripe.get_subscription.called
        # Should not count as synced (subscription not found in Stripe)
        assert result["synced_count"] == 0
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCleanupStaleSubscriptions:
    """Exercise cleanup_stale_subscriptions against a real test session."""

    @staticmethod
    async def _run_cleanup(db):
        """Run ``cleanup_stale_subscriptions`` with ``SessionLocal`` bound to *db*."""
        with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
            return await cleanup_stale_subscriptions()

    @staticmethod
    def _status_of(db, subscription_id):
        """Drop cached ORM state and return the stored status for *subscription_id*."""
        db.expire_all()
        fresh = db.query(VendorSubscription).filter_by(id=subscription_id).first()
        return fresh.status

    async def test_cleans_old_cancelled_subscriptions(self, db, cancelled_subscription):
        """A cancellation well past its period end is moved to EXPIRED."""
        sub_id = cancelled_subscription.id

        outcome = await self._run_cleanup(db)

        assert outcome["cleaned_count"] == 1
        assert self._status_of(db, sub_id) == SubscriptionStatus.EXPIRED.value

    async def test_does_not_clean_recent_cancelled(self, db, test_vendor):
        """A cancellation whose period ended only days ago stays CANCELLED."""
        recent = VendorSubscription(
            vendor_id=test_vendor.id,
            tier="essential",
            status=SubscriptionStatus.CANCELLED.value,
            period_start=datetime.now(UTC) - timedelta(days=25),
            period_end=datetime.now(UTC) - timedelta(days=5),  # only 5 days back
        )
        db.add(recent)
        db.commit()
        sub_id = recent.id

        outcome = await self._run_cleanup(db)

        assert outcome["cleaned_count"] == 0
        assert self._status_of(db, sub_id) == SubscriptionStatus.CANCELLED.value
|