diff --git a/tests/integration/api/v1/vendor/test_orders.py b/tests/integration/api/v1/vendor/test_orders.py new file mode 100644 index 00000000..79255c02 --- /dev/null +++ b/tests/integration/api/v1/vendor/test_orders.py @@ -0,0 +1,313 @@ +# tests/integration/api/v1/vendor/test_orders.py +"""Integration tests for vendor orders API endpoints.""" + +from datetime import UTC, datetime + +import pytest + +from models.database.order import Order, OrderItem + + +@pytest.fixture +def test_order(db, test_vendor_with_vendor_user, test_customer): + """Create a test order for the vendor.""" + order = Order( + vendor_id=test_vendor_with_vendor_user.id, + customer_id=test_customer.id, + order_number="ORD-TEST-001", + status="pending", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=10000, + tax_amount_cents=1700, + shipping_amount_cents=500, + total_amount_cents=12200, + currency="EUR", + customer_email="customer@test.com", + customer_first_name="Test", + customer_last_name="Customer", + ship_first_name="Test", + ship_last_name="Customer", + ship_address_line_1="123 Test St", + ship_city="Test City", + ship_postal_code="12345", + ship_country_iso="LU", + bill_first_name="Test", + bill_last_name="Customer", + bill_address_line_1="123 Test St", + bill_city="Test City", + bill_postal_code="12345", + bill_country_iso="LU", + ) + db.add(order) + db.commit() + db.refresh(order) + return order + + +@pytest.fixture +def test_orders(db, test_vendor_with_vendor_user, test_customer): + """Create multiple test orders.""" + orders = [] + for i, status in enumerate(["pending", "processing", "shipped", "delivered", "cancelled"]): + order = Order( + vendor_id=test_vendor_with_vendor_user.id, + customer_id=test_customer.id, + order_number=f"ORD-TEST-{i+1:03d}", + status=status, + channel="direct" if i % 2 == 0 else "letzshop", + order_date=datetime.now(UTC), + subtotal_cents=10000 * (i + 1), + tax_amount_cents=1700 * (i + 1), + shipping_amount_cents=500, + total_amount_cents=12200 * (i + 1), + currency="EUR", + customer_email=f"customer{i}@test.com", + customer_first_name="Test", + customer_last_name=f"Customer{i}", + ship_first_name="Test", + ship_last_name=f"Customer{i}", + ship_address_line_1="123 Test St", + ship_city="Test City", + ship_postal_code="12345", + ship_country_iso="LU", + bill_first_name="Test", + bill_last_name=f"Customer{i}", + bill_address_line_1="123 Test St", + bill_city="Test City", + bill_postal_code="12345", + bill_country_iso="LU", + ) + db.add(order) + orders.append(order) + db.commit() + for order in orders: + db.refresh(order) + return orders + + +@pytest.fixture +def test_order_with_items(db, test_vendor_with_vendor_user, test_customer): + """Create a test order with order items.""" + order = Order( + vendor_id=test_vendor_with_vendor_user.id, + customer_id=test_customer.id, + order_number="ORD-ITEMS-001", + status="pending", + channel="direct", + order_date=datetime.now(UTC), + subtotal_cents=20000, + tax_amount_cents=3400, + shipping_amount_cents=500, + total_amount_cents=23900, + currency="EUR", + customer_email="customer@test.com", + customer_first_name="Test", + customer_last_name="Customer", + ship_first_name="Test", + ship_last_name="Customer", + ship_address_line_1="123 Test St", + ship_city="Test City", + ship_postal_code="12345", + ship_country_iso="LU", + bill_first_name="Test", + bill_last_name="Customer", + bill_address_line_1="123 Test St", + bill_city="Test City", + bill_postal_code="12345", + bill_country_iso="LU", + ) + db.add(order) + db.flush() + + # Add order items + item = OrderItem( + order_id=order.id, + product_id=1, # Placeholder, no product FK constraint in test + product_sku="TEST-SKU-001", + product_name="Test Product", + quantity=2, + unit_price_cents=10000, + total_price_cents=20000, + ) + db.add(item) + db.commit() + db.refresh(order) + return order + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.vendor +class TestVendorOrdersListAPI: + """Test vendor orders list endpoint.""" + + def test_list_orders_success(self, client, vendor_user_headers, test_orders): + """Test listing vendor orders.""" + response = client.get("/api/v1/vendor/orders", headers=vendor_user_headers) + + assert response.status_code == 200 + data = response.json() + assert "orders" in data + assert "total" in data + assert data["total"] == 5 + assert len(data["orders"]) == 5 + + def test_list_orders_with_pagination(self, client, vendor_user_headers, test_orders): + """Test orders list with pagination.""" + response = client.get( + "/api/v1/vendor/orders?skip=2&limit=2", headers=vendor_user_headers + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["orders"]) == 2 + assert data["skip"] == 2 + assert data["limit"] == 2 + + def test_list_orders_filter_by_status(self, client, vendor_user_headers, test_orders): + """Test filtering orders by status.""" + response = client.get( + "/api/v1/vendor/orders?status=pending", headers=vendor_user_headers + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["orders"]) == 1 + assert data["orders"][0]["status"] == "pending" + + def test_list_orders_filter_by_customer( + self, client, vendor_user_headers, test_orders, test_customer + ): + """Test filtering orders by customer ID.""" + response = client.get( + f"/api/v1/vendor/orders?customer_id={test_customer.id}", + headers=vendor_user_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] == 5 # All orders belong to test_customer + + def test_list_orders_empty(self, client, vendor_user_headers): + """Test empty orders list.""" + response = client.get("/api/v1/vendor/orders", headers=vendor_user_headers) + + assert response.status_code == 200 + data = response.json() + assert data["orders"] == [] + assert data["total"] == 0 + + def test_list_orders_unauthorized(self, client): + """Test orders list without authentication.""" + response = client.get("/api/v1/vendor/orders") + + assert response.status_code == 401 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.vendor +class TestVendorOrderDetailAPI: + """Test vendor order detail endpoint.""" + + def test_get_order_detail(self, client, vendor_user_headers, test_order_with_items): + """Test getting order details.""" + response = client.get( + f"/api/v1/vendor/orders/{test_order_with_items.id}", + headers=vendor_user_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["order_number"] == "ORD-ITEMS-001" + assert data["status"] == "pending" + assert "items" in data + assert len(data["items"]) == 1 + assert data["items"][0]["quantity"] == 2 + + def test_get_order_not_found(self, client, vendor_user_headers): + """Test getting non-existent order.""" + response = client.get( + "/api/v1/vendor/orders/99999", headers=vendor_user_headers + ) + + assert response.status_code == 404 + + def test_get_order_unauthorized(self, client, test_order): + """Test getting order without authentication.""" + response = client.get(f"/api/v1/vendor/orders/{test_order.id}") + + assert response.status_code == 401 + + +@pytest.mark.integration +@pytest.mark.api +@pytest.mark.vendor +class TestVendorOrderStatusUpdateAPI: + """Test vendor order status update endpoint.""" + + def test_update_order_status_to_processing( + self, client, vendor_user_headers, test_order + ): + """Test updating order status to processing.""" + response = client.put( + f"/api/v1/vendor/orders/{test_order.id}/status", + json={"status": "processing"}, + headers=vendor_user_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "processing" + + def test_update_order_status_to_shipped( + self, client, vendor_user_headers, test_order + ): + """Test updating order status to shipped with tracking.""" + response = client.put( + f"/api/v1/vendor/orders/{test_order.id}/status", + json={ + "status": "shipped", + "tracking_number": "TRACK123456", + "tracking_url": "https://tracking.example.com/TRACK123456", + }, + headers=vendor_user_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "shipped" + + def test_update_order_status_to_cancelled( + self, client, vendor_user_headers, test_order + ): + """Test updating order status to cancelled.""" + response = client.put( + f"/api/v1/vendor/orders/{test_order.id}/status", + json={"status": "cancelled"}, + headers=vendor_user_headers, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "cancelled" + + def test_update_order_not_found(self, client, vendor_user_headers): + """Test updating non-existent order.""" + response = client.put( + "/api/v1/vendor/orders/99999/status", + json={"status": "processing"}, + headers=vendor_user_headers, + ) + + assert response.status_code == 404 + + def test_update_order_unauthorized(self, client, test_order): + """Test updating order without authentication.""" + response = client.put( + f"/api/v1/vendor/orders/{test_order.id}/status", + json={"status": "processing"}, + ) + + assert response.status_code == 401 diff --git a/tests/integration/tasks/test_letzshop_tasks.py b/tests/integration/tasks/test_letzshop_tasks.py new file mode 100644 index 00000000..5aba23c2 --- /dev/null +++ b/tests/integration/tasks/test_letzshop_tasks.py @@ -0,0 +1,341 @@ +# tests/integration/tasks/test_letzshop_tasks.py +"""Integration tests for Letzshop background tasks.""" + +from datetime import UTC, datetime +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.letzshop import LetzshopClientError +from app.tasks.letzshop_tasks import process_historical_import +from models.database.letzshop import LetzshopHistoricalImportJob + + +@pytest.fixture +def historical_import_job(db, test_vendor, test_user): + """Create a test historical import job.""" + job = LetzshopHistoricalImportJob( + vendor_id=test_vendor.id, + user_id=test_user.id, + status="pending", + ) + db.add(job) + db.commit() + db.refresh(job) + return job + + +@pytest.mark.integration +@pytest.mark.database +@pytest.mark.letzshop +class TestHistoricalImportTask: + """Test historical import background task.""" + + def test_job_not_found(self, db, test_vendor): + """Test handling when job doesn't exist.""" + with patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db): + # Should not raise an exception + process_historical_import(job_id=99999, vendor_id=test_vendor.id) + + def test_successful_import(self, db, test_vendor, historical_import_job): + """Test successful historical import with both phases.""" + job_id = historical_import_job.id + + # Mock the services + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.get_all_shipments_paginated.side_effect = [ + # First call: confirmed shipments + [{"id": "1", "state": "confirmed"}, {"id": "2", "state": "confirmed"}], + # Second call: unconfirmed shipments + [{"id": "3", "state": "unconfirmed"}], + ] + + mock_creds_service = MagicMock() + mock_creds_service.create_client.return_value = mock_client + mock_creds_service.update_sync_status = MagicMock() + + mock_order_service = MagicMock() + mock_order_service.import_historical_shipments.side_effect = [ + # First call: confirmed stats + { + "total": 2, + "imported": 2, + "updated": 0, + "skipped": 0, + "products_matched": 5, + "products_not_found": 1, + }, + # Second call: unconfirmed stats + { + "total": 1, + "imported": 1, + "updated": 0, + "skipped": 0, + "products_matched": 2, + "products_not_found": 0, + }, + ] + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + ): + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify job was updated + updated_job = ( + db.query(LetzshopHistoricalImportJob) + .filter(LetzshopHistoricalImportJob.id == job_id) + .first() + ) + assert updated_job.status == "completed" + assert updated_job.started_at is not None + assert updated_job.completed_at is not None + assert updated_job.confirmed_stats["imported"] == 2 + assert updated_job.declined_stats["imported"] == 1 + assert updated_job.products_matched == 7 + assert updated_job.products_not_found == 1 + + # Verify sync status was updated + mock_creds_service.update_sync_status.assert_called_with( + test_vendor.id, "success", None + ) + + def test_letzshop_client_error(self, db, test_vendor, historical_import_job): + """Test handling Letzshop API errors.""" + job_id = historical_import_job.id + + mock_creds_service = MagicMock() + mock_creds_service.create_client.side_effect = LetzshopClientError( + "API connection failed" + ) + mock_creds_service.update_sync_status = MagicMock() + + mock_order_service = MagicMock() + mock_order_service.get_vendor.return_value = test_vendor + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + patch( + "app.tasks.letzshop_tasks.admin_notification_service" + ) as mock_notify, + ): + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify job failed + updated_job = ( + db.query(LetzshopHistoricalImportJob) + .filter(LetzshopHistoricalImportJob.id == job_id) + .first() + ) + assert updated_job.status == "failed" + assert "API connection failed" in updated_job.error_message + assert updated_job.completed_at is not None + + # Verify sync status was updated to failed + mock_creds_service.update_sync_status.assert_called_with( + test_vendor.id, "failed", "API connection failed" + ) + + def test_unexpected_error(self, db, test_vendor, historical_import_job): + """Test handling unexpected errors.""" + job_id = historical_import_job.id + + mock_creds_service = MagicMock() + mock_creds_service.create_client.side_effect = RuntimeError("Unexpected error") + + mock_order_service = MagicMock() + mock_order_service.get_vendor.return_value = test_vendor + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + patch( + "app.tasks.letzshop_tasks.admin_notification_service" + ) as mock_notify, + ): + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify job failed + updated_job = ( + db.query(LetzshopHistoricalImportJob) + .filter(LetzshopHistoricalImportJob.id == job_id) + .first() + ) + assert updated_job.status == "failed" + assert "Unexpected error" in updated_job.error_message + + # Verify critical error notification was sent + mock_notify.notify_critical_error.assert_called_once() + + def test_progress_tracking(self, db, test_vendor, historical_import_job): + """Test that progress is tracked correctly during import.""" + job_id = historical_import_job.id + progress_updates = [] + + # Create a mock client that tracks progress calls + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + + def track_fetch_progress(*args, **kwargs): + # Simulate fetching shipments and call progress callback + progress_callback = kwargs.get("progress_callback") + if progress_callback: + progress_callback(1, 10) + progress_callback(2, 20) + return [{"id": str(i)} for i in range(20)] + + mock_client.get_all_shipments_paginated.side_effect = [ + track_fetch_progress(state="confirmed", page_size=50, progress_callback=None), + [], # Empty unconfirmed + ] + + mock_creds_service = MagicMock() + mock_creds_service.create_client.return_value = mock_client + mock_creds_service.update_sync_status = MagicMock() + + mock_order_service = MagicMock() + mock_order_service.import_historical_shipments.return_value = { + "total": 20, + "imported": 18, + "updated": 2, + "skipped": 0, + "products_matched": 50, + "products_not_found": 5, + } + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=db), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + ): + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify final job state + updated_job = ( + db.query(LetzshopHistoricalImportJob) + .filter(LetzshopHistoricalImportJob.id == job_id) + .first() + ) + assert updated_job.status == "completed" + + def test_commit_error_in_exception_handler( + self, db, test_vendor, historical_import_job + ): + """Test handling when commit fails during exception handling.""" + job_id = historical_import_job.id + + # Create a mock session that fails on the second commit + mock_session = MagicMock() + mock_session.query.return_value.filter.return_value.first.return_value = ( + historical_import_job + ) + mock_session.commit.side_effect = [ + None, # First commit (status update) succeeds + Exception("Commit failed"), # Second commit fails + ] + mock_session.rollback = MagicMock() + mock_session.close = MagicMock() + + mock_creds_service = MagicMock() + mock_creds_service.create_client.side_effect = RuntimeError("Test error") + + mock_order_service = MagicMock() + mock_order_service.get_vendor.return_value = test_vendor + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=mock_session), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + patch("app.tasks.letzshop_tasks.admin_notification_service"), + ): + # Should not raise + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify rollback was called + mock_session.rollback.assert_called() + + def test_close_error_handling(self, db, test_vendor, historical_import_job): + """Test handling when session close fails.""" + job_id = historical_import_job.id + + # Create a mock session that fails on close + mock_session = MagicMock() + mock_session.query.return_value.filter.return_value.first.return_value = ( + historical_import_job + ) + mock_session.commit = MagicMock() + mock_session.close.side_effect = Exception("Close failed") + + mock_client = MagicMock() + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=False) + mock_client.get_all_shipments_paginated.return_value = [] + + mock_creds_service = MagicMock() + mock_creds_service.create_client.return_value = mock_client + mock_creds_service.update_sync_status = MagicMock() + + mock_order_service = MagicMock() + mock_order_service.import_historical_shipments.return_value = { + "total": 0, + "imported": 0, + "updated": 0, + "skipped": 0, + "products_matched": 0, + "products_not_found": 0, + } + + with ( + patch("app.tasks.letzshop_tasks.SessionLocal", return_value=mock_session), + patch( + "app.tasks.letzshop_tasks._get_credentials_service", + return_value=mock_creds_service, + ), + patch( + "app.tasks.letzshop_tasks._get_order_service", + return_value=mock_order_service, + ), + ): + # Should not raise + process_historical_import(job_id=job_id, vendor_id=test_vendor.id) + + # Verify close was attempted + mock_session.close.assert_called() diff --git a/tests/integration/tasks/test_subscription_tasks.py b/tests/integration/tasks/test_subscription_tasks.py new file mode 100644 index 00000000..bfb2d3c9 --- /dev/null +++ b/tests/integration/tasks/test_subscription_tasks.py @@ -0,0 +1,362 @@ +# tests/integration/tasks/test_subscription_tasks.py +"""Integration tests for subscription background tasks.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from app.tasks.subscription_tasks import ( + check_trial_expirations, + cleanup_stale_subscriptions, + reset_period_counters, + sync_stripe_status, +) +from models.database.subscription import SubscriptionStatus, VendorSubscription + + +@pytest.fixture +def active_subscription(db, test_vendor): + """Create an active subscription.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE.value, + is_annual=False, + period_start=datetime.now(UTC) - timedelta(days=25), + period_end=datetime.now(UTC) - timedelta(days=1), # Ended yesterday + orders_this_period=50, + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def trial_subscription_expired_no_payment(db, test_vendor): + """Create an expired trial subscription without payment method.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.TRIAL.value, + trial_ends_at=datetime.now(UTC) - timedelta(days=1), # Ended yesterday + period_start=datetime.now(UTC) - timedelta(days=14), + period_end=datetime.now(UTC) + timedelta(days=16), + stripe_payment_method_id=None, # No payment method + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def trial_subscription_expired_with_payment(db, test_vendor): + """Create an expired trial subscription with payment method.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.TRIAL.value, + trial_ends_at=datetime.now(UTC) - timedelta(days=1), # Ended yesterday + period_start=datetime.now(UTC) - timedelta(days=14), + period_end=datetime.now(UTC) + timedelta(days=16), + stripe_payment_method_id="pm_test_123", # Has payment method + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.fixture +def cancelled_subscription(db, test_vendor): + """Create a cancelled subscription past period end.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.CANCELLED.value, + period_start=datetime.now(UTC) - timedelta(days=60), + period_end=datetime.now(UTC) - timedelta(days=35), # Ended 35 days ago + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return subscription + + +@pytest.mark.integration +@pytest.mark.database +@pytest.mark.asyncio +class TestResetPeriodCounters: + """Test reset_period_counters task.""" + + async def test_resets_expired_period_counters(self, db, active_subscription): + """Test that period counters are reset for expired periods.""" + subscription_id = active_subscription.id + old_orders = active_subscription.orders_this_period + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await reset_period_counters() + + assert result["reset_count"] == 1 + + # Verify subscription was updated + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.orders_this_period == 0 + assert updated.orders_limit_reached_at is None + # Handle timezone-naive datetime from database + period_end = updated.period_end + if period_end.tzinfo is None: + period_end = period_end.replace(tzinfo=UTC) + assert period_end > datetime.now(UTC) + + async def test_does_not_reset_active_period(self, db, test_vendor): + """Test that active periods are not reset.""" + # Create subscription with future period end + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE.value, + period_start=datetime.now(UTC) - timedelta(days=15), + period_end=datetime.now(UTC) + timedelta(days=15), # Still active + orders_this_period=25, + ) + db.add(subscription) + db.commit() + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await reset_period_counters() + + assert result["reset_count"] == 0 + + # Orders should not be reset + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription.id).first() + assert updated.orders_this_period == 25 + + async def test_handles_annual_subscription(self, db, test_vendor): + """Test that annual subscriptions get 365-day periods.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="professional", + status=SubscriptionStatus.ACTIVE.value, + is_annual=True, + period_start=datetime.now(UTC) - timedelta(days=370), + period_end=datetime.now(UTC) - timedelta(days=5), + orders_this_period=500, + ) + db.add(subscription) + db.commit() + subscription_id = subscription.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await reset_period_counters() + + assert result["reset_count"] == 1 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + # Should be ~365 days from now + expected_end = datetime.now(UTC) + timedelta(days=365) + # Handle timezone-naive datetime from database + period_end = updated.period_end + if period_end.tzinfo is None: + period_end = period_end.replace(tzinfo=UTC) + assert abs((period_end - expected_end).total_seconds()) < 60 + + +@pytest.mark.integration +@pytest.mark.database +@pytest.mark.asyncio +class TestCheckTrialExpirations: + """Test check_trial_expirations task.""" + + async def test_expires_trial_without_payment( + self, db, trial_subscription_expired_no_payment + ): + """Test that trials without payment method are expired.""" + subscription_id = trial_subscription_expired_no_payment.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await check_trial_expirations() + + assert result["expired_count"] == 1 + assert result["activated_count"] == 0 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.EXPIRED.value + + async def test_activates_trial_with_payment( + self, db, trial_subscription_expired_with_payment + ): + """Test that trials with payment method are activated.""" + subscription_id = trial_subscription_expired_with_payment.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await check_trial_expirations() + + assert result["expired_count"] == 0 + assert result["activated_count"] == 1 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.ACTIVE.value + + async def test_does_not_affect_active_trial(self, db, test_vendor): + """Test that active trials are not affected.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.TRIAL.value, + trial_ends_at=datetime.now(UTC) + timedelta(days=7), # Still active + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(subscription) + db.commit() + subscription_id = subscription.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await check_trial_expirations() + + assert result["expired_count"] == 0 + assert result["activated_count"] == 0 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.TRIAL.value + + +@pytest.mark.integration +@pytest.mark.database +@pytest.mark.asyncio +class TestSyncStripeStatus: + """Test sync_stripe_status task.""" + + async def test_skips_when_stripe_not_configured(self, db): + """Test that sync is skipped when Stripe is not configured.""" + mock_stripe = MagicMock() + mock_stripe.is_configured = False + + with ( + patch("app.tasks.subscription_tasks.SessionLocal", return_value=db), + patch("app.tasks.subscription_tasks.stripe_service", mock_stripe), + ): + result = await sync_stripe_status() + + assert result["skipped"] is True + assert result["synced"] == 0 + + async def test_syncs_subscription_status(self, db, test_vendor): + """Test that subscription status is synced from Stripe.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.TRIAL.value, + stripe_subscription_id="sub_test_123", + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(subscription) + db.commit() + subscription_id = subscription.id + + # Mock Stripe response + mock_stripe_sub = MagicMock() + mock_stripe_sub.status = "active" + mock_stripe_sub.current_period_start = int( + (datetime.now(UTC) - timedelta(days=5)).timestamp() + ) + mock_stripe_sub.current_period_end = int( + (datetime.now(UTC) + timedelta(days=25)).timestamp() + ) + mock_stripe_sub.default_payment_method = "pm_test_456" + + mock_stripe = MagicMock() + mock_stripe.is_configured = True + mock_stripe.get_subscription.return_value = mock_stripe_sub + + with ( + patch("app.tasks.subscription_tasks.SessionLocal", return_value=db), + patch("app.tasks.subscription_tasks.stripe_service", mock_stripe), + ): + result = await sync_stripe_status() + + assert result["synced_count"] == 1 + assert result["error_count"] == 0 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.ACTIVE.value + assert updated.stripe_payment_method_id == "pm_test_456" + + async def test_handles_missing_stripe_subscription(self, db, test_vendor): + """Test handling when Stripe subscription is not found.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.ACTIVE.value, + stripe_subscription_id="sub_deleted_123", + period_start=datetime.now(UTC), + period_end=datetime.now(UTC) + timedelta(days=30), + ) + db.add(subscription) + db.commit() + + mock_stripe = MagicMock() + mock_stripe.is_configured = True + mock_stripe.get_subscription.return_value = None + + with ( + patch("app.tasks.subscription_tasks.SessionLocal", return_value=db), + patch("app.tasks.subscription_tasks.stripe_service", mock_stripe), + ): + result = await sync_stripe_status() + + # Should not count as synced (subscription not found in Stripe) + assert result["synced_count"] == 0 + + +@pytest.mark.integration +@pytest.mark.database +@pytest.mark.asyncio +class TestCleanupStaleSubscriptions: + """Test cleanup_stale_subscriptions task.""" + + async def test_cleans_old_cancelled_subscriptions(self, db, cancelled_subscription): + """Test that old cancelled subscriptions are marked as expired.""" + subscription_id = cancelled_subscription.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await cleanup_stale_subscriptions() + + assert result["cleaned_count"] == 1 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.EXPIRED.value + + async def test_does_not_clean_recent_cancelled(self, db, test_vendor): + """Test that recently cancelled subscriptions are not cleaned.""" + subscription = VendorSubscription( + vendor_id=test_vendor.id, + tier="essential", + status=SubscriptionStatus.CANCELLED.value, + period_start=datetime.now(UTC) - timedelta(days=25), + period_end=datetime.now(UTC) - timedelta(days=5), # Only 5 days ago + ) + db.add(subscription) + db.commit() + subscription_id = subscription.id + + with patch("app.tasks.subscription_tasks.SessionLocal", return_value=db): + result = await cleanup_stale_subscriptions() + + assert result["cleaned_count"] == 0 + + db.expire_all() + updated = db.query(VendorSubscription).filter_by(id=subscription_id).first() + assert updated.status == SubscriptionStatus.CANCELLED.value