test: add integration tests for vendor orders API and background tasks

Add comprehensive integration tests for:
- Vendor orders API (list, detail, status updates, pagination, filtering)
- Letzshop historical import background task (success, errors, progress)
- Subscription background tasks (period reset, trial expiration, Stripe sync)

These 39 new tests improve coverage for background task functionality
and vendor order management.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-28 19:01:54 +01:00
parent beebcb5631
commit 23a5ccbcff
3 changed files with 1016 additions and 0 deletions

View File

@@ -0,0 +1,341 @@
# tests/integration/tasks/test_letzshop_tasks.py
"""Integration tests for Letzshop background tasks."""
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
import pytest
from app.services.letzshop import LetzshopClientError
from app.tasks.letzshop_tasks import process_historical_import
from models.database.letzshop import LetzshopHistoricalImportJob
@pytest.fixture
def historical_import_job(db, test_vendor, test_user):
"""Create a test historical import job."""
job = LetzshopHistoricalImportJob(
vendor_id=test_vendor.id,
user_id=test_user.id,
status="pending",
)
db.add(job)
db.commit()
db.refresh(job)
return job
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.letzshop
class TestHistoricalImportTask:
"""Test historical import background task."""
def test_job_not_found(self, db, test_vendor):
"""Test handling when job doesn't exist."""
with patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db):
# Should not raise an exception
process_historical_import(job_id=99999, vendor_id=test_vendor.id)
def test_successful_import(self, db, test_vendor, historical_import_job):
"""Test successful historical import with both phases."""
job_id = historical_import_job.id
# Mock the services
mock_client = MagicMock()
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
mock_client.get_all_shipments_paginated.side_effect = [
# First call: confirmed shipments
[{"id": "1", "state": "confirmed"}, {"id": "2", "state": "confirmed"}],
# Second call: unconfirmed shipments
[{"id": "3", "state": "unconfirmed"}],
]
mock_creds_service = MagicMock()
mock_creds_service.create_client.return_value = mock_client
mock_creds_service.update_sync_status = MagicMock()
mock_order_service = MagicMock()
mock_order_service.import_historical_shipments.side_effect = [
# First call: confirmed stats
{
"total": 2,
"imported": 2,
"updated": 0,
"skipped": 0,
"products_matched": 5,
"products_not_found": 1,
},
# Second call: unconfirmed stats
{
"total": 1,
"imported": 1,
"updated": 0,
"skipped": 0,
"products_matched": 2,
"products_not_found": 0,
},
]
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
):
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify job was updated
updated_job = (
db.query(LetzshopHistoricalImportJob)
.filter(LetzshopHistoricalImportJob.id == job_id)
.first()
)
assert updated_job.status == "completed"
assert updated_job.started_at is not None
assert updated_job.completed_at is not None
assert updated_job.confirmed_stats["imported"] == 2
assert updated_job.declined_stats["imported"] == 1
assert updated_job.products_matched == 7
assert updated_job.products_not_found == 1
# Verify sync status was updated
mock_creds_service.update_sync_status.assert_called_with(
test_vendor.id, "success", None
)
def test_letzshop_client_error(self, db, test_vendor, historical_import_job):
"""Test handling Letzshop API errors."""
job_id = historical_import_job.id
mock_creds_service = MagicMock()
mock_creds_service.create_client.side_effect = LetzshopClientError(
"API connection failed"
)
mock_creds_service.update_sync_status = MagicMock()
mock_order_service = MagicMock()
mock_order_service.get_vendor.return_value = test_vendor
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
patch(
"app.tasks.letzshop_tasks.admin_notification_service"
) as mock_notify,
):
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify job failed
updated_job = (
db.query(LetzshopHistoricalImportJob)
.filter(LetzshopHistoricalImportJob.id == job_id)
.first()
)
assert updated_job.status == "failed"
assert "API connection failed" in updated_job.error_message
assert updated_job.completed_at is not None
# Verify sync status was updated to failed
mock_creds_service.update_sync_status.assert_called_with(
test_vendor.id, "failed", "API connection failed"
)
def test_unexpected_error(self, db, test_vendor, historical_import_job):
"""Test handling unexpected errors."""
job_id = historical_import_job.id
mock_creds_service = MagicMock()
mock_creds_service.create_client.side_effect = RuntimeError("Unexpected error")
mock_order_service = MagicMock()
mock_order_service.get_vendor.return_value = test_vendor
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
patch(
"app.tasks.letzshop_tasks.admin_notification_service"
) as mock_notify,
):
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify job failed
updated_job = (
db.query(LetzshopHistoricalImportJob)
.filter(LetzshopHistoricalImportJob.id == job_id)
.first()
)
assert updated_job.status == "failed"
assert "Unexpected error" in updated_job.error_message
# Verify critical error notification was sent
mock_notify.notify_critical_error.assert_called_once()
def test_progress_tracking(self, db, test_vendor, historical_import_job):
"""Test that progress is tracked correctly during import."""
job_id = historical_import_job.id
progress_updates = []
# Create a mock client that tracks progress calls
mock_client = MagicMock()
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
def track_fetch_progress(*args, **kwargs):
# Simulate fetching shipments and call progress callback
progress_callback = kwargs.get("progress_callback")
if progress_callback:
progress_callback(1, 10)
progress_callback(2, 20)
return [{"id": str(i)} for i in range(20)]
mock_client.get_all_shipments_paginated.side_effect = [
track_fetch_progress(state="confirmed", page_size=50, progress_callback=None),
[], # Empty unconfirmed
]
mock_creds_service = MagicMock()
mock_creds_service.create_client.return_value = mock_client
mock_creds_service.update_sync_status = MagicMock()
mock_order_service = MagicMock()
mock_order_service.import_historical_shipments.return_value = {
"total": 20,
"imported": 18,
"updated": 2,
"skipped": 0,
"products_matched": 50,
"products_not_found": 5,
}
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
):
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify final job state
updated_job = (
db.query(LetzshopHistoricalImportJob)
.filter(LetzshopHistoricalImportJob.id == job_id)
.first()
)
assert updated_job.status == "completed"
def test_commit_error_in_exception_handler(
self, db, test_vendor, historical_import_job
):
"""Test handling when commit fails during exception handling."""
job_id = historical_import_job.id
# Create a mock session that fails on the second commit
mock_session = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = (
historical_import_job
)
mock_session.commit.side_effect = [
None, # First commit (status update) succeeds
Exception("Commit failed"), # Second commit fails
]
mock_session.rollback = MagicMock()
mock_session.close = MagicMock()
mock_creds_service = MagicMock()
mock_creds_service.create_client.side_effect = RuntimeError("Test error")
mock_order_service = MagicMock()
mock_order_service.get_vendor.return_value = test_vendor
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=mock_session),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
patch("app.tasks.letzshop_tasks.admin_notification_service"),
):
# Should not raise
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify rollback was called
mock_session.rollback.assert_called()
def test_close_error_handling(self, db, test_vendor, historical_import_job):
"""Test handling when session close fails."""
job_id = historical_import_job.id
# Create a mock session that fails on close
mock_session = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = (
historical_import_job
)
mock_session.commit = MagicMock()
mock_session.close.side_effect = Exception("Close failed")
mock_client = MagicMock()
mock_client.__enter__ = MagicMock(return_value=mock_client)
mock_client.__exit__ = MagicMock(return_value=False)
mock_client.get_all_shipments_paginated.return_value = []
mock_creds_service = MagicMock()
mock_creds_service.create_client.return_value = mock_client
mock_creds_service.update_sync_status = MagicMock()
mock_order_service = MagicMock()
mock_order_service.import_historical_shipments.return_value = {
"total": 0,
"imported": 0,
"updated": 0,
"skipped": 0,
"products_matched": 0,
"products_not_found": 0,
}
with (
patch("app.tasks.letzshop_tasks.SessionLocal", return_value=mock_session),
patch(
"app.tasks.letzshop_tasks._get_credentials_service",
return_value=mock_creds_service,
),
patch(
"app.tasks.letzshop_tasks._get_order_service",
return_value=mock_order_service,
),
):
# Should not raise
process_historical_import(job_id=job_id, vendor_id=test_vendor.id)
# Verify close was attempted
mock_session.close.assert_called()

