style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -7,13 +7,13 @@ from sqlalchemy.pool import StaticPool
from app.core.database import Base, get_db
from main import app
from models.database.inventory import Inventory
# Import all models to ensure they're registered with Base metadata
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
from models.database.product import Product
from models.database.inventory import Inventory
from models.database.user import User
from models.database.vendor import Vendor
# Use in-memory SQLite database for tests
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"

View File

@@ -53,6 +53,7 @@ def test_admin(db, auth_manager):
db.expunge(admin)
return admin
@pytest.fixture
def another_admin(db, auth_manager):
"""Create another test admin user for testing admin-to-admin interactions"""
@@ -72,6 +73,7 @@ def another_admin(db, auth_manager):
db.expunge(admin)
return admin
@pytest.fixture
def other_user(db, auth_manager):
"""Create a different user for testing access controls"""

View File

@@ -1,5 +1,6 @@
# tests/fixtures/customer_fixtures.py
import pytest
from models.database.customer import Customer, CustomerAddress

View File

@@ -1,5 +1,6 @@
# tests/fixtures/marketplace_import_job_fixtures.py
import pytest
from models.database.marketplace_import_job import MarketplaceImportJob

View File

@@ -117,7 +117,9 @@ def marketplace_product_factory():
@pytest.fixture
def test_marketplace_product_with_inventory(db, test_marketplace_product, test_inventory):
def test_marketplace_product_with_inventory(
db, test_marketplace_product, test_inventory
):
"""MarketplaceProduct with associated inventory record."""
# Ensure they're linked by GTIN
if test_marketplace_product.gtin != test_inventory.gtin:
@@ -126,6 +128,6 @@ def test_marketplace_product_with_inventory(db, test_marketplace_product, test_i
db.refresh(test_inventory)
return {
'marketplace_product': test_marketplace_product,
'inventory': test_inventory
"marketplace_product": test_marketplace_product,
"inventory": test_inventory,
}

View File

@@ -8,8 +8,9 @@ This module provides fixtures for:
- Additional testing utilities
"""
import pytest
from unittest.mock import Mock
import pytest
from sqlalchemy.exc import SQLAlchemyError
@@ -26,7 +27,7 @@ def empty_db(db):
"inventory", # Fixed: singular not plural
"products", # Referenced by products
"vendors", # Has foreign key to users
"users" # Base table
"users", # Base table
]
for table in tables_to_clear:

View File

@@ -1,10 +1,11 @@
# tests/fixtures/vendor_fixtures.py
import uuid
import pytest
from models.database.vendor import Vendor
from models.database.product import Product
from models.database.inventory import Inventory
from models.database.product import Product
from models.database.vendor import Vendor
@pytest.fixture
@@ -185,6 +186,7 @@ def multiple_inventory_entries(db, multiple_products, test_vendor):
def create_unique_vendor_factory():
"""Factory function to create unique vendors in tests"""
def _create_vendor(db, owner_user_id, **kwargs):
unique_id = str(uuid.uuid4())[:8]
defaults = {

View File

@@ -59,7 +59,9 @@ class TestAdminAPI:
assert response.status_code == 400 # Business logic error
data = response.json()
assert data["error_code"] == "CANNOT_MODIFY_SELF"
assert "Cannot perform 'deactivate account' on your own account" in data["message"]
assert (
"Cannot perform 'deactivate account' on your own account" in data["message"]
)
def test_toggle_user_status_cannot_modify_admin(
self, client, admin_headers, test_admin, another_admin
@@ -85,7 +87,9 @@ class TestAdminAPI:
# Check that test_vendor is in the response
vendor_codes = [
vendor ["vendor_code"] for vendor in data["vendors"] if "vendor_code" in vendor
vendor["vendor_code"]
for vendor in data["vendors"]
if "vendor_code" in vendor
]
assert test_vendor.vendor_code in vendor_codes
@@ -98,7 +102,7 @@ class TestAdminAPI:
assert data["error_code"] == "ADMIN_REQUIRED"
def test_verify_vendor_admin(self, client, admin_headers, test_vendor):
"""Test admin verifying/unverifying vendor """
"""Test admin verifying/unverifying vendor"""
response = client.put(
f"/api/v1/admin/vendors/{test_vendor.id}/verify", headers=admin_headers
)
@@ -109,8 +113,10 @@ class TestAdminAPI:
assert test_vendor.vendor_code in message
def test_verify_vendor_not_found(self, client, admin_headers):
"""Test admin verifying non-existent vendor """
response = client.put("/api/v1/admin/vendors/99999/verify", headers=admin_headers)
"""Test admin verifying non-existent vendor"""
response = client.put(
"/api/v1/admin/vendors/99999/verify", headers=admin_headers
)
assert response.status_code == 404
data = response.json()
@@ -129,8 +135,10 @@ class TestAdminAPI:
assert test_vendor.vendor_code in message
def test_toggle_vendor_status_not_found(self, client, admin_headers):
"""Test admin toggling status for non-existent vendor """
response = client.put("/api/v1/admin/vendors/99999/status", headers=admin_headers)
"""Test admin toggling status for non-existent vendor"""
response = client.put(
"/api/v1/admin/vendors/99999/status", headers=admin_headers
)
assert response.status_code == 404
data = response.json()
@@ -166,7 +174,8 @@ class TestAdminAPI:
data = response.json()
assert len(data) >= 1
assert all(
job["marketplace"] == test_marketplace_import_job.marketplace for job in data
job["marketplace"] == test_marketplace_import_job.marketplace
for job in data
)
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):

View File

@@ -1,7 +1,8 @@
# tests/integration/api/v1/test_auth_endpoints.py
from datetime import datetime, timedelta, timezone
import pytest
from jose import jwt
from datetime import datetime, timedelta, timezone
@pytest.mark.integration
@@ -178,8 +179,7 @@ class TestAuthenticationAPI:
def test_get_current_user_invalid_token(self, client):
"""Test getting current user with invalid token"""
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": "Bearer invalid_token_here"}
"/api/v1/auth/me", headers={"Authorization": "Bearer invalid_token_here"}
)
assert response.status_code == 401
@@ -201,14 +201,11 @@ class TestAuthenticationAPI:
}
expired_token = jwt.encode(
expired_payload,
auth_manager.secret_key,
algorithm=auth_manager.algorithm
expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {expired_token}"}
"/api/v1/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
)
assert response.status_code == 401
@@ -321,4 +318,3 @@ class TestAuthManager:
user = auth_manager.authenticate_user(db, "nonexistent", "password")
assert user is None

View File

@@ -14,19 +14,34 @@ class TestFiltering:
"""Test filtering products by brand successfully"""
# Create products with different brands using unique IDs
import uuid
unique_suffix = str(uuid.uuid4())[:8]
products = [
MarketplaceProduct(marketplace_product_id=f"BRAND1_{unique_suffix}", title="MarketplaceProduct 1", brand="BrandA"),
MarketplaceProduct(marketplace_product_id=f"BRAND2_{unique_suffix}", title="MarketplaceProduct 2", brand="BrandB"),
MarketplaceProduct(marketplace_product_id=f"BRAND3_{unique_suffix}", title="MarketplaceProduct 3", brand="BrandA"),
MarketplaceProduct(
marketplace_product_id=f"BRAND1_{unique_suffix}",
title="MarketplaceProduct 1",
brand="BrandA",
),
MarketplaceProduct(
marketplace_product_id=f"BRAND2_{unique_suffix}",
title="MarketplaceProduct 2",
brand="BrandB",
),
MarketplaceProduct(
marketplace_product_id=f"BRAND3_{unique_suffix}",
title="MarketplaceProduct 3",
brand="BrandA",
),
]
db.add_all(products)
db.commit()
# Filter by BrandA
response = client.get("/api/v1/marketplace/product?brand=BrandA", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?brand=BrandA", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # At least our test products
@@ -37,7 +52,9 @@ class TestFiltering:
assert product["brand"] == "BrandA"
# Filter by BrandB
response = client.get("/api/v1/marketplace/product?brand=BrandB", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?brand=BrandB", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1 # At least our test product
@@ -45,37 +62,57 @@ class TestFiltering:
def test_product_marketplace_filter_success(self, client, auth_headers, db):
"""Test filtering products by marketplace successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
products = [
MarketplaceProduct(marketplace_product_id=f"MKT1_{unique_suffix}", title="MarketplaceProduct 1", marketplace="Amazon"),
MarketplaceProduct(marketplace_product_id=f"MKT2_{unique_suffix}", title="MarketplaceProduct 2", marketplace="eBay"),
MarketplaceProduct(marketplace_product_id=f"MKT3_{unique_suffix}", title="MarketplaceProduct 3", marketplace="Amazon"),
MarketplaceProduct(
marketplace_product_id=f"MKT1_{unique_suffix}",
title="MarketplaceProduct 1",
marketplace="Amazon",
),
MarketplaceProduct(
marketplace_product_id=f"MKT2_{unique_suffix}",
title="MarketplaceProduct 2",
marketplace="eBay",
),
MarketplaceProduct(
marketplace_product_id=f"MKT3_{unique_suffix}",
title="MarketplaceProduct 3",
marketplace="Amazon",
),
]
db.add_all(products)
db.commit()
response = client.get("/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # At least our test products
# Verify all returned products have Amazon marketplace
amazon_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
amazon_products = [
p
for p in data["products"]
if p["marketplace_product_id"].endswith(unique_suffix)
]
for product in amazon_products:
assert product["marketplace"] == "Amazon"
def test_product_search_filter_success(self, client, auth_headers, db):
"""Test searching products by text successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
products = [
MarketplaceProduct(
marketplace_product_id=f"SEARCH1_{unique_suffix}",
title=f"Apple iPhone {unique_suffix}",
description="Smartphone"
description="Smartphone",
),
MarketplaceProduct(
marketplace_product_id=f"SEARCH2_{unique_suffix}",
@@ -85,7 +122,7 @@ class TestFiltering:
MarketplaceProduct(
marketplace_product_id=f"SEARCH3_{unique_suffix}",
title=f"iPad Tablet {unique_suffix}",
description="Apple tablet"
description="Apple tablet",
),
]
@@ -93,13 +130,17 @@ class TestFiltering:
db.commit()
# Search for "Apple"
response = client.get(f"/api/v1/marketplace/product?search=Apple", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?search=Apple", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # iPhone and iPad
# Search for "phone"
response = client.get(f"/api/v1/marketplace/product?search=phone", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?search=phone", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 2 # iPhone and Galaxy
@@ -107,6 +148,7 @@ class TestFiltering:
def test_combined_filters_success(self, client, auth_headers, db):
"""Test combining multiple filters successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
products = [
@@ -135,14 +177,19 @@ class TestFiltering:
# Filter by brand AND marketplace
response = client.get(
"/api/v1/marketplace/product?brand=Apple&marketplace=Amazon", headers=auth_headers
"/api/v1/marketplace/product?brand=Apple&marketplace=Amazon",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1 # At least iPhone matches both
# Find our specific test product
matching_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
matching_products = [
p
for p in data["products"]
if p["marketplace_product_id"].endswith(unique_suffix)
]
for product in matching_products:
assert product["brand"] == "Apple"
assert product["marketplace"] == "Amazon"
@@ -150,7 +197,8 @@ class TestFiltering:
def test_filter_with_no_results(self, client, auth_headers):
"""Test filtering with criteria that returns no results"""
response = client.get(
"/api/v1/marketplace/product?brand=NonexistentBrand123456", headers=auth_headers
"/api/v1/marketplace/product?brand=NonexistentBrand123456",
headers=auth_headers,
)
assert response.status_code == 200
@@ -161,6 +209,7 @@ class TestFiltering:
def test_filter_case_insensitive(self, client, auth_headers, db):
"""Test that filters are case-insensitive"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
product = MarketplaceProduct(
@@ -174,7 +223,10 @@ class TestFiltering:
# Test different case variations
for brand_filter in ["TestBrand", "testbrand", "TESTBRAND"]:
response = client.get(f"/api/v1/marketplace/product?brand={brand_filter}", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?brand={brand_filter}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
@@ -183,9 +235,14 @@ class TestFiltering:
"""Test behavior with invalid filter parameters"""
# Test with very long filter values
long_brand = "A" * 1000
response = client.get(f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers
)
assert response.status_code == 200 # Should handle gracefully
# Test with special characters
response = client.get("/api/v1/marketplace/product?brand=<script>alert('test')</script>", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?brand=<script>alert('test')</script>",
headers=auth_headers,
)
assert response.status_code == 200 # Should handle gracefully

View File

@@ -17,7 +17,9 @@ class TestInventoryAPI:
"quantity": 100,
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
assert response.status_code == 200
data = response.json()
@@ -38,7 +40,9 @@ class TestInventoryAPI:
"quantity": 75,
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
assert response.status_code == 200
data = response.json()
@@ -52,7 +56,9 @@ class TestInventoryAPI:
"quantity": 100,
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
assert response.status_code == 422
data = response.json()
@@ -60,7 +66,9 @@ class TestInventoryAPI:
assert data["status_code"] == 422
assert "GTIN is required" in data["message"]
def test_set_inventory_invalid_quantity_validation_error(self, client, auth_headers):
def test_set_inventory_invalid_quantity_validation_error(
self, client, auth_headers
):
"""Test setting inventory with invalid quantity returns InvalidQuantityException"""
inventory_data = {
"gtin": "1234567890123",
@@ -68,7 +76,9 @@ class TestInventoryAPI:
"quantity": -10, # Negative quantity
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
assert response.status_code in [400, 422]
data = response.json()
@@ -138,7 +148,9 @@ class TestInventoryAPI:
data = response.json()
assert data["quantity"] == 35 # 50 - 15
def test_remove_inventory_insufficient_returns_business_logic_error(self, client, auth_headers, db):
def test_remove_inventory_insufficient_returns_business_logic_error(
self, client, auth_headers, db
):
"""Test removing more inventory than available returns InsufficientInventoryException"""
# Create initial inventory
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=10)
@@ -184,7 +196,9 @@ class TestInventoryAPI:
assert data["error_code"] == "INVENTORY_NOT_FOUND"
assert data["status_code"] == 404
def test_negative_inventory_not_allowed_business_logic_error(self, client, auth_headers, db):
def test_negative_inventory_not_allowed_business_logic_error(
self, client, auth_headers, db
):
"""Test operations resulting in negative inventory returns NegativeInventoryException"""
# Create initial inventory
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=5)
@@ -204,14 +218,21 @@ class TestInventoryAPI:
assert response.status_code == 400
data = response.json()
# This might be caught as INSUFFICIENT_INVENTORY or NEGATIVE_INVENTORY_NOT_ALLOWED
assert data["error_code"] in ["INSUFFICIENT_INVENTORY", "NEGATIVE_INVENTORY_NOT_ALLOWED"]
assert data["error_code"] in [
"INSUFFICIENT_INVENTORY",
"NEGATIVE_INVENTORY_NOT_ALLOWED",
]
assert data["status_code"] == 400
def test_get_inventory_by_gtin_success(self, client, auth_headers, db):
"""Test getting inventory summary for GTIN successfully"""
# Create inventory in multiple locations
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
inventory1 = Inventory(
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
)
inventory2 = Inventory(
gtin="1234567890123", location="WAREHOUSE_B", quantity=25
)
db.add_all([inventory1, inventory2])
db.commit()
@@ -238,12 +259,18 @@ class TestInventoryAPI:
def test_get_total_inventory_success(self, client, auth_headers, db):
"""Test getting total inventory for GTIN successfully"""
# Create inventory in multiple locations
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
inventory1 = Inventory(
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
)
inventory2 = Inventory(
gtin="1234567890123", location="WAREHOUSE_B", quantity=25
)
db.add_all([inventory1, inventory2])
db.commit()
response = client.get("/api/v1/inventory/1234567890123/total", headers=auth_headers)
response = client.get(
"/api/v1/inventory/1234567890123/total", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
@@ -253,7 +280,9 @@ class TestInventoryAPI:
def test_get_total_inventory_not_found(self, client, auth_headers):
"""Test getting total inventory for nonexistent GTIN returns InventoryNotFoundException"""
response = client.get("/api/v1/inventory/9999999999999/total", headers=auth_headers)
response = client.get(
"/api/v1/inventory/9999999999999/total", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
@@ -263,8 +292,12 @@ class TestInventoryAPI:
def test_get_all_inventory_success(self, client, auth_headers, db):
"""Test getting all inventory entries successfully"""
# Create some inventory entries
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
inventory1 = Inventory(
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
)
inventory2 = Inventory(
gtin="9876543210987", location="WAREHOUSE_B", quantity=25
)
db.add_all([inventory1, inventory2])
db.commit()
@@ -277,20 +310,28 @@ class TestInventoryAPI:
def test_get_all_inventory_with_filters(self, client, auth_headers, db):
"""Test getting inventory entries with filtering"""
# Create inventory entries
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
inventory1 = Inventory(
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
)
inventory2 = Inventory(
gtin="9876543210987", location="WAREHOUSE_B", quantity=25
)
db.add_all([inventory1, inventory2])
db.commit()
# Filter by location
response = client.get("/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers)
response = client.get(
"/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
for inventory in data:
assert inventory["location"] == "WAREHOUSE_A"
# Filter by GTIN
response = client.get("/api/v1/inventory?gtin=1234567890123", headers=auth_headers)
response = client.get(
"/api/v1/inventory?gtin=1234567890123", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
for inventory in data:
@@ -390,7 +431,9 @@ class TestInventoryAPI:
"quantity": 100,
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
# This depends on whether your service validates locations
if response.status_code == 404:

View File

@@ -8,9 +8,11 @@ import pytest
@pytest.mark.api
@pytest.mark.marketplace
class TestMarketplaceImportJobAPI:
def test_import_from_marketplace(self, client, auth_headers, test_vendor, test_user):
def test_import_from_marketplace(
self, client, auth_headers, test_vendor, test_user
):
"""Test marketplace import endpoint - just test job creation"""
# Ensure user owns the vendor
# Ensure user owns the vendor
test_vendor.owner_user_id = test_user.id
import_data = {
@@ -32,7 +34,7 @@ class TestMarketplaceImportJobAPI:
assert data["vendor_id"] == test_vendor.id
def test_import_from_marketplace_invalid_vendor(self, client, auth_headers):
"""Test marketplace import with invalid vendor """
"""Test marketplace import with invalid vendor"""
import_data = {
"url": "https://example.com/products.csv",
"marketplace": "TestMarket",
@@ -48,7 +50,9 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "VENDOR_NOT_FOUND"
assert "NONEXISTENT" in data["message"]
def test_import_from_marketplace_unauthorized_vendor(self, client, auth_headers, test_vendor, other_user):
def test_import_from_marketplace_unauthorized_vendor(
self, client, auth_headers, test_vendor, other_user
):
"""Test marketplace import with unauthorized vendor access"""
# Set vendor owner to different user
test_vendor.owner_user_id = other_user.id
@@ -85,8 +89,10 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "VALIDATION_ERROR"
assert "Request validation failed" in data["message"]
def test_import_from_marketplace_admin_access(self, client, admin_headers, test_vendor):
"""Test that admin can import for any vendor """
def test_import_from_marketplace_admin_access(
self, client, admin_headers, test_vendor
):
"""Test that admin can import for any vendor"""
import_data = {
"url": "https://example.com/products.csv",
"marketplace": "AdminMarket",
@@ -94,7 +100,9 @@ class TestMarketplaceImportJobAPI:
}
response = client.post(
"/api/v1/marketplace/import-product", headers=admin_headers, json=import_data
"/api/v1/marketplace/import-product",
headers=admin_headers,
json=import_data,
)
assert response.status_code == 200
@@ -102,11 +110,13 @@ class TestMarketplaceImportJobAPI:
assert data["marketplace"] == "AdminMarket"
assert data["vendor_code"] == test_vendor.vendor_code
def test_get_marketplace_import_status(self, client, auth_headers, test_marketplace_import_job):
def test_get_marketplace_import_status(
self, client, auth_headers, test_marketplace_import_job
):
"""Test getting marketplace import status"""
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
@@ -118,8 +128,7 @@ class TestMarketplaceImportJobAPI:
def test_get_marketplace_import_status_not_found(self, client, auth_headers):
"""Test getting status of non-existent import job"""
response = client.get(
"/api/v1/marketplace/import-status/99999",
headers=auth_headers
"/api/v1/marketplace/import-status/99999", headers=auth_headers
)
assert response.status_code == 404
@@ -127,21 +136,25 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
assert "99999" in data["message"]
def test_get_marketplace_import_status_unauthorized(self, client, auth_headers, test_marketplace_import_job, other_user):
def test_get_marketplace_import_status_unauthorized(
self, client, auth_headers, test_marketplace_import_job, other_user
):
"""Test getting status of unauthorized import job"""
# Change job owner to other user
test_marketplace_import_job.user_id = other_user.id
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 403
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_OWNED"
def test_get_marketplace_import_jobs(self, client, auth_headers, test_marketplace_import_job):
def test_get_marketplace_import_jobs(
self, client, auth_headers, test_marketplace_import_job
):
"""Test getting marketplace import jobs"""
response = client.get("/api/v1/marketplace/import-jobs", headers=auth_headers)
@@ -154,11 +167,13 @@ class TestMarketplaceImportJobAPI:
job_ids = [job["job_id"] for job in data]
assert test_marketplace_import_job.id in job_ids
def test_get_marketplace_import_jobs_with_filters(self, client, auth_headers, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_filters(
self, client, auth_headers, test_marketplace_import_job
):
"""Test getting import jobs with filters"""
response = client.get(
f"/api/v1/marketplace/import-jobs?marketplace={test_marketplace_import_job.marketplace}",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
@@ -167,13 +182,15 @@ class TestMarketplaceImportJobAPI:
assert len(data) >= 1
for job in data:
assert test_marketplace_import_job.marketplace.lower() in job["marketplace"].lower()
assert (
test_marketplace_import_job.marketplace.lower()
in job["marketplace"].lower()
)
def test_get_marketplace_import_jobs_pagination(self, client, auth_headers):
"""Test import jobs pagination"""
response = client.get(
"/api/v1/marketplace/import-jobs?skip=0&limit=5",
headers=auth_headers
"/api/v1/marketplace/import-jobs?skip=0&limit=5", headers=auth_headers
)
assert response.status_code == 200
@@ -181,9 +198,13 @@ class TestMarketplaceImportJobAPI:
assert isinstance(data, list)
assert len(data) <= 5
def test_get_marketplace_import_stats(self, client, auth_headers, test_marketplace_import_job):
def test_get_marketplace_import_stats(
self, client, auth_headers, test_marketplace_import_job
):
"""Test getting marketplace import statistics"""
response = client.get("/api/v1/marketplace/marketplace-import-stats", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/marketplace-import-stats", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
@@ -195,12 +216,15 @@ class TestMarketplaceImportJobAPI:
assert isinstance(data["total_jobs"], int)
assert data["total_jobs"] >= 1
def test_cancel_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
def test_cancel_marketplace_import_job(
self, client, auth_headers, test_user, test_vendor, db
):
"""Test cancelling a marketplace import job"""
# Create a pending job that can be cancelled
from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
from models.database.marketplace_import_job import MarketplaceImportJob
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -219,8 +243,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.put(
f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
headers=auth_headers
f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=auth_headers
)
assert response.status_code == 200
@@ -232,15 +255,16 @@ class TestMarketplaceImportJobAPI:
def test_cancel_marketplace_import_job_not_found(self, client, auth_headers):
"""Test cancelling non-existent import job"""
response = client.put(
"/api/v1/marketplace/import-jobs/99999/cancel",
headers=auth_headers
"/api/v1/marketplace/import-jobs/99999/cancel", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
def test_cancel_marketplace_import_job_cannot_cancel(self, client, auth_headers, test_marketplace_import_job, db):
def test_cancel_marketplace_import_job_cannot_cancel(
self, client, auth_headers, test_marketplace_import_job, db
):
"""Test cancelling a job that cannot be cancelled"""
# Set job to completed status
test_marketplace_import_job.status = "completed"
@@ -248,7 +272,7 @@ class TestMarketplaceImportJobAPI:
response = client.put(
f"/api/v1/marketplace/import-jobs/{test_marketplace_import_job.id}/cancel",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 400
@@ -256,12 +280,15 @@ class TestMarketplaceImportJobAPI:
assert data["error_code"] == "IMPORT_JOB_CANNOT_BE_CANCELLED"
assert "completed" in data["message"]
def test_delete_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
def test_delete_marketplace_import_job(
self, client, auth_headers, test_user, test_vendor, db
):
"""Test deleting a marketplace import job"""
# Create a completed job that can be deleted
from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
from models.database.marketplace_import_job import MarketplaceImportJob
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="completed",
@@ -280,8 +307,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.delete(
f"/api/v1/marketplace/import-jobs/{job.id}",
headers=auth_headers
f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
)
assert response.status_code == 200
@@ -291,20 +317,22 @@ class TestMarketplaceImportJobAPI:
def test_delete_marketplace_import_job_not_found(self, client, auth_headers):
"""Test deleting non-existent import job"""
response = client.delete(
"/api/v1/marketplace/import-jobs/99999",
headers=auth_headers
"/api/v1/marketplace/import-jobs/99999", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
def test_delete_marketplace_import_job_cannot_delete(self, client, auth_headers, test_user, test_vendor, db):
def test_delete_marketplace_import_job_cannot_delete(
self, client, auth_headers, test_user, test_vendor, db
):
"""Test deleting a job that cannot be deleted"""
# Create a pending job that cannot be deleted
from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
from models.database.marketplace_import_job import MarketplaceImportJob
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -323,8 +351,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.delete(
f"/api/v1/marketplace/import-jobs/{job.id}",
headers=auth_headers
f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
)
assert response.status_code == 400
@@ -352,7 +379,9 @@ class TestMarketplaceImportJobAPI:
data = response.json()
assert data["error_code"] == "INVALID_TOKEN"
def test_admin_can_access_all_jobs(self, client, admin_headers, test_marketplace_import_job):
def test_admin_can_access_all_jobs(
self, client, admin_headers, test_marketplace_import_job
):
"""Test that admin can access all import jobs"""
response = client.get("/api/v1/marketplace/import-jobs", headers=admin_headers)
@@ -363,23 +392,28 @@ class TestMarketplaceImportJobAPI:
job_ids = [job["job_id"] for job in data]
assert test_marketplace_import_job.id in job_ids
def test_admin_can_view_any_job_status(self, client, admin_headers, test_marketplace_import_job):
def test_admin_can_view_any_job_status(
self, client, admin_headers, test_marketplace_import_job
):
"""Test that admin can view any job status"""
response = client.get(
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
headers=admin_headers
headers=admin_headers,
)
assert response.status_code == 200
data = response.json()
assert data["job_id"] == test_marketplace_import_job.id
def test_admin_can_cancel_any_job(self, client, admin_headers, test_user, test_vendor, db):
def test_admin_can_cancel_any_job(
self, client, admin_headers, test_user, test_vendor, db
):
"""Test that admin can cancel any job"""
# Create a pending job owned by different user
from models.database.marketplace_import_job import MarketplaceImportJob
import uuid
from models.database.marketplace_import_job import MarketplaceImportJob
unique_id = str(uuid.uuid4())[:8]
job = MarketplaceImportJob(
status="pending",
@@ -398,8 +432,7 @@ class TestMarketplaceImportJobAPI:
db.refresh(job)
response = client.put(
f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
headers=admin_headers
f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=admin_headers
)
assert response.status_code == 200

View File

@@ -1,7 +1,7 @@
# tests/integration/api/v1/test_export.py
import csv
from io import StringIO
import uuid
from io import StringIO
import pytest
@@ -13,9 +13,13 @@ from models.database.marketplace_product import MarketplaceProduct
@pytest.mark.performance # for the performance test
class TestExportFunctionality:
def test_csv_export_basic_success(self, client, auth_headers, test_marketplace_product):
def test_csv_export_basic_success(
self, client, auth_headers, test_marketplace_product
):
"""Test basic CSV export functionality successfully"""
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/export-csv", headers=auth_headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -26,13 +30,22 @@ class TestExportFunctionality:
# Check header row
header = next(csv_reader)
expected_fields = ["marketplace_product_id", "title", "description", "price", "marketplace"]
expected_fields = [
"marketplace_product_id",
"title",
"description",
"price",
"marketplace",
]
for field in expected_fields:
assert field in header
# Verify test product appears in export
csv_lines = csv_content.split('\n')
test_product_found = any(test_marketplace_product.marketplace_product_id in line for line in csv_lines)
csv_lines = csv_content.split("\n")
test_product_found = any(
test_marketplace_product.marketplace_product_id in line
for line in csv_lines
)
assert test_product_found, "Test product should appear in CSV export"
def test_csv_export_with_marketplace_filter_success(self, client, auth_headers, db):
@@ -43,12 +56,12 @@ class TestExportFunctionality:
MarketplaceProduct(
marketplace_product_id=f"EXP1_{unique_suffix}",
title=f"Amazon MarketplaceProduct {unique_suffix}",
marketplace="Amazon"
marketplace="Amazon",
),
MarketplaceProduct(
marketplace_product_id=f"EXP2_{unique_suffix}",
title=f"eBay MarketplaceProduct {unique_suffix}",
marketplace="eBay"
marketplace="eBay",
),
]
@@ -56,7 +69,8 @@ class TestExportFunctionality:
db.commit()
response = client.get(
"/api/v1/marketplace/product/export-csv?marketplace=Amazon", headers=auth_headers
"/api/v1/marketplace/product/export-csv?marketplace=Amazon",
headers=auth_headers,
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -72,12 +86,12 @@ class TestExportFunctionality:
MarketplaceProduct(
marketplace_product_id=f"VENDOR1_{unique_suffix}",
title=f"Vendor1 MarketplaceProduct {unique_suffix}",
vendor_name="TestVendor1"
vendor_name="TestVendor1",
),
MarketplaceProduct(
marketplace_product_id=f"VENDOR2_{unique_suffix}",
title=f"Vendor2 MarketplaceProduct {unique_suffix}",
vendor_name="TestVendor2"
vendor_name="TestVendor2",
),
]
@@ -101,19 +115,19 @@ class TestExportFunctionality:
marketplace_product_id=f"COMBO1_{unique_suffix}",
title=f"Combo MarketplaceProduct 1 {unique_suffix}",
marketplace="Amazon",
vendor_name="TestVendor"
vendor_name="TestVendor",
),
MarketplaceProduct(
marketplace_product_id=f"COMBO2_{unique_suffix}",
title=f"Combo MarketplaceProduct 2 {unique_suffix}",
marketplace="eBay",
vendor_name="TestVendor"
vendor_name="TestVendor",
),
MarketplaceProduct(
marketplace_product_id=f"COMBO3_{unique_suffix}",
title=f"Combo MarketplaceProduct 3 {unique_suffix}",
marketplace="Amazon",
vendor_name="OtherVendor"
vendor_name="OtherVendor",
),
]
@@ -122,27 +136,27 @@ class TestExportFunctionality:
response = client.get(
"/api/v1/marketplace/product?marketplace=Amazon&name=TestVendor",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
csv_content = response.content.decode("utf-8")
assert f"COMBO1_{unique_suffix}" in csv_content # Matches both filters
assert f"COMBO2_{unique_suffix}" not in csv_content # Wrong marketplace
assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
def test_csv_export_no_results(self, client, auth_headers):
"""Test CSV export with filters that return no results"""
response = client.get(
"/api/v1/marketplace/product/export-csv?marketplace=NonexistentMarketplace12345",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
csv_content = response.content.decode("utf-8")
csv_lines = csv_content.strip().split('\n')
csv_lines = csv_content.strip().split("\n")
# Should have header row even with no data
assert len(csv_lines) >= 1
# First line should be headers
@@ -161,27 +175,32 @@ class TestExportFunctionality:
title=f"Performance MarketplaceProduct {i}",
marketplace="Performance",
description=f"Performance test product {i}",
price="10.99"
price="10.99",
)
products.append(product)
# Add in batches to avoid memory issues
for i in range(0, len(products), 50):
batch = products[i:i + 50]
batch = products[i : i + 50]
db.add_all(batch)
db.commit()
import time
start_time = time.time()
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/export-csv", headers=auth_headers
)
end_time = time.time()
execution_time = end_time - start_time
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
assert execution_time < 10.0, f"Export took {execution_time:.2f} seconds, should be under 10s"
assert (
execution_time < 10.0
), f"Export took {execution_time:.2f} seconds, should be under 10s"
# Verify content contains our test data
csv_content = response.content.decode("utf-8")
@@ -197,9 +216,13 @@ class TestExportFunctionality:
assert data["error_code"] == "INVALID_TOKEN"
assert data["status_code"] == 401
def test_csv_export_streaming_response_format(self, client, auth_headers, test_marketplace_product):
def test_csv_export_streaming_response_format(
self, client, auth_headers, test_marketplace_product
):
"""Test that CSV export returns proper streaming response with correct headers"""
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/export-csv", headers=auth_headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/csv; charset=utf-8"
@@ -217,23 +240,27 @@ class TestExportFunctionality:
# Create product with special characters that might break CSV
product = MarketplaceProduct(
marketplace_product_id=f"SPECIAL_{unique_suffix}",
title=f'MarketplaceProduct with quotes and commas {unique_suffix}', # Simplified to avoid CSV escaping issues
title=f"MarketplaceProduct with quotes and commas {unique_suffix}", # Simplified to avoid CSV escaping issues
description=f"Description with special chars {unique_suffix}",
marketplace="Test Market",
price="19.99"
price="19.99",
)
db.add(product)
db.commit()
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/export-csv", headers=auth_headers
)
assert response.status_code == 200
csv_content = response.content.decode("utf-8")
# Verify our test product appears in the CSV content
assert f"SPECIAL_{unique_suffix}" in csv_content
assert f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
assert (
f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
)
assert "Test Market" in csv_content
assert "19.99" in csv_content
@@ -243,7 +270,12 @@ class TestExportFunctionality:
header = next(csv_reader)
# Verify header contains expected fields
expected_fields = ["marketplace_product_id", "title", "marketplace", "price"]
expected_fields = [
"marketplace_product_id",
"title",
"marketplace",
"price",
]
for field in expected_fields:
assert field in header
@@ -264,12 +296,15 @@ class TestExportFunctionality:
assert parsed_successfully, "CSV should be parseable despite special characters"
def test_csv_export_error_handling_service_failure(self, client, auth_headers, monkeypatch):
def test_csv_export_error_handling_service_failure(
self, client, auth_headers, monkeypatch
):
"""Test CSV export handles service failures gracefully"""
# Mock the service to raise an exception
def mock_generate_csv_export(*args, **kwargs):
from app.exceptions import ValidationException
raise ValidationException("Mocked service failure")
# This would require access to your service instance to mock properly
@@ -293,7 +328,9 @@ class TestExportFunctionality:
def test_csv_export_filename_generation(self, client, auth_headers):
"""Test CSV export generates appropriate filenames based on filters"""
# Test basic export filename
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/export-csv", headers=auth_headers
)
assert response.status_code == 200
content_disposition = response.headers.get("content-disposition", "")
@@ -302,7 +339,7 @@ class TestExportFunctionality:
# Test with marketplace filter
response = client.get(
"/api/v1/marketplace/product/export-csv?marketplace=Amazon",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200

View File

@@ -16,7 +16,9 @@ class TestMarketplaceProductsAPI:
assert data["products"] == []
assert data["total"] == 0
def test_get_products_with_data(self, client, auth_headers, test_marketplace_product):
def test_get_products_with_data(
self, client, auth_headers, test_marketplace_product
):
"""Test getting products with data"""
response = client.get("/api/v1/marketplace/product", headers=auth_headers)
@@ -25,25 +27,39 @@ class TestMarketplaceProductsAPI:
assert len(data["products"]) >= 1
assert data["total"] >= 1
# Find our test product
test_product_found = any(p["marketplace_product_id"] == test_marketplace_product.marketplace_product_id for p in data["products"])
test_product_found = any(
p["marketplace_product_id"]
== test_marketplace_product.marketplace_product_id
for p in data["products"]
)
assert test_product_found
def test_get_products_with_filters(self, client, auth_headers, test_marketplace_product):
def test_get_products_with_filters(
self, client, auth_headers, test_marketplace_product
):
"""Test filtering products"""
# Test brand filter
response = client.get(f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Test marketplace filter
response = client.get(f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}", headers=auth_headers)
response = client.get(
f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Test search
response = client.get("/api/v1/marketplace/product?search=Test", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?search=Test", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
@@ -71,7 +87,9 @@ class TestMarketplaceProductsAPI:
assert data["title"] == "New MarketplaceProduct"
assert data["marketplace"] == "Amazon"
def test_create_product_duplicate_id_returns_conflict(self, client, auth_headers, test_marketplace_product):
def test_create_product_duplicate_id_returns_conflict(
self, client, auth_headers, test_marketplace_product
):
"""Test creating product with duplicate ID returns MarketplaceProductAlreadyExistsException"""
product_data = {
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
@@ -93,7 +111,10 @@ class TestMarketplaceProductsAPI:
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
assert data["status_code"] == 409
assert test_marketplace_product.marketplace_product_id in data["message"]
assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
assert (
data["details"]["marketplace_product_id"]
== test_marketplace_product.marketplace_product_id
)
def test_create_product_missing_title_validation_error(self, client, auth_headers):
"""Test creating product without title returns ValidationException"""
@@ -114,7 +135,9 @@ class TestMarketplaceProductsAPI:
assert "MarketplaceProduct title is required" in data["message"]
assert data["details"]["field"] == "title"
def test_create_product_missing_product_id_validation_error(self, client, auth_headers):
def test_create_product_missing_product_id_validation_error(
self, client, auth_headers
):
"""Test creating product without marketplace_product_id returns ValidationException"""
product_data = {
"marketplace_product_id": "", # Empty product ID
@@ -191,20 +214,28 @@ class TestMarketplaceProductsAPI:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
def test_get_product_by_id_success(self, client, auth_headers, test_marketplace_product):
def test_get_product_by_id_success(
self, client, auth_headers, test_marketplace_product
):
"""Test getting specific product successfully"""
response = client.get(
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
headers=auth_headers,
)
assert response.status_code == 200
data = response.json()
assert data["product"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
assert (
data["product"]["marketplace_product_id"]
== test_marketplace_product.marketplace_product_id
)
assert data["product"]["title"] == test_marketplace_product.title
def test_get_nonexistent_product_returns_not_found(self, client, auth_headers):
"""Test getting nonexistent product returns MarketplaceProductNotFoundException"""
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
@@ -214,7 +245,9 @@ class TestMarketplaceProductsAPI:
assert data["details"]["resource_type"] == "MarketplaceProduct"
assert data["details"]["identifier"] == "NONEXISTENT"
def test_update_product_success(self, client, auth_headers, test_marketplace_product):
def test_update_product_success(
self, client, auth_headers, test_marketplace_product
):
"""Test updating product successfully"""
update_data = {"title": "Updated MarketplaceProduct Title", "price": "25.99"}
@@ -247,7 +280,9 @@ class TestMarketplaceProductsAPI:
assert data["details"]["resource_type"] == "MarketplaceProduct"
assert data["details"]["identifier"] == "NONEXISTENT"
def test_update_product_empty_title_validation_error(self, client, auth_headers, test_marketplace_product):
def test_update_product_empty_title_validation_error(
self, client, auth_headers, test_marketplace_product
):
"""Test updating product with empty title returns MarketplaceProductValidationException"""
update_data = {"title": ""}
@@ -264,7 +299,9 @@ class TestMarketplaceProductsAPI:
assert "MarketplaceProduct title cannot be empty" in data["message"]
assert data["details"]["field"] == "title"
def test_update_product_invalid_gtin_data_error(self, client, auth_headers, test_marketplace_product):
def test_update_product_invalid_gtin_data_error(
self, client, auth_headers, test_marketplace_product
):
"""Test updating product with invalid GTIN returns InvalidMarketplaceProductDataException"""
update_data = {"gtin": "invalid_gtin"}
@@ -281,7 +318,9 @@ class TestMarketplaceProductsAPI:
assert "Invalid GTIN format" in data["message"]
assert data["details"]["field"] == "gtin"
def test_update_product_invalid_price_data_error(self, client, auth_headers, test_marketplace_product):
def test_update_product_invalid_price_data_error(
self, client, auth_headers, test_marketplace_product
):
"""Test updating product with invalid price returns InvalidMarketplaceProductDataException"""
update_data = {"price": "invalid_price"}
@@ -298,10 +337,13 @@ class TestMarketplaceProductsAPI:
assert "Invalid price format" in data["message"]
assert data["details"]["field"] == "price"
def test_delete_product_success(self, client, auth_headers, test_marketplace_product):
def test_delete_product_success(
self, client, auth_headers, test_marketplace_product
):
"""Test deleting product successfully"""
response = client.delete(
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
headers=auth_headers,
)
assert response.status_code == 200
@@ -309,7 +351,9 @@ class TestMarketplaceProductsAPI:
def test_delete_nonexistent_product_returns_not_found(self, client, auth_headers):
"""Test deleting nonexistent product returns MarketplaceProductNotFoundException"""
response = client.delete("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
response = client.delete(
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
@@ -331,7 +375,9 @@ class TestMarketplaceProductsAPI:
def test_exception_structure_consistency(self, client, auth_headers):
"""Test that all exceptions follow the consistent WizamartException structure"""
# Test with a known error case
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
)
assert response.status_code == 404
data = response.json()

View File

@@ -4,6 +4,7 @@ import pytest
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.database
@@ -13,6 +14,7 @@ class TestPagination:
def test_product_pagination_success(self, client, auth_headers, db):
"""Test pagination for product listing successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple products
@@ -29,7 +31,9 @@ class TestPagination:
db.commit()
# Test first page
response = client.get("/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert len(data["products"]) == 10
@@ -38,21 +42,29 @@ class TestPagination:
assert data["limit"] == 10
# Test second page
response = client.get("/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert len(data["products"]) == 10
assert data["skip"] == 10
# Test last page (should have remaining products)
response = client.get("/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert len(data["products"]) >= 5 # At least 5 remaining from our test set
def test_pagination_boundary_negative_skip_validation_error(self, client, auth_headers):
def test_pagination_boundary_negative_skip_validation_error(
self, client, auth_headers
):
"""Test negative skip parameter returns ValidationException"""
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
)
assert response.status_code == 422
data = response.json()
@@ -61,9 +73,13 @@ class TestPagination:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
def test_pagination_boundary_zero_limit_validation_error(self, client, auth_headers):
def test_pagination_boundary_zero_limit_validation_error(
self, client, auth_headers
):
"""Test zero limit parameter returns ValidationException"""
response = client.get("/api/v1/marketplace/product?limit=0", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=0", headers=auth_headers
)
assert response.status_code == 422
data = response.json()
@@ -71,9 +87,13 @@ class TestPagination:
assert data["status_code"] == 422
assert "Request validation failed" in data["message"]
def test_pagination_boundary_excessive_limit_validation_error(self, client, auth_headers):
def test_pagination_boundary_excessive_limit_validation_error(
self, client, auth_headers
):
"""Test excessive limit parameter returns ValidationException"""
response = client.get("/api/v1/marketplace/product?limit=10000", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=10000", headers=auth_headers
)
assert response.status_code == 422
data = response.json()
@@ -84,7 +104,9 @@ class TestPagination:
def test_pagination_beyond_available_records(self, client, auth_headers, db):
"""Test pagination beyond available records returns empty results"""
# Test skip beyond available records
response = client.get("/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers
)
assert response.status_code == 200
data = response.json()
@@ -96,6 +118,7 @@ class TestPagination:
def test_pagination_with_filters(self, client, auth_headers, db):
"""Test pagination combined with filtering"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
# Create products with same brand for filtering
@@ -115,7 +138,7 @@ class TestPagination:
# Test first page with filter
response = client.get(
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=0",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
@@ -124,14 +147,18 @@ class TestPagination:
assert data["total"] >= 15 # At least our test products
# Verify all products have the filtered brand
test_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
test_products = [
p
for p in data["products"]
if p["marketplace_product_id"].endswith(unique_suffix)
]
for product in test_products:
assert product["brand"] == "FilterBrand"
# Test second page with same filter
response = client.get(
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=5",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
@@ -152,6 +179,7 @@ class TestPagination:
def test_pagination_consistency(self, client, auth_headers, db):
"""Test pagination consistency across multiple requests"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
# Create products with predictable ordering
@@ -168,14 +196,22 @@ class TestPagination:
db.commit()
# Get first page
response1 = client.get("/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers)
response1 = client.get(
"/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers
)
assert response1.status_code == 200
first_page_ids = [p["marketplace_product_id"] for p in response1.json()["products"]]
first_page_ids = [
p["marketplace_product_id"] for p in response1.json()["products"]
]
# Get second page
response2 = client.get("/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers)
response2 = client.get(
"/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers
)
assert response2.status_code == 200
second_page_ids = [p["marketplace_product_id"] for p in response2.json()["products"]]
second_page_ids = [
p["marketplace_product_id"] for p in response2.json()["products"]
]
# Verify no overlap between pages
overlap = set(first_page_ids) & set(second_page_ids)
@@ -184,11 +220,13 @@ class TestPagination:
def test_vendor_pagination_success(self, client, admin_headers, db, test_user):
"""Test pagination for vendor listing successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple vendors for pagination testing
from models.database.vendor import Vendor
vendors =[]
vendors = []
for i in range(15):
vendor = Vendor(
vendor_code=f"PAGEVENDOR{i:03d}_{unique_suffix}",
@@ -202,9 +240,7 @@ class TestPagination:
db.commit()
# Test first page (assuming admin endpoint exists)
response = client.get(
"/api/v1/vendor?limit=5&skip=0", headers=admin_headers
)
response = client.get("/api/v1/vendor?limit=5&skip=0", headers=admin_headers)
assert response.status_code == 200
data = response.json()
assert len(data["vendors"]) == 5
@@ -215,10 +251,12 @@ class TestPagination:
def test_inventory_pagination_success(self, client, auth_headers, db):
"""Test pagination for inventory listing successfully"""
import uuid
unique_suffix = str(uuid.uuid4())[:8]
# Create multiple inventory entries
from models.database.inventory import Inventory
inventory_entries = []
for i in range(20):
inventory = Inventory(
@@ -249,7 +287,9 @@ class TestPagination:
import time
start_time = time.time()
response = client.get("/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers
)
end_time = time.time()
assert response.status_code == 200
@@ -262,19 +302,25 @@ class TestPagination:
def test_pagination_with_invalid_parameters_types(self, client, auth_headers):
"""Test pagination with invalid parameter types returns ValidationException"""
# Test non-numeric skip
response = client.get("/api/v1/marketplace/product?skip=invalid", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=invalid", headers=auth_headers
)
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "VALIDATION_ERROR"
# Test non-numeric limit
response = client.get("/api/v1/marketplace/product?limit=invalid", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=invalid", headers=auth_headers
)
assert response.status_code == 422
data = response.json()
assert data["error_code"] == "VALIDATION_ERROR"
# Test float values (should be converted or rejected)
response = client.get("/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers
)
assert response.status_code in [200, 422] # Depends on implementation
def test_empty_dataset_pagination(self, client, auth_headers):
@@ -282,7 +328,7 @@ class TestPagination:
# Use a filter that should return no results
response = client.get(
"/api/v1/marketplace/product?brand=NonexistentBrand999&limit=10&skip=0",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
@@ -294,7 +340,9 @@ class TestPagination:
def test_exception_structure_in_pagination_errors(self, client, auth_headers):
"""Test that pagination validation errors follow consistent exception structure"""
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
)
assert response.status_code == 422
data = response.json()

View File

@@ -1,6 +1,7 @@
# tests/integration/api/v1/test_stats_endpoints.py
import pytest
@pytest.mark.integration
@pytest.mark.api
@pytest.mark.stats
@@ -18,7 +19,9 @@ class TestStatsAPI:
assert "unique_vendors" in data
assert data["total_products"] >= 1
def test_get_marketplace_stats(self, client, auth_headers, test_marketplace_product):
def test_get_marketplace_stats(
self, client, auth_headers, test_marketplace_product
):
"""Test getting marketplace statistics"""
response = client.get("/api/v1/stats/marketplace", headers=auth_headers)

View File

@@ -23,7 +23,9 @@ class TestVendorsAPI:
assert data["name"] == "New Vendor"
assert data["is_active"] is True
def test_create_vendor_duplicate_code_returns_conflict(self, client, auth_headers, test_vendor):
def test_create_vendor_duplicate_code_returns_conflict(
self, client, auth_headers, test_vendor
):
"""Test creating vendor with duplicate code returns VendorAlreadyExistsException"""
vendor_data = {
"vendor_code": test_vendor.vendor_code,
@@ -40,7 +42,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == test_vendor.vendor_code
def test_create_vendor_missing_vendor_code_validation_error(self, client, auth_headers):
def test_create_vendor_missing_vendor_code_validation_error(
self, client, auth_headers
):
"""Test creating vendor without vendor_code returns ValidationException"""
vendor_data = {
"name": "Vendor without Code",
@@ -56,7 +60,9 @@ class TestVendorsAPI:
assert "Request validation failed" in data["message"]
assert "validation_errors" in data["details"]
def test_create_vendor_empty_vendor_name_validation_error(self, client, auth_headers):
def test_create_vendor_empty_vendor_name_validation_error(
self, client, auth_headers
):
"""Test creating vendor with empty name returns VendorValidationException"""
vendor_data = {
"vendor_code": "EMPTYNAME",
@@ -73,7 +79,9 @@ class TestVendorsAPI:
assert "Vendor name is required" in data["message"]
assert data["details"]["field"] == "name"
def test_create_vendor_max_vendors_reached_business_logic_error(self, client, auth_headers, db, test_user):
def test_create_vendor_max_vendors_reached_business_logic_error(
self, client, auth_headers, db, test_user
):
"""Test creating vendor when max vendors reached returns MaxVendorsReachedException"""
# This test would require creating the maximum allowed vendors first
# The exact implementation depends on your business rules
@@ -94,8 +102,10 @@ class TestVendorsAPI:
assert data["total"] >= 1
assert len(data["vendors"]) >= 1
# Find our test vendor
test_vendor_found = any(s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"])
# Find our test vendor
test_vendor_found = any(
s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"]
)
assert test_vendor_found
def test_get_vendors_with_filters(self, client, auth_headers, test_vendor):
@@ -105,7 +115,7 @@ class TestVendorsAPI:
assert response.status_code == 200
data = response.json()
for vendor in data["vendors"]:
assert vendor ["is_active"] is True
assert vendor["is_active"] is True
# Test verified_only filter
response = client.get("/api/v1/vendor?verified_only=true", headers=auth_headers)
@@ -135,7 +145,9 @@ class TestVendorsAPI:
assert data["details"]["resource_type"] == "Vendor"
assert data["details"]["identifier"] == "NONEXISTENT"
def test_get_vendor_unauthorized_access(self, client, auth_headers, test_vendor, other_user, db):
def test_get_vendor_unauthorized_access(
self, client, auth_headers, test_vendor, other_user, db
):
"""Test accessing vendor owned by another user returns UnauthorizedVendorAccessException"""
# Change vendor owner to other user AND make it unverified/inactive
# so that non-owner users cannot access it
@@ -154,7 +166,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == test_vendor.vendor_code
def test_get_vendor_unauthorized_access_with_inactive_vendor(self, client, auth_headers, inactive_vendor):
def test_get_vendor_unauthorized_access_with_inactive_vendor(
self, client, auth_headers, inactive_vendor
):
"""Test accessing inactive vendor owned by another user returns UnauthorizedVendorAccessException"""
# inactive_vendor fixture already creates an unverified, inactive vendor owned by other_user
response = client.get(
@@ -168,7 +182,9 @@ class TestVendorsAPI:
assert inactive_vendor.vendor_code in data["message"]
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
def test_get_vendor_public_access_allowed(self, client, auth_headers, verified_vendor):
def test_get_vendor_public_access_allowed(
self, client, auth_headers, verified_vendor
):
"""Test accessing verified vendor owned by another user is allowed (public access)"""
# verified_vendor fixture creates a verified, active vendor owned by other_user
# This should allow public access per your business logic
@@ -181,7 +197,9 @@ class TestVendorsAPI:
assert data["vendor_code"] == verified_vendor.vendor_code
assert data["name"] == verified_vendor.name
def test_add_product_to_vendor_success(self, client, auth_headers, test_vendor, unique_product):
def test_add_product_to_vendor_success(
self, client, auth_headers, test_vendor, unique_product
):
"""Test adding product to vendor successfully"""
product_data = {
"marketplace_product_id": unique_product.marketplace_product_id, # Use string marketplace_product_id, not database id
@@ -193,7 +211,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
json=product_data
json=product_data,
)
assert response.status_code == 200
@@ -207,10 +225,15 @@ class TestVendorsAPI:
# MarketplaceProduct details are nested in the 'marketplace_product' field
assert "marketplace_product" in data
assert data["marketplace_product"]["marketplace_product_id"] == unique_product.marketplace_product_id
assert (
data["marketplace_product"]["marketplace_product_id"]
== unique_product.marketplace_product_id
)
assert data["marketplace_product"]["id"] == unique_product.id
def test_add_product_to_vendor_already_exists_conflict(self, client, auth_headers, test_vendor, test_product):
def test_add_product_to_vendor_already_exists_conflict(
self, client, auth_headers, test_vendor, test_product
):
"""Test adding product that already exists in vendor returns ProductAlreadyExistsException"""
# test_product fixture already creates a relationship, get the marketplace_product_id string
existing_product = test_product.marketplace_product
@@ -223,7 +246,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
json=product_data
json=product_data,
)
assert response.status_code == 409
@@ -233,7 +256,9 @@ class TestVendorsAPI:
assert test_vendor.vendor_code in data["message"]
assert existing_product.marketplace_product_id in data["message"]
def test_add_nonexistent_product_to_vendor_not_found(self, client, auth_headers, test_vendor):
def test_add_nonexistent_product_to_vendor_not_found(
self, client, auth_headers, test_vendor
):
"""Test adding nonexistent product to vendor returns MarketplaceProductNotFoundException"""
product_data = {
"marketplace_product_id": "NONEXISTENT_PRODUCT", # Use string marketplace_product_id that doesn't exist
@@ -243,7 +268,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
json=product_data
json=product_data,
)
assert response.status_code == 404
@@ -252,11 +277,12 @@ class TestVendorsAPI:
assert data["status_code"] == 404
assert "NONEXISTENT_PRODUCT" in data["message"]
def test_get_products_success(self, client, auth_headers, test_vendor, test_product):
def test_get_products_success(
self, client, auth_headers, test_vendor, test_product
):
"""Test getting vendor products successfully"""
response = client.get(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers
f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers
)
assert response.status_code == 200
@@ -271,22 +297,21 @@ class TestVendorsAPI:
# Test active_only filter
response = client.get(
f"/api/v1/vendor/{test_vendor.vendor_code}/products?active_only=true",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
# Test featured_only filter
response = client.get(
f"/api/v1/vendor/{test_vendor.vendor_code}/products?featured_only=true",
headers=auth_headers
headers=auth_headers,
)
assert response.status_code == 200
def test_get_products_from_nonexistent_vendor_not_found(self, client, auth_headers):
"""Test getting products from nonexistent vendor returns VendorNotFoundException"""
response = client.get(
"/api/v1/vendor/NONEXISTENT/products",
headers=auth_headers
"/api/v1/vendor/NONEXISTENT/products", headers=auth_headers
)
assert response.status_code == 404
@@ -295,7 +320,9 @@ class TestVendorsAPI:
assert data["status_code"] == 404
assert "NONEXISTENT" in data["message"]
def test_vendor_not_active_business_logic_error(self, client, auth_headers, test_vendor, db):
def test_vendor_not_active_business_logic_error(
self, client, auth_headers, test_vendor, db
):
"""Test accessing inactive vendor returns VendorNotActiveException (if enforced)"""
# Set vendor to inactive
test_vendor.is_active = False
@@ -313,7 +340,9 @@ class TestVendorsAPI:
assert data["status_code"] == 400
assert test_vendor.vendor_code in data["message"]
def test_vendor_not_verified_business_logic_error(self, client, auth_headers, test_vendor, db):
def test_vendor_not_verified_business_logic_error(
self, client, auth_headers, test_vendor, db
):
"""Test operations requiring verification returns VendorNotVerifiedException (if enforced)"""
# Set vendor to unverified
test_vendor.is_verified = False
@@ -328,7 +357,7 @@ class TestVendorsAPI:
response = client.post(
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
headers=auth_headers,
json=product_data
json=product_data,
)
# If your service requires verification for adding products

View File

@@ -10,8 +10,9 @@ These tests verify that:
5. Vendor context middleware works correctly with API authentication
"""
import pytest
from datetime import datetime, timedelta, timezone
import pytest
from jose import jwt
@@ -26,7 +27,9 @@ class TestVendorAPIAuthentication:
# Authentication Tests - /api/v1/vendor/auth/me
# ========================================================================
def test_vendor_auth_me_success(self, client, vendor_user_headers, test_vendor_user):
def test_vendor_auth_me_success(
self, client, vendor_user_headers, test_vendor_user
):
"""Test /auth/me endpoint with valid vendor user token"""
response = client.get("/api/v1/vendor/auth/me", headers=vendor_user_headers)
@@ -50,7 +53,7 @@ class TestVendorAPIAuthentication:
"""Test /auth/me endpoint with invalid token format"""
response = client.get(
"/api/v1/vendor/auth/me",
headers={"Authorization": "Bearer invalid_token_here"}
headers={"Authorization": "Bearer invalid_token_here"},
)
assert response.status_code == 401
@@ -66,7 +69,9 @@ class TestVendorAPIAuthentication:
assert data["error_code"] == "FORBIDDEN"
assert "Admin users cannot access vendor API" in data["message"]
def test_vendor_auth_me_with_regular_user_token(self, client, auth_headers, test_user):
def test_vendor_auth_me_with_regular_user_token(
self, client, auth_headers, test_user
):
"""Test /auth/me endpoint rejects regular users"""
response = client.get("/api/v1/vendor/auth/me", headers=auth_headers)
@@ -88,14 +93,12 @@ class TestVendorAPIAuthentication:
}
expired_token = jwt.encode(
expired_payload,
auth_manager.secret_key,
algorithm=auth_manager.algorithm
expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
"/api/v1/vendor/auth/me",
headers={"Authorization": f"Bearer {expired_token}"}
headers={"Authorization": f"Bearer {expired_token}"},
)
assert response.status_code == 401
@@ -111,8 +114,7 @@ class TestVendorAPIAuthentication:
):
"""Test dashboard stats with valid vendor authentication"""
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -131,10 +133,7 @@ class TestVendorAPIAuthentication:
def test_vendor_dashboard_stats_with_admin(self, client, admin_headers):
"""Test dashboard stats rejects admin users"""
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=admin_headers
)
response = client.get("/api/v1/vendor/dashboard/stats", headers=admin_headers)
assert response.status_code == 403
data = response.json()
@@ -145,10 +144,7 @@ class TestVendorAPIAuthentication:
# Login to get session cookie
login_response = client.post(
"/api/v1/vendor/auth/login",
json={
"username": test_vendor_user.username,
"password": "vendorpass123"
}
json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
assert login_response.status_code == 200
@@ -169,10 +165,7 @@ class TestVendorAPIAuthentication:
# Get a valid session by logging in
login_response = client.post(
"/api/v1/vendor/auth/login",
json={
"username": test_vendor_user.username,
"password": "vendorpass123"
}
json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
assert login_response.status_code == 200
@@ -191,8 +184,9 @@ class TestVendorAPIAuthentication:
response = client.get(endpoint)
# All should fail with 401 (header required)
assert response.status_code == 401, \
f"Endpoint {endpoint} should reject cookie-only auth"
assert (
response.status_code == 401
), f"Endpoint {endpoint} should reject cookie-only auth"
# ========================================================================
# Role-Based Access Control Tests
@@ -211,13 +205,15 @@ class TestVendorAPIAuthentication:
for endpoint in endpoints:
# Test with regular user token
response = client.get(endpoint, headers=auth_headers)
assert response.status_code == 403, \
f"Endpoint {endpoint} should reject regular users"
assert (
response.status_code == 403
), f"Endpoint {endpoint} should reject regular users"
# Test with admin token
response = client.get(endpoint, headers=admin_headers)
assert response.status_code == 403, \
f"Endpoint {endpoint} should reject admin users"
assert (
response.status_code == 403
), f"Endpoint {endpoint} should reject admin users"
def test_vendor_api_accepts_only_vendor_role(
self, client, vendor_user_headers, test_vendor_user
@@ -229,8 +225,10 @@ class TestVendorAPIAuthentication:
for endpoint in endpoints:
response = client.get(endpoint, headers=vendor_user_headers)
assert response.status_code in [200, 404], \
f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
assert response.status_code in [
200,
404,
], f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
# ========================================================================
# Token Validation Tests
@@ -248,8 +246,9 @@ class TestVendorAPIAuthentication:
for headers in malformed_headers:
response = client.get("/api/v1/vendor/auth/me", headers=headers)
assert response.status_code == 401, \
f"Should reject malformed header: {headers}"
assert (
response.status_code == 401
), f"Should reject malformed header: {headers}"
def test_token_with_missing_claims(self, client, auth_manager):
"""Test token missing required claims"""
@@ -261,14 +260,12 @@ class TestVendorAPIAuthentication:
}
invalid_token = jwt.encode(
invalid_payload,
auth_manager.secret_key,
algorithm=auth_manager.algorithm
invalid_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
response = client.get(
"/api/v1/vendor/auth/me",
headers={"Authorization": f"Bearer {invalid_token}"}
headers={"Authorization": f"Bearer {invalid_token}"},
)
assert response.status_code == 401
@@ -298,9 +295,7 @@ class TestVendorAPIAuthentication:
db.add(test_vendor_user)
db.commit()
def test_concurrent_requests_with_same_token(
self, client, vendor_user_headers
):
def test_concurrent_requests_with_same_token(self, client, vendor_user_headers):
"""Test that the same token can be used for multiple concurrent requests"""
# Make multiple requests with the same token
responses = []
@@ -314,10 +309,7 @@ class TestVendorAPIAuthentication:
def test_vendor_api_with_empty_authorization_header(self, client):
"""Test vendor API with empty Authorization header value"""
response = client.get(
"/api/v1/vendor/auth/me",
headers={"Authorization": ""}
)
response = client.get("/api/v1/vendor/auth/me", headers={"Authorization": ""})
assert response.status_code == 401
@@ -328,17 +320,12 @@ class TestVendorAPIAuthentication:
class TestVendorAPIConsistency:
"""Test that all vendor API endpoints use consistent authentication"""
def test_all_vendor_endpoints_require_header_auth(
self, client, test_vendor_user
):
def test_all_vendor_endpoints_require_header_auth(self, client, test_vendor_user):
"""Verify all vendor API endpoints require Authorization header"""
# Login to establish session
client.post(
"/api/v1/vendor/auth/login",
json={
"username": test_vendor_user.username,
"password": "vendorpass123"
}
json={"username": test_vendor_user.username, "password": "vendorpass123"},
)
# All vendor API endpoints (excluding public endpoints like /info)
@@ -361,8 +348,9 @@ class TestVendorAPIConsistency:
response = client.post(endpoint, json={})
# All should reject cookie-only auth with 401
assert response.status_code == 401, \
f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
assert (
response.status_code == 401
), f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
def test_vendor_endpoints_accept_vendor_token(
self, client, vendor_user_headers, test_vendor_with_vendor_user
@@ -380,5 +368,7 @@ class TestVendorAPIConsistency:
response = client.get(endpoint, headers=vendor_user_headers)
# Should not be authentication/authorization errors
assert response.status_code not in [401, 403], \
f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"
assert response.status_code not in [
401,
403,
], f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"

View File

@@ -23,8 +23,7 @@ class TestVendorDashboardAPI:
):
"""Test dashboard stats returns correct data structure"""
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -66,9 +65,9 @@ class TestVendorDashboardAPI:
self, client, db, test_vendor_user, auth_manager
):
"""Test that dashboard stats only show data for the authenticated vendor"""
from models.database.vendor import Vendor, VendorUser
from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
from models.database.vendor import Vendor, VendorUser
# Create two separate vendors with different data
vendor1 = Vendor(
@@ -118,10 +117,7 @@ class TestVendorDashboardAPI:
vendor1_headers = {"Authorization": f"Bearer {token_data['access_token']}"}
# Get stats for vendor1
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor1_headers
)
response = client.get("/api/v1/vendor/dashboard/stats", headers=vendor1_headers)
assert response.status_code == 200
data = response.json()
@@ -130,9 +126,7 @@ class TestVendorDashboardAPI:
assert data["vendor"]["id"] == vendor1.id
assert data["products"]["total"] == 3
def test_dashboard_stats_without_vendor_association(
self, client, db, auth_manager
):
def test_dashboard_stats_without_vendor_association(self, client, db, auth_manager):
"""Test dashboard stats for user not associated with any vendor"""
from models.database.user import User
@@ -206,8 +200,7 @@ class TestVendorDashboardAPI:
):
"""Test dashboard stats for vendor with no data"""
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -224,8 +217,8 @@ class TestVendorDashboardAPI:
self, client, db, vendor_user_headers, test_vendor_with_vendor_user
):
"""Test dashboard stats accuracy with actual products"""
from models.database.product import Product
from models.database.marketplace_product import MarketplaceProduct
from models.database.product import Product
# Create marketplace products
mp = MarketplaceProduct(
@@ -249,8 +242,7 @@ class TestVendorDashboardAPI:
# Get stats
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
assert response.status_code == 200
@@ -267,8 +259,7 @@ class TestVendorDashboardAPI:
start_time = time.time()
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
end_time = time.time()
@@ -284,8 +275,7 @@ class TestVendorDashboardAPI:
responses = []
for _ in range(3):
response = client.get(
"/api/v1/vendor/dashboard/stats",
headers=vendor_user_headers
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
)
responses.append(response)

View File

@@ -3,6 +3,7 @@
Fixtures specific to middleware integration tests.
"""
import pytest
from models.database.vendor import Vendor
from models.database.vendor_domain import VendorDomain
from models.database.vendor_theme import VendorTheme
@@ -12,10 +13,7 @@ from models.database.vendor_theme import VendorTheme
def vendor_with_subdomain(db):
"""Create a vendor with subdomain for testing."""
vendor = Vendor(
name="Test Vendor",
code="testvendor",
subdomain="testvendor",
is_active=True
name="Test Vendor", code="testvendor", subdomain="testvendor", is_active=True
)
db.add(vendor)
db.commit()
@@ -30,7 +28,7 @@ def vendor_with_custom_domain(db):
name="Custom Domain Vendor",
code="customvendor",
subdomain="customvendor",
is_active=True
is_active=True,
)
db.add(vendor)
db.commit()
@@ -38,10 +36,7 @@ def vendor_with_custom_domain(db):
# Add custom domain
domain = VendorDomain(
vendor_id=vendor.id,
domain="customdomain.com",
is_active=True,
is_primary=True
vendor_id=vendor.id, domain="customdomain.com", is_active=True, is_primary=True
)
db.add(domain)
db.commit()
@@ -56,7 +51,7 @@ def vendor_with_theme(db):
name="Themed Vendor",
code="themedvendor",
subdomain="themedvendor",
is_active=True
is_active=True,
)
db.add(vendor)
db.commit()
@@ -69,7 +64,7 @@ def vendor_with_theme(db):
secondary_color="#33FF57",
logo_url="/static/vendors/themedvendor/logo.png",
favicon_url="/static/vendors/themedvendor/favicon.ico",
custom_css="body { background: #FF5733; }"
custom_css="body { background: #FF5733; }",
)
db.add(theme)
db.commit()
@@ -81,10 +76,7 @@ def vendor_with_theme(db):
def inactive_vendor(db):
"""Create an inactive vendor for testing."""
vendor = Vendor(
name="Inactive Vendor",
code="inactive",
subdomain="inactive",
is_active=False
name="Inactive Vendor", code="inactive", subdomain="inactive", is_active=False
)
db.add(vendor)
db.commit()

View File

@@ -5,8 +5,10 @@ Integration tests for request context detection end-to-end flow.
These tests verify that context type (API, ADMIN, VENDOR_DASHBOARD, SHOP, FALLBACK)
is correctly detected through real HTTP requests.
"""
import pytest
from unittest.mock import patch
import pytest
from middleware.context import RequestContext
@@ -23,13 +25,22 @@ class TestContextDetectionFlow:
def test_api_path_detected_as_api_context(self, client):
"""Test that /api/* paths are detected as API context."""
from fastapi import Request
from main import app
@app.get("/api/test-api-context")
async def test_api(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"context_enum": (
request.state.context_type.name
if hasattr(request.state, "context_type")
else None
),
}
response = client.get("/api/test-api-context")
@@ -42,12 +53,17 @@ class TestContextDetectionFlow:
def test_nested_api_path_detected_as_api_context(self, client):
"""Test that nested /api/ paths are detected as API context."""
from fastapi import Request
from main import app
@app.get("/api/v1/vendor/products")
async def test_nested_api(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
response = client.get("/api/v1/vendor/products")
@@ -63,13 +79,22 @@ class TestContextDetectionFlow:
def test_admin_path_detected_as_admin_context(self, client):
"""Test that /admin/* paths are detected as ADMIN context."""
from fastapi import Request
from main import app
@app.get("/admin/test-admin-context")
async def test_admin(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"context_enum": (
request.state.context_type.name
if hasattr(request.state, "context_type")
else None
),
}
response = client.get("/admin/test-admin-context")
@@ -82,19 +107,23 @@ class TestContextDetectionFlow:
def test_admin_subdomain_detected_as_admin_context(self, client):
"""Test that admin.* subdomain is detected as ADMIN context."""
from fastapi import Request
from main import app
@app.get("/test-admin-subdomain-context")
async def test_admin_subdomain(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-admin-subdomain-context",
headers={"host": "admin.platform.com"}
"/test-admin-subdomain-context", headers={"host": "admin.platform.com"}
)
assert response.status_code == 200
@@ -104,12 +133,17 @@ class TestContextDetectionFlow:
def test_nested_admin_path_detected_as_admin_context(self, client):
"""Test that nested /admin/ paths are detected as ADMIN context."""
from fastapi import Request
from main import app
@app.get("/admin/vendors/123/edit")
async def test_nested_admin(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
response = client.get("/admin/vendors/123/edit")
@@ -125,21 +159,31 @@ class TestContextDetectionFlow:
def test_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
"""Test that /vendor/* paths are detected as VENDOR_DASHBOARD context."""
from fastapi import Request
from main import app
@app.get("/vendor/test-vendor-dashboard")
async def test_vendor_dashboard(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"context_enum": (
request.state.context_type.name
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-vendor-dashboard",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -151,19 +195,24 @@ class TestContextDetectionFlow:
def test_nested_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
"""Test that nested /vendor/ paths are detected as VENDOR_DASHBOARD context."""
from fastapi import Request
from main import app
@app.get("/vendor/products/123/edit")
async def test_nested_vendor(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/products/123/edit",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -174,24 +223,36 @@ class TestContextDetectionFlow:
# Shop Context Detection Tests
# ========================================================================
def test_shop_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
def test_shop_path_with_vendor_detected_as_shop(
self, client, vendor_with_subdomain
):
"""Test that /shop/* paths with vendor are detected as SHOP context."""
from fastapi import Request
from main import app
@app.get("/shop/test-shop-context")
async def test_shop(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"context_enum": (
request.state.context_type.name
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-shop-context",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -200,23 +261,31 @@ class TestContextDetectionFlow:
assert data["context_enum"] == "SHOP"
assert data["has_vendor"] is True
def test_root_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
def test_root_path_with_vendor_detected_as_shop(
self, client, vendor_with_subdomain
):
"""Test that root path with vendor is detected as SHOP context."""
from fastapi import Request
from main import app
@app.get("/test-root-shop")
async def test_root_shop(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-root-shop",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -228,21 +297,27 @@ class TestContextDetectionFlow:
def test_custom_domain_shop_detected(self, client, vendor_with_custom_domain):
"""Test that custom domain shop is detected as SHOP context."""
from fastapi import Request
from main import app
@app.get("/products")
async def test_custom_domain_shop(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/products",
headers={"host": "customdomain.com"}
)
response = client.get("/products", headers={"host": "customdomain.com"})
assert response.status_code == 200
data = response.json()
@@ -256,21 +331,30 @@ class TestContextDetectionFlow:
def test_unknown_path_without_vendor_fallback_context(self, client):
"""Test that unknown paths without vendor get FALLBACK context."""
from fastapi import Request
from main import app
@app.get("/test-fallback-context")
async def test_fallback(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"context_enum": (
request.state.context_type.name
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-fallback-context",
headers={"host": "platform.com"}
"/test-fallback-context", headers={"host": "platform.com"}
)
assert response.status_code == 200
@@ -286,20 +370,26 @@ class TestContextDetectionFlow:
def test_api_path_overrides_vendor_context(self, client, vendor_with_subdomain):
"""Test that /api/* path sets API context even with vendor."""
from fastapi import Request
from main import app
@app.get("/api/test-api-priority")
async def test_api_priority(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/api/test-api-priority",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -312,20 +402,26 @@ class TestContextDetectionFlow:
def test_admin_path_overrides_vendor_context(self, client, vendor_with_subdomain):
"""Test that /admin/* path sets ADMIN context even with vendor."""
from fastapi import Request
from main import app
@app.get("/admin/test-admin-priority")
async def test_admin_priority(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/admin/test-admin-priority",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -333,22 +429,29 @@ class TestContextDetectionFlow:
# Admin path should override vendor context
assert data["context_type"] == "admin"
def test_vendor_dashboard_overrides_shop_context(self, client, vendor_with_subdomain):
def test_vendor_dashboard_overrides_shop_context(
self, client, vendor_with_subdomain
):
"""Test that /vendor/* path sets VENDOR_DASHBOARD, not SHOP."""
from fastapi import Request
from main import app
@app.get("/vendor/test-priority")
async def test_vendor_priority(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-priority",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -363,21 +466,30 @@ class TestContextDetectionFlow:
def test_context_uses_clean_path_for_detection(self, client, vendor_with_subdomain):
"""Test that context detection uses clean_path, not original path."""
from fastapi import Request
from main import app
@app.get("/vendors/{vendor_code}/shop/products")
async def test_clean_path_context(vendor_code: str, request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
"original_path": request.url.path
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"clean_path": (
request.state.clean_path
if hasattr(request.state, "clean_path")
else None
),
"original_path": request.url.path,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/shop/products",
headers={"host": "localhost:8000"}
headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -394,15 +506,20 @@ class TestContextDetectionFlow:
def test_context_type_is_enum_instance(self, client):
"""Test that context_type is a RequestContext enum instance."""
from fastapi import Request
from main import app
@app.get("/api/test-enum")
async def test_enum(request: Request):
context = request.state.context_type if hasattr(request.state, 'context_type') else None
context = (
request.state.context_type
if hasattr(request.state, "context_type")
else None
)
return {
"is_enum": isinstance(context, RequestContext) if context else False,
"enum_name": context.name if context else None,
"enum_value": context.value if context else None
"enum_value": context.value if context else None,
}
response = client.get("/api/test-enum")
@@ -417,23 +534,30 @@ class TestContextDetectionFlow:
# Edge Cases
# ========================================================================
def test_empty_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
def test_empty_path_with_vendor_detected_as_shop(
self, client, vendor_with_subdomain
):
"""Test that empty/root path with vendor is detected as SHOP."""
from fastapi import Request
from main import app
@app.get("/")
async def test_root(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
"/", headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
)
assert response.status_code in [200, 404] # Might not have root handler
@@ -445,13 +569,18 @@ class TestContextDetectionFlow:
def test_case_insensitive_context_detection(self, client):
"""Test that context detection is case insensitive for paths."""
from fastapi import Request
from main import app
@app.get("/API/test-case")
@app.get("/api/test-case")
async def test_case(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
# Test uppercase

View File

@@ -5,8 +5,10 @@ Integration tests for the complete middleware stack.
These tests verify that all middleware components work together correctly
through real HTTP requests, ensuring proper execution order and state injection.
"""
import pytest
from unittest.mock import patch
import pytest
from middleware.context import RequestContext
@@ -23,14 +25,20 @@ class TestMiddlewareStackIntegration:
"""Test that /admin/* paths set ADMIN context type."""
# Create a simple endpoint to inspect request state
from fastapi import Request
from main import app
@app.get("/admin/test-context")
async def test_admin_context(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"has_theme": hasattr(request.state, 'theme')
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"has_theme": hasattr(request.state, "theme"),
}
response = client.get("/admin/test-context")
@@ -44,20 +52,24 @@ class TestMiddlewareStackIntegration:
def test_admin_subdomain_sets_admin_context(self, client):
"""Test that admin.* subdomain sets ADMIN context type."""
from fastapi import Request
from main import app
@app.get("/test-admin-subdomain")
async def test_admin_subdomain(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
# Simulate request with admin subdomain
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-admin-subdomain",
headers={"host": "admin.platform.com"}
"/test-admin-subdomain", headers={"host": "admin.platform.com"}
)
assert response.status_code == 200
@@ -71,12 +83,17 @@ class TestMiddlewareStackIntegration:
def test_api_path_sets_api_context(self, client):
"""Test that /api/* paths set API context type."""
from fastapi import Request
from main import app
@app.get("/api/test-context")
async def test_api_context(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
)
}
response = client.get("/api/test-context")
@@ -89,25 +106,40 @@ class TestMiddlewareStackIntegration:
# Vendor Dashboard Context Tests
# ========================================================================
def test_vendor_dashboard_path_sets_vendor_context(self, client, vendor_with_subdomain):
def test_vendor_dashboard_path_sets_vendor_context(
self, client, vendor_with_subdomain
):
"""Test that /vendor/* paths with vendor set VENDOR_DASHBOARD context."""
from fastapi import Request
from main import app
@app.get("/vendor/test-context")
async def test_vendor_context(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor")
else None
),
}
# Request with vendor subdomain
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-context",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -120,24 +152,35 @@ class TestMiddlewareStackIntegration:
# Shop Context Tests
# ========================================================================
def test_shop_path_with_subdomain_sets_shop_context(self, client, vendor_with_subdomain):
def test_shop_path_with_subdomain_sets_shop_context(
self, client, vendor_with_subdomain
):
"""Test that /shop/* paths with vendor subdomain set SHOP context."""
from fastapi import Request
from main import app
@app.get("/shop/test-context")
async def test_shop_context(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"has_theme": hasattr(request.state, 'theme')
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"has_theme": hasattr(request.state, "theme"),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-context",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -146,23 +189,33 @@ class TestMiddlewareStackIntegration:
assert data["vendor_id"] == vendor_with_subdomain.id
assert data["has_theme"] is True
def test_shop_path_with_custom_domain_sets_shop_context(self, client, vendor_with_custom_domain):
def test_shop_path_with_custom_domain_sets_shop_context(
self, client, vendor_with_custom_domain
):
"""Test that /shop/* paths with custom domain set SHOP context."""
from fastapi import Request
from main import app
@app.get("/shop/test-custom-domain")
async def test_shop_custom_domain(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-custom-domain",
headers={"host": "customdomain.com"}
"/shop/test-custom-domain", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -174,9 +227,12 @@ class TestMiddlewareStackIntegration:
# Middleware Execution Order Tests
# ========================================================================
def test_vendor_context_runs_before_context_detection(self, client, vendor_with_subdomain):
def test_vendor_context_runs_before_context_detection(
self, client, vendor_with_subdomain
):
"""Test that VendorContextMiddleware runs before ContextDetectionMiddleware."""
from fastapi import Request
from main import app
@app.get("/test-execution-order")
@@ -184,16 +240,16 @@ class TestMiddlewareStackIntegration:
# If vendor context runs first, clean_path should be available
# before context detection uses it
return {
"has_vendor": hasattr(request.state, 'vendor'),
"has_clean_path": hasattr(request.state, 'clean_path'),
"has_context_type": hasattr(request.state, 'context_type')
"has_vendor": hasattr(request.state, "vendor"),
"has_clean_path": hasattr(request.state, "clean_path"),
"has_context_type": hasattr(request.state, "context_type"),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-execution-order",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -206,21 +262,26 @@ class TestMiddlewareStackIntegration:
def test_theme_context_runs_after_vendor_context(self, client, vendor_with_theme):
"""Test that ThemeContextMiddleware runs after VendorContextMiddleware."""
from fastapi import Request
from main import app
@app.get("/test-theme-loading")
async def test_theme_loading(request: Request):
return {
"has_vendor": hasattr(request.state, 'vendor'),
"has_theme": hasattr(request.state, 'theme'),
"theme_primary_color": request.state.theme.get('primary_color') if hasattr(request.state, 'theme') else None
"has_vendor": hasattr(request.state, "vendor"),
"has_theme": hasattr(request.state, "theme"),
"theme_primary_color": (
request.state.theme.get("primary_color")
if hasattr(request.state, "theme")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-loading",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -249,21 +310,27 @@ class TestMiddlewareStackIntegration:
def test_missing_vendor_graceful_handling(self, client):
"""Test that missing vendor is handled gracefully."""
from fastapi import Request
from main import app
@app.get("/test-missing-vendor")
async def test_missing_vendor(request: Request):
return {
"has_vendor": hasattr(request.state, 'vendor'),
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None,
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
"has_vendor": hasattr(request.state, "vendor"),
"vendor": (
request.state.vendor if hasattr(request.state, "vendor") else None
),
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-missing-vendor",
headers={"host": "nonexistent.platform.com"}
"/test-missing-vendor", headers={"host": "nonexistent.platform.com"}
)
assert response.status_code == 200
@@ -276,20 +343,23 @@ class TestMiddlewareStackIntegration:
def test_inactive_vendor_not_loaded(self, client, inactive_vendor):
"""Test that inactive vendors are not loaded."""
from fastapi import Request
from main import app
@app.get("/test-inactive-vendor")
async def test_inactive_vendor_endpoint(request: Request):
return {
"has_vendor": hasattr(request.state, 'vendor'),
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
"has_vendor": hasattr(request.state, "vendor"),
"vendor": (
request.state.vendor if hasattr(request.state, "vendor") else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-inactive-vendor",
headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
)
assert response.status_code == 200

View File

@@ -5,9 +5,10 @@ Integration tests for theme loading end-to-end flow.
These tests verify that vendor themes are correctly loaded and injected
into request.state through real HTTP requests.
"""
import pytest
from unittest.mock import patch
import pytest
@pytest.mark.integration
@pytest.mark.middleware
@@ -22,21 +23,22 @@ class TestThemeLoadingFlow:
def test_theme_loaded_for_vendor_with_custom_theme(self, client, vendor_with_theme):
"""Test that custom theme is loaded for vendor with theme."""
from fastapi import Request
from main import app
@app.get("/test-theme-loading")
async def test_theme(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else None
theme = request.state.theme if hasattr(request.state, "theme") else None
return {
"has_theme": theme is not None,
"theme_data": theme if theme else None
"theme_data": theme if theme else None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-loading",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -46,24 +48,27 @@ class TestThemeLoadingFlow:
assert data["theme_data"]["primary_color"] == "#FF5733"
assert data["theme_data"]["secondary_color"] == "#33FF57"
def test_default_theme_loaded_for_vendor_without_theme(self, client, vendor_with_subdomain):
def test_default_theme_loaded_for_vendor_without_theme(
self, client, vendor_with_subdomain
):
"""Test that default theme is loaded for vendor without custom theme."""
from fastapi import Request
from main import app
@app.get("/test-default-theme")
async def test_default_theme(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else None
theme = request.state.theme if hasattr(request.state, "theme") else None
return {
"has_theme": theme is not None,
"theme_data": theme if theme else None
"theme_data": theme if theme else None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-default-theme",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -76,21 +81,20 @@ class TestThemeLoadingFlow:
def test_no_theme_loaded_without_vendor(self, client):
"""Test that no theme is loaded when there's no vendor."""
from fastapi import Request
from main import app
@app.get("/test-no-theme")
async def test_no_theme(request: Request):
return {
"has_theme": hasattr(request.state, 'theme'),
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
"has_theme": hasattr(request.state, "theme"),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-no-theme",
headers={"host": "platform.com"}
)
response = client.get("/test-no-theme", headers={"host": "platform.com"})
assert response.status_code == 200
data = response.json()
@@ -105,24 +109,25 @@ class TestThemeLoadingFlow:
def test_custom_theme_contains_all_fields(self, client, vendor_with_theme):
"""Test that custom theme contains all expected fields."""
from fastapi import Request
from main import app
@app.get("/test-theme-fields")
async def test_theme_fields(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else {}
theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"primary_color": theme.get("primary_color"),
"secondary_color": theme.get("secondary_color"),
"logo_url": theme.get("logo_url"),
"favicon_url": theme.get("favicon_url"),
"custom_css": theme.get("custom_css")
"custom_css": theme.get("custom_css"),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-fields",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -136,24 +141,25 @@ class TestThemeLoadingFlow:
def test_default_theme_structure(self, client, vendor_with_subdomain):
"""Test that default theme has expected structure."""
from fastapi import Request
from main import app
@app.get("/test-default-theme-structure")
async def test_default_structure(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else {}
theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"has_primary_color": "primary_color" in theme,
"has_secondary_color": "secondary_color" in theme,
"has_logo_url": "logo_url" in theme,
"has_favicon_url": "favicon_url" in theme,
"has_custom_css": "custom_css" in theme
"has_custom_css": "custom_css" in theme,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-default-theme-structure",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -169,21 +175,30 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_shop_context(self, client, vendor_with_theme):
"""Test that theme is loaded in SHOP context."""
from fastapi import Request
from main import app
@app.get("/shop/test-shop-theme")
async def test_shop_theme(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_theme": hasattr(request.state, 'theme'),
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_theme": hasattr(request.state, "theme"),
"theme_primary": (
request.state.theme.get("primary_color")
if hasattr(request.state, "theme")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-shop-theme",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -195,21 +210,30 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_vendor_dashboard_context(self, client, vendor_with_theme):
"""Test that theme is loaded in VENDOR_DASHBOARD context."""
from fastapi import Request
from main import app
@app.get("/vendor/test-dashboard-theme")
async def test_dashboard_theme(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_theme": hasattr(request.state, 'theme'),
"theme_secondary": request.state.theme.get("secondary_color") if hasattr(request.state, 'theme') else None
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_theme": hasattr(request.state, "theme"),
"theme_secondary": (
request.state.theme.get("secondary_color")
if hasattr(request.state, "theme")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/vendor/test-dashboard-theme",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -221,21 +245,27 @@ class TestThemeLoadingFlow:
def test_theme_loaded_in_api_context_with_vendor(self, client, vendor_with_theme):
"""Test that theme is loaded in API context when vendor is present."""
from fastapi import Request
from main import app
@app.get("/api/test-api-theme")
async def test_api_theme(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"has_theme": hasattr(request.state, 'theme')
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"has_theme": hasattr(request.state, "theme"),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/api/test-api-theme",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -248,13 +278,18 @@ class TestThemeLoadingFlow:
def test_no_theme_in_admin_context(self, client):
"""Test that theme is not loaded in ADMIN context (no vendor)."""
from fastapi import Request
from main import app
@app.get("/admin/test-admin-no-theme")
async def test_admin_no_theme(request: Request):
return {
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
"has_theme": hasattr(request.state, 'theme')
"context_type": (
request.state.context_type.value
if hasattr(request.state, "context_type")
else None
),
"has_theme": hasattr(request.state, "theme"),
}
response = client.get("/admin/test-admin-no-theme")
@@ -272,20 +307,29 @@ class TestThemeLoadingFlow:
def test_theme_loaded_with_subdomain_routing(self, client, vendor_with_theme):
"""Test theme loading with subdomain routing."""
from fastapi import Request
from main import app
@app.get("/test-subdomain-theme")
async def test_subdomain_theme(request: Request):
return {
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
"theme_logo": request.state.theme.get("logo_url") if hasattr(request.state, 'theme') else None
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
"theme_logo": (
request.state.theme.get("logo_url")
if hasattr(request.state, "theme")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-theme",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -293,33 +337,40 @@ class TestThemeLoadingFlow:
assert data["vendor_code"] == vendor_with_theme.code
assert data["theme_logo"] == "/static/vendors/themedvendor/logo.png"
def test_theme_loaded_with_custom_domain_routing(self, client, vendor_with_custom_domain, db):
def test_theme_loaded_with_custom_domain_routing(
self, client, vendor_with_custom_domain, db
):
"""Test theme loading with custom domain routing."""
# Add theme to custom domain vendor
from models.database.vendor_theme import VendorTheme
theme = VendorTheme(
vendor_id=vendor_with_custom_domain.id,
primary_color="#123456",
secondary_color="#654321"
secondary_color="#654321",
)
db.add(theme)
db.commit()
from fastapi import Request
from main import app
@app.get("/test-custom-domain-theme")
async def test_custom_domain_theme(request: Request):
return {
"has_theme": hasattr(request.state, 'theme'),
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
"has_theme": hasattr(request.state, "theme"),
"theme_primary": (
request.state.theme.get("primary_color")
if hasattr(request.state, "theme")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-custom-domain-theme",
headers={"host": "customdomain.com"}
"/test-custom-domain-theme", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -331,28 +382,37 @@ class TestThemeLoadingFlow:
# Theme Dependency on Vendor Context Tests
# ========================================================================
def test_theme_middleware_depends_on_vendor_middleware(self, client, vendor_with_theme):
def test_theme_middleware_depends_on_vendor_middleware(
self, client, vendor_with_theme
):
"""Test that theme loading depends on vendor being detected first."""
from fastapi import Request
from main import app
@app.get("/test-theme-vendor-dependency")
async def test_dependency(request: Request):
return {
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"has_theme": hasattr(request.state, 'theme'),
"has_vendor": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"has_theme": hasattr(request.state, "theme"),
"vendor_matches_theme": (
request.state.vendor_id == vendor_with_theme.id
if hasattr(request.state, 'vendor_id') else False
)
if hasattr(request.state, "vendor_id")
else False
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-vendor-dependency",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -369,15 +429,20 @@ class TestThemeLoadingFlow:
def test_theme_loaded_consistently_across_requests(self, client, vendor_with_theme):
"""Test that theme is loaded consistently across multiple requests."""
from fastapi import Request
from main import app
@app.get("/test-theme-consistency")
async def test_consistency(request: Request):
return {
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
"theme_primary": (
request.state.theme.get("primary_color")
if hasattr(request.state, "theme")
else None
)
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
# Make multiple requests
@@ -385,7 +450,7 @@ class TestThemeLoadingFlow:
for _ in range(3):
response = client.get(
"/test-theme-consistency",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
responses.append(response.json())
@@ -396,25 +461,28 @@ class TestThemeLoadingFlow:
# Edge Cases and Error Handling Tests
# ========================================================================
def test_theme_gracefully_handles_missing_theme_fields(self, client, vendor_with_subdomain):
def test_theme_gracefully_handles_missing_theme_fields(
self, client, vendor_with_subdomain
):
"""Test that missing theme fields are handled gracefully."""
from fastapi import Request
from main import app
@app.get("/test-partial-theme")
async def test_partial_theme(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else {}
theme = request.state.theme if hasattr(request.state, "theme") else {}
return {
"has_theme": bool(theme),
"primary_color": theme.get("primary_color", "default"),
"logo_url": theme.get("logo_url", "default")
"logo_url": theme.get("logo_url", "default"),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-partial-theme",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -427,23 +495,21 @@ class TestThemeLoadingFlow:
def test_theme_dict_is_mutable(self, client, vendor_with_theme):
"""Test that theme dict can be accessed and read from."""
from fastapi import Request
from main import app
@app.get("/test-theme-mutable")
async def test_mutable(request: Request):
theme = request.state.theme if hasattr(request.state, 'theme') else {}
theme = request.state.theme if hasattr(request.state, "theme") else {}
# Try to access theme values
primary = theme.get("primary_color")
return {
"can_read": primary is not None,
"value": primary
}
return {"can_read": primary is not None, "value": primary}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-theme-mutable",
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
)
assert response.status_code == 200

View File

@@ -5,9 +5,10 @@ Integration tests for vendor context detection end-to-end flow.
These tests verify that vendor detection works correctly through real HTTP requests
for all routing modes: subdomain, custom domain, and path-based.
"""
import pytest
from unittest.mock import patch
import pytest
@pytest.mark.integration
@pytest.mark.middleware
@@ -22,23 +23,37 @@ class TestVendorContextFlow:
def test_subdomain_vendor_detection(self, client, vendor_with_subdomain):
"""Test vendor detection via subdomain routing."""
from fastapi import Request
from main import app
@app.get("/test-subdomain-detection")
async def test_subdomain(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
"vendor_name": request.state.vendor.name if hasattr(request.state, 'vendor') and request.state.vendor else None,
"detection_method": "subdomain"
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
"vendor_name": (
request.state.vendor.name
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
"detection_method": "subdomain",
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-detection",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -51,20 +66,28 @@ class TestVendorContextFlow:
def test_subdomain_with_port_detection(self, client, vendor_with_subdomain):
"""Test vendor detection via subdomain with port number."""
from fastapi import Request
from main import app
@app.get("/test-subdomain-port")
async def test_subdomain_port(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-subdomain-port",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"}
headers={
"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"
},
)
assert response.status_code == 200
@@ -75,20 +98,24 @@ class TestVendorContextFlow:
def test_nonexistent_subdomain_returns_no_vendor(self, client):
"""Test that nonexistent subdomain doesn't crash and returns no vendor."""
from fastapi import Request
from main import app
@app.get("/test-nonexistent-subdomain")
async def test_nonexistent(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor": (
request.state.vendor if hasattr(request.state, "vendor") else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-nonexistent-subdomain",
headers={"host": "nonexistent.platform.com"}
headers={"host": "nonexistent.platform.com"},
)
assert response.status_code == 200
@@ -102,22 +129,31 @@ class TestVendorContextFlow:
def test_custom_domain_vendor_detection(self, client, vendor_with_custom_domain):
"""Test vendor detection via custom domain."""
from fastapi import Request
from main import app
@app.get("/test-custom-domain")
async def test_custom_domain(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
"detection_method": "custom_domain"
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
"detection_method": "custom_domain",
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-custom-domain",
headers={"host": "customdomain.com"}
"/test-custom-domain", headers={"host": "customdomain.com"}
)
assert response.status_code == 200
@@ -129,21 +165,26 @@ class TestVendorContextFlow:
def test_custom_domain_with_www_detection(self, client, vendor_with_custom_domain):
"""Test vendor detection via custom domain with www prefix."""
from fastapi import Request
from main import app
@app.get("/test-custom-domain-www")
async def test_custom_domain_www(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
# Test with www prefix - should still detect vendor
response = client.get(
"/test-custom-domain-www",
headers={"host": "www.customdomain.com"}
"/test-custom-domain-www", headers={"host": "www.customdomain.com"}
)
# This might fail if your implementation doesn't strip www
@@ -154,25 +195,37 @@ class TestVendorContextFlow:
# Path-Based Detection Tests (Development Mode)
# ========================================================================
def test_path_based_vendor_detection_vendors_prefix(self, client, vendor_with_subdomain):
def test_path_based_vendor_detection_vendors_prefix(
self, client, vendor_with_subdomain
):
"""Test vendor detection via path-based routing with /vendors/ prefix."""
from fastapi import Request
from main import app
@app.get("/vendors/{vendor_code}/test-path")
async def test_path_based(vendor_code: str, request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_code_param": vendor_code,
"vendor_code_state": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None
"vendor_code_state": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
"clean_path": (
request.state.clean_path
if hasattr(request.state, "clean_path")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/test-path",
headers={"host": "localhost:8000"}
headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -181,23 +234,31 @@ class TestVendorContextFlow:
assert data["vendor_code_param"] == vendor_with_subdomain.code
assert data["vendor_code_state"] == vendor_with_subdomain.code
def test_path_based_vendor_detection_vendor_prefix(self, client, vendor_with_subdomain):
def test_path_based_vendor_detection_vendor_prefix(
self, client, vendor_with_subdomain
):
"""Test vendor detection via path-based routing with /vendor/ prefix."""
from fastapi import Request
from main import app
@app.get("/vendor/{vendor_code}/test")
async def test_vendor_path(vendor_code: str, request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None,
"vendor_code": (
request.state.vendor.code
if hasattr(request.state, "vendor") and request.state.vendor
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendor/{vendor_with_subdomain.code}/test",
headers={"host": "localhost:8000"}
headers={"host": "localhost:8000"},
)
assert response.status_code == 200
@@ -209,48 +270,63 @@ class TestVendorContextFlow:
# Clean Path Extraction Tests
# ========================================================================
def test_clean_path_extracted_from_vendor_prefix(self, client, vendor_with_subdomain):
def test_clean_path_extracted_from_vendor_prefix(
self, client, vendor_with_subdomain
):
"""Test that clean_path is correctly extracted from path-based routing."""
from fastapi import Request
from main import app
@app.get("/vendors/{vendor_code}/shop/products")
async def test_clean_path(vendor_code: str, request: Request):
return {
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
"original_path": request.url.path
"clean_path": (
request.state.clean_path
if hasattr(request.state, "clean_path")
else None
),
"original_path": request.url.path,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
f"/vendors/{vendor_with_subdomain.code}/shop/products",
headers={"host": "localhost:8000"}
headers={"host": "localhost:8000"},
)
assert response.status_code == 200
data = response.json()
# Clean path should have vendor prefix removed
assert data["clean_path"] == "/shop/products"
assert f"/vendors/{vendor_with_subdomain.code}/shop/products" in data["original_path"]
assert (
f"/vendors/{vendor_with_subdomain.code}/shop/products"
in data["original_path"]
)
def test_clean_path_unchanged_for_subdomain(self, client, vendor_with_subdomain):
"""Test that clean_path equals original path for subdomain routing."""
from fastapi import Request
from main import app
@app.get("/shop/test-clean-path")
async def test_subdomain_clean_path(request: Request):
return {
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
"original_path": request.url.path
"clean_path": (
request.state.clean_path
if hasattr(request.state, "clean_path")
else None
),
"original_path": request.url.path,
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/shop/test-clean-path",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -266,21 +342,30 @@ class TestVendorContextFlow:
def test_vendor_id_injected_into_request_state(self, client, vendor_with_subdomain):
"""Test that vendor_id is correctly injected into request.state."""
from fastapi import Request
from main import app
@app.get("/test-vendor-id-injection")
async def test_vendor_id(request: Request):
return {
"has_vendor_id": hasattr(request.state, 'vendor_id'),
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
"vendor_id_type": type(request.state.vendor_id).__name__ if hasattr(request.state, 'vendor_id') else None
"has_vendor_id": hasattr(request.state, "vendor_id"),
"vendor_id": (
request.state.vendor_id
if hasattr(request.state, "vendor_id")
else None
),
"vendor_id_type": (
type(request.state.vendor_id).__name__
if hasattr(request.state, "vendor_id")
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-vendor-id-injection",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -289,30 +374,41 @@ class TestVendorContextFlow:
assert data["vendor_id"] == vendor_with_subdomain.id
assert data["vendor_id_type"] == "int"
def test_vendor_object_injected_into_request_state(self, client, vendor_with_subdomain):
def test_vendor_object_injected_into_request_state(
self, client, vendor_with_subdomain
):
"""Test that full vendor object is injected into request.state."""
from fastapi import Request
from main import app
@app.get("/test-vendor-object-injection")
async def test_vendor_object(request: Request):
vendor = request.state.vendor if hasattr(request.state, 'vendor') and request.state.vendor else None
vendor = (
request.state.vendor
if hasattr(request.state, "vendor") and request.state.vendor
else None
)
return {
"has_vendor": vendor is not None,
"vendor_attributes": {
"id": vendor.id if vendor else None,
"name": vendor.name if vendor else None,
"code": vendor.code if vendor else None,
"subdomain": vendor.subdomain if vendor else None,
"is_active": vendor.is_active if vendor else None
} if vendor else None
"vendor_attributes": (
{
"id": vendor.id if vendor else None,
"name": vendor.name if vendor else None,
"code": vendor.code if vendor else None,
"subdomain": vendor.subdomain if vendor else None,
"is_active": vendor.is_active if vendor else None,
}
if vendor
else None
),
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-vendor-object-injection",
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -330,19 +426,21 @@ class TestVendorContextFlow:
def test_inactive_vendor_not_detected(self, client, inactive_vendor):
"""Test that inactive vendors are not detected."""
from fastapi import Request
from main import app
@app.get("/test-inactive-vendor-detection")
async def test_inactive(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-inactive-vendor-detection",
headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
)
assert response.status_code == 200
@@ -352,19 +450,20 @@ class TestVendorContextFlow:
def test_platform_domain_without_subdomain_no_vendor(self, client):
"""Test that platform domain without subdomain doesn't detect vendor."""
from fastapi import Request
from main import app
@app.get("/test-platform-domain")
async def test_platform(request: Request):
return {
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
"vendor_detected": hasattr(request.state, "vendor")
and request.state.vendor is not None
}
with patch('app.core.config.settings') as mock_settings:
with patch("app.core.config.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
response = client.get(
"/test-platform-domain",
headers={"host": "platform.com"}
"/test-platform-domain", headers={"host": "platform.com"}
)
assert response.status_code == 200

View File

@@ -11,7 +11,8 @@ class TestInputValidation:
malicious_search = "'; DROP TABLE products; --"
response = client.get(
f"/api/v1/marketplace/product?search={malicious_search}", headers=auth_headers
f"/api/v1/marketplace/product?search={malicious_search}",
headers=auth_headers,
)
# Should not crash and should return normal response
@@ -40,17 +41,23 @@ class TestInputValidation:
def test_parameter_validation(self, client, auth_headers):
"""Test parameter validation for API endpoints"""
# Test invalid pagination parameters
response = client.get("/api/v1/marketplace/product?limit=-1", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=-1", headers=auth_headers
)
assert response.status_code == 422 # Validation error
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
)
assert response.status_code == 422 # Validation error
def test_json_validation(self, client, auth_headers):
"""Test JSON validation for POST requests"""
# Test invalid JSON structure
response = client.post(
"/api/v1/marketplace/product", headers=auth_headers, content="invalid json content"
"/api/v1/marketplace/product",
headers=auth_headers,
content="invalid json content",
)
assert response.status_code == 422 # JSON decode error
@@ -58,6 +65,8 @@ class TestInputValidation:
response = client.post(
"/api/v1/marketplace/product",
headers=auth_headers,
json={"title": "Test MarketplaceProduct"}, # Missing required marketplace_product_id
json={
"title": "Test MarketplaceProduct"
}, # Missing required marketplace_product_id
)
assert response.status_code == 422 # Validation error

View File

@@ -33,12 +33,15 @@ class TestIntegrationFlows:
"quantity": 50,
}
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
response = client.post(
"/api/v1/inventory", headers=auth_headers, json=inventory_data
)
assert response.status_code == 200
# 3. Get product with inventory info
response = client.get(
f"/api/v1/marketplace/product/{product['marketplace_product_id']}", headers=auth_headers
f"/api/v1/marketplace/product/{product['marketplace_product_id']}",
headers=auth_headers,
)
assert response.status_code == 200
product_detail = response.json()
@@ -55,14 +58,15 @@ class TestIntegrationFlows:
# 5. Search for product
response = client.get(
"/api/v1/marketplace/product?search=Updated Integration", headers=auth_headers
"/api/v1/marketplace/product?search=Updated Integration",
headers=auth_headers,
)
assert response.status_code == 200
assert response.json()["total"] == 1
def test_product_workflow(self, client, auth_headers):
"""Test vendor creation and product management workflow"""
# 1. Create a vendor
# 1. Create a vendor
vendor_data = {
"vendor_code": "FLOWVENDOR",
"name": "Integration Flow Vendor",
@@ -91,7 +95,9 @@ class TestIntegrationFlows:
# This would test the vendor -product association
# 4. Get vendor details
response = client.get(f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers)
response = client.get(
f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers
)
assert response.status_code == 200
def test_inventory_operations_workflow(self, client, auth_headers):

View File

@@ -28,7 +28,9 @@ class TestPerformance:
# Time the request
start_time = time.time()
response = client.get("/api/v1/marketplace/product?limit=100", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?limit=100", headers=auth_headers
)
end_time = time.time()
assert response.status_code == 200
@@ -54,7 +56,9 @@ class TestPerformance:
# Time search request
start_time = time.time()
response = client.get("/api/v1/marketplace/product?search=Searchable", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product?search=Searchable", headers=auth_headers
)
end_time = time.time()
assert response.status_code == 200
@@ -115,7 +119,8 @@ class TestPerformance:
for offset in offsets:
start_time = time.time()
response = client.get(
f"/api/v1/marketplace/product?skip={offset}&limit=20", headers=auth_headers
f"/api/v1/marketplace/product?skip={offset}&limit=20",
headers=auth_headers,
)
end_time = time.time()

View File

@@ -5,9 +5,10 @@ System tests for error handling across the LetzVendor API.
Tests the complete error handling flow from FastAPI through custom exception handlers
to ensure proper HTTP status codes, error structures, and client-friendly responses.
"""
import pytest
import json
import pytest
@pytest.mark.system
class TestErrorHandling:
@@ -16,9 +17,7 @@ class TestErrorHandling:
def test_invalid_json_request(self, client, auth_headers):
"""Test handling of malformed JSON requests"""
response = client.post(
"/api/v1/vendor",
headers=auth_headers,
content="{ invalid json syntax"
"/api/v1/vendor", headers=auth_headers, content="{ invalid json syntax"
)
assert response.status_code == 422
@@ -31,9 +30,7 @@ class TestErrorHandling:
"""Test validation errors for missing required fields"""
# Missing name
response = client.post(
"/api/v1/vendor",
headers=auth_headers,
json={"vendor_code": "TESTVENDOR"}
"/api/v1/vendor", headers=auth_headers, json={"vendor_code": "TESTVENDOR"}
)
assert response.status_code == 422
@@ -48,10 +45,7 @@ class TestErrorHandling:
response = client.post(
"/api/v1/vendor",
headers=auth_headers,
json={
"vendor_code": "INVALID@VENDOR!",
"name": "Test Vendor"
}
json={"vendor_code": "INVALID@VENDOR!", "name": "Test Vendor"},
)
assert response.status_code == 422
@@ -92,7 +86,7 @@ class TestErrorHandling:
assert data["status_code"] == 401
def test_vendor_not_found(self, client, auth_headers):
"""Test accessing non-existent vendor """
"""Test accessing non-existent vendor"""
response = client.get("/api/v1/vendor/NONEXISTENT", headers=auth_headers)
assert response.status_code == 404
@@ -104,7 +98,9 @@ class TestErrorHandling:
def test_product_not_found(self, client, auth_headers):
"""Test accessing non-existent product"""
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
response = client.get(
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
)
assert response.status_code == 404
data = response.json()
@@ -117,7 +113,7 @@ class TestErrorHandling:
"""Test creating vendor with duplicate vendor code"""
vendor_data = {
"vendor_code": test_vendor.vendor_code,
"name": "Duplicate Vendor"
"name": "Duplicate Vendor",
}
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
@@ -128,25 +124,34 @@ class TestErrorHandling:
assert data["status_code"] == 409
assert data["details"]["vendor_code"] == test_vendor.vendor_code.upper()
def test_duplicate_product_creation(self, client, auth_headers, test_marketplace_product):
def test_duplicate_product_creation(
self, client, auth_headers, test_marketplace_product
):
"""Test creating product with duplicate product ID"""
product_data = {
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
"title": "Duplicate MarketplaceProduct",
"gtin": "1234567890123"
"gtin": "1234567890123",
}
response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
response = client.post(
"/api/v1/marketplace/product", headers=auth_headers, json=product_data
)
assert response.status_code == 409
data = response.json()
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
assert data["status_code"] == 409
assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
assert (
data["details"]["marketplace_product_id"]
== test_marketplace_product.marketplace_product_id
)
def test_unauthorized_vendor_access(self, client, auth_headers, inactive_vendor):
"""Test accessing vendor without proper permissions"""
response = client.get(f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers)
response = client.get(
f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers
)
assert response.status_code == 403
data = response.json()
@@ -154,27 +159,33 @@ class TestErrorHandling:
assert data["status_code"] == 403
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
def test_insufficient_permissions(self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"):
def test_insufficient_permissions(
self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"
):
"""Test accessing admin endpoints with regular user"""
response = client.get(admin_only_endpoint, headers=auth_headers)
assert response.status_code in [403, 404] # 403 for permission denied, 404 if endpoint doesn't exist
assert response.status_code in [
403,
404,
] # 403 for permission denied, 404 if endpoint doesn't exist
if response.status_code == 403:
data = response.json()
assert data["error_code"] in ["ADMIN_REQUIRED", "INSUFFICIENT_PERMISSIONS"]
assert data["status_code"] == 403
def test_business_logic_violation_max_vendors(self, client, auth_headers, monkeypatch):
def test_business_logic_violation_max_vendors(
self, client, auth_headers, monkeypatch
):
"""Test business logic violation - creating too many vendors"""
# This test would require mocking the vendor limit check
# For now, test the error structure when creating multiple vendors
vendors_created = []
for i in range(6): # Assume limit is 5
vendor_data = {
"vendor_code": f"VENDOR{i:03d}",
"name": f"Test Vendor {i}"
}
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
vendor_data = {"vendor_code": f"VENDOR{i:03d}", "name": f"Test Vendor {i}"}
response = client.post(
"/api/v1/vendor", headers=auth_headers, json=vendor_data
)
vendors_created.append(response)
# At least one should succeed, and if limit is enforced, later ones should fail
@@ -193,10 +204,12 @@ class TestErrorHandling:
product_data = {
"marketplace_product_id": "TESTPROD001",
"title": "Test MarketplaceProduct",
"gtin": "invalid_gtin_format"
"gtin": "invalid_gtin_format",
}
response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
response = client.post(
"/api/v1/marketplace/product", headers=auth_headers, json=product_data
)
assert response.status_code == 422
data = response.json()
@@ -204,13 +217,15 @@ class TestErrorHandling:
assert data["status_code"] == 422
assert data["details"]["field"] == "gtin"
def test_inventory_insufficient_quantity(self, client, auth_headers, test_vendor, test_marketplace_product):
def test_inventory_insufficient_quantity(
self, client, auth_headers, test_vendor, test_marketplace_product
):
"""Test business logic error for insufficient inventory"""
# First create some inventory
inventory_data = {
"gtin": test_marketplace_product.gtin,
"location": "WAREHOUSE_A",
"quantity": 5
"quantity": 5,
}
client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
@@ -218,9 +233,11 @@ class TestErrorHandling:
remove_data = {
"gtin": test_marketplace_product.gtin,
"location": "WAREHOUSE_A",
"quantity": 10 # More than the 5 we added
"quantity": 10, # More than the 5 we added
}
response = client.post("/api/v1/inventory/remove", headers=auth_headers, json=remove_data)
response = client.post(
"/api/v1/inventory/remove", headers=auth_headers, json=remove_data
)
# This should ALWAYS fail with insufficient inventory error
assert response.status_code == 400
@@ -257,7 +274,7 @@ class TestErrorHandling:
response = client.post(
"/api/v1/vendor",
headers=headers,
content="<vendor ><code>TEST</code></vendor >"
content="<vendor ><code>TEST</code></vendor >",
)
assert response.status_code in [400, 415, 422]
@@ -268,7 +285,7 @@ class TestErrorHandling:
vendor_data = {
"vendor_code": "LARGEVENDOR",
"name": "Large Vendor",
"description": large_description
"description": large_description,
}
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
@@ -310,10 +327,12 @@ class TestErrorHandling:
# Test invalid marketplace
import_data = {
"marketplace": "INVALID_MARKETPLACE",
"vendor_code": test_vendor.vendor_code
"vendor_code": test_vendor.vendor_code,
}
response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
response = client.post(
"/api/v1/imports", headers=auth_headers, json=import_data
)
if response.status_code == 422:
data = response.json()
@@ -329,10 +348,12 @@ class TestErrorHandling:
# Test with potentially problematic external data
import_data = {
"marketplace": "LETZSHOP",
"external_url": "https://nonexistent-marketplace.com/api"
"external_url": "https://nonexistent-marketplace.com/api",
}
response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
response = client.post(
"/api/v1/imports", headers=auth_headers, json=import_data
)
# If it's a real external service error, check structure
if response.status_code == 502:
@@ -356,7 +377,9 @@ class TestErrorHandling:
# All error responses should have these fields
required_fields = ["error_code", "message", "status_code"]
for field in required_fields:
assert field in data, f"Missing {field} in error response for {endpoint}"
assert (
field in data
), f"Missing {field} in error response for {endpoint}"
# Details field should be present (can be empty dict)
assert "details" in data
@@ -420,5 +443,7 @@ class TestErrorRecovery:
# Check that error was logged (if your app logs 404s as errors)
# Adjust based on your logging configuration
error_logs = [record for record in caplog.records if record.levelno >= logging.ERROR]
error_logs = [
record for record in caplog.records if record.levelno >= logging.ERROR
]
# May or may not have logs depending on whether 404s are logged as errors

View File

@@ -12,21 +12,18 @@ Tests cover:
- Error handling and edge cases
"""
import pytest
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timedelta, timezone
from jose import jwt
from fastapi import HTTPException
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi import HTTPException
from jose import jwt
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, UserNotActiveException)
from middleware.auth import AuthManager
from app.exceptions import (
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException,
AdminRequiredException,
InsufficientPermissionsException,
)
from models.database.user import User
@@ -124,7 +121,9 @@ class TestUserAuthentication:
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
result = auth_manager.authenticate_user(
mock_db, "test@example.com", "password123"
)
assert result is mock_user
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
token = token_data["access_token"]
# Decode without verification to check payload
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
payload = jwt.decode(
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
)
assert payload["sub"] == "42"
assert payload["username"] == "testuser"
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
"""Test tokens are different for different users."""
auth_manager = AuthManager()
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
user1 = Mock(
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
)
user2 = Mock(
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
)
token1 = auth_manager.create_access_token(user1)["access_token"]
token2 = auth_manager.create_access_token(user2)["access_token"]
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
payload = jwt.decode(
token_data["access_token"],
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
algorithms=[auth_manager.algorithm],
)
assert payload["role"] == "admin"
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
# Create token without 'sub' field
payload = {
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
auth_manager = AuthManager()
# Create token without 'exp' field
payload = {
"sub": "1",
"username": "testuser"
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser"}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
payload = {
"sub": "1",
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
# Create token with different algorithm
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
# Create a token with expiration in the past
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
payload = {
"sub": "1",
"username": "testuser",
"exp": past_time.timestamp()
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Mock jwt.decode to bypass its expiration check and test line 205
with patch('middleware.auth.jwt.decode') as mock_decode:
with patch("middleware.auth.jwt.decode") as mock_decode:
mock_decode.return_value = payload
with pytest.raises(TokenExpiredException):
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
# Existing admin user
existing_admin = Mock(spec=User)
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
mock_db.query.return_value.filter.return_value.first.return_value = (
existing_admin
)
result = auth_manager.create_default_admin_user(mock_db)
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
def test_default_configuration(self):
"""Test AuthManager uses default configuration."""
with patch.dict('os.environ', {}, clear=True):
with patch.dict("os.environ", {}, clear=True):
auth_manager = AuthManager()
assert auth_manager.algorithm == "HS256"
assert auth_manager.token_expire_minutes == 30
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
assert (
auth_manager.secret_key == "your-secret-key-change-in-production-please"
)
def test_custom_configuration(self):
"""Test AuthManager uses environment variables."""
with patch.dict('os.environ', {
'JWT_SECRET_KEY': 'custom-secret-key',
'JWT_EXPIRE_MINUTES': '60'
}):
with patch.dict(
"os.environ",
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
):
auth_manager = AuthManager()
assert auth_manager.secret_key == "custom-secret-key"
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
def test_partial_custom_configuration(self):
"""Test AuthManager with partial environment configuration."""
with patch.dict('os.environ', {
'JWT_EXPIRE_MINUTES': '120'
}, clear=False):
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
auth_manager = AuthManager()
assert auth_manager.token_expire_minutes == 120
@@ -656,9 +662,11 @@ class TestEdgeCases:
"sub": "1",
"username": "testuser",
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Should still verify successfully (JWT doesn't validate iat by default)
result = auth_manager.verify_token(token)
@@ -698,7 +706,9 @@ class TestEdgeCases:
token = token_data["access_token"]
# Mock jose.jwt.decode to raise an unexpected exception
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
with patch(
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
):
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)

View File

@@ -10,16 +10,13 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request
from middleware.context import (
ContextManager,
ContextMiddleware,
RequestContext,
get_request_context,
)
from middleware.context import (ContextManager, ContextMiddleware,
RequestContext, get_request_context)
@pytest.mark.unit
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
def test_is_admin_context_from_subdomain(self):
"""Test _is_admin_context with admin subdomain."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/dashboard"
)
is True
)
def test_is_admin_context_from_path(self):
"""Test _is_admin_context with admin path."""
request = Mock()
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
assert (
ContextManager._is_admin_context(request, "localhost", "/admin/users")
is True
)
def test_is_admin_context_both(self):
"""Test _is_admin_context with both subdomain and path."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/admin/users"
)
is True
)
def test_is_not_admin_context(self):
"""Test _is_admin_context returns False for non-admin."""
request = Mock()
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
assert (
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
is False
)
def test_is_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context with /vendor/ path."""
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
def test_is_vendor_dashboard_context_nested(self):
"""Test _is_vendor_dashboard_context with nested vendor path."""
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
assert (
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
)
def test_is_not_vendor_dashboard_context_vendors_plural(self):
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
assert (
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
is False
)
def test_is_not_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
@@ -373,7 +391,7 @@ class TestContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'context_type')
assert hasattr(request.state, "context_type")
assert request.state.context_type == RequestContext.API
call_next.assert_called_once_with(request)
@@ -565,7 +583,7 @@ class TestEdgeCases:
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
# No state attribute at all
delattr(request, 'state')
delattr(request, "state")
# Should still work, falling back to url.path
with pytest.raises(AttributeError):

View File

@@ -12,16 +12,18 @@ Tests cover:
- Edge cases and isolation
"""
import pytest
from unittest.mock import Mock
from middleware.decorators import rate_limit, rate_limiter
from app.exceptions.base import RateLimitException
import pytest
from app.exceptions.base import RateLimitException
from middleware.decorators import rate_limit, rate_limiter
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(autouse=True)
def reset_rate_limiter():
"""Reset rate limiter state before each test to ensure isolation."""
@@ -34,6 +36,7 @@ def reset_rate_limiter():
# Rate Limit Decorator Tests
# =============================================================================
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimitDecorator:
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_allows_within_limit(self):
"""Test decorator allows requests within rate limit."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_blocked():
return {"status": "ok"}
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_preserves_function_metadata(self):
"""Test decorator preserves original function metadata."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
"""Test endpoint docstring."""
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_with_args_and_kwargs(self):
"""Test decorator works with function arguments."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint(arg1, arg2, kwarg1=None):
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
result = await test_endpoint("value1", "value2", kwarg1="value3")
assert result == {
"arg1": "value1",
"arg2": "value2",
"kwarg1": "value3"
}
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
@pytest.mark.asyncio
async def test_decorator_default_parameters(self):
"""Test decorator uses default parameters."""
@rate_limit() # Use defaults
async def test_endpoint():
return {"status": "ok"}
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after."""
@rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint_retry():
return {"status": "ok"}
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_zero_max_requests(self):
"""Test decorator with max_requests=0 blocks all requests."""
@rate_limit(max_requests=0, window_seconds=3600)
async def test_endpoint_zero():
return {"status": "ok"}
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_very_short_window(self):
"""Test decorator with very short time window."""
@rate_limit(max_requests=1, window_seconds=1)
async def test_endpoint_short():
return {"status": "ok"}
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_multiple_functions_separate_limits(self):
"""Test that different functions have separate rate limits."""
@rate_limit(max_requests=1, window_seconds=3600)
async def endpoint1():
return {"endpoint": "1"}
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_exception_in_function(self):
"""Test decorator handles exceptions from wrapped function."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint_error():
raise ValueError("Function error")
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_isolation_between_tests(self):
"""Test that rate limiter state is properly isolated between tests."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_isolation():
return {"status": "ok"}
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_dict(self):
"""Test decorator correctly returns dictionary."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_dict():
return {"key": "value", "number": 42}
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_list(self):
"""Test decorator correctly returns list."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_list():
return [1, 2, 3, 4, 5]
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_none(self):
"""Test decorator correctly returns None."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_none():
return None
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_object(self):
"""Test decorator correctly returns custom objects."""
class TestObject:
def __init__(self):
self.name = "test_object"

View File

@@ -11,9 +11,10 @@ Tests cover:
- Edge cases (missing client info, etc.)
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from middleware.logging import LoggingMiddleware
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify request was logged
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
result = await middleware.dispatch(request, call_next)
# Verify response was logged
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in response.headers
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Should log "unknown" for client
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
assert any(
"unknown" in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_exceptions(self):
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
call_next = AsyncMock(side_effect=Exception("Test error"))
with patch('middleware.logging.logger') as mock_logger, \
pytest.raises(Exception):
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
Exception
):
await middleware.dispatch(request, call_next)
# Verify error was logged
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
call_next = slow_call_next
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
# Should still have process time, even if very small
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify method was logged
assert any(method in str(call) for call in mock_logger.info.call_args_list)
assert any(
method in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_different_status_codes(self):
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
response = Mock(status_code=status_code, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify status code was logged
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
assert any(
str(status_code) in str(call)
for call in mock_logger.info.call_args_list
)

View File

@@ -11,10 +11,11 @@ Tests cover:
- Edge cases and concurrency scenarios
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from collections import deque
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
import pytest
from middleware.rate_limiter import RateLimiter
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
# Add requests at different times
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
stats = limiter.get_client_stats(client_id)
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
limiter = RateLimiter()
client_id = "long_window_client"
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
result = limiter.allow_request(
client_id, max_requests=10, window_seconds=86400 * 365
)
assert result is True
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
client_id = "same_client"
# Allow with one limit
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
# Check with stricter limit
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
assert (
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
is False
)
def test_rate_limiter_unicode_client_id(self):
"""Test rate limiter with unicode client ID."""

View File

@@ -11,15 +11,14 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request
from middleware.theme_context import (
ThemeContextManager,
ThemeContextMiddleware,
get_current_theme,
)
from middleware.theme_context import (ThemeContextManager,
ThemeContextMiddleware,
get_current_theme)
@pytest.mark.unit
@@ -42,7 +41,14 @@ class TestThemeContextManager:
"""Test default theme has all required colors."""
theme = ThemeContextManager.get_default_theme()
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
required_colors = [
"primary",
"secondary",
"accent",
"background",
"text",
"border",
]
for color in required_colors:
assert color in theme["colors"]
assert theme["colors"][color].startswith("#")
@@ -79,10 +85,7 @@ class TestThemeContextManager:
mock_theme = Mock()
# Mock to_dict to return actual dictionary
custom_theme_dict = {
"theme_name": "custom",
"colors": {"primary": "#ff0000"}
}
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
mock_theme.to_dict.return_value = custom_theme_dict
# Correct filter chain: query().filter().first()
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
mock_theme = {"theme_name": "test_theme"}
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
):
await middleware.dispatch(request, call_next)
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'theme')
assert hasattr(request.state, "theme")
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
):
await middleware.dispatch(request, call_next)
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
# Verify database was closed

View File

@@ -11,17 +11,16 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from fastapi import Request, HTTPException
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from middleware.vendor_context import (
VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context,
)
from middleware.vendor_context import (VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context)
@pytest.mark.unit
@@ -39,7 +38,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -55,7 +54,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -71,7 +70,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -87,7 +86,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -140,7 +139,7 @@ class TestVendorContextManager:
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -153,7 +152,7 @@ class TestVendorContextManager:
request.headers = {"host": "www.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -166,7 +165,7 @@ class TestVendorContextManager:
request.headers = {"host": "api.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -179,7 +178,7 @@ class TestVendorContextManager:
request.headers = {"host": "localhost"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -198,12 +197,11 @@ class TestVendorContextManager:
mock_vendor.is_active = True
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -218,12 +216,11 @@ class TestVendorContextManager:
mock_vendor.is_active = False
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -232,12 +229,11 @@ class TestVendorContextManager:
def test_get_vendor_from_custom_domain_not_found(self):
"""Test custom domain not found in database."""
mock_db = Mock(spec=Session)
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
None
)
context = {
"detection_method": "custom_domain",
"domain": "nonexistent.com"
}
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -249,12 +245,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -266,12 +261,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "path",
"subdomain": "vendor1"
}
context = {"detection_method": "path", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -291,12 +285,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "VENDOR1" # Uppercase
}
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -311,10 +304,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -325,10 +315,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendors/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendors/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -339,10 +326,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -353,10 +337,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -425,21 +406,24 @@ class TestVendorContextManager:
# Static File Detection Tests
# ========================================================================
@pytest.mark.parametrize("path", [
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
])
@pytest.mark.parametrize(
"path",
[
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
],
)
def test_is_static_file_request(self, path):
"""Test static file detection for various paths and extensions."""
request = Mock(spec=Request)
@@ -447,12 +431,15 @@ class TestVendorContextManager:
assert VendorContextManager.is_static_file_request(request) is True
@pytest.mark.parametrize("path", [
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
])
@pytest.mark.parametrize(
"path",
[
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
],
)
def test_is_not_static_file_request(self, path):
"""Test non-static file paths."""
request = Mock(spec=Request)
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
with patch.object(VendorContextManager, "is_api_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
with patch.object(
VendorContextManager, "is_static_file_request", return_value=True
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
mock_vendor.name = "Test Vendor"
mock_vendor.subdomain = "vendor1"
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
), patch.object(
VendorContextManager, "extract_clean_path", return_value="/shop/products"
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
vendor_context = {
"detection_method": "subdomain",
"subdomain": "nonexistent"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=None
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=None
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -714,7 +708,7 @@ class TestEdgeCases:
request.headers = {"host": "shop.vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -735,11 +729,14 @@ class TestEdgeCases:
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
with patch('middleware.vendor_context.logger') as mock_logger:
with patch("middleware.vendor_context.logger") as mock_logger:
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
# Verify warning was logged
mock_logger.warning.assert_called()
warning_message = str(mock_logger.warning.call_args)
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
assert (
"No active vendor found for subdomain" in warning_message
and "nonexistent" in warning_message
)

View File

@@ -1,16 +1,17 @@
# tests/unit/models/test_database_models.py
import pytest
from datetime import datetime, timezone
import pytest
from sqlalchemy.exc import IntegrityError
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor, VendorUser, Role
from models.database.inventory import Inventory
from models.database.user import User
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.product import Product
from models.database.customer import Customer, CustomerAddress
from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
@pytest.mark.unit
@@ -277,7 +278,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 1",
marketplace="Letzshop"
marketplace="Letzshop",
)
db.add(product1)
db.commit()
@@ -288,7 +289,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 2",
marketplace="Letzshop"
marketplace="Letzshop",
)
db.add(product2)
db.commit()
@@ -515,7 +516,9 @@ class TestCustomerModel:
class TestOrderModel:
"""Test Order model"""
def test_order_creation(self, db, test_vendor, test_customer, test_customer_address):
def test_order_creation(
self, db, test_vendor, test_customer, test_customer_address
):
"""Test Order model with customer relationship"""
order = Order(
vendor_id=test_vendor.id,
@@ -563,7 +566,9 @@ class TestOrderModel:
assert float(order_item.unit_price) == 49.99
assert float(order_item.total_price) == 99.98
def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address):
def test_order_number_uniqueness(
self, db, test_vendor, test_customer, test_customer_address
):
"""Test order_number unique constraint"""
order1 = Order(
vendor_id=test_vendor.id,

View File

@@ -1,14 +1,10 @@
# tests/unit/services/test_admin_service.py
import pytest
from app.exceptions import (
UserNotFoundException,
UserStatusChangeException,
CannotModifySelfException,
VendorNotFoundException,
VendorVerificationException,
AdminOperationException,
)
from app.exceptions import (AdminOperationException, CannotModifySelfException,
UserNotFoundException, UserStatusChangeException,
VendorNotFoundException,
VendorVerificationException)
from app.services.admin_service import AdminService
from app.services.stats_service import stats_service
from models.database.marketplace_import_job import MarketplaceImportJob
@@ -85,7 +81,9 @@ class TestAdminService:
assert exception.error_code == "CANNOT_MODIFY_SELF"
assert "deactivate account" in exception.message
def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin):
def test_toggle_user_status_cannot_modify_admin(
self, db, test_admin, another_admin
):
"""Test that admin cannot modify another admin"""
with pytest.raises(UserStatusChangeException) as exc_info:
self.service.toggle_user_status(db, another_admin.id, test_admin.id)
@@ -148,7 +146,7 @@ class TestAdminService:
assert "99999" in exception.message
def test_toggle_vendor_status_deactivate(self, db, test_vendor):
"""Test deactivating a vendor """
"""Test deactivating a vendor"""
original_status = test_vendor.is_active
vendor, message = self.service.toggle_vendor_status(db, test_vendor.id)
@@ -170,21 +168,26 @@ class TestAdminService:
assert exception.error_code == "VENDOR_NOT_FOUND"
# Marketplace Import Jobs Tests
def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_no_filters(
self, db, test_marketplace_import_job
):
"""Test getting marketplace import jobs without filters"""
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
assert len(result) >= 1
# Find our test job in the results
test_job = next(
(job for job in result if job.job_id == test_marketplace_import_job.id), None
(job for job in result if job.job_id == test_marketplace_import_job.id),
None,
)
assert test_job is not None
assert test_job.marketplace == test_marketplace_import_job.marketplace
assert test_job.vendor_name == test_marketplace_import_job.name
assert test_job.status == test_marketplace_import_job.status
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_marketplace_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by marketplace"""
result = self.service.get_marketplace_import_jobs(
db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10
@@ -192,9 +195,14 @@ class TestAdminService:
assert len(result) >= 1
for job in result:
assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower()
assert (
test_marketplace_import_job.marketplace.lower()
in job.marketplace.lower()
)
def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_vendor_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by vendor name"""
result = self.service.get_marketplace_import_jobs(
db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10
@@ -204,7 +212,9 @@ class TestAdminService:
for job in result:
assert test_marketplace_import_job.name.lower() in job.vendor_name.lower()
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_status_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by status"""
result = self.service.get_marketplace_import_jobs(
db, status=test_marketplace_import_job.status, skip=0, limit=10
@@ -214,7 +224,9 @@ class TestAdminService:
for job in result:
assert job.status == test_marketplace_import_job.status
def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_pagination(
self, db, test_marketplace_import_job
):
"""Test marketplace import jobs pagination"""
result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1)
result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1)

View File

@@ -1,11 +1,9 @@
# tests/test_auth_service.py
import pytest
from app.exceptions.auth import (
UserAlreadyExistsException,
InvalidCredentialsException,
UserNotActiveException,
)
from app.exceptions.auth import (InvalidCredentialsException,
UserAlreadyExistsException,
UserNotActiveException)
from app.exceptions.base import ValidationException
from app.services.auth_service import AuthService
from models.schema.auth import UserLogin, UserRegister
@@ -218,11 +216,14 @@ class TestAuthService:
def test_create_access_token_failure(self, test_user, monkeypatch):
"""Test creating access token handles failures"""
# Mock the auth_manager to raise an exception
def mock_create_token(*args, **kwargs):
raise Exception("Token creation failed")
monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token)
monkeypatch.setattr(
self.service.auth_manager, "create_access_token", mock_create_token
)
with pytest.raises(ValidationException) as exc_info:
self.service.create_access_token(test_user)
@@ -250,11 +251,14 @@ class TestAuthService:
def test_hash_password_failure(self, monkeypatch):
"""Test password hashing handles failures"""
# Mock the auth_manager to raise an exception
def mock_hash_password(*args, **kwargs):
raise Exception("Hashing failed")
monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password)
monkeypatch.setattr(
self.service.auth_manager, "hash_password", mock_hash_password
)
with pytest.raises(ValidationException) as exc_info:
self.service.hash_password("testpassword")
@@ -267,9 +271,7 @@ class TestAuthService:
def test_register_user_database_error(self, db_with_error):
"""Test user registration handles database errors"""
user_data = UserRegister(
email="test@example.com",
username="testuser",
password="password123"
email="test@example.com", username="testuser", password="password123"
)
with pytest.raises(ValidationException) as exc_info:

View File

@@ -3,19 +3,17 @@ import uuid
import pytest
from app.exceptions import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
NegativeInventoryException, ValidationException)
from app.services.inventory_service import InventoryService
from app.exceptions import (
InventoryNotFoundException,
InsufficientInventoryException,
InvalidInventoryOperationException,
InventoryValidationException,
NegativeInventoryException,
InvalidQuantityException,
ValidationException,
)
from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.schema.inventory import (InventoryAdd, InventoryCreate,
InventoryUpdate)
@pytest.mark.unit
@@ -40,10 +38,14 @@ class TestInventoryService:
def test_normalize_gtin_valid(self):
"""Test GTIN normalization with valid GTINs."""
# Test various valid GTIN formats - these should remain unchanged
assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13
assert (
self.service._normalize_gtin("1234567890123") == "1234567890123"
) # EAN-13
assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A
assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8
assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
assert (
self.service._normalize_gtin("12345678901234") == "12345678901234"
) # GTIN-14
# Test with decimal points (should be removed)
assert self.service._normalize_gtin("1234567890123.0") == "1234567890123"
@@ -52,11 +54,17 @@ class TestInventoryService:
assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123"
# Test short GTINs being padded
assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13
assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
assert (
self.service._normalize_gtin("123") == "0000000000123"
) # Padded to EAN-13
assert (
self.service._normalize_gtin("12345") == "0000000012345"
) # Padded to EAN-13
# Test long GTINs being truncated
assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
assert (
self.service._normalize_gtin("123456789012345") == "3456789012345"
) # Truncated to 13
def test_normalize_gtin_edge_cases(self):
"""Test GTIN normalization edge cases."""
@@ -65,9 +73,15 @@ class TestInventoryService:
assert self.service._normalize_gtin(123) == "0000000000123"
# Test mixed valid/invalid characters
assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
assert (
self.service._normalize_gtin("123-456-789-012") == "123456789012"
) # Dashes removed
assert (
self.service._normalize_gtin("123 456 789 012") == "123456789012"
) # Spaces removed
assert (
self.service._normalize_gtin("ABC123456789012DEF") == "123456789012"
) # Letters removed
def test_set_inventory_new_entry_success(self, db):
"""Test setting inventory for a new GTIN/location combination successfully."""
@@ -162,7 +176,9 @@ class TestInventoryService:
def test_add_inventory_invalid_gtin_validation_error(self, db):
"""Test adding inventory with invalid GTIN returns InventoryValidationException."""
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
inventory_data = InventoryAdd(
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50
)
with pytest.raises(InventoryValidationException) as exc_info:
self.service.add_inventory(db, inventory_data)
@@ -180,11 +196,12 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVALID_QUANTITY"
assert "Quantity must be positive" in str(exc_info.value)
def test_remove_inventory_success(self, db, test_inventory):
"""Test removing inventory successfully."""
original_quantity = test_inventory.quantity
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
remove_quantity = min(
10, original_quantity
) # Ensure we don't remove more than available
inventory_data = InventoryAdd(
gtin=test_inventory.gtin,
@@ -212,7 +229,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
assert exc_info.value.details["gtin"] == test_inventory.gtin
assert exc_info.value.details["location"] == test_inventory.location
assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
assert (
exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
)
assert exc_info.value.details["available_quantity"] == test_inventory.quantity
def test_remove_inventory_nonexistent_entry_not_found(self, db):
@@ -231,7 +250,9 @@ class TestInventoryService:
def test_remove_inventory_invalid_gtin_validation_error(self, db):
"""Test removing inventory with invalid GTIN returns InventoryValidationException."""
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
inventory_data = InventoryAdd(
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10
)
with pytest.raises(InventoryValidationException) as exc_info:
self.service.remove_inventory(db, inventory_data)
@@ -254,7 +275,9 @@ class TestInventoryService:
# The service prevents negative inventory through InsufficientInventoryException
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product):
def test_get_inventory_by_gtin_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting inventory summary by GTIN successfully."""
result = self.service.get_inventory_by_gtin(db, test_inventory.gtin)
@@ -265,14 +288,20 @@ class TestInventoryService:
assert result.locations[0].quantity == test_inventory.quantity
assert result.product_title == test_marketplace_product.title
def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product):
def test_get_inventory_by_gtin_multiple_locations_success(
self, db, test_marketplace_product
):
"""Test getting inventory summary with multiple locations successfully."""
unique_gtin = test_marketplace_product.gtin
unique_id = str(uuid.uuid4())[:8]
# Create multiple inventory entries for the same GTIN with unique locations
inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50)
inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30)
inventory1 = Inventory(
gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
)
inventory2 = Inventory(
gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
)
db.add(inventory1)
db.add(inventory2)
@@ -301,7 +330,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED"
assert "Invalid GTIN format" in str(exc_info.value)
def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product):
def test_get_total_inventory_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting total inventory for a GTIN successfully."""
result = self.service.get_total_inventory(db, test_inventory.gtin)
@@ -364,7 +395,9 @@ class TestInventoryService:
result = self.service.get_all_inventory(db, skip=2, limit=2)
assert len(result) <= 2 # Should be at most 2, might be less if other records exist
assert (
len(result) <= 2
) # Should be at most 2, might be less if other records exist
def test_update_inventory_success(self, db, test_inventory):
"""Test updating inventory quantity successfully."""
@@ -404,7 +437,9 @@ class TestInventoryService:
assert result is True
# Verify the inventory is actually deleted
deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
deleted_inventory = (
db.query(Inventory).filter(Inventory.id == inventory_id).first()
)
assert deleted_inventory is None
def test_delete_inventory_not_found_error(self, db):
@@ -415,7 +450,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_NOT_FOUND"
assert "99999" in str(exc_info.value)
def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product):
def test_get_low_inventory_items_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting low inventory items successfully."""
# Set inventory to a low value
test_inventory.quantity = 5
@@ -424,7 +461,9 @@ class TestInventoryService:
result = self.service.get_low_inventory_items(db, threshold=10)
assert len(result) >= 1
low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None)
low_inventory_item = next(
(item for item in result if item["gtin"] == test_inventory.gtin), None
)
assert low_inventory_item is not None
assert low_inventory_item["current_quantity"] == 5
assert low_inventory_item["location"] == test_inventory.location
@@ -440,9 +479,13 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_success(self, db, test_inventory):
"""Test getting inventory summary by location successfully."""
result = self.service.get_inventory_summary_by_location(db, test_inventory.location)
result = self.service.get_inventory_summary_by_location(
db, test_inventory.location
)
assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase
assert (
result["location"] == test_inventory.location.upper()
) # Service normalizes to uppercase
assert result["total_items"] >= 1
assert result["total_quantity"] >= test_inventory.quantity
assert result["unique_gtins"] >= 1
@@ -450,7 +493,9 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_empty_result(self, db):
"""Test getting inventory summary for location with no inventory."""
unique_id = str(uuid.uuid4())[:8]
result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}")
result = self.service.get_inventory_summary_by_location(
db, f"EMPTY_LOCATION_{unique_id}"
)
assert result["total_items"] == 0
assert result["total_quantity"] == 0
@@ -459,12 +504,16 @@ class TestInventoryService:
def test_validate_quantity_edge_cases(self, db):
"""Test quantity validation with edge cases."""
# Test zero quantity with allow_zero=True (should succeed)
inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0)
inventory_data = InventoryCreate(
gtin="1234567890123", location="WAREHOUSE_A", quantity=0
)
result = self.service.set_inventory(db, inventory_data)
assert result.quantity == 0
# Test zero quantity with add_inventory (should fail - doesn't allow zero)
inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0)
inventory_data_add = InventoryAdd(
gtin="1234567890123", location="WAREHOUSE_B", quantity=0
)
with pytest.raises(InvalidQuantityException):
self.service.add_inventory(db, inventory_data_add)
@@ -477,10 +526,10 @@ class TestInventoryService:
exception = exc_info.value
# Verify exception structure matches WizamartException.to_dict()
assert hasattr(exception, 'error_code')
assert hasattr(exception, 'message')
assert hasattr(exception, 'status_code')
assert hasattr(exception, 'details')
assert hasattr(exception, "error_code")
assert hasattr(exception, "message")
assert hasattr(exception, "status_code")
assert hasattr(exception, "details")
assert isinstance(exception.error_code, str)
assert isinstance(exception.message, str)

View File

@@ -4,19 +4,18 @@ from datetime import datetime
import pytest
from app.exceptions.marketplace_import_job import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
)
from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException
from app.exceptions.base import ValidationException
from app.services.marketplace_import_job_service import MarketplaceImportJobService
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
from app.exceptions.marketplace_import_job import (
ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException,
ImportJobNotFoundException, ImportJobNotOwnedException)
from app.exceptions.vendor import (UnauthorizedVendorAccessException,
VendorNotFoundException)
from app.services.marketplace_import_job_service import \
MarketplaceImportJobService
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
@pytest.mark.unit
@@ -31,7 +30,9 @@ class TestMarketplaceService:
test_vendor.owner_user_id = test_user.id
db.commit()
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code, test_user
)
assert result.vendor_code == test_vendor.vendor_code
assert result.owner_user_id == test_user.id
@@ -39,8 +40,10 @@ class TestMarketplaceService:
def test_validate_vendor_access_admin_can_access_any_vendor(
self, db, test_vendor, test_admin
):
"""Test that admin users can access any vendor """
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin)
"""Test that admin users can access any vendor"""
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code, test_admin
)
assert result.vendor_code == test_vendor.vendor_code
@@ -57,7 +60,7 @@ class TestMarketplaceService:
def test_validate_vendor_access_permission_denied(
self, db, test_vendor, test_user, other_user
):
"""Test vendor access validation when user doesn't own the vendor """
"""Test vendor access validation when user doesn't own the vendor"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
db.commit()
@@ -93,7 +96,7 @@ class TestMarketplaceService:
assert result.vendor_name == test_vendor.name
def test_create_import_job_invalid_vendor(self, db, test_user):
"""Test import job creation with invalid vendor """
"""Test import job creation with invalid vendor"""
request = MarketplaceImportJobRequest(
url="https://example.com/products.csv",
marketplace="Amazon",
@@ -108,7 +111,9 @@ class TestMarketplaceService:
assert exception.error_code == "VENDOR_NOT_FOUND"
assert "INVALID_VENDOR" in exception.message
def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user):
def test_create_import_job_unauthorized_access(
self, db, test_vendor, test_user, other_user
):
"""Test import job creation with unauthorized vendor access"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
@@ -127,7 +132,9 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS"
def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user):
def test_get_import_job_by_id_success(
self, db, test_marketplace_import_job, test_user
):
"""Test getting import job by ID for job owner"""
result = self.service.get_import_job_by_id(
db, test_marketplace_import_job.id, test_user
@@ -161,14 +168,18 @@ class TestMarketplaceService:
):
"""Test access denied when user doesn't own the job"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user)
self.service.get_import_job_by_id(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
assert exception.status_code == 403
assert str(test_marketplace_import_job.id) in exception.message
def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user):
def test_get_import_jobs_user_filter(
self, db, test_marketplace_import_job, test_user
):
"""Test getting import jobs filtered by user"""
jobs = self.service.get_import_jobs(db, test_user)
@@ -176,7 +187,9 @@ class TestMarketplaceService:
assert any(job.id == test_marketplace_import_job.id for job in jobs)
assert test_marketplace_import_job.user_id == test_user.id
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin):
def test_get_import_jobs_admin_sees_all(
self, db, test_marketplace_import_job, test_admin
):
"""Test that admin sees all import jobs"""
jobs = self.service.get_import_jobs(db, test_admin)
@@ -192,7 +205,9 @@ class TestMarketplaceService:
)
assert len(jobs) >= 1
assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs)
assert any(
job.marketplace == test_marketplace_import_job.marketplace for job in jobs
)
def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor):
"""Test getting import jobs with pagination"""
@@ -330,10 +345,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
def test_cancel_import_job_access_denied(
self, db, test_marketplace_import_job, other_user
):
"""Test cancelling import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user)
self.service.cancel_import_job(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -347,7 +366,9 @@ class TestMarketplaceService:
db.commit()
with pytest.raises(ImportJobCannotBeCancelledException) as exc_info:
self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user)
self.service.cancel_import_job(
db, test_marketplace_import_job.id, test_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED"
@@ -396,10 +417,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
def test_delete_import_job_access_denied(
self, db, test_marketplace_import_job, other_user
):
"""Test deleting import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.delete_import_job(db, test_marketplace_import_job.id, other_user)
self.service.delete_import_job(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -440,11 +465,15 @@ class TestMarketplaceService:
db.commit()
# Test with lowercase vendor code
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code.lower(), test_user
)
assert result.vendor_code == test_vendor.vendor_code
# Test with uppercase vendor code
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code.upper(), test_user
)
assert result.vendor_code == test_vendor.vendor_code
def test_create_import_job_database_error(self, db_with_error, test_user):

View File

@@ -1,16 +1,15 @@
# tests/test_product_service.py
import pytest
from app.exceptions import (InvalidMarketplaceProductDataException,
MarketplaceProductAlreadyExistsException,
MarketplaceProductNotFoundException,
MarketplaceProductValidationException,
ValidationException)
from app.services.marketplace_product_service import MarketplaceProductService
from app.exceptions import (
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
ValidationException,
)
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.database.marketplace_product import MarketplaceProduct
from models.schema.marketplace_product import (MarketplaceProductCreate,
MarketplaceProductUpdate)
@pytest.mark.unit
@@ -98,7 +97,10 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS"
assert test_marketplace_product.marketplace_product_id in str(exc_info.value)
assert exc_info.value.status_code == 409
assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id
assert (
exc_info.value.details.get("marketplace_product_id")
== test_marketplace_product.marketplace_product_id
)
def test_create_product_invalid_price(self, db):
"""Test product creation with invalid price raises InvalidMarketplaceProductDataException"""
@@ -117,9 +119,14 @@ class TestProductService:
def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product):
"""Test successful product retrieval by ID"""
product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id)
product = self.service.get_product_by_id_or_raise(
db, test_marketplace_product.marketplace_product_id
)
assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id
assert (
product.marketplace_product_id
== test_marketplace_product.marketplace_product_id
)
assert product.title == test_marketplace_product.title
def test_get_product_by_id_or_raise_not_found(self, db):
@@ -152,21 +159,35 @@ class TestProductService:
assert total >= 1
assert len(products) >= 1
# Verify search worked by checking that title contains search term
found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None)
found_product = next(
(
p
for p in products
if p.marketplace_product_id
== test_marketplace_product.marketplace_product_id
),
None,
)
assert found_product is not None
def test_update_product_success(self, db, test_marketplace_product):
"""Test successful product update"""
update_data = MarketplaceProductUpdate(
title="Updated MarketplaceProduct Title",
price="39.99"
title="Updated MarketplaceProduct Title", price="39.99"
)
updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
updated_product = self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert updated_product.title == "Updated MarketplaceProduct Title"
assert updated_product.price == "39.99" # Price is stored as string after processing
assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged
assert (
updated_product.price == "39.99"
) # Price is stored as string after processing
assert (
updated_product.marketplace_product_id
== test_marketplace_product.marketplace_product_id
) # ID unchanged
def test_update_product_not_found(self, db):
"""Test updating non-existent product raises MarketplaceProductNotFoundException"""
@@ -183,7 +204,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(gtin="invalid_gtin")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid GTIN format" in str(exc_info.value)
@@ -194,7 +217,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(title="")
with pytest.raises(MarketplaceProductValidationException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED"
assert "MarketplaceProduct title cannot be empty" in str(exc_info.value)
@@ -205,7 +230,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(price="invalid_price")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid price format" in str(exc_info.value)
@@ -213,12 +240,16 @@ class TestProductService:
def test_delete_product_success(self, db, test_marketplace_product):
"""Test successful product deletion"""
result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id)
result = self.service.delete_product(
db, test_marketplace_product.marketplace_product_id
)
assert result is True
# Verify product is deleted
deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id)
deleted_product = self.service.get_product_by_id(
db, test_marketplace_product.marketplace_product_id
)
assert deleted_product is None
def test_delete_product_not_found(self, db):
@@ -229,10 +260,14 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_NOT_FOUND"
assert "NONEXISTENT" in str(exc_info.value)
def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory):
def test_get_inventory_info_success(
self, db, test_marketplace_product_with_inventory
):
"""Test getting inventory info for product with inventory"""
# Extract the product from the dictionary
marketplace_product = test_marketplace_product_with_inventory['marketplace_product']
marketplace_product = test_marketplace_product_with_inventory[
"marketplace_product"
]
inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin)
@@ -243,13 +278,17 @@ class TestProductService:
def test_get_inventory_info_no_inventory(self, db, test_marketplace_product):
"""Test getting inventory info for product without inventory"""
inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123")
inventory_info = self.service.get_inventory_info(
db, test_marketplace_product.gtin or "1234567890123"
)
assert inventory_info is None
def test_product_exists_true(self, db, test_marketplace_product):
"""Test product_exists returns True for existing product"""
exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id)
exists = self.service.product_exists(
db, test_marketplace_product.marketplace_product_id
)
assert exists is True
def test_product_exists_false(self, db):
@@ -265,7 +304,9 @@ class TestProductService:
csv_lines = list(csv_generator)
assert len(csv_lines) > 1 # Header + at least one data row
assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header
assert csv_lines[0].startswith(
"marketplace_product_id,title,description"
) # Check header
# Check that test product appears in CSV
csv_content = "".join(csv_lines)
@@ -274,8 +315,7 @@ class TestProductService:
def test_generate_csv_export_with_filters(self, db, test_marketplace_product):
"""Test CSV export with marketplace filter"""
csv_generator = self.service.generate_csv_export(
db,
marketplace=test_marketplace_product.marketplace
db, marketplace=test_marketplace_product.marketplace
)
csv_lines = list(csv_generator)

View File

@@ -2,8 +2,8 @@
import pytest
from app.services.stats_service import StatsService
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
@pytest.mark.unit
@@ -15,7 +15,9 @@ class TestStatsService:
"""Setup method following the same pattern as other service tests"""
self.service = StatsService()
def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory):
def test_get_comprehensive_stats_basic(
self, db, test_marketplace_product, test_inventory
):
"""Test getting comprehensive stats with basic data"""
stats = self.service.get_comprehensive_stats(db)
@@ -31,7 +33,9 @@ class TestStatsService:
assert stats["total_inventory_entries"] >= 1
assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10
def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product):
def test_get_comprehensive_stats_multiple_products(
self, db, test_marketplace_product
):
"""Test comprehensive stats with multiple products across different dimensions"""
# Create products with different brands, categories, marketplaces
additional_products = [
@@ -87,7 +91,7 @@ class TestStatsService:
brand=None, # Null brand
google_product_category=None, # Null category
marketplace=None, # Null marketplace
vendor_name=None, # Null vendor
vendor_name=None, # Null vendor
price="10.00",
currency="EUR",
),
@@ -97,7 +101,7 @@ class TestStatsService:
brand="", # Empty brand
google_product_category="", # Empty category
marketplace="", # Empty marketplace
vendor_name="", # Empty vendor
vendor_name="", # Empty vendor
price="15.00",
currency="EUR",
),
@@ -124,7 +128,11 @@ class TestStatsService:
# Find our test marketplace in the results
test_marketplace_stat = next(
(stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace),
(
stat
for stat in stats
if stat["marketplace"] == test_marketplace_product.marketplace
),
None,
)
assert test_marketplace_stat is not None
@@ -309,7 +317,9 @@ class TestStatsService:
count = self.service._get_unique_marketplaces_count(db)
assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace
assert (
count >= 2
) # At least Amazon and eBay, plus test_marketplace_product marketplace
assert isinstance(count, int)
def test_get_unique_vendors_count(self, db, test_marketplace_product):
@@ -338,7 +348,9 @@ class TestStatsService:
count = self.service._get_unique_vendors_count(db)
assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor
assert (
count >= 2
) # At least VendorA and VendorB, plus test_marketplace_product vendor
assert isinstance(count, int)
def test_get_inventory_statistics(self, db, test_inventory):
@@ -438,7 +450,7 @@ class TestStatsService:
db.add_all(marketplace_products)
db.commit()
vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace")
vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace")
assert len(vendors) == 2
assert "TestVendor1" in vendors
@@ -482,7 +494,9 @@ class TestStatsService:
def test_get_products_by_marketplace_not_found(self, db):
"""Test getting product count for non-existent marketplace"""
count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace")
count = self.service._get_products_by_marketplace_count(
db, "NonExistentMarketplace"
)
assert count == 0

View File

@@ -1,19 +1,16 @@
# tests/test_vendor_service.py (updated to use custom exceptions)
import pytest
from app.exceptions import (InvalidVendorDataException,
MarketplaceProductNotFoundException,
MaxVendorsReachedException,
ProductAlreadyExistsException,
UnauthorizedVendorAccessException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException)
from app.services.vendor_service import VendorService
from app.exceptions import (
VendorNotFoundException,
VendorAlreadyExistsException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
MarketplaceProductNotFoundException,
ProductAlreadyExistsException,
MaxVendorsReachedException,
ValidationException,
)
from models.schema.vendor import VendorCreate
from models.schema.product import ProductCreate
from models.schema.vendor import VendorCreate
@pytest.mark.unit
@@ -38,15 +35,17 @@ class TestVendorService:
assert vendor is not None
assert vendor.vendor_code == "NEWVENDOR"
assert vendor.owner_user_id == test_user.id
assert vendor.is_verified is False # Regular user creates unverified vendor
assert vendor.is_verified is False # Regular user creates unverified vendor
def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory):
"""Test admin creates verified vendor automatically"""
vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor")
vendor_data = VendorCreate(
vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor"
)
vendor = self.service.create_vendor(db, vendor_data, test_admin)
assert vendor.is_verified is True # Admin creates verified vendor
assert vendor.is_verified is True # Admin creates verified vendor
def test_create_vendor_duplicate_code(self, db, test_user, test_vendor):
"""Test vendor creation fails with duplicate vendor code"""
@@ -88,7 +87,9 @@ class TestVendorService:
def test_create_vendor_invalid_code_format(self, db, test_user):
"""Test vendor creation fails with invalid vendor code format"""
vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor")
vendor_data = VendorCreate(
vendor_code="INVALID@CODE!", vendor_name="Test Vendor"
)
with pytest.raises(InvalidVendorDataException) as exc_info:
self.service.create_vendor(db, vendor_data, test_user)
@@ -105,7 +106,9 @@ class TestVendorService:
def mock_check_vendor_limit(self, db, user):
raise MaxVendorsReachedException(max_vendors=5, user_id=user.id)
monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit)
monkeypatch.setattr(
VendorService, "_check_vendor_limit", mock_check_vendor_limit
)
vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor")
@@ -118,7 +121,9 @@ class TestVendorService:
assert exception.details["max_vendors"] == 5
assert exception.details["user_id"] == test_user.id
def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor):
def test_get_vendors_regular_user(
self, db, test_user, test_vendor, inactive_vendor
):
"""Test regular user can only see active verified vendors and own vendors"""
vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10)
@@ -127,7 +132,7 @@ class TestVendorService:
assert inactive_vendor.vendor_code not in vendor_codes
def test_get_vendors_admin_user(
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
):
"""Test admin user can see all vendors with filters"""
vendors, total = self.service.get_vendors(
@@ -140,14 +145,16 @@ class TestVendorService:
assert verified_vendor.vendor_code in vendor_codes
def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor):
"""Test vendor owner can access their own vendor """
vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user)
"""Test vendor owner can access their own vendor"""
vendor = self.service.get_vendor_by_code(
db, test_vendor.vendor_code.lower(), test_user
)
assert vendor is not None
assert vendor.id == test_vendor.id
def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor):
"""Test admin can access any vendor """
"""Test admin can access any vendor"""
vendor = self.service.get_vendor_by_code(
db, test_vendor.vendor_code.lower(), test_admin
)
@@ -178,16 +185,14 @@ class TestVendorService:
assert exception.details["user_id"] == test_user.id
def test_add_product_to_vendor_success(self, db, test_vendor, unique_product):
"""Test successfully adding product to vendor """
"""Test successfully adding product to vendor"""
product_data = ProductCreate(
marketplace_product_id=unique_product.marketplace_product_id,
price="15.99",
is_featured=True,
)
product = self.service.add_product_to_catalog(
db, test_vendor, product_data
)
product = self.service.add_product_to_catalog(db, test_vendor, product_data)
assert product is not None
assert product.vendor_id == test_vendor.id
@@ -195,7 +200,9 @@ class TestVendorService:
def test_add_product_to_vendor_product_not_found(self, db, test_vendor):
"""Test adding non-existent product to vendor fails"""
product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99")
product_data = ProductCreate(
marketplace_product_id="NONEXISTENT", price="15.99"
)
with pytest.raises(MarketplaceProductNotFoundException) as exc_info:
self.service.add_product_to_catalog(db, test_vendor, product_data)
@@ -209,7 +216,8 @@ class TestVendorService:
def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product):
"""Test adding product that's already in vendor fails"""
product_data = ProductCreate(
marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99"
marketplace_product_id=test_product.marketplace_product.marketplace_product_id,
price="15.99",
)
with pytest.raises(ProductAlreadyExistsException) as exc_info:
@@ -219,11 +227,12 @@ class TestVendorService:
assert exception.status_code == 409
assert exception.error_code == "PRODUCT_ALREADY_EXISTS"
assert exception.details["vendor_code"] == test_vendor.vendor_code
assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id
assert (
exception.details["marketplace_product_id"]
== test_product.marketplace_product.marketplace_product_id
)
def test_get_products_owner_access(
self, db, test_user, test_vendor, test_product
):
def test_get_products_owner_access(self, db, test_user, test_vendor, test_product):
"""Test vendor owner can get vendor products"""
products, total = self.service.get_products(db, test_vendor, test_user)
@@ -291,7 +300,9 @@ class TestVendorService:
assert exception.error_code == "VALIDATION_ERROR"
assert "Failed to retrieve vendors" in exception.message
def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch):
def test_add_product_database_error(
self, db, test_vendor, unique_product, monkeypatch
):
"""Test add product handles database errors gracefully"""
def mock_commit():

View File

@@ -18,7 +18,9 @@ class TestCSVProcessor:
def test_download_csv_encoding_fallback(self, mock_get):
"""Test CSV download with encoding fallback"""
# Create content with special characters that would fail UTF-8 if not properly encoded
special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
special_content = (
"marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
)
mock_response = Mock()
mock_response.status_code = 200
@@ -40,9 +42,7 @@ class TestCSVProcessor:
mock_response = Mock()
mock_response.status_code = 200
# Create bytes that will fail most encodings
mock_response.content = (
b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
)
mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response