View File

@@ -0,0 +1,362 @@
# tests/integration/tasks/test_subscription_tasks.py
"""Integration tests for subscription background tasks."""
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from app.tasks.subscription_tasks import (
check_trial_expirations,
cleanup_stale_subscriptions,
reset_period_counters,
sync_stripe_status,
)
from models.database.subscription import SubscriptionStatus, VendorSubscription
@pytest.fixture
def active_subscription(db, test_vendor):
"""Create an active subscription."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.ACTIVE.value,
is_annual=False,
period_start=datetime.now(UTC) - timedelta(days=25),
period_end=datetime.now(UTC) - timedelta(days=1), # Ended yesterday
orders_this_period=50,
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.fixture
def trial_subscription_expired_no_payment(db, test_vendor):
"""Create an expired trial subscription without payment method."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.TRIAL.value,
trial_ends_at=datetime.now(UTC) - timedelta(days=1), # Ended yesterday
period_start=datetime.now(UTC) - timedelta(days=14),
period_end=datetime.now(UTC) + timedelta(days=16),
stripe_payment_method_id=None, # No payment method
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.fixture
def trial_subscription_expired_with_payment(db, test_vendor):
"""Create an expired trial subscription with payment method."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.TRIAL.value,
trial_ends_at=datetime.now(UTC) - timedelta(days=1), # Ended yesterday
period_start=datetime.now(UTC) - timedelta(days=14),
period_end=datetime.now(UTC) + timedelta(days=16),
stripe_payment_method_id="pm_test_123", # Has payment method
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.fixture
def cancelled_subscription(db, test_vendor):
"""Create a cancelled subscription past period end."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.CANCELLED.value,
period_start=datetime.now(UTC) - timedelta(days=60),
period_end=datetime.now(UTC) - timedelta(days=35), # Ended 35 days ago
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return subscription
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestResetPeriodCounters:
"""Test reset_period_counters task."""
async def test_resets_expired_period_counters(self, db, active_subscription):
"""Test that period counters are reset for expired periods."""
subscription_id = active_subscription.id
old_orders = active_subscription.orders_this_period
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await reset_period_counters()
assert result["reset_count"] == 1
# Verify subscription was updated
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.orders_this_period == 0
assert updated.orders_limit_reached_at is None
# Handle timezone-naive datetime from database
period_end = updated.period_end
if period_end.tzinfo is None:
period_end = period_end.replace(tzinfo=UTC)
assert period_end > datetime.now(UTC)
async def test_does_not_reset_active_period(self, db, test_vendor):
"""Test that active periods are not reset."""
# Create subscription with future period end
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.ACTIVE.value,
period_start=datetime.now(UTC) - timedelta(days=15),
period_end=datetime.now(UTC) + timedelta(days=15), # Still active
orders_this_period=25,
)
db.add(subscription)
db.commit()
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await reset_period_counters()
assert result["reset_count"] == 0
# Orders should not be reset
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription.id).first()
assert updated.orders_this_period == 25
async def test_handles_annual_subscription(self, db, test_vendor):
"""Test that annual subscriptions get 365-day periods."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="professional",
status=SubscriptionStatus.ACTIVE.value,
is_annual=True,
period_start=datetime.now(UTC) - timedelta(days=370),
period_end=datetime.now(UTC) - timedelta(days=5),
orders_this_period=500,
)
db.add(subscription)
db.commit()
subscription_id = subscription.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await reset_period_counters()
assert result["reset_count"] == 1
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
# Should be ~365 days from now
expected_end = datetime.now(UTC) + timedelta(days=365)
# Handle timezone-naive datetime from database
period_end = updated.period_end
if period_end.tzinfo is None:
period_end = period_end.replace(tzinfo=UTC)
assert abs((period_end - expected_end).total_seconds()) < 60
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCheckTrialExpirations:
"""Test check_trial_expirations task."""
async def test_expires_trial_without_payment(
self, db, trial_subscription_expired_no_payment
):
"""Test that trials without payment method are expired."""
subscription_id = trial_subscription_expired_no_payment.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await check_trial_expirations()
assert result["expired_count"] == 1
assert result["activated_count"] == 0
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.EXPIRED.value
async def test_activates_trial_with_payment(
self, db, trial_subscription_expired_with_payment
):
"""Test that trials with payment method are activated."""
subscription_id = trial_subscription_expired_with_payment.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await check_trial_expirations()
assert result["expired_count"] == 0
assert result["activated_count"] == 1
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.ACTIVE.value
async def test_does_not_affect_active_trial(self, db, test_vendor):
"""Test that active trials are not affected."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.TRIAL.value,
trial_ends_at=datetime.now(UTC) + timedelta(days=7), # Still active
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(subscription)
db.commit()
subscription_id = subscription.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await check_trial_expirations()
assert result["expired_count"] == 0
assert result["activated_count"] == 0
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.TRIAL.value
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestSyncStripeStatus:
"""Test sync_stripe_status task."""
async def test_skips_when_stripe_not_configured(self, db):
"""Test that sync is skipped when Stripe is not configured."""
mock_stripe = MagicMock()
mock_stripe.is_configured = False
with (
patch("app.tasks.subscription_tasks.SessionLocal", return_value=db),
patch("app.tasks.subscription_tasks.stripe_service", mock_stripe),
):
result = await sync_stripe_status()
assert result["skipped"] is True
assert result["synced"] == 0
async def test_syncs_subscription_status(self, db, test_vendor):
"""Test that subscription status is synced from Stripe."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.TRIAL.value,
stripe_subscription_id="sub_test_123",
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(subscription)
db.commit()
subscription_id = subscription.id
# Mock Stripe response
mock_stripe_sub = MagicMock()
mock_stripe_sub.status = "active"
mock_stripe_sub.current_period_start = int(
(datetime.now(UTC) - timedelta(days=5)).timestamp()
)
mock_stripe_sub.current_period_end = int(
(datetime.now(UTC) + timedelta(days=25)).timestamp()
)
mock_stripe_sub.default_payment_method = "pm_test_456"
mock_stripe = MagicMock()
mock_stripe.is_configured = True
mock_stripe.get_subscription.return_value = mock_stripe_sub
with (
patch("app.tasks.subscription_tasks.SessionLocal", return_value=db),
patch("app.tasks.subscription_tasks.stripe_service", mock_stripe),
):
result = await sync_stripe_status()
assert result["synced_count"] == 1
assert result["error_count"] == 0
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.ACTIVE.value
assert updated.stripe_payment_method_id == "pm_test_456"
async def test_handles_missing_stripe_subscription(self, db, test_vendor):
"""Test handling when Stripe subscription is not found."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.ACTIVE.value,
stripe_subscription_id="sub_deleted_123",
period_start=datetime.now(UTC),
period_end=datetime.now(UTC) + timedelta(days=30),
)
db.add(subscription)
db.commit()
mock_stripe = MagicMock()
mock_stripe.is_configured = True
mock_stripe.get_subscription.return_value = None
with (
patch("app.tasks.subscription_tasks.SessionLocal", return_value=db),
patch("app.tasks.subscription_tasks.stripe_service", mock_stripe),
):
result = await sync_stripe_status()
# Should not count as synced (subscription not found in Stripe)
assert result["synced_count"] == 0
@pytest.mark.integration
@pytest.mark.database
@pytest.mark.asyncio
class TestCleanupStaleSubscriptions:
"""Test cleanup_stale_subscriptions task."""
async def test_cleans_old_cancelled_subscriptions(self, db, cancelled_subscription):
"""Test that old cancelled subscriptions are marked as expired."""
subscription_id = cancelled_subscription.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await cleanup_stale_subscriptions()
assert result["cleaned_count"] == 1
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.EXPIRED.value
async def test_does_not_clean_recent_cancelled(self, db, test_vendor):
"""Test that recently cancelled subscriptions are not cleaned."""
subscription = VendorSubscription(
vendor_id=test_vendor.id,
tier="essential",
status=SubscriptionStatus.CANCELLED.value,
period_start=datetime.now(UTC) - timedelta(days=25),
period_end=datetime.now(UTC) - timedelta(days=5), # Only 5 days ago
)
db.add(subscription)
db.commit()
subscription_id = subscription.id
with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db):
result = await cleanup_stale_subscriptions()
assert result["cleaned_count"] == 0
db.expire_all()
updated = db.query(VendorSubscription).filter_by(id=subscription_id).first()
assert updated.status == SubscriptionStatus.CANCELLED.value