style: apply black and isort formatting across entire codebase
- Standardize quote style (single to double quotes) - Reorder and group imports alphabetically - Fix line breaks and indentation for consistency - Apply PEP 8 formatting standards Also updated Makefile to exclude both venv and .venv from code quality checks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -7,13 +7,13 @@ from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.database import Base, get_db
|
||||
from main import app
|
||||
from models.database.inventory import Inventory
|
||||
# Import all models to ensure they're registered with Base metadata
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.product import Product
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
# Use in-memory SQLite database for tests
|
||||
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///:memory:"
|
||||
|
||||
2
tests/fixtures/auth_fixtures.py
vendored
2
tests/fixtures/auth_fixtures.py
vendored
@@ -53,6 +53,7 @@ def test_admin(db, auth_manager):
|
||||
db.expunge(admin)
|
||||
return admin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def another_admin(db, auth_manager):
|
||||
"""Create another test admin user for testing admin-to-admin interactions"""
|
||||
@@ -72,6 +73,7 @@ def another_admin(db, auth_manager):
|
||||
db.expunge(admin)
|
||||
return admin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user(db, auth_manager):
|
||||
"""Create a different user for testing access controls"""
|
||||
|
||||
1
tests/fixtures/customer_fixtures.py
vendored
1
tests/fixtures/customer_fixtures.py
vendored
@@ -1,5 +1,6 @@
|
||||
# tests/fixtures/customer_fixtures.py
|
||||
import pytest
|
||||
|
||||
from models.database.customer import Customer, CustomerAddress
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# tests/fixtures/marketplace_import_job_fixtures.py
|
||||
import pytest
|
||||
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
|
||||
|
||||
|
||||
@@ -117,7 +117,9 @@ def marketplace_product_factory():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_marketplace_product_with_inventory(db, test_marketplace_product, test_inventory):
|
||||
def test_marketplace_product_with_inventory(
|
||||
db, test_marketplace_product, test_inventory
|
||||
):
|
||||
"""MarketplaceProduct with associated inventory record."""
|
||||
# Ensure they're linked by GTIN
|
||||
if test_marketplace_product.gtin != test_inventory.gtin:
|
||||
@@ -126,6 +128,6 @@ def test_marketplace_product_with_inventory(db, test_marketplace_product, test_i
|
||||
db.refresh(test_inventory)
|
||||
|
||||
return {
|
||||
'marketplace_product': test_marketplace_product,
|
||||
'inventory': test_inventory
|
||||
"marketplace_product": test_marketplace_product,
|
||||
"inventory": test_inventory,
|
||||
}
|
||||
|
||||
5
tests/fixtures/testing_fixtures.py
vendored
5
tests/fixtures/testing_fixtures.py
vendored
@@ -8,8 +8,9 @@ This module provides fixtures for:
|
||||
- Additional testing utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
|
||||
@@ -26,7 +27,7 @@ def empty_db(db):
|
||||
"inventory", # Fixed: singular not plural
|
||||
"products", # Referenced by products
|
||||
"vendors", # Has foreign key to users
|
||||
"users" # Base table
|
||||
"users", # Base table
|
||||
]
|
||||
|
||||
for table in tables_to_clear:
|
||||
|
||||
6
tests/fixtures/vendor_fixtures.py
vendored
6
tests/fixtures/vendor_fixtures.py
vendored
@@ -1,10 +1,11 @@
|
||||
# tests/fixtures/vendor_fixtures.py
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.product import Product
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -185,6 +186,7 @@ def multiple_inventory_entries(db, multiple_products, test_vendor):
|
||||
|
||||
def create_unique_vendor_factory():
|
||||
"""Factory function to create unique vendors in tests"""
|
||||
|
||||
def _create_vendor(db, owner_user_id, **kwargs):
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
defaults = {
|
||||
|
||||
@@ -59,7 +59,9 @@ class TestAdminAPI:
|
||||
assert response.status_code == 400 # Business logic error
|
||||
data = response.json()
|
||||
assert data["error_code"] == "CANNOT_MODIFY_SELF"
|
||||
assert "Cannot perform 'deactivate account' on your own account" in data["message"]
|
||||
assert (
|
||||
"Cannot perform 'deactivate account' on your own account" in data["message"]
|
||||
)
|
||||
|
||||
def test_toggle_user_status_cannot_modify_admin(
|
||||
self, client, admin_headers, test_admin, another_admin
|
||||
@@ -85,7 +87,9 @@ class TestAdminAPI:
|
||||
|
||||
# Check that test_vendor is in the response
|
||||
vendor_codes = [
|
||||
vendor ["vendor_code"] for vendor in data["vendors"] if "vendor_code" in vendor
|
||||
vendor["vendor_code"]
|
||||
for vendor in data["vendors"]
|
||||
if "vendor_code" in vendor
|
||||
]
|
||||
assert test_vendor.vendor_code in vendor_codes
|
||||
|
||||
@@ -98,7 +102,7 @@ class TestAdminAPI:
|
||||
assert data["error_code"] == "ADMIN_REQUIRED"
|
||||
|
||||
def test_verify_vendor_admin(self, client, admin_headers, test_vendor):
|
||||
"""Test admin verifying/unverifying vendor """
|
||||
"""Test admin verifying/unverifying vendor"""
|
||||
response = client.put(
|
||||
f"/api/v1/admin/vendors/{test_vendor.id}/verify", headers=admin_headers
|
||||
)
|
||||
@@ -109,8 +113,10 @@ class TestAdminAPI:
|
||||
assert test_vendor.vendor_code in message
|
||||
|
||||
def test_verify_vendor_not_found(self, client, admin_headers):
|
||||
"""Test admin verifying non-existent vendor """
|
||||
response = client.put("/api/v1/admin/vendors/99999/verify", headers=admin_headers)
|
||||
"""Test admin verifying non-existent vendor"""
|
||||
response = client.put(
|
||||
"/api/v1/admin/vendors/99999/verify", headers=admin_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -129,8 +135,10 @@ class TestAdminAPI:
|
||||
assert test_vendor.vendor_code in message
|
||||
|
||||
def test_toggle_vendor_status_not_found(self, client, admin_headers):
|
||||
"""Test admin toggling status for non-existent vendor """
|
||||
response = client.put("/api/v1/admin/vendors/99999/status", headers=admin_headers)
|
||||
"""Test admin toggling status for non-existent vendor"""
|
||||
response = client.put(
|
||||
"/api/v1/admin/vendors/99999/status", headers=admin_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -166,7 +174,8 @@ class TestAdminAPI:
|
||||
data = response.json()
|
||||
assert len(data) >= 1
|
||||
assert all(
|
||||
job["marketplace"] == test_marketplace_import_job.marketplace for job in data
|
||||
job["marketplace"] == test_marketplace_import_job.marketplace
|
||||
for job in data
|
||||
)
|
||||
|
||||
def test_get_marketplace_import_jobs_non_admin(self, client, auth_headers):
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# tests/integration/api/v1/test_auth_endpoints.py
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@@ -178,8 +179,7 @@ class TestAuthenticationAPI:
|
||||
def test_get_current_user_invalid_token(self, client):
|
||||
"""Test getting current user with invalid token"""
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": "Bearer invalid_token_here"}
|
||||
"/api/v1/auth/me", headers={"Authorization": "Bearer invalid_token_here"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -201,14 +201,11 @@ class TestAuthenticationAPI:
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
expired_payload,
|
||||
auth_manager.secret_key,
|
||||
algorithm=auth_manager.algorithm
|
||||
expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {expired_token}"}
|
||||
"/api/v1/auth/me", headers={"Authorization": f"Bearer {expired_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -321,4 +318,3 @@ class TestAuthManager:
|
||||
user = auth_manager.authenticate_user(db, "nonexistent", "password")
|
||||
|
||||
assert user is None
|
||||
|
||||
|
||||
@@ -14,19 +14,34 @@ class TestFiltering:
|
||||
"""Test filtering products by brand successfully"""
|
||||
# Create products with different brands using unique IDs
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
products = [
|
||||
MarketplaceProduct(marketplace_product_id=f"BRAND1_{unique_suffix}", title="MarketplaceProduct 1", brand="BrandA"),
|
||||
MarketplaceProduct(marketplace_product_id=f"BRAND2_{unique_suffix}", title="MarketplaceProduct 2", brand="BrandB"),
|
||||
MarketplaceProduct(marketplace_product_id=f"BRAND3_{unique_suffix}", title="MarketplaceProduct 3", brand="BrandA"),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"BRAND1_{unique_suffix}",
|
||||
title="MarketplaceProduct 1",
|
||||
brand="BrandA",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"BRAND2_{unique_suffix}",
|
||||
title="MarketplaceProduct 2",
|
||||
brand="BrandB",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"BRAND3_{unique_suffix}",
|
||||
title="MarketplaceProduct 3",
|
||||
brand="BrandA",
|
||||
),
|
||||
]
|
||||
|
||||
db.add_all(products)
|
||||
db.commit()
|
||||
|
||||
# Filter by BrandA
|
||||
response = client.get("/api/v1/marketplace/product?brand=BrandA", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=BrandA", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 2 # At least our test products
|
||||
@@ -37,7 +52,9 @@ class TestFiltering:
|
||||
assert product["brand"] == "BrandA"
|
||||
|
||||
# Filter by BrandB
|
||||
response = client.get("/api/v1/marketplace/product?brand=BrandB", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=BrandB", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1 # At least our test product
|
||||
@@ -45,37 +62,57 @@ class TestFiltering:
|
||||
def test_product_marketplace_filter_success(self, client, auth_headers, db):
|
||||
"""Test filtering products by marketplace successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
products = [
|
||||
MarketplaceProduct(marketplace_product_id=f"MKT1_{unique_suffix}", title="MarketplaceProduct 1", marketplace="Amazon"),
|
||||
MarketplaceProduct(marketplace_product_id=f"MKT2_{unique_suffix}", title="MarketplaceProduct 2", marketplace="eBay"),
|
||||
MarketplaceProduct(marketplace_product_id=f"MKT3_{unique_suffix}", title="MarketplaceProduct 3", marketplace="Amazon"),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"MKT1_{unique_suffix}",
|
||||
title="MarketplaceProduct 1",
|
||||
marketplace="Amazon",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"MKT2_{unique_suffix}",
|
||||
title="MarketplaceProduct 2",
|
||||
marketplace="eBay",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"MKT3_{unique_suffix}",
|
||||
title="MarketplaceProduct 3",
|
||||
marketplace="Amazon",
|
||||
),
|
||||
]
|
||||
|
||||
db.add_all(products)
|
||||
db.commit()
|
||||
|
||||
response = client.get("/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?marketplace=Amazon", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 2 # At least our test products
|
||||
|
||||
# Verify all returned products have Amazon marketplace
|
||||
amazon_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
|
||||
amazon_products = [
|
||||
p
|
||||
for p in data["products"]
|
||||
if p["marketplace_product_id"].endswith(unique_suffix)
|
||||
]
|
||||
for product in amazon_products:
|
||||
assert product["marketplace"] == "Amazon"
|
||||
|
||||
def test_product_search_filter_success(self, client, auth_headers, db):
|
||||
"""Test searching products by text successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
products = [
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"SEARCH1_{unique_suffix}",
|
||||
title=f"Apple iPhone {unique_suffix}",
|
||||
description="Smartphone"
|
||||
description="Smartphone",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"SEARCH2_{unique_suffix}",
|
||||
@@ -85,7 +122,7 @@ class TestFiltering:
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"SEARCH3_{unique_suffix}",
|
||||
title=f"iPad Tablet {unique_suffix}",
|
||||
description="Apple tablet"
|
||||
description="Apple tablet",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -93,13 +130,17 @@ class TestFiltering:
|
||||
db.commit()
|
||||
|
||||
# Search for "Apple"
|
||||
response = client.get(f"/api/v1/marketplace/product?search=Apple", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?search=Apple", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 2 # iPhone and iPad
|
||||
|
||||
# Search for "phone"
|
||||
response = client.get(f"/api/v1/marketplace/product?search=phone", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?search=phone", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 2 # iPhone and Galaxy
|
||||
@@ -107,6 +148,7 @@ class TestFiltering:
|
||||
def test_combined_filters_success(self, client, auth_headers, db):
|
||||
"""Test combining multiple filters successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
products = [
|
||||
@@ -135,14 +177,19 @@ class TestFiltering:
|
||||
|
||||
# Filter by brand AND marketplace
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=Apple&marketplace=Amazon", headers=auth_headers
|
||||
"/api/v1/marketplace/product?brand=Apple&marketplace=Amazon",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1 # At least iPhone matches both
|
||||
|
||||
# Find our specific test product
|
||||
matching_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
|
||||
matching_products = [
|
||||
p
|
||||
for p in data["products"]
|
||||
if p["marketplace_product_id"].endswith(unique_suffix)
|
||||
]
|
||||
for product in matching_products:
|
||||
assert product["brand"] == "Apple"
|
||||
assert product["marketplace"] == "Amazon"
|
||||
@@ -150,7 +197,8 @@ class TestFiltering:
|
||||
def test_filter_with_no_results(self, client, auth_headers):
|
||||
"""Test filtering with criteria that returns no results"""
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=NonexistentBrand123456", headers=auth_headers
|
||||
"/api/v1/marketplace/product?brand=NonexistentBrand123456",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -161,6 +209,7 @@ class TestFiltering:
|
||||
def test_filter_case_insensitive(self, client, auth_headers, db):
|
||||
"""Test that filters are case-insensitive"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
product = MarketplaceProduct(
|
||||
@@ -174,7 +223,10 @@ class TestFiltering:
|
||||
|
||||
# Test different case variations
|
||||
for brand_filter in ["TestBrand", "testbrand", "TESTBRAND"]:
|
||||
response = client.get(f"/api/v1/marketplace/product?brand={brand_filter}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?brand={brand_filter}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
@@ -183,9 +235,14 @@ class TestFiltering:
|
||||
"""Test behavior with invalid filter parameters"""
|
||||
# Test with very long filter values
|
||||
long_brand = "A" * 1000
|
||||
response = client.get(f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?brand={long_brand}", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200 # Should handle gracefully
|
||||
|
||||
# Test with special characters
|
||||
response = client.get("/api/v1/marketplace/product?brand=<script>alert('test')</script>", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=<script>alert('test')</script>",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200 # Should handle gracefully
|
||||
|
||||
@@ -17,7 +17,9 @@ class TestInventoryAPI:
|
||||
"quantity": 100,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -38,7 +40,9 @@ class TestInventoryAPI:
|
||||
"quantity": 75,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -52,7 +56,9 @@ class TestInventoryAPI:
|
||||
"quantity": 100,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
@@ -60,7 +66,9 @@ class TestInventoryAPI:
|
||||
assert data["status_code"] == 422
|
||||
assert "GTIN is required" in data["message"]
|
||||
|
||||
def test_set_inventory_invalid_quantity_validation_error(self, client, auth_headers):
|
||||
def test_set_inventory_invalid_quantity_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test setting inventory with invalid quantity returns InvalidQuantityException"""
|
||||
inventory_data = {
|
||||
"gtin": "1234567890123",
|
||||
@@ -68,7 +76,9 @@ class TestInventoryAPI:
|
||||
"quantity": -10, # Negative quantity
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
|
||||
assert response.status_code in [400, 422]
|
||||
data = response.json()
|
||||
@@ -138,7 +148,9 @@ class TestInventoryAPI:
|
||||
data = response.json()
|
||||
assert data["quantity"] == 35 # 50 - 15
|
||||
|
||||
def test_remove_inventory_insufficient_returns_business_logic_error(self, client, auth_headers, db):
|
||||
def test_remove_inventory_insufficient_returns_business_logic_error(
|
||||
self, client, auth_headers, db
|
||||
):
|
||||
"""Test removing more inventory than available returns InsufficientInventoryException"""
|
||||
# Create initial inventory
|
||||
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=10)
|
||||
@@ -184,7 +196,9 @@ class TestInventoryAPI:
|
||||
assert data["error_code"] == "INVENTORY_NOT_FOUND"
|
||||
assert data["status_code"] == 404
|
||||
|
||||
def test_negative_inventory_not_allowed_business_logic_error(self, client, auth_headers, db):
|
||||
def test_negative_inventory_not_allowed_business_logic_error(
|
||||
self, client, auth_headers, db
|
||||
):
|
||||
"""Test operations resulting in negative inventory returns NegativeInventoryException"""
|
||||
# Create initial inventory
|
||||
inventory = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=5)
|
||||
@@ -204,14 +218,21 @@ class TestInventoryAPI:
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
# This might be caught as INSUFFICIENT_INVENTORY or NEGATIVE_INVENTORY_NOT_ALLOWED
|
||||
assert data["error_code"] in ["INSUFFICIENT_INVENTORY", "NEGATIVE_INVENTORY_NOT_ALLOWED"]
|
||||
assert data["error_code"] in [
|
||||
"INSUFFICIENT_INVENTORY",
|
||||
"NEGATIVE_INVENTORY_NOT_ALLOWED",
|
||||
]
|
||||
assert data["status_code"] == 400
|
||||
|
||||
def test_get_inventory_by_gtin_success(self, client, auth_headers, db):
|
||||
"""Test getting inventory summary for GTIN successfully"""
|
||||
# Create inventory in multiple locations
|
||||
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
|
||||
inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
|
||||
inventory1 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_B", quantity=25
|
||||
)
|
||||
db.add_all([inventory1, inventory2])
|
||||
db.commit()
|
||||
|
||||
@@ -238,12 +259,18 @@ class TestInventoryAPI:
|
||||
def test_get_total_inventory_success(self, client, auth_headers, db):
|
||||
"""Test getting total inventory for GTIN successfully"""
|
||||
# Create inventory in multiple locations
|
||||
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
|
||||
inventory2 = Inventory(gtin="1234567890123", location="WAREHOUSE_B", quantity=25)
|
||||
inventory1 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_B", quantity=25
|
||||
)
|
||||
db.add_all([inventory1, inventory2])
|
||||
db.commit()
|
||||
|
||||
response = client.get("/api/v1/inventory/1234567890123/total", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/inventory/1234567890123/total", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -253,7 +280,9 @@ class TestInventoryAPI:
|
||||
|
||||
def test_get_total_inventory_not_found(self, client, auth_headers):
|
||||
"""Test getting total inventory for nonexistent GTIN returns InventoryNotFoundException"""
|
||||
response = client.get("/api/v1/inventory/9999999999999/total", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/inventory/9999999999999/total", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -263,8 +292,12 @@ class TestInventoryAPI:
|
||||
def test_get_all_inventory_success(self, client, auth_headers, db):
|
||||
"""Test getting all inventory entries successfully"""
|
||||
# Create some inventory entries
|
||||
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
|
||||
inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
|
||||
inventory1 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin="9876543210987", location="WAREHOUSE_B", quantity=25
|
||||
)
|
||||
db.add_all([inventory1, inventory2])
|
||||
db.commit()
|
||||
|
||||
@@ -277,20 +310,28 @@ class TestInventoryAPI:
|
||||
def test_get_all_inventory_with_filters(self, client, auth_headers, db):
|
||||
"""Test getting inventory entries with filtering"""
|
||||
# Create inventory entries
|
||||
inventory1 = Inventory(gtin="1234567890123", location="WAREHOUSE_A", quantity=50)
|
||||
inventory2 = Inventory(gtin="9876543210987", location="WAREHOUSE_B", quantity=25)
|
||||
inventory1 = Inventory(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin="9876543210987", location="WAREHOUSE_B", quantity=25
|
||||
)
|
||||
db.add_all([inventory1, inventory2])
|
||||
db.commit()
|
||||
|
||||
# Filter by location
|
||||
response = client.get("/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/inventory?location=WAREHOUSE_A", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for inventory in data:
|
||||
assert inventory["location"] == "WAREHOUSE_A"
|
||||
|
||||
# Filter by GTIN
|
||||
response = client.get("/api/v1/inventory?gtin=1234567890123", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/inventory?gtin=1234567890123", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for inventory in data:
|
||||
@@ -390,7 +431,9 @@ class TestInventoryAPI:
|
||||
"quantity": 100,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
|
||||
# This depends on whether your service validates locations
|
||||
if response.status_code == 404:
|
||||
|
||||
@@ -8,9 +8,11 @@ import pytest
|
||||
@pytest.mark.api
|
||||
@pytest.mark.marketplace
|
||||
class TestMarketplaceImportJobAPI:
|
||||
def test_import_from_marketplace(self, client, auth_headers, test_vendor, test_user):
|
||||
def test_import_from_marketplace(
|
||||
self, client, auth_headers, test_vendor, test_user
|
||||
):
|
||||
"""Test marketplace import endpoint - just test job creation"""
|
||||
# Ensure user owns the vendor
|
||||
# Ensure user owns the vendor
|
||||
test_vendor.owner_user_id = test_user.id
|
||||
|
||||
import_data = {
|
||||
@@ -32,7 +34,7 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["vendor_id"] == test_vendor.id
|
||||
|
||||
def test_import_from_marketplace_invalid_vendor(self, client, auth_headers):
|
||||
"""Test marketplace import with invalid vendor """
|
||||
"""Test marketplace import with invalid vendor"""
|
||||
import_data = {
|
||||
"url": "https://example.com/products.csv",
|
||||
"marketplace": "TestMarket",
|
||||
@@ -48,7 +50,9 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["error_code"] == "VENDOR_NOT_FOUND"
|
||||
assert "NONEXISTENT" in data["message"]
|
||||
|
||||
def test_import_from_marketplace_unauthorized_vendor(self, client, auth_headers, test_vendor, other_user):
|
||||
def test_import_from_marketplace_unauthorized_vendor(
|
||||
self, client, auth_headers, test_vendor, other_user
|
||||
):
|
||||
"""Test marketplace import with unauthorized vendor access"""
|
||||
# Set vendor owner to different user
|
||||
test_vendor.owner_user_id = other_user.id
|
||||
@@ -85,8 +89,10 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["error_code"] == "VALIDATION_ERROR"
|
||||
assert "Request validation failed" in data["message"]
|
||||
|
||||
def test_import_from_marketplace_admin_access(self, client, admin_headers, test_vendor):
|
||||
"""Test that admin can import for any vendor """
|
||||
def test_import_from_marketplace_admin_access(
|
||||
self, client, admin_headers, test_vendor
|
||||
):
|
||||
"""Test that admin can import for any vendor"""
|
||||
import_data = {
|
||||
"url": "https://example.com/products.csv",
|
||||
"marketplace": "AdminMarket",
|
||||
@@ -94,7 +100,9 @@ class TestMarketplaceImportJobAPI:
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/marketplace/import-product", headers=admin_headers, json=import_data
|
||||
"/api/v1/marketplace/import-product",
|
||||
headers=admin_headers,
|
||||
json=import_data,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -102,11 +110,13 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["marketplace"] == "AdminMarket"
|
||||
assert data["vendor_code"] == test_vendor.vendor_code
|
||||
|
||||
def test_get_marketplace_import_status(self, client, auth_headers, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_status(
|
||||
self, client, auth_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting marketplace import status"""
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -118,8 +128,7 @@ class TestMarketplaceImportJobAPI:
|
||||
def test_get_marketplace_import_status_not_found(self, client, auth_headers):
|
||||
"""Test getting status of non-existent import job"""
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/import-status/99999",
|
||||
headers=auth_headers
|
||||
"/api/v1/marketplace/import-status/99999", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -127,21 +136,25 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
|
||||
assert "99999" in data["message"]
|
||||
|
||||
def test_get_marketplace_import_status_unauthorized(self, client, auth_headers, test_marketplace_import_job, other_user):
|
||||
def test_get_marketplace_import_status_unauthorized(
|
||||
self, client, auth_headers, test_marketplace_import_job, other_user
|
||||
):
|
||||
"""Test getting status of unauthorized import job"""
|
||||
# Change job owner to other user
|
||||
test_marketplace_import_job.user_id = other_user.id
|
||||
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
data = response.json()
|
||||
assert data["error_code"] == "IMPORT_JOB_NOT_OWNED"
|
||||
|
||||
def test_get_marketplace_import_jobs(self, client, auth_headers, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs(
|
||||
self, client, auth_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting marketplace import jobs"""
|
||||
response = client.get("/api/v1/marketplace/import-jobs", headers=auth_headers)
|
||||
|
||||
@@ -154,11 +167,13 @@ class TestMarketplaceImportJobAPI:
|
||||
job_ids = [job["job_id"] for job in data]
|
||||
assert test_marketplace_import_job.id in job_ids
|
||||
|
||||
def test_get_marketplace_import_jobs_with_filters(self, client, auth_headers, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_filters(
|
||||
self, client, auth_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting import jobs with filters"""
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/import-jobs?marketplace={test_marketplace_import_job.marketplace}",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -167,13 +182,15 @@ class TestMarketplaceImportJobAPI:
|
||||
assert len(data) >= 1
|
||||
|
||||
for job in data:
|
||||
assert test_marketplace_import_job.marketplace.lower() in job["marketplace"].lower()
|
||||
assert (
|
||||
test_marketplace_import_job.marketplace.lower()
|
||||
in job["marketplace"].lower()
|
||||
)
|
||||
|
||||
def test_get_marketplace_import_jobs_pagination(self, client, auth_headers):
|
||||
"""Test import jobs pagination"""
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/import-jobs?skip=0&limit=5",
|
||||
headers=auth_headers
|
||||
"/api/v1/marketplace/import-jobs?skip=0&limit=5", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -181,9 +198,13 @@ class TestMarketplaceImportJobAPI:
|
||||
assert isinstance(data, list)
|
||||
assert len(data) <= 5
|
||||
|
||||
def test_get_marketplace_import_stats(self, client, auth_headers, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_stats(
|
||||
self, client, auth_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting marketplace import statistics"""
|
||||
response = client.get("/api/v1/marketplace/marketplace-import-stats", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/marketplace-import-stats", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -195,12 +216,15 @@ class TestMarketplaceImportJobAPI:
|
||||
assert isinstance(data["total_jobs"], int)
|
||||
assert data["total_jobs"] >= 1
|
||||
|
||||
def test_cancel_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
|
||||
def test_cancel_marketplace_import_job(
|
||||
self, client, auth_headers, test_user, test_vendor, db
|
||||
):
|
||||
"""Test cancelling a marketplace import job"""
|
||||
# Create a pending job that can be cancelled
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
import uuid
|
||||
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
job = MarketplaceImportJob(
|
||||
status="pending",
|
||||
@@ -219,8 +243,7 @@ class TestMarketplaceImportJobAPI:
|
||||
db.refresh(job)
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
|
||||
headers=auth_headers
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -232,15 +255,16 @@ class TestMarketplaceImportJobAPI:
|
||||
def test_cancel_marketplace_import_job_not_found(self, client, auth_headers):
|
||||
"""Test cancelling non-existent import job"""
|
||||
response = client.put(
|
||||
"/api/v1/marketplace/import-jobs/99999/cancel",
|
||||
headers=auth_headers
|
||||
"/api/v1/marketplace/import-jobs/99999/cancel", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_cancel_marketplace_import_job_cannot_cancel(self, client, auth_headers, test_marketplace_import_job, db):
|
||||
def test_cancel_marketplace_import_job_cannot_cancel(
|
||||
self, client, auth_headers, test_marketplace_import_job, db
|
||||
):
|
||||
"""Test cancelling a job that cannot be cancelled"""
|
||||
# Set job to completed status
|
||||
test_marketplace_import_job.status = "completed"
|
||||
@@ -248,7 +272,7 @@ class TestMarketplaceImportJobAPI:
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/marketplace/import-jobs/{test_marketplace_import_job.id}/cancel",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -256,12 +280,15 @@ class TestMarketplaceImportJobAPI:
|
||||
assert data["error_code"] == "IMPORT_JOB_CANNOT_BE_CANCELLED"
|
||||
assert "completed" in data["message"]
|
||||
|
||||
def test_delete_marketplace_import_job(self, client, auth_headers, test_user, test_vendor, db):
|
||||
def test_delete_marketplace_import_job(
|
||||
self, client, auth_headers, test_user, test_vendor, db
|
||||
):
|
||||
"""Test deleting a marketplace import job"""
|
||||
# Create a completed job that can be deleted
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
import uuid
|
||||
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
job = MarketplaceImportJob(
|
||||
status="completed",
|
||||
@@ -280,8 +307,7 @@ class TestMarketplaceImportJobAPI:
|
||||
db.refresh(job)
|
||||
|
||||
response = client.delete(
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}",
|
||||
headers=auth_headers
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -291,20 +317,22 @@ class TestMarketplaceImportJobAPI:
|
||||
def test_delete_marketplace_import_job_not_found(self, client, auth_headers):
|
||||
"""Test deleting non-existent import job"""
|
||||
response = client.delete(
|
||||
"/api/v1/marketplace/import-jobs/99999",
|
||||
headers=auth_headers
|
||||
"/api/v1/marketplace/import-jobs/99999", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["error_code"] == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_delete_marketplace_import_job_cannot_delete(self, client, auth_headers, test_user, test_vendor, db):
|
||||
def test_delete_marketplace_import_job_cannot_delete(
|
||||
self, client, auth_headers, test_user, test_vendor, db
|
||||
):
|
||||
"""Test deleting a job that cannot be deleted"""
|
||||
# Create a pending job that cannot be deleted
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
import uuid
|
||||
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
job = MarketplaceImportJob(
|
||||
status="pending",
|
||||
@@ -323,8 +351,7 @@ class TestMarketplaceImportJobAPI:
|
||||
db.refresh(job)
|
||||
|
||||
response = client.delete(
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}",
|
||||
headers=auth_headers
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -352,7 +379,9 @@ class TestMarketplaceImportJobAPI:
|
||||
data = response.json()
|
||||
assert data["error_code"] == "INVALID_TOKEN"
|
||||
|
||||
def test_admin_can_access_all_jobs(self, client, admin_headers, test_marketplace_import_job):
|
||||
def test_admin_can_access_all_jobs(
|
||||
self, client, admin_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test that admin can access all import jobs"""
|
||||
response = client.get("/api/v1/marketplace/import-jobs", headers=admin_headers)
|
||||
|
||||
@@ -363,23 +392,28 @@ class TestMarketplaceImportJobAPI:
|
||||
job_ids = [job["job_id"] for job in data]
|
||||
assert test_marketplace_import_job.id in job_ids
|
||||
|
||||
def test_admin_can_view_any_job_status(self, client, admin_headers, test_marketplace_import_job):
|
||||
def test_admin_can_view_any_job_status(
|
||||
self, client, admin_headers, test_marketplace_import_job
|
||||
):
|
||||
"""Test that admin can view any job status"""
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/import-status/{test_marketplace_import_job.id}",
|
||||
headers=admin_headers
|
||||
headers=admin_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["job_id"] == test_marketplace_import_job.id
|
||||
|
||||
def test_admin_can_cancel_any_job(self, client, admin_headers, test_user, test_vendor, db):
|
||||
def test_admin_can_cancel_any_job(
|
||||
self, client, admin_headers, test_user, test_vendor, db
|
||||
):
|
||||
"""Test that admin can cancel any job"""
|
||||
# Create a pending job owned by different user
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
import uuid
|
||||
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
job = MarketplaceImportJob(
|
||||
status="pending",
|
||||
@@ -398,8 +432,7 @@ class TestMarketplaceImportJobAPI:
|
||||
db.refresh(job)
|
||||
|
||||
response = client.put(
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}/cancel",
|
||||
headers=admin_headers
|
||||
f"/api/v1/marketplace/import-jobs/{job.id}/cancel", headers=admin_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# tests/integration/api/v1/test_export.py
|
||||
import csv
|
||||
from io import StringIO
|
||||
import uuid
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,9 +13,13 @@ from models.database.marketplace_product import MarketplaceProduct
|
||||
@pytest.mark.performance # for the performance test
|
||||
class TestExportFunctionality:
|
||||
|
||||
def test_csv_export_basic_success(self, client, auth_headers, test_marketplace_product):
|
||||
def test_csv_export_basic_success(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test basic CSV export functionality successfully"""
|
||||
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||
@@ -26,13 +30,22 @@ class TestExportFunctionality:
|
||||
|
||||
# Check header row
|
||||
header = next(csv_reader)
|
||||
expected_fields = ["marketplace_product_id", "title", "description", "price", "marketplace"]
|
||||
expected_fields = [
|
||||
"marketplace_product_id",
|
||||
"title",
|
||||
"description",
|
||||
"price",
|
||||
"marketplace",
|
||||
]
|
||||
for field in expected_fields:
|
||||
assert field in header
|
||||
|
||||
# Verify test product appears in export
|
||||
csv_lines = csv_content.split('\n')
|
||||
test_product_found = any(test_marketplace_product.marketplace_product_id in line for line in csv_lines)
|
||||
csv_lines = csv_content.split("\n")
|
||||
test_product_found = any(
|
||||
test_marketplace_product.marketplace_product_id in line
|
||||
for line in csv_lines
|
||||
)
|
||||
assert test_product_found, "Test product should appear in CSV export"
|
||||
|
||||
def test_csv_export_with_marketplace_filter_success(self, client, auth_headers, db):
|
||||
@@ -43,12 +56,12 @@ class TestExportFunctionality:
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"EXP1_{unique_suffix}",
|
||||
title=f"Amazon MarketplaceProduct {unique_suffix}",
|
||||
marketplace="Amazon"
|
||||
marketplace="Amazon",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"EXP2_{unique_suffix}",
|
||||
title=f"eBay MarketplaceProduct {unique_suffix}",
|
||||
marketplace="eBay"
|
||||
marketplace="eBay",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -56,7 +69,8 @@ class TestExportFunctionality:
|
||||
db.commit()
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv?marketplace=Amazon", headers=auth_headers
|
||||
"/api/v1/marketplace/product/export-csv?marketplace=Amazon",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||
@@ -72,12 +86,12 @@ class TestExportFunctionality:
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"VENDOR1_{unique_suffix}",
|
||||
title=f"Vendor1 MarketplaceProduct {unique_suffix}",
|
||||
vendor_name="TestVendor1"
|
||||
vendor_name="TestVendor1",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"VENDOR2_{unique_suffix}",
|
||||
title=f"Vendor2 MarketplaceProduct {unique_suffix}",
|
||||
vendor_name="TestVendor2"
|
||||
vendor_name="TestVendor2",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -101,19 +115,19 @@ class TestExportFunctionality:
|
||||
marketplace_product_id=f"COMBO1_{unique_suffix}",
|
||||
title=f"Combo MarketplaceProduct 1 {unique_suffix}",
|
||||
marketplace="Amazon",
|
||||
vendor_name="TestVendor"
|
||||
vendor_name="TestVendor",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"COMBO2_{unique_suffix}",
|
||||
title=f"Combo MarketplaceProduct 2 {unique_suffix}",
|
||||
marketplace="eBay",
|
||||
vendor_name="TestVendor"
|
||||
vendor_name="TestVendor",
|
||||
),
|
||||
MarketplaceProduct(
|
||||
marketplace_product_id=f"COMBO3_{unique_suffix}",
|
||||
title=f"Combo MarketplaceProduct 3 {unique_suffix}",
|
||||
marketplace="Amazon",
|
||||
vendor_name="OtherVendor"
|
||||
vendor_name="OtherVendor",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -122,27 +136,27 @@ class TestExportFunctionality:
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?marketplace=Amazon&name=TestVendor",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
csv_content = response.content.decode("utf-8")
|
||||
assert f"COMBO1_{unique_suffix}" in csv_content # Matches both filters
|
||||
assert f"COMBO2_{unique_suffix}" not in csv_content # Wrong marketplace
|
||||
assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
|
||||
assert f"COMBO3_{unique_suffix}" not in csv_content # Wrong vendor
|
||||
|
||||
def test_csv_export_no_results(self, client, auth_headers):
|
||||
"""Test CSV export with filters that return no results"""
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv?marketplace=NonexistentMarketplace12345",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||
|
||||
csv_content = response.content.decode("utf-8")
|
||||
csv_lines = csv_content.strip().split('\n')
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
# Should have header row even with no data
|
||||
assert len(csv_lines) >= 1
|
||||
# First line should be headers
|
||||
@@ -161,27 +175,32 @@ class TestExportFunctionality:
|
||||
title=f"Performance MarketplaceProduct {i}",
|
||||
marketplace="Performance",
|
||||
description=f"Performance test product {i}",
|
||||
price="10.99"
|
||||
price="10.99",
|
||||
)
|
||||
products.append(product)
|
||||
|
||||
# Add in batches to avoid memory issues
|
||||
for i in range(0, len(products), 50):
|
||||
batch = products[i:i + 50]
|
||||
batch = products[i : i + 50]
|
||||
db.add_all(batch)
|
||||
db.commit()
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv", headers=auth_headers
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||
assert execution_time < 10.0, f"Export took {execution_time:.2f} seconds, should be under 10s"
|
||||
assert (
|
||||
execution_time < 10.0
|
||||
), f"Export took {execution_time:.2f} seconds, should be under 10s"
|
||||
|
||||
# Verify content contains our test data
|
||||
csv_content = response.content.decode("utf-8")
|
||||
@@ -197,9 +216,13 @@ class TestExportFunctionality:
|
||||
assert data["error_code"] == "INVALID_TOKEN"
|
||||
assert data["status_code"] == 401
|
||||
|
||||
def test_csv_export_streaming_response_format(self, client, auth_headers, test_marketplace_product):
|
||||
def test_csv_export_streaming_response_format(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test that CSV export returns proper streaming response with correct headers"""
|
||||
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/csv; charset=utf-8"
|
||||
@@ -217,23 +240,27 @@ class TestExportFunctionality:
|
||||
# Create product with special characters that might break CSV
|
||||
product = MarketplaceProduct(
|
||||
marketplace_product_id=f"SPECIAL_{unique_suffix}",
|
||||
title=f'MarketplaceProduct with quotes and commas {unique_suffix}', # Simplified to avoid CSV escaping issues
|
||||
title=f"MarketplaceProduct with quotes and commas {unique_suffix}", # Simplified to avoid CSV escaping issues
|
||||
description=f"Description with special chars {unique_suffix}",
|
||||
marketplace="Test Market",
|
||||
price="19.99"
|
||||
price="19.99",
|
||||
)
|
||||
|
||||
db.add(product)
|
||||
db.commit()
|
||||
|
||||
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
csv_content = response.content.decode("utf-8")
|
||||
|
||||
# Verify our test product appears in the CSV content
|
||||
assert f"SPECIAL_{unique_suffix}" in csv_content
|
||||
assert f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
|
||||
assert (
|
||||
f"MarketplaceProduct with quotes and commas {unique_suffix}" in csv_content
|
||||
)
|
||||
assert "Test Market" in csv_content
|
||||
assert "19.99" in csv_content
|
||||
|
||||
@@ -243,7 +270,12 @@ class TestExportFunctionality:
|
||||
header = next(csv_reader)
|
||||
|
||||
# Verify header contains expected fields
|
||||
expected_fields = ["marketplace_product_id", "title", "marketplace", "price"]
|
||||
expected_fields = [
|
||||
"marketplace_product_id",
|
||||
"title",
|
||||
"marketplace",
|
||||
"price",
|
||||
]
|
||||
for field in expected_fields:
|
||||
assert field in header
|
||||
|
||||
@@ -264,12 +296,15 @@ class TestExportFunctionality:
|
||||
|
||||
assert parsed_successfully, "CSV should be parseable despite special characters"
|
||||
|
||||
def test_csv_export_error_handling_service_failure(self, client, auth_headers, monkeypatch):
|
||||
def test_csv_export_error_handling_service_failure(
|
||||
self, client, auth_headers, monkeypatch
|
||||
):
|
||||
"""Test CSV export handles service failures gracefully"""
|
||||
|
||||
# Mock the service to raise an exception
|
||||
def mock_generate_csv_export(*args, **kwargs):
|
||||
from app.exceptions import ValidationException
|
||||
|
||||
raise ValidationException("Mocked service failure")
|
||||
|
||||
# This would require access to your service instance to mock properly
|
||||
@@ -293,7 +328,9 @@ class TestExportFunctionality:
|
||||
def test_csv_export_filename_generation(self, client, auth_headers):
|
||||
"""Test CSV export generates appropriate filenames based on filters"""
|
||||
# Test basic export filename
|
||||
response = client.get("/api/v1/marketplace/product/export-csv", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
content_disposition = response.headers.get("content-disposition", "")
|
||||
@@ -302,7 +339,7 @@ class TestExportFunctionality:
|
||||
# Test with marketplace filter
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/export-csv?marketplace=Amazon",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert data["products"] == []
|
||||
assert data["total"] == 0
|
||||
|
||||
def test_get_products_with_data(self, client, auth_headers, test_marketplace_product):
|
||||
def test_get_products_with_data(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test getting products with data"""
|
||||
response = client.get("/api/v1/marketplace/product", headers=auth_headers)
|
||||
|
||||
@@ -25,25 +27,39 @@ class TestMarketplaceProductsAPI:
|
||||
assert len(data["products"]) >= 1
|
||||
assert data["total"] >= 1
|
||||
# Find our test product
|
||||
test_product_found = any(p["marketplace_product_id"] == test_marketplace_product.marketplace_product_id for p in data["products"])
|
||||
test_product_found = any(
|
||||
p["marketplace_product_id"]
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
for p in data["products"]
|
||||
)
|
||||
assert test_product_found
|
||||
|
||||
def test_get_products_with_filters(self, client, auth_headers, test_marketplace_product):
|
||||
def test_get_products_with_filters(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test filtering products"""
|
||||
# Test brand filter
|
||||
response = client.get(f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?brand={test_marketplace_product.brand}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
# Test marketplace filter
|
||||
response = client.get(f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?marketplace={test_marketplace_product.marketplace}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
# Test search
|
||||
response = client.get("/api/v1/marketplace/product?search=Test", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?search=Test", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
@@ -71,7 +87,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert data["title"] == "New MarketplaceProduct"
|
||||
assert data["marketplace"] == "Amazon"
|
||||
|
||||
def test_create_product_duplicate_id_returns_conflict(self, client, auth_headers, test_marketplace_product):
|
||||
def test_create_product_duplicate_id_returns_conflict(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test creating product with duplicate ID returns MarketplaceProductAlreadyExistsException"""
|
||||
product_data = {
|
||||
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
|
||||
@@ -93,7 +111,10 @@ class TestMarketplaceProductsAPI:
|
||||
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
|
||||
assert data["status_code"] == 409
|
||||
assert test_marketplace_product.marketplace_product_id in data["message"]
|
||||
assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
data["details"]["marketplace_product_id"]
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_create_product_missing_title_validation_error(self, client, auth_headers):
|
||||
"""Test creating product without title returns ValidationException"""
|
||||
@@ -114,7 +135,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert "MarketplaceProduct title is required" in data["message"]
|
||||
assert data["details"]["field"] == "title"
|
||||
|
||||
def test_create_product_missing_product_id_validation_error(self, client, auth_headers):
|
||||
def test_create_product_missing_product_id_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test creating product without marketplace_product_id returns ValidationException"""
|
||||
product_data = {
|
||||
"marketplace_product_id": "", # Empty product ID
|
||||
@@ -191,20 +214,28 @@ class TestMarketplaceProductsAPI:
|
||||
assert "Request validation failed" in data["message"]
|
||||
assert "validation_errors" in data["details"]
|
||||
|
||||
def test_get_product_by_id_success(self, client, auth_headers, test_marketplace_product):
|
||||
def test_get_product_by_id_success(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test getting specific product successfully"""
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
|
||||
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["product"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
data["product"]["marketplace_product_id"]
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert data["product"]["title"] == test_marketplace_product.title
|
||||
|
||||
def test_get_nonexistent_product_returns_not_found(self, client, auth_headers):
|
||||
"""Test getting nonexistent product returns MarketplaceProductNotFoundException"""
|
||||
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -214,7 +245,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert data["details"]["resource_type"] == "MarketplaceProduct"
|
||||
assert data["details"]["identifier"] == "NONEXISTENT"
|
||||
|
||||
def test_update_product_success(self, client, auth_headers, test_marketplace_product):
|
||||
def test_update_product_success(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test updating product successfully"""
|
||||
update_data = {"title": "Updated MarketplaceProduct Title", "price": "25.99"}
|
||||
|
||||
@@ -247,7 +280,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert data["details"]["resource_type"] == "MarketplaceProduct"
|
||||
assert data["details"]["identifier"] == "NONEXISTENT"
|
||||
|
||||
def test_update_product_empty_title_validation_error(self, client, auth_headers, test_marketplace_product):
|
||||
def test_update_product_empty_title_validation_error(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test updating product with empty title returns MarketplaceProductValidationException"""
|
||||
update_data = {"title": ""}
|
||||
|
||||
@@ -264,7 +299,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert "MarketplaceProduct title cannot be empty" in data["message"]
|
||||
assert data["details"]["field"] == "title"
|
||||
|
||||
def test_update_product_invalid_gtin_data_error(self, client, auth_headers, test_marketplace_product):
|
||||
def test_update_product_invalid_gtin_data_error(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test updating product with invalid GTIN returns InvalidMarketplaceProductDataException"""
|
||||
update_data = {"gtin": "invalid_gtin"}
|
||||
|
||||
@@ -281,7 +318,9 @@ class TestMarketplaceProductsAPI:
|
||||
assert "Invalid GTIN format" in data["message"]
|
||||
assert data["details"]["field"] == "gtin"
|
||||
|
||||
def test_update_product_invalid_price_data_error(self, client, auth_headers, test_marketplace_product):
|
||||
def test_update_product_invalid_price_data_error(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test updating product with invalid price returns InvalidMarketplaceProductDataException"""
|
||||
update_data = {"price": "invalid_price"}
|
||||
|
||||
@@ -298,10 +337,13 @@ class TestMarketplaceProductsAPI:
|
||||
assert "Invalid price format" in data["message"]
|
||||
assert data["details"]["field"] == "price"
|
||||
|
||||
def test_delete_product_success(self, client, auth_headers, test_marketplace_product):
|
||||
def test_delete_product_success(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test deleting product successfully"""
|
||||
response = client.delete(
|
||||
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}", headers=auth_headers
|
||||
f"/api/v1/marketplace/product/{test_marketplace_product.marketplace_product_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -309,7 +351,9 @@ class TestMarketplaceProductsAPI:
|
||||
|
||||
def test_delete_nonexistent_product_returns_not_found(self, client, auth_headers):
|
||||
"""Test deleting nonexistent product returns MarketplaceProductNotFoundException"""
|
||||
response = client.delete("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
|
||||
response = client.delete(
|
||||
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -331,7 +375,9 @@ class TestMarketplaceProductsAPI:
|
||||
def test_exception_structure_consistency(self, client, auth_headers):
|
||||
"""Test that all exceptions follow the consistent WizamartException structure"""
|
||||
# Test with a known error case
|
||||
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.vendor import Vendor
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.api
|
||||
@pytest.mark.database
|
||||
@@ -13,6 +14,7 @@ class TestPagination:
|
||||
def test_product_pagination_success(self, client, auth_headers, db):
|
||||
"""Test pagination for product listing successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create multiple products
|
||||
@@ -29,7 +31,9 @@ class TestPagination:
|
||||
db.commit()
|
||||
|
||||
# Test first page
|
||||
response = client.get("/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=10&skip=0", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["products"]) == 10
|
||||
@@ -38,21 +42,29 @@ class TestPagination:
|
||||
assert data["limit"] == 10
|
||||
|
||||
# Test second page
|
||||
response = client.get("/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=10&skip=10", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["products"]) == 10
|
||||
assert data["skip"] == 10
|
||||
|
||||
# Test last page (should have remaining products)
|
||||
response = client.get("/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=10&skip=20", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["products"]) >= 5 # At least 5 remaining from our test set
|
||||
|
||||
def test_pagination_boundary_negative_skip_validation_error(self, client, auth_headers):
|
||||
def test_pagination_boundary_negative_skip_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test negative skip parameter returns ValidationException"""
|
||||
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
@@ -61,9 +73,13 @@ class TestPagination:
|
||||
assert "Request validation failed" in data["message"]
|
||||
assert "validation_errors" in data["details"]
|
||||
|
||||
def test_pagination_boundary_zero_limit_validation_error(self, client, auth_headers):
|
||||
def test_pagination_boundary_zero_limit_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test zero limit parameter returns ValidationException"""
|
||||
response = client.get("/api/v1/marketplace/product?limit=0", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=0", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
@@ -71,9 +87,13 @@ class TestPagination:
|
||||
assert data["status_code"] == 422
|
||||
assert "Request validation failed" in data["message"]
|
||||
|
||||
def test_pagination_boundary_excessive_limit_validation_error(self, client, auth_headers):
|
||||
def test_pagination_boundary_excessive_limit_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test excessive limit parameter returns ValidationException"""
|
||||
response = client.get("/api/v1/marketplace/product?limit=10000", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=10000", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
@@ -84,7 +104,9 @@ class TestPagination:
|
||||
def test_pagination_beyond_available_records(self, client, auth_headers, db):
|
||||
"""Test pagination beyond available records returns empty results"""
|
||||
# Test skip beyond available records
|
||||
response = client.get("/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=10000&limit=10", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -96,6 +118,7 @@ class TestPagination:
|
||||
def test_pagination_with_filters(self, client, auth_headers, db):
|
||||
"""Test pagination combined with filtering"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create products with same brand for filtering
|
||||
@@ -115,7 +138,7 @@ class TestPagination:
|
||||
# Test first page with filter
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=0",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -124,14 +147,18 @@ class TestPagination:
|
||||
assert data["total"] >= 15 # At least our test products
|
||||
|
||||
# Verify all products have the filtered brand
|
||||
test_products = [p for p in data["products"] if p["marketplace_product_id"].endswith(unique_suffix)]
|
||||
test_products = [
|
||||
p
|
||||
for p in data["products"]
|
||||
if p["marketplace_product_id"].endswith(unique_suffix)
|
||||
]
|
||||
for product in test_products:
|
||||
assert product["brand"] == "FilterBrand"
|
||||
|
||||
# Test second page with same filter
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=FilterBrand&limit=5&skip=5",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -152,6 +179,7 @@ class TestPagination:
|
||||
def test_pagination_consistency(self, client, auth_headers, db):
|
||||
"""Test pagination consistency across multiple requests"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create products with predictable ordering
|
||||
@@ -168,14 +196,22 @@ class TestPagination:
|
||||
db.commit()
|
||||
|
||||
# Get first page
|
||||
response1 = client.get("/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers)
|
||||
response1 = client.get(
|
||||
"/api/v1/marketplace/product?limit=5&skip=0", headers=auth_headers
|
||||
)
|
||||
assert response1.status_code == 200
|
||||
first_page_ids = [p["marketplace_product_id"] for p in response1.json()["products"]]
|
||||
first_page_ids = [
|
||||
p["marketplace_product_id"] for p in response1.json()["products"]
|
||||
]
|
||||
|
||||
# Get second page
|
||||
response2 = client.get("/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers)
|
||||
response2 = client.get(
|
||||
"/api/v1/marketplace/product?limit=5&skip=5", headers=auth_headers
|
||||
)
|
||||
assert response2.status_code == 200
|
||||
second_page_ids = [p["marketplace_product_id"] for p in response2.json()["products"]]
|
||||
second_page_ids = [
|
||||
p["marketplace_product_id"] for p in response2.json()["products"]
|
||||
]
|
||||
|
||||
# Verify no overlap between pages
|
||||
overlap = set(first_page_ids) & set(second_page_ids)
|
||||
@@ -184,11 +220,13 @@ class TestPagination:
|
||||
def test_vendor_pagination_success(self, client, admin_headers, db, test_user):
|
||||
"""Test pagination for vendor listing successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create multiple vendors for pagination testing
|
||||
from models.database.vendor import Vendor
|
||||
vendors =[]
|
||||
|
||||
vendors = []
|
||||
for i in range(15):
|
||||
vendor = Vendor(
|
||||
vendor_code=f"PAGEVENDOR{i:03d}_{unique_suffix}",
|
||||
@@ -202,9 +240,7 @@ class TestPagination:
|
||||
db.commit()
|
||||
|
||||
# Test first page (assuming admin endpoint exists)
|
||||
response = client.get(
|
||||
"/api/v1/vendor?limit=5&skip=0", headers=admin_headers
|
||||
)
|
||||
response = client.get("/api/v1/vendor?limit=5&skip=0", headers=admin_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["vendors"]) == 5
|
||||
@@ -215,10 +251,12 @@ class TestPagination:
|
||||
def test_inventory_pagination_success(self, client, auth_headers, db):
|
||||
"""Test pagination for inventory listing successfully"""
|
||||
import uuid
|
||||
|
||||
unique_suffix = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create multiple inventory entries
|
||||
from models.database.inventory import Inventory
|
||||
|
||||
inventory_entries = []
|
||||
for i in range(20):
|
||||
inventory = Inventory(
|
||||
@@ -249,7 +287,9 @@ class TestPagination:
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
response = client.get("/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=1000&limit=10", headers=auth_headers
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -262,19 +302,25 @@ class TestPagination:
|
||||
def test_pagination_with_invalid_parameters_types(self, client, auth_headers):
|
||||
"""Test pagination with invalid parameter types returns ValidationException"""
|
||||
# Test non-numeric skip
|
||||
response = client.get("/api/v1/marketplace/product?skip=invalid", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=invalid", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["error_code"] == "VALIDATION_ERROR"
|
||||
|
||||
# Test non-numeric limit
|
||||
response = client.get("/api/v1/marketplace/product?limit=invalid", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=invalid", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert data["error_code"] == "VALIDATION_ERROR"
|
||||
|
||||
# Test float values (should be converted or rejected)
|
||||
response = client.get("/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=10.5&limit=5.5", headers=auth_headers
|
||||
)
|
||||
assert response.status_code in [200, 422] # Depends on implementation
|
||||
|
||||
def test_empty_dataset_pagination(self, client, auth_headers):
|
||||
@@ -282,7 +328,7 @@ class TestPagination:
|
||||
# Use a filter that should return no results
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?brand=NonexistentBrand999&limit=10&skip=0",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -294,7 +340,9 @@ class TestPagination:
|
||||
|
||||
def test_exception_structure_in_pagination_errors(self, client, auth_headers):
|
||||
"""Test that pagination validation errors follow consistent exception structure"""
|
||||
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# tests/integration/api/v1/test_stats_endpoints.py
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.api
|
||||
@pytest.mark.stats
|
||||
@@ -18,7 +19,9 @@ class TestStatsAPI:
|
||||
assert "unique_vendors" in data
|
||||
assert data["total_products"] >= 1
|
||||
|
||||
def test_get_marketplace_stats(self, client, auth_headers, test_marketplace_product):
|
||||
def test_get_marketplace_stats(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test getting marketplace statistics"""
|
||||
response = client.get("/api/v1/stats/marketplace", headers=auth_headers)
|
||||
|
||||
|
||||
@@ -23,7 +23,9 @@ class TestVendorsAPI:
|
||||
assert data["name"] == "New Vendor"
|
||||
assert data["is_active"] is True
|
||||
|
||||
def test_create_vendor_duplicate_code_returns_conflict(self, client, auth_headers, test_vendor):
|
||||
def test_create_vendor_duplicate_code_returns_conflict(
|
||||
self, client, auth_headers, test_vendor
|
||||
):
|
||||
"""Test creating vendor with duplicate code returns VendorAlreadyExistsException"""
|
||||
vendor_data = {
|
||||
"vendor_code": test_vendor.vendor_code,
|
||||
@@ -40,7 +42,9 @@ class TestVendorsAPI:
|
||||
assert test_vendor.vendor_code in data["message"]
|
||||
assert data["details"]["vendor_code"] == test_vendor.vendor_code
|
||||
|
||||
def test_create_vendor_missing_vendor_code_validation_error(self, client, auth_headers):
|
||||
def test_create_vendor_missing_vendor_code_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test creating vendor without vendor_code returns ValidationException"""
|
||||
vendor_data = {
|
||||
"name": "Vendor without Code",
|
||||
@@ -56,7 +60,9 @@ class TestVendorsAPI:
|
||||
assert "Request validation failed" in data["message"]
|
||||
assert "validation_errors" in data["details"]
|
||||
|
||||
def test_create_vendor_empty_vendor_name_validation_error(self, client, auth_headers):
|
||||
def test_create_vendor_empty_vendor_name_validation_error(
|
||||
self, client, auth_headers
|
||||
):
|
||||
"""Test creating vendor with empty name returns VendorValidationException"""
|
||||
vendor_data = {
|
||||
"vendor_code": "EMPTYNAME",
|
||||
@@ -73,7 +79,9 @@ class TestVendorsAPI:
|
||||
assert "Vendor name is required" in data["message"]
|
||||
assert data["details"]["field"] == "name"
|
||||
|
||||
def test_create_vendor_max_vendors_reached_business_logic_error(self, client, auth_headers, db, test_user):
|
||||
def test_create_vendor_max_vendors_reached_business_logic_error(
|
||||
self, client, auth_headers, db, test_user
|
||||
):
|
||||
"""Test creating vendor when max vendors reached returns MaxVendorsReachedException"""
|
||||
# This test would require creating the maximum allowed vendors first
|
||||
# The exact implementation depends on your business rules
|
||||
@@ -94,8 +102,10 @@ class TestVendorsAPI:
|
||||
assert data["total"] >= 1
|
||||
assert len(data["vendors"]) >= 1
|
||||
|
||||
# Find our test vendor
|
||||
test_vendor_found = any(s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"])
|
||||
# Find our test vendor
|
||||
test_vendor_found = any(
|
||||
s["vendor_code"] == test_vendor.vendor_code for s in data["vendors"]
|
||||
)
|
||||
assert test_vendor_found
|
||||
|
||||
def test_get_vendors_with_filters(self, client, auth_headers, test_vendor):
|
||||
@@ -105,7 +115,7 @@ class TestVendorsAPI:
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for vendor in data["vendors"]:
|
||||
assert vendor ["is_active"] is True
|
||||
assert vendor["is_active"] is True
|
||||
|
||||
# Test verified_only filter
|
||||
response = client.get("/api/v1/vendor?verified_only=true", headers=auth_headers)
|
||||
@@ -135,7 +145,9 @@ class TestVendorsAPI:
|
||||
assert data["details"]["resource_type"] == "Vendor"
|
||||
assert data["details"]["identifier"] == "NONEXISTENT"
|
||||
|
||||
def test_get_vendor_unauthorized_access(self, client, auth_headers, test_vendor, other_user, db):
|
||||
def test_get_vendor_unauthorized_access(
|
||||
self, client, auth_headers, test_vendor, other_user, db
|
||||
):
|
||||
"""Test accessing vendor owned by another user returns UnauthorizedVendorAccessException"""
|
||||
# Change vendor owner to other user AND make it unverified/inactive
|
||||
# so that non-owner users cannot access it
|
||||
@@ -154,7 +166,9 @@ class TestVendorsAPI:
|
||||
assert test_vendor.vendor_code in data["message"]
|
||||
assert data["details"]["vendor_code"] == test_vendor.vendor_code
|
||||
|
||||
def test_get_vendor_unauthorized_access_with_inactive_vendor(self, client, auth_headers, inactive_vendor):
|
||||
def test_get_vendor_unauthorized_access_with_inactive_vendor(
|
||||
self, client, auth_headers, inactive_vendor
|
||||
):
|
||||
"""Test accessing inactive vendor owned by another user returns UnauthorizedVendorAccessException"""
|
||||
# inactive_vendor fixture already creates an unverified, inactive vendor owned by other_user
|
||||
response = client.get(
|
||||
@@ -168,7 +182,9 @@ class TestVendorsAPI:
|
||||
assert inactive_vendor.vendor_code in data["message"]
|
||||
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
|
||||
|
||||
def test_get_vendor_public_access_allowed(self, client, auth_headers, verified_vendor):
|
||||
def test_get_vendor_public_access_allowed(
|
||||
self, client, auth_headers, verified_vendor
|
||||
):
|
||||
"""Test accessing verified vendor owned by another user is allowed (public access)"""
|
||||
# verified_vendor fixture creates a verified, active vendor owned by other_user
|
||||
# This should allow public access per your business logic
|
||||
@@ -181,7 +197,9 @@ class TestVendorsAPI:
|
||||
assert data["vendor_code"] == verified_vendor.vendor_code
|
||||
assert data["name"] == verified_vendor.name
|
||||
|
||||
def test_add_product_to_vendor_success(self, client, auth_headers, test_vendor, unique_product):
|
||||
def test_add_product_to_vendor_success(
|
||||
self, client, auth_headers, test_vendor, unique_product
|
||||
):
|
||||
"""Test adding product to vendor successfully"""
|
||||
product_data = {
|
||||
"marketplace_product_id": unique_product.marketplace_product_id, # Use string marketplace_product_id, not database id
|
||||
@@ -193,7 +211,7 @@ class TestVendorsAPI:
|
||||
response = client.post(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
|
||||
headers=auth_headers,
|
||||
json=product_data
|
||||
json=product_data,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -207,10 +225,15 @@ class TestVendorsAPI:
|
||||
|
||||
# MarketplaceProduct details are nested in the 'marketplace_product' field
|
||||
assert "marketplace_product" in data
|
||||
assert data["marketplace_product"]["marketplace_product_id"] == unique_product.marketplace_product_id
|
||||
assert (
|
||||
data["marketplace_product"]["marketplace_product_id"]
|
||||
== unique_product.marketplace_product_id
|
||||
)
|
||||
assert data["marketplace_product"]["id"] == unique_product.id
|
||||
|
||||
def test_add_product_to_vendor_already_exists_conflict(self, client, auth_headers, test_vendor, test_product):
|
||||
def test_add_product_to_vendor_already_exists_conflict(
|
||||
self, client, auth_headers, test_vendor, test_product
|
||||
):
|
||||
"""Test adding product that already exists in vendor returns ProductAlreadyExistsException"""
|
||||
# test_product fixture already creates a relationship, get the marketplace_product_id string
|
||||
existing_product = test_product.marketplace_product
|
||||
@@ -223,7 +246,7 @@ class TestVendorsAPI:
|
||||
response = client.post(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
|
||||
headers=auth_headers,
|
||||
json=product_data
|
||||
json=product_data,
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
@@ -233,7 +256,9 @@ class TestVendorsAPI:
|
||||
assert test_vendor.vendor_code in data["message"]
|
||||
assert existing_product.marketplace_product_id in data["message"]
|
||||
|
||||
def test_add_nonexistent_product_to_vendor_not_found(self, client, auth_headers, test_vendor):
|
||||
def test_add_nonexistent_product_to_vendor_not_found(
|
||||
self, client, auth_headers, test_vendor
|
||||
):
|
||||
"""Test adding nonexistent product to vendor returns MarketplaceProductNotFoundException"""
|
||||
product_data = {
|
||||
"marketplace_product_id": "NONEXISTENT_PRODUCT", # Use string marketplace_product_id that doesn't exist
|
||||
@@ -243,7 +268,7 @@ class TestVendorsAPI:
|
||||
response = client.post(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
|
||||
headers=auth_headers,
|
||||
json=product_data
|
||||
json=product_data,
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -252,11 +277,12 @@ class TestVendorsAPI:
|
||||
assert data["status_code"] == 404
|
||||
assert "NONEXISTENT_PRODUCT" in data["message"]
|
||||
|
||||
def test_get_products_success(self, client, auth_headers, test_vendor, test_product):
|
||||
def test_get_products_success(
|
||||
self, client, auth_headers, test_vendor, test_product
|
||||
):
|
||||
"""Test getting vendor products successfully"""
|
||||
response = client.get(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
|
||||
headers=auth_headers
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -271,22 +297,21 @@ class TestVendorsAPI:
|
||||
# Test active_only filter
|
||||
response = client.get(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products?active_only=true",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test featured_only filter
|
||||
response = client.get(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products?featured_only=true",
|
||||
headers=auth_headers
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_products_from_nonexistent_vendor_not_found(self, client, auth_headers):
|
||||
"""Test getting products from nonexistent vendor returns VendorNotFoundException"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/NONEXISTENT/products",
|
||||
headers=auth_headers
|
||||
"/api/v1/vendor/NONEXISTENT/products", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -295,7 +320,9 @@ class TestVendorsAPI:
|
||||
assert data["status_code"] == 404
|
||||
assert "NONEXISTENT" in data["message"]
|
||||
|
||||
def test_vendor_not_active_business_logic_error(self, client, auth_headers, test_vendor, db):
|
||||
def test_vendor_not_active_business_logic_error(
|
||||
self, client, auth_headers, test_vendor, db
|
||||
):
|
||||
"""Test accessing inactive vendor returns VendorNotActiveException (if enforced)"""
|
||||
# Set vendor to inactive
|
||||
test_vendor.is_active = False
|
||||
@@ -313,7 +340,9 @@ class TestVendorsAPI:
|
||||
assert data["status_code"] == 400
|
||||
assert test_vendor.vendor_code in data["message"]
|
||||
|
||||
def test_vendor_not_verified_business_logic_error(self, client, auth_headers, test_vendor, db):
|
||||
def test_vendor_not_verified_business_logic_error(
|
||||
self, client, auth_headers, test_vendor, db
|
||||
):
|
||||
"""Test operations requiring verification returns VendorNotVerifiedException (if enforced)"""
|
||||
# Set vendor to unverified
|
||||
test_vendor.is_verified = False
|
||||
@@ -328,7 +357,7 @@ class TestVendorsAPI:
|
||||
response = client.post(
|
||||
f"/api/v1/vendor/{test_vendor.vendor_code}/products",
|
||||
headers=auth_headers,
|
||||
json=product_data
|
||||
json=product_data,
|
||||
)
|
||||
|
||||
# If your service requires verification for adding products
|
||||
|
||||
@@ -10,8 +10,9 @@ These tests verify that:
|
||||
5. Vendor context middleware works correctly with API authentication
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from jose import jwt
|
||||
|
||||
|
||||
@@ -26,7 +27,9 @@ class TestVendorAPIAuthentication:
|
||||
# Authentication Tests - /api/v1/vendor/auth/me
|
||||
# ========================================================================
|
||||
|
||||
def test_vendor_auth_me_success(self, client, vendor_user_headers, test_vendor_user):
|
||||
def test_vendor_auth_me_success(
|
||||
self, client, vendor_user_headers, test_vendor_user
|
||||
):
|
||||
"""Test /auth/me endpoint with valid vendor user token"""
|
||||
response = client.get("/api/v1/vendor/auth/me", headers=vendor_user_headers)
|
||||
|
||||
@@ -50,7 +53,7 @@ class TestVendorAPIAuthentication:
|
||||
"""Test /auth/me endpoint with invalid token format"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/auth/me",
|
||||
headers={"Authorization": "Bearer invalid_token_here"}
|
||||
headers={"Authorization": "Bearer invalid_token_here"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -66,7 +69,9 @@ class TestVendorAPIAuthentication:
|
||||
assert data["error_code"] == "FORBIDDEN"
|
||||
assert "Admin users cannot access vendor API" in data["message"]
|
||||
|
||||
def test_vendor_auth_me_with_regular_user_token(self, client, auth_headers, test_user):
|
||||
def test_vendor_auth_me_with_regular_user_token(
|
||||
self, client, auth_headers, test_user
|
||||
):
|
||||
"""Test /auth/me endpoint rejects regular users"""
|
||||
response = client.get("/api/v1/vendor/auth/me", headers=auth_headers)
|
||||
|
||||
@@ -88,14 +93,12 @@ class TestVendorAPIAuthentication:
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(
|
||||
expired_payload,
|
||||
auth_manager.secret_key,
|
||||
algorithm=auth_manager.algorithm
|
||||
expired_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/vendor/auth/me",
|
||||
headers={"Authorization": f"Bearer {expired_token}"}
|
||||
headers={"Authorization": f"Bearer {expired_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -111,8 +114,7 @@ class TestVendorAPIAuthentication:
|
||||
):
|
||||
"""Test dashboard stats with valid vendor authentication"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -131,10 +133,7 @@ class TestVendorAPIAuthentication:
|
||||
|
||||
def test_vendor_dashboard_stats_with_admin(self, client, admin_headers):
|
||||
"""Test dashboard stats rejects admin users"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=admin_headers
|
||||
)
|
||||
response = client.get("/api/v1/vendor/dashboard/stats", headers=admin_headers)
|
||||
|
||||
assert response.status_code == 403
|
||||
data = response.json()
|
||||
@@ -145,10 +144,7 @@ class TestVendorAPIAuthentication:
|
||||
# Login to get session cookie
|
||||
login_response = client.post(
|
||||
"/api/v1/vendor/auth/login",
|
||||
json={
|
||||
"username": test_vendor_user.username,
|
||||
"password": "vendorpass123"
|
||||
}
|
||||
json={"username": test_vendor_user.username, "password": "vendorpass123"},
|
||||
)
|
||||
assert login_response.status_code == 200
|
||||
|
||||
@@ -169,10 +165,7 @@ class TestVendorAPIAuthentication:
|
||||
# Get a valid session by logging in
|
||||
login_response = client.post(
|
||||
"/api/v1/vendor/auth/login",
|
||||
json={
|
||||
"username": test_vendor_user.username,
|
||||
"password": "vendorpass123"
|
||||
}
|
||||
json={"username": test_vendor_user.username, "password": "vendorpass123"},
|
||||
)
|
||||
assert login_response.status_code == 200
|
||||
|
||||
@@ -191,8 +184,9 @@ class TestVendorAPIAuthentication:
|
||||
response = client.get(endpoint)
|
||||
|
||||
# All should fail with 401 (header required)
|
||||
assert response.status_code == 401, \
|
||||
f"Endpoint {endpoint} should reject cookie-only auth"
|
||||
assert (
|
||||
response.status_code == 401
|
||||
), f"Endpoint {endpoint} should reject cookie-only auth"
|
||||
|
||||
# ========================================================================
|
||||
# Role-Based Access Control Tests
|
||||
@@ -211,13 +205,15 @@ class TestVendorAPIAuthentication:
|
||||
for endpoint in endpoints:
|
||||
# Test with regular user token
|
||||
response = client.get(endpoint, headers=auth_headers)
|
||||
assert response.status_code == 403, \
|
||||
f"Endpoint {endpoint} should reject regular users"
|
||||
assert (
|
||||
response.status_code == 403
|
||||
), f"Endpoint {endpoint} should reject regular users"
|
||||
|
||||
# Test with admin token
|
||||
response = client.get(endpoint, headers=admin_headers)
|
||||
assert response.status_code == 403, \
|
||||
f"Endpoint {endpoint} should reject admin users"
|
||||
assert (
|
||||
response.status_code == 403
|
||||
), f"Endpoint {endpoint} should reject admin users"
|
||||
|
||||
def test_vendor_api_accepts_only_vendor_role(
|
||||
self, client, vendor_user_headers, test_vendor_user
|
||||
@@ -229,8 +225,10 @@ class TestVendorAPIAuthentication:
|
||||
|
||||
for endpoint in endpoints:
|
||||
response = client.get(endpoint, headers=vendor_user_headers)
|
||||
assert response.status_code in [200, 404], \
|
||||
f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
|
||||
assert response.status_code in [
|
||||
200,
|
||||
404,
|
||||
], f"Endpoint {endpoint} should accept vendor users (got {response.status_code})"
|
||||
|
||||
# ========================================================================
|
||||
# Token Validation Tests
|
||||
@@ -248,8 +246,9 @@ class TestVendorAPIAuthentication:
|
||||
|
||||
for headers in malformed_headers:
|
||||
response = client.get("/api/v1/vendor/auth/me", headers=headers)
|
||||
assert response.status_code == 401, \
|
||||
f"Should reject malformed header: {headers}"
|
||||
assert (
|
||||
response.status_code == 401
|
||||
), f"Should reject malformed header: {headers}"
|
||||
|
||||
def test_token_with_missing_claims(self, client, auth_manager):
|
||||
"""Test token missing required claims"""
|
||||
@@ -261,14 +260,12 @@ class TestVendorAPIAuthentication:
|
||||
}
|
||||
|
||||
invalid_token = jwt.encode(
|
||||
invalid_payload,
|
||||
auth_manager.secret_key,
|
||||
algorithm=auth_manager.algorithm
|
||||
invalid_payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/vendor/auth/me",
|
||||
headers={"Authorization": f"Bearer {invalid_token}"}
|
||||
headers={"Authorization": f"Bearer {invalid_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -298,9 +295,7 @@ class TestVendorAPIAuthentication:
|
||||
db.add(test_vendor_user)
|
||||
db.commit()
|
||||
|
||||
def test_concurrent_requests_with_same_token(
|
||||
self, client, vendor_user_headers
|
||||
):
|
||||
def test_concurrent_requests_with_same_token(self, client, vendor_user_headers):
|
||||
"""Test that the same token can be used for multiple concurrent requests"""
|
||||
# Make multiple requests with the same token
|
||||
responses = []
|
||||
@@ -314,10 +309,7 @@ class TestVendorAPIAuthentication:
|
||||
|
||||
def test_vendor_api_with_empty_authorization_header(self, client):
|
||||
"""Test vendor API with empty Authorization header value"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/auth/me",
|
||||
headers={"Authorization": ""}
|
||||
)
|
||||
response = client.get("/api/v1/vendor/auth/me", headers={"Authorization": ""})
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -328,17 +320,12 @@ class TestVendorAPIAuthentication:
|
||||
class TestVendorAPIConsistency:
|
||||
"""Test that all vendor API endpoints use consistent authentication"""
|
||||
|
||||
def test_all_vendor_endpoints_require_header_auth(
|
||||
self, client, test_vendor_user
|
||||
):
|
||||
def test_all_vendor_endpoints_require_header_auth(self, client, test_vendor_user):
|
||||
"""Verify all vendor API endpoints require Authorization header"""
|
||||
# Login to establish session
|
||||
client.post(
|
||||
"/api/v1/vendor/auth/login",
|
||||
json={
|
||||
"username": test_vendor_user.username,
|
||||
"password": "vendorpass123"
|
||||
}
|
||||
json={"username": test_vendor_user.username, "password": "vendorpass123"},
|
||||
)
|
||||
|
||||
# All vendor API endpoints (excluding public endpoints like /info)
|
||||
@@ -361,8 +348,9 @@ class TestVendorAPIConsistency:
|
||||
response = client.post(endpoint, json={})
|
||||
|
||||
# All should reject cookie-only auth with 401
|
||||
assert response.status_code == 401, \
|
||||
f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
|
||||
assert (
|
||||
response.status_code == 401
|
||||
), f"Endpoint {endpoint} should require Authorization header (got {response.status_code})"
|
||||
|
||||
def test_vendor_endpoints_accept_vendor_token(
|
||||
self, client, vendor_user_headers, test_vendor_with_vendor_user
|
||||
@@ -380,5 +368,7 @@ class TestVendorAPIConsistency:
|
||||
response = client.get(endpoint, headers=vendor_user_headers)
|
||||
|
||||
# Should not be authentication/authorization errors
|
||||
assert response.status_code not in [401, 403], \
|
||||
f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"
|
||||
assert response.status_code not in [
|
||||
401,
|
||||
403,
|
||||
], f"Endpoint {endpoint} should accept vendor token (got {response.status_code}: {response.text})"
|
||||
|
||||
@@ -23,8 +23,7 @@ class TestVendorDashboardAPI:
|
||||
):
|
||||
"""Test dashboard stats returns correct data structure"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -66,9 +65,9 @@ class TestVendorDashboardAPI:
|
||||
self, client, db, test_vendor_user, auth_manager
|
||||
):
|
||||
"""Test that dashboard stats only show data for the authenticated vendor"""
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
from models.database.product import Product
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.product import Product
|
||||
from models.database.vendor import Vendor, VendorUser
|
||||
|
||||
# Create two separate vendors with different data
|
||||
vendor1 = Vendor(
|
||||
@@ -118,10 +117,7 @@ class TestVendorDashboardAPI:
|
||||
vendor1_headers = {"Authorization": f"Bearer {token_data['access_token']}"}
|
||||
|
||||
# Get stats for vendor1
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor1_headers
|
||||
)
|
||||
response = client.get("/api/v1/vendor/dashboard/stats", headers=vendor1_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -130,9 +126,7 @@ class TestVendorDashboardAPI:
|
||||
assert data["vendor"]["id"] == vendor1.id
|
||||
assert data["products"]["total"] == 3
|
||||
|
||||
def test_dashboard_stats_without_vendor_association(
|
||||
self, client, db, auth_manager
|
||||
):
|
||||
def test_dashboard_stats_without_vendor_association(self, client, db, auth_manager):
|
||||
"""Test dashboard stats for user not associated with any vendor"""
|
||||
from models.database.user import User
|
||||
|
||||
@@ -206,8 +200,7 @@ class TestVendorDashboardAPI:
|
||||
):
|
||||
"""Test dashboard stats for vendor with no data"""
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -224,8 +217,8 @@ class TestVendorDashboardAPI:
|
||||
self, client, db, vendor_user_headers, test_vendor_with_vendor_user
|
||||
):
|
||||
"""Test dashboard stats accuracy with actual products"""
|
||||
from models.database.product import Product
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.product import Product
|
||||
|
||||
# Create marketplace products
|
||||
mp = MarketplaceProduct(
|
||||
@@ -249,8 +242,7 @@ class TestVendorDashboardAPI:
|
||||
|
||||
# Get stats
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -267,8 +259,7 @@ class TestVendorDashboardAPI:
|
||||
|
||||
start_time = time.time()
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
@@ -284,8 +275,7 @@ class TestVendorDashboardAPI:
|
||||
responses = []
|
||||
for _ in range(3):
|
||||
response = client.get(
|
||||
"/api/v1/vendor/dashboard/stats",
|
||||
headers=vendor_user_headers
|
||||
"/api/v1/vendor/dashboard/stats", headers=vendor_user_headers
|
||||
)
|
||||
responses.append(response)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Fixtures specific to middleware integration tests.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.vendor_domain import VendorDomain
|
||||
from models.database.vendor_theme import VendorTheme
|
||||
@@ -12,10 +13,7 @@ from models.database.vendor_theme import VendorTheme
|
||||
def vendor_with_subdomain(db):
|
||||
"""Create a vendor with subdomain for testing."""
|
||||
vendor = Vendor(
|
||||
name="Test Vendor",
|
||||
code="testvendor",
|
||||
subdomain="testvendor",
|
||||
is_active=True
|
||||
name="Test Vendor", code="testvendor", subdomain="testvendor", is_active=True
|
||||
)
|
||||
db.add(vendor)
|
||||
db.commit()
|
||||
@@ -30,7 +28,7 @@ def vendor_with_custom_domain(db):
|
||||
name="Custom Domain Vendor",
|
||||
code="customvendor",
|
||||
subdomain="customvendor",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
db.add(vendor)
|
||||
db.commit()
|
||||
@@ -38,10 +36,7 @@ def vendor_with_custom_domain(db):
|
||||
|
||||
# Add custom domain
|
||||
domain = VendorDomain(
|
||||
vendor_id=vendor.id,
|
||||
domain="customdomain.com",
|
||||
is_active=True,
|
||||
is_primary=True
|
||||
vendor_id=vendor.id, domain="customdomain.com", is_active=True, is_primary=True
|
||||
)
|
||||
db.add(domain)
|
||||
db.commit()
|
||||
@@ -56,7 +51,7 @@ def vendor_with_theme(db):
|
||||
name="Themed Vendor",
|
||||
code="themedvendor",
|
||||
subdomain="themedvendor",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
db.add(vendor)
|
||||
db.commit()
|
||||
@@ -69,7 +64,7 @@ def vendor_with_theme(db):
|
||||
secondary_color="#33FF57",
|
||||
logo_url="/static/vendors/themedvendor/logo.png",
|
||||
favicon_url="/static/vendors/themedvendor/favicon.ico",
|
||||
custom_css="body { background: #FF5733; }"
|
||||
custom_css="body { background: #FF5733; }",
|
||||
)
|
||||
db.add(theme)
|
||||
db.commit()
|
||||
@@ -81,10 +76,7 @@ def vendor_with_theme(db):
|
||||
def inactive_vendor(db):
|
||||
"""Create an inactive vendor for testing."""
|
||||
vendor = Vendor(
|
||||
name="Inactive Vendor",
|
||||
code="inactive",
|
||||
subdomain="inactive",
|
||||
is_active=False
|
||||
name="Inactive Vendor", code="inactive", subdomain="inactive", is_active=False
|
||||
)
|
||||
db.add(vendor)
|
||||
db.commit()
|
||||
|
||||
@@ -5,8 +5,10 @@ Integration tests for request context detection end-to-end flow.
|
||||
These tests verify that context type (API, ADMIN, VENDOR_DASHBOARD, SHOP, FALLBACK)
|
||||
is correctly detected through real HTTP requests.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.context import RequestContext
|
||||
|
||||
|
||||
@@ -23,13 +25,22 @@ class TestContextDetectionFlow:
|
||||
def test_api_path_detected_as_api_context(self, client):
|
||||
"""Test that /api/* paths are detected as API context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/test-api-context")
|
||||
async def test_api(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"context_enum": (
|
||||
request.state.context_type.name
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
response = client.get("/api/test-api-context")
|
||||
@@ -42,12 +53,17 @@ class TestContextDetectionFlow:
|
||||
def test_nested_api_path_detected_as_api_context(self, client):
|
||||
"""Test that nested /api/ paths are detected as API context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/v1/vendor/products")
|
||||
async def test_nested_api(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
response = client.get("/api/v1/vendor/products")
|
||||
@@ -63,13 +79,22 @@ class TestContextDetectionFlow:
|
||||
def test_admin_path_detected_as_admin_context(self, client):
|
||||
"""Test that /admin/* paths are detected as ADMIN context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/admin/test-admin-context")
|
||||
async def test_admin(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"context_enum": (
|
||||
request.state.context_type.name
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
response = client.get("/admin/test-admin-context")
|
||||
@@ -82,19 +107,23 @@ class TestContextDetectionFlow:
|
||||
def test_admin_subdomain_detected_as_admin_context(self, client):
|
||||
"""Test that admin.* subdomain is detected as ADMIN context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-admin-subdomain-context")
|
||||
async def test_admin_subdomain(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-admin-subdomain-context",
|
||||
headers={"host": "admin.platform.com"}
|
||||
"/test-admin-subdomain-context", headers={"host": "admin.platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -104,12 +133,17 @@ class TestContextDetectionFlow:
|
||||
def test_nested_admin_path_detected_as_admin_context(self, client):
|
||||
"""Test that nested /admin/ paths are detected as ADMIN context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/admin/vendors/123/edit")
|
||||
async def test_nested_admin(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
response = client.get("/admin/vendors/123/edit")
|
||||
@@ -125,21 +159,31 @@ class TestContextDetectionFlow:
|
||||
def test_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
|
||||
"""Test that /vendor/* paths are detected as VENDOR_DASHBOARD context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/test-vendor-dashboard")
|
||||
async def test_vendor_dashboard(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"context_enum": (
|
||||
request.state.context_type.name
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/vendor/test-vendor-dashboard",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -151,19 +195,24 @@ class TestContextDetectionFlow:
|
||||
def test_nested_vendor_dashboard_path_detected(self, client, vendor_with_subdomain):
|
||||
"""Test that nested /vendor/ paths are detected as VENDOR_DASHBOARD context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/products/123/edit")
|
||||
async def test_nested_vendor(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/vendor/products/123/edit",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -174,24 +223,36 @@ class TestContextDetectionFlow:
|
||||
# Shop Context Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_shop_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
|
||||
def test_shop_path_with_vendor_detected_as_shop(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that /shop/* paths with vendor are detected as SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/shop/test-shop-context")
|
||||
async def test_shop(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"context_enum": (
|
||||
request.state.context_type.name
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/shop/test-shop-context",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -200,23 +261,31 @@ class TestContextDetectionFlow:
|
||||
assert data["context_enum"] == "SHOP"
|
||||
assert data["has_vendor"] is True
|
||||
|
||||
def test_root_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
|
||||
def test_root_path_with_vendor_detected_as_shop(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that root path with vendor is detected as SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-root-shop")
|
||||
async def test_root_shop(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-root-shop",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -228,21 +297,27 @@ class TestContextDetectionFlow:
|
||||
def test_custom_domain_shop_detected(self, client, vendor_with_custom_domain):
|
||||
"""Test that custom domain shop is detected as SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/products")
|
||||
async def test_custom_domain_shop(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/products",
|
||||
headers={"host": "customdomain.com"}
|
||||
)
|
||||
response = client.get("/products", headers={"host": "customdomain.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -256,21 +331,30 @@ class TestContextDetectionFlow:
|
||||
def test_unknown_path_without_vendor_fallback_context(self, client):
|
||||
"""Test that unknown paths without vendor get FALLBACK context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-fallback-context")
|
||||
async def test_fallback(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"context_enum": request.state.context_type.name if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"context_enum": (
|
||||
request.state.context_type.name
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-fallback-context",
|
||||
headers={"host": "platform.com"}
|
||||
"/test-fallback-context", headers={"host": "platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -286,20 +370,26 @@ class TestContextDetectionFlow:
|
||||
def test_api_path_overrides_vendor_context(self, client, vendor_with_subdomain):
|
||||
"""Test that /api/* path sets API context even with vendor."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/test-api-priority")
|
||||
async def test_api_priority(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/api/test-api-priority",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -312,20 +402,26 @@ class TestContextDetectionFlow:
|
||||
def test_admin_path_overrides_vendor_context(self, client, vendor_with_subdomain):
|
||||
"""Test that /admin/* path sets ADMIN context even with vendor."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/admin/test-admin-priority")
|
||||
async def test_admin_priority(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/admin/test-admin-priority",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -333,22 +429,29 @@ class TestContextDetectionFlow:
|
||||
# Admin path should override vendor context
|
||||
assert data["context_type"] == "admin"
|
||||
|
||||
def test_vendor_dashboard_overrides_shop_context(self, client, vendor_with_subdomain):
|
||||
def test_vendor_dashboard_overrides_shop_context(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that /vendor/* path sets VENDOR_DASHBOARD, not SHOP."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/test-priority")
|
||||
async def test_vendor_priority(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/vendor/test-priority",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -363,21 +466,30 @@ class TestContextDetectionFlow:
|
||||
def test_context_uses_clean_path_for_detection(self, client, vendor_with_subdomain):
|
||||
"""Test that context detection uses clean_path, not original path."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendors/{vendor_code}/shop/products")
|
||||
async def test_clean_path_context(vendor_code: str, request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
|
||||
"original_path": request.url.path
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"clean_path": (
|
||||
request.state.clean_path
|
||||
if hasattr(request.state, "clean_path")
|
||||
else None
|
||||
),
|
||||
"original_path": request.url.path,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
f"/vendors/{vendor_with_subdomain.code}/shop/products",
|
||||
headers={"host": "localhost:8000"}
|
||||
headers={"host": "localhost:8000"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -394,15 +506,20 @@ class TestContextDetectionFlow:
|
||||
def test_context_type_is_enum_instance(self, client):
|
||||
"""Test that context_type is a RequestContext enum instance."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/test-enum")
|
||||
async def test_enum(request: Request):
|
||||
context = request.state.context_type if hasattr(request.state, 'context_type') else None
|
||||
context = (
|
||||
request.state.context_type
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"is_enum": isinstance(context, RequestContext) if context else False,
|
||||
"enum_name": context.name if context else None,
|
||||
"enum_value": context.value if context else None
|
||||
"enum_value": context.value if context else None,
|
||||
}
|
||||
|
||||
response = client.get("/api/test-enum")
|
||||
@@ -417,23 +534,30 @@ class TestContextDetectionFlow:
|
||||
# Edge Cases
|
||||
# ========================================================================
|
||||
|
||||
def test_empty_path_with_vendor_detected_as_shop(self, client, vendor_with_subdomain):
|
||||
def test_empty_path_with_vendor_detected_as_shop(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that empty/root path with vendor is detected as SHOP."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/")
|
||||
async def test_root(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
"/", headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 404] # Might not have root handler
|
||||
@@ -445,13 +569,18 @@ class TestContextDetectionFlow:
|
||||
def test_case_insensitive_context_detection(self, client):
|
||||
"""Test that context detection is case insensitive for paths."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/API/test-case")
|
||||
@app.get("/api/test-case")
|
||||
async def test_case(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
# Test uppercase
|
||||
|
||||
@@ -5,8 +5,10 @@ Integration tests for the complete middleware stack.
|
||||
These tests verify that all middleware components work together correctly
|
||||
through real HTTP requests, ensuring proper execution order and state injection.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.context import RequestContext
|
||||
|
||||
|
||||
@@ -23,14 +25,20 @@ class TestMiddlewareStackIntegration:
|
||||
"""Test that /admin/* paths set ADMIN context type."""
|
||||
# Create a simple endpoint to inspect request state
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/admin/test-context")
|
||||
async def test_admin_context(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"has_theme": hasattr(request.state, 'theme')
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
}
|
||||
|
||||
response = client.get("/admin/test-context")
|
||||
@@ -44,20 +52,24 @@ class TestMiddlewareStackIntegration:
|
||||
def test_admin_subdomain_sets_admin_context(self, client):
|
||||
"""Test that admin.* subdomain sets ADMIN context type."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-admin-subdomain")
|
||||
async def test_admin_subdomain(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
# Simulate request with admin subdomain
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-admin-subdomain",
|
||||
headers={"host": "admin.platform.com"}
|
||||
"/test-admin-subdomain", headers={"host": "admin.platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -71,12 +83,17 @@ class TestMiddlewareStackIntegration:
|
||||
def test_api_path_sets_api_context(self, client):
|
||||
"""Test that /api/* paths set API context type."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/test-context")
|
||||
async def test_api_context(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
response = client.get("/api/test-context")
|
||||
@@ -89,25 +106,40 @@ class TestMiddlewareStackIntegration:
|
||||
# Vendor Dashboard Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_vendor_dashboard_path_sets_vendor_context(self, client, vendor_with_subdomain):
|
||||
def test_vendor_dashboard_path_sets_vendor_context(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that /vendor/* paths with vendor set VENDOR_DASHBOARD context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/test-context")
|
||||
async def test_vendor_context(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
# Request with vendor subdomain
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/vendor/test-context",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -120,24 +152,35 @@ class TestMiddlewareStackIntegration:
|
||||
# Shop Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_shop_path_with_subdomain_sets_shop_context(self, client, vendor_with_subdomain):
|
||||
def test_shop_path_with_subdomain_sets_shop_context(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that /shop/* paths with vendor subdomain set SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/shop/test-context")
|
||||
async def test_shop_context(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"has_theme": hasattr(request.state, 'theme')
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/shop/test-context",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -146,23 +189,33 @@ class TestMiddlewareStackIntegration:
|
||||
assert data["vendor_id"] == vendor_with_subdomain.id
|
||||
assert data["has_theme"] is True
|
||||
|
||||
def test_shop_path_with_custom_domain_sets_shop_context(self, client, vendor_with_custom_domain):
|
||||
def test_shop_path_with_custom_domain_sets_shop_context(
|
||||
self, client, vendor_with_custom_domain
|
||||
):
|
||||
"""Test that /shop/* paths with custom domain set SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/shop/test-custom-domain")
|
||||
async def test_shop_custom_domain(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/shop/test-custom-domain",
|
||||
headers={"host": "customdomain.com"}
|
||||
"/shop/test-custom-domain", headers={"host": "customdomain.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -174,9 +227,12 @@ class TestMiddlewareStackIntegration:
|
||||
# Middleware Execution Order Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_vendor_context_runs_before_context_detection(self, client, vendor_with_subdomain):
|
||||
def test_vendor_context_runs_before_context_detection(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that VendorContextMiddleware runs before ContextDetectionMiddleware."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-execution-order")
|
||||
@@ -184,16 +240,16 @@ class TestMiddlewareStackIntegration:
|
||||
# If vendor context runs first, clean_path should be available
|
||||
# before context detection uses it
|
||||
return {
|
||||
"has_vendor": hasattr(request.state, 'vendor'),
|
||||
"has_clean_path": hasattr(request.state, 'clean_path'),
|
||||
"has_context_type": hasattr(request.state, 'context_type')
|
||||
"has_vendor": hasattr(request.state, "vendor"),
|
||||
"has_clean_path": hasattr(request.state, "clean_path"),
|
||||
"has_context_type": hasattr(request.state, "context_type"),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-execution-order",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -206,21 +262,26 @@ class TestMiddlewareStackIntegration:
|
||||
def test_theme_context_runs_after_vendor_context(self, client, vendor_with_theme):
|
||||
"""Test that ThemeContextMiddleware runs after VendorContextMiddleware."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-loading")
|
||||
async def test_theme_loading(request: Request):
|
||||
return {
|
||||
"has_vendor": hasattr(request.state, 'vendor'),
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"theme_primary_color": request.state.theme.get('primary_color') if hasattr(request.state, 'theme') else None
|
||||
"has_vendor": hasattr(request.state, "vendor"),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"theme_primary_color": (
|
||||
request.state.theme.get("primary_color")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-theme-loading",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -249,21 +310,27 @@ class TestMiddlewareStackIntegration:
|
||||
def test_missing_vendor_graceful_handling(self, client):
|
||||
"""Test that missing vendor is handled gracefully."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-missing-vendor")
|
||||
async def test_missing_vendor(request: Request):
|
||||
return {
|
||||
"has_vendor": hasattr(request.state, 'vendor'),
|
||||
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None,
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None
|
||||
"has_vendor": hasattr(request.state, "vendor"),
|
||||
"vendor": (
|
||||
request.state.vendor if hasattr(request.state, "vendor") else None
|
||||
),
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-missing-vendor",
|
||||
headers={"host": "nonexistent.platform.com"}
|
||||
"/test-missing-vendor", headers={"host": "nonexistent.platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -276,20 +343,23 @@ class TestMiddlewareStackIntegration:
|
||||
def test_inactive_vendor_not_loaded(self, client, inactive_vendor):
|
||||
"""Test that inactive vendors are not loaded."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-inactive-vendor")
|
||||
async def test_inactive_vendor_endpoint(request: Request):
|
||||
return {
|
||||
"has_vendor": hasattr(request.state, 'vendor'),
|
||||
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
|
||||
"has_vendor": hasattr(request.state, "vendor"),
|
||||
"vendor": (
|
||||
request.state.vendor if hasattr(request.state, "vendor") else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-inactive-vendor",
|
||||
headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
|
||||
headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -5,9 +5,10 @@ Integration tests for theme loading end-to-end flow.
|
||||
These tests verify that vendor themes are correctly loaded and injected
|
||||
into request.state through real HTTP requests.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.middleware
|
||||
@@ -22,21 +23,22 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_for_vendor_with_custom_theme(self, client, vendor_with_theme):
|
||||
"""Test that custom theme is loaded for vendor with theme."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-loading")
|
||||
async def test_theme(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else None
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else None
|
||||
return {
|
||||
"has_theme": theme is not None,
|
||||
"theme_data": theme if theme else None
|
||||
"theme_data": theme if theme else None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-theme-loading",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -46,24 +48,27 @@ class TestThemeLoadingFlow:
|
||||
assert data["theme_data"]["primary_color"] == "#FF5733"
|
||||
assert data["theme_data"]["secondary_color"] == "#33FF57"
|
||||
|
||||
def test_default_theme_loaded_for_vendor_without_theme(self, client, vendor_with_subdomain):
|
||||
def test_default_theme_loaded_for_vendor_without_theme(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that default theme is loaded for vendor without custom theme."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-default-theme")
|
||||
async def test_default_theme(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else None
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else None
|
||||
return {
|
||||
"has_theme": theme is not None,
|
||||
"theme_data": theme if theme else None
|
||||
"theme_data": theme if theme else None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-default-theme",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -76,21 +81,20 @@ class TestThemeLoadingFlow:
|
||||
def test_no_theme_loaded_without_vendor(self, client):
|
||||
"""Test that no theme is loaded when there's no vendor."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-no-theme")
|
||||
async def test_no_theme(request: Request):
|
||||
return {
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-no-theme",
|
||||
headers={"host": "platform.com"}
|
||||
)
|
||||
response = client.get("/test-no-theme", headers={"host": "platform.com"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
@@ -105,24 +109,25 @@ class TestThemeLoadingFlow:
|
||||
def test_custom_theme_contains_all_fields(self, client, vendor_with_theme):
|
||||
"""Test that custom theme contains all expected fields."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-fields")
|
||||
async def test_theme_fields(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else {}
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else {}
|
||||
return {
|
||||
"primary_color": theme.get("primary_color"),
|
||||
"secondary_color": theme.get("secondary_color"),
|
||||
"logo_url": theme.get("logo_url"),
|
||||
"favicon_url": theme.get("favicon_url"),
|
||||
"custom_css": theme.get("custom_css")
|
||||
"custom_css": theme.get("custom_css"),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-theme-fields",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -136,24 +141,25 @@ class TestThemeLoadingFlow:
|
||||
def test_default_theme_structure(self, client, vendor_with_subdomain):
|
||||
"""Test that default theme has expected structure."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-default-theme-structure")
|
||||
async def test_default_structure(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else {}
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else {}
|
||||
return {
|
||||
"has_primary_color": "primary_color" in theme,
|
||||
"has_secondary_color": "secondary_color" in theme,
|
||||
"has_logo_url": "logo_url" in theme,
|
||||
"has_favicon_url": "favicon_url" in theme,
|
||||
"has_custom_css": "custom_css" in theme
|
||||
"has_custom_css": "custom_css" in theme,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-default-theme-structure",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -169,21 +175,30 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_in_shop_context(self, client, vendor_with_theme):
|
||||
"""Test that theme is loaded in SHOP context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/shop/test-shop-theme")
|
||||
async def test_shop_theme(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"theme_primary": (
|
||||
request.state.theme.get("primary_color")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/shop/test-shop-theme",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -195,21 +210,30 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_in_vendor_dashboard_context(self, client, vendor_with_theme):
|
||||
"""Test that theme is loaded in VENDOR_DASHBOARD context."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/test-dashboard-theme")
|
||||
async def test_dashboard_theme(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"theme_secondary": request.state.theme.get("secondary_color") if hasattr(request.state, 'theme') else None
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"theme_secondary": (
|
||||
request.state.theme.get("secondary_color")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/vendor/test-dashboard-theme",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -221,21 +245,27 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_in_api_context_with_vendor(self, client, vendor_with_theme):
|
||||
"""Test that theme is loaded in API context when vendor is present."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/api/test-api-theme")
|
||||
async def test_api_theme(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"has_theme": hasattr(request.state, 'theme')
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/api/test-api-theme",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -248,13 +278,18 @@ class TestThemeLoadingFlow:
|
||||
def test_no_theme_in_admin_context(self, client):
|
||||
"""Test that theme is not loaded in ADMIN context (no vendor)."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/admin/test-admin-no-theme")
|
||||
async def test_admin_no_theme(request: Request):
|
||||
return {
|
||||
"context_type": request.state.context_type.value if hasattr(request.state, 'context_type') else None,
|
||||
"has_theme": hasattr(request.state, 'theme')
|
||||
"context_type": (
|
||||
request.state.context_type.value
|
||||
if hasattr(request.state, "context_type")
|
||||
else None
|
||||
),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
}
|
||||
|
||||
response = client.get("/admin/test-admin-no-theme")
|
||||
@@ -272,20 +307,29 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_with_subdomain_routing(self, client, vendor_with_theme):
|
||||
"""Test theme loading with subdomain routing."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-subdomain-theme")
|
||||
async def test_subdomain_theme(request: Request):
|
||||
return {
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
|
||||
"theme_logo": request.state.theme.get("logo_url") if hasattr(request.state, 'theme') else None
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
"theme_logo": (
|
||||
request.state.theme.get("logo_url")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-subdomain-theme",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -293,33 +337,40 @@ class TestThemeLoadingFlow:
|
||||
assert data["vendor_code"] == vendor_with_theme.code
|
||||
assert data["theme_logo"] == "/static/vendors/themedvendor/logo.png"
|
||||
|
||||
def test_theme_loaded_with_custom_domain_routing(self, client, vendor_with_custom_domain, db):
|
||||
def test_theme_loaded_with_custom_domain_routing(
|
||||
self, client, vendor_with_custom_domain, db
|
||||
):
|
||||
"""Test theme loading with custom domain routing."""
|
||||
# Add theme to custom domain vendor
|
||||
from models.database.vendor_theme import VendorTheme
|
||||
|
||||
theme = VendorTheme(
|
||||
vendor_id=vendor_with_custom_domain.id,
|
||||
primary_color="#123456",
|
||||
secondary_color="#654321"
|
||||
secondary_color="#654321",
|
||||
)
|
||||
db.add(theme)
|
||||
db.commit()
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-custom-domain-theme")
|
||||
async def test_custom_domain_theme(request: Request):
|
||||
return {
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"theme_primary": (
|
||||
request.state.theme.get("primary_color")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-custom-domain-theme",
|
||||
headers={"host": "customdomain.com"}
|
||||
"/test-custom-domain-theme", headers={"host": "customdomain.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -331,28 +382,37 @@ class TestThemeLoadingFlow:
|
||||
# Theme Dependency on Vendor Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_theme_middleware_depends_on_vendor_middleware(self, client, vendor_with_theme):
|
||||
def test_theme_middleware_depends_on_vendor_middleware(
|
||||
self, client, vendor_with_theme
|
||||
):
|
||||
"""Test that theme loading depends on vendor being detected first."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-vendor-dependency")
|
||||
async def test_dependency(request: Request):
|
||||
return {
|
||||
"has_vendor": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"has_theme": hasattr(request.state, 'theme'),
|
||||
"has_vendor": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"has_theme": hasattr(request.state, "theme"),
|
||||
"vendor_matches_theme": (
|
||||
request.state.vendor_id == vendor_with_theme.id
|
||||
if hasattr(request.state, 'vendor_id') else False
|
||||
)
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-theme-vendor-dependency",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -369,15 +429,20 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_loaded_consistently_across_requests(self, client, vendor_with_theme):
|
||||
"""Test that theme is loaded consistently across multiple requests."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-consistency")
|
||||
async def test_consistency(request: Request):
|
||||
return {
|
||||
"theme_primary": request.state.theme.get("primary_color") if hasattr(request.state, 'theme') else None
|
||||
"theme_primary": (
|
||||
request.state.theme.get("primary_color")
|
||||
if hasattr(request.state, "theme")
|
||||
else None
|
||||
)
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
# Make multiple requests
|
||||
@@ -385,7 +450,7 @@ class TestThemeLoadingFlow:
|
||||
for _ in range(3):
|
||||
response = client.get(
|
||||
"/test-theme-consistency",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
responses.append(response.json())
|
||||
|
||||
@@ -396,25 +461,28 @@ class TestThemeLoadingFlow:
|
||||
# Edge Cases and Error Handling Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_theme_gracefully_handles_missing_theme_fields(self, client, vendor_with_subdomain):
|
||||
def test_theme_gracefully_handles_missing_theme_fields(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that missing theme fields are handled gracefully."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-partial-theme")
|
||||
async def test_partial_theme(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else {}
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else {}
|
||||
return {
|
||||
"has_theme": bool(theme),
|
||||
"primary_color": theme.get("primary_color", "default"),
|
||||
"logo_url": theme.get("logo_url", "default")
|
||||
"logo_url": theme.get("logo_url", "default"),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-partial-theme",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -427,23 +495,21 @@ class TestThemeLoadingFlow:
|
||||
def test_theme_dict_is_mutable(self, client, vendor_with_theme):
|
||||
"""Test that theme dict can be accessed and read from."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-theme-mutable")
|
||||
async def test_mutable(request: Request):
|
||||
theme = request.state.theme if hasattr(request.state, 'theme') else {}
|
||||
theme = request.state.theme if hasattr(request.state, "theme") else {}
|
||||
# Try to access theme values
|
||||
primary = theme.get("primary_color")
|
||||
return {
|
||||
"can_read": primary is not None,
|
||||
"value": primary
|
||||
}
|
||||
return {"can_read": primary is not None, "value": primary}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-theme-mutable",
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_theme.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -5,9 +5,10 @@ Integration tests for vendor context detection end-to-end flow.
|
||||
These tests verify that vendor detection works correctly through real HTTP requests
|
||||
for all routing modes: subdomain, custom domain, and path-based.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.middleware
|
||||
@@ -22,23 +23,37 @@ class TestVendorContextFlow:
|
||||
def test_subdomain_vendor_detection(self, client, vendor_with_subdomain):
|
||||
"""Test vendor detection via subdomain routing."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-subdomain-detection")
|
||||
async def test_subdomain(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
|
||||
"vendor_name": request.state.vendor.name if hasattr(request.state, 'vendor') and request.state.vendor else None,
|
||||
"detection_method": "subdomain"
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
"vendor_name": (
|
||||
request.state.vendor.name
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
"detection_method": "subdomain",
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-subdomain-detection",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -51,20 +66,28 @@ class TestVendorContextFlow:
|
||||
def test_subdomain_with_port_detection(self, client, vendor_with_subdomain):
|
||||
"""Test vendor detection via subdomain with port number."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-subdomain-port")
|
||||
async def test_subdomain_port(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-subdomain-port",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"}
|
||||
headers={
|
||||
"host": f"{vendor_with_subdomain.subdomain}.platform.com:8000"
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -75,20 +98,24 @@ class TestVendorContextFlow:
|
||||
def test_nonexistent_subdomain_returns_no_vendor(self, client):
|
||||
"""Test that nonexistent subdomain doesn't crash and returns no vendor."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-nonexistent-subdomain")
|
||||
async def test_nonexistent(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor": request.state.vendor if hasattr(request.state, 'vendor') else None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor": (
|
||||
request.state.vendor if hasattr(request.state, "vendor") else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-nonexistent-subdomain",
|
||||
headers={"host": "nonexistent.platform.com"}
|
||||
headers={"host": "nonexistent.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -102,22 +129,31 @@ class TestVendorContextFlow:
|
||||
def test_custom_domain_vendor_detection(self, client, vendor_with_custom_domain):
|
||||
"""Test vendor detection via custom domain."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-custom-domain")
|
||||
async def test_custom_domain(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
|
||||
"detection_method": "custom_domain"
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
"detection_method": "custom_domain",
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-custom-domain",
|
||||
headers={"host": "customdomain.com"}
|
||||
"/test-custom-domain", headers={"host": "customdomain.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -129,21 +165,26 @@ class TestVendorContextFlow:
|
||||
def test_custom_domain_with_www_detection(self, client, vendor_with_custom_domain):
|
||||
"""Test vendor detection via custom domain with www prefix."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-custom-domain-www")
|
||||
async def test_custom_domain_www(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
# Test with www prefix - should still detect vendor
|
||||
response = client.get(
|
||||
"/test-custom-domain-www",
|
||||
headers={"host": "www.customdomain.com"}
|
||||
"/test-custom-domain-www", headers={"host": "www.customdomain.com"}
|
||||
)
|
||||
|
||||
# This might fail if your implementation doesn't strip www
|
||||
@@ -154,25 +195,37 @@ class TestVendorContextFlow:
|
||||
# Path-Based Detection Tests (Development Mode)
|
||||
# ========================================================================
|
||||
|
||||
def test_path_based_vendor_detection_vendors_prefix(self, client, vendor_with_subdomain):
|
||||
def test_path_based_vendor_detection_vendors_prefix(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test vendor detection via path-based routing with /vendors/ prefix."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendors/{vendor_code}/test-path")
|
||||
async def test_path_based(vendor_code: str, request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_code_param": vendor_code,
|
||||
"vendor_code_state": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None,
|
||||
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None
|
||||
"vendor_code_state": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
"clean_path": (
|
||||
request.state.clean_path
|
||||
if hasattr(request.state, "clean_path")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
f"/vendors/{vendor_with_subdomain.code}/test-path",
|
||||
headers={"host": "localhost:8000"}
|
||||
headers={"host": "localhost:8000"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -181,23 +234,31 @@ class TestVendorContextFlow:
|
||||
assert data["vendor_code_param"] == vendor_with_subdomain.code
|
||||
assert data["vendor_code_state"] == vendor_with_subdomain.code
|
||||
|
||||
def test_path_based_vendor_detection_vendor_prefix(self, client, vendor_with_subdomain):
|
||||
def test_path_based_vendor_detection_vendor_prefix(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test vendor detection via path-based routing with /vendor/ prefix."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendor/{vendor_code}/test")
|
||||
async def test_vendor_path(vendor_code: str, request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None,
|
||||
"vendor_code": request.state.vendor.code if hasattr(request.state, 'vendor') and request.state.vendor else None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None,
|
||||
"vendor_code": (
|
||||
request.state.vendor.code
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
f"/vendor/{vendor_with_subdomain.code}/test",
|
||||
headers={"host": "localhost:8000"}
|
||||
headers={"host": "localhost:8000"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -209,48 +270,63 @@ class TestVendorContextFlow:
|
||||
# Clean Path Extraction Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_clean_path_extracted_from_vendor_prefix(self, client, vendor_with_subdomain):
|
||||
def test_clean_path_extracted_from_vendor_prefix(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that clean_path is correctly extracted from path-based routing."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/vendors/{vendor_code}/shop/products")
|
||||
async def test_clean_path(vendor_code: str, request: Request):
|
||||
return {
|
||||
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
|
||||
"original_path": request.url.path
|
||||
"clean_path": (
|
||||
request.state.clean_path
|
||||
if hasattr(request.state, "clean_path")
|
||||
else None
|
||||
),
|
||||
"original_path": request.url.path,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
f"/vendors/{vendor_with_subdomain.code}/shop/products",
|
||||
headers={"host": "localhost:8000"}
|
||||
headers={"host": "localhost:8000"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Clean path should have vendor prefix removed
|
||||
assert data["clean_path"] == "/shop/products"
|
||||
assert f"/vendors/{vendor_with_subdomain.code}/shop/products" in data["original_path"]
|
||||
assert (
|
||||
f"/vendors/{vendor_with_subdomain.code}/shop/products"
|
||||
in data["original_path"]
|
||||
)
|
||||
|
||||
def test_clean_path_unchanged_for_subdomain(self, client, vendor_with_subdomain):
|
||||
"""Test that clean_path equals original path for subdomain routing."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/shop/test-clean-path")
|
||||
async def test_subdomain_clean_path(request: Request):
|
||||
return {
|
||||
"clean_path": request.state.clean_path if hasattr(request.state, 'clean_path') else None,
|
||||
"original_path": request.url.path
|
||||
"clean_path": (
|
||||
request.state.clean_path
|
||||
if hasattr(request.state, "clean_path")
|
||||
else None
|
||||
),
|
||||
"original_path": request.url.path,
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/shop/test-clean-path",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -266,21 +342,30 @@ class TestVendorContextFlow:
|
||||
def test_vendor_id_injected_into_request_state(self, client, vendor_with_subdomain):
|
||||
"""Test that vendor_id is correctly injected into request.state."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-vendor-id-injection")
|
||||
async def test_vendor_id(request: Request):
|
||||
return {
|
||||
"has_vendor_id": hasattr(request.state, 'vendor_id'),
|
||||
"vendor_id": request.state.vendor_id if hasattr(request.state, 'vendor_id') else None,
|
||||
"vendor_id_type": type(request.state.vendor_id).__name__ if hasattr(request.state, 'vendor_id') else None
|
||||
"has_vendor_id": hasattr(request.state, "vendor_id"),
|
||||
"vendor_id": (
|
||||
request.state.vendor_id
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
"vendor_id_type": (
|
||||
type(request.state.vendor_id).__name__
|
||||
if hasattr(request.state, "vendor_id")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-vendor-id-injection",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -289,30 +374,41 @@ class TestVendorContextFlow:
|
||||
assert data["vendor_id"] == vendor_with_subdomain.id
|
||||
assert data["vendor_id_type"] == "int"
|
||||
|
||||
def test_vendor_object_injected_into_request_state(self, client, vendor_with_subdomain):
|
||||
def test_vendor_object_injected_into_request_state(
|
||||
self, client, vendor_with_subdomain
|
||||
):
|
||||
"""Test that full vendor object is injected into request.state."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-vendor-object-injection")
|
||||
async def test_vendor_object(request: Request):
|
||||
vendor = request.state.vendor if hasattr(request.state, 'vendor') and request.state.vendor else None
|
||||
vendor = (
|
||||
request.state.vendor
|
||||
if hasattr(request.state, "vendor") and request.state.vendor
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"has_vendor": vendor is not None,
|
||||
"vendor_attributes": {
|
||||
"id": vendor.id if vendor else None,
|
||||
"name": vendor.name if vendor else None,
|
||||
"code": vendor.code if vendor else None,
|
||||
"subdomain": vendor.subdomain if vendor else None,
|
||||
"is_active": vendor.is_active if vendor else None
|
||||
} if vendor else None
|
||||
"vendor_attributes": (
|
||||
{
|
||||
"id": vendor.id if vendor else None,
|
||||
"name": vendor.name if vendor else None,
|
||||
"code": vendor.code if vendor else None,
|
||||
"subdomain": vendor.subdomain if vendor else None,
|
||||
"is_active": vendor.is_active if vendor else None,
|
||||
}
|
||||
if vendor
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-vendor-object-injection",
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"}
|
||||
headers={"host": f"{vendor_with_subdomain.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -330,19 +426,21 @@ class TestVendorContextFlow:
|
||||
def test_inactive_vendor_not_detected(self, client, inactive_vendor):
|
||||
"""Test that inactive vendors are not detected."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-inactive-vendor-detection")
|
||||
async def test_inactive(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-inactive-vendor-detection",
|
||||
headers={"host": f"{inactive_vendor.subdomain}.platform.com"}
|
||||
headers={"host": f"{inactive_vendor.subdomain}.platform.com"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -352,19 +450,20 @@ class TestVendorContextFlow:
|
||||
def test_platform_domain_without_subdomain_no_vendor(self, client):
|
||||
"""Test that platform domain without subdomain doesn't detect vendor."""
|
||||
from fastapi import Request
|
||||
|
||||
from main import app
|
||||
|
||||
@app.get("/test-platform-domain")
|
||||
async def test_platform(request: Request):
|
||||
return {
|
||||
"vendor_detected": hasattr(request.state, 'vendor') and request.state.vendor is not None
|
||||
"vendor_detected": hasattr(request.state, "vendor")
|
||||
and request.state.vendor is not None
|
||||
}
|
||||
|
||||
with patch('app.core.config.settings') as mock_settings:
|
||||
with patch("app.core.config.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
response = client.get(
|
||||
"/test-platform-domain",
|
||||
headers={"host": "platform.com"}
|
||||
"/test-platform-domain", headers={"host": "platform.com"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -11,7 +11,8 @@ class TestInputValidation:
|
||||
malicious_search = "'; DROP TABLE products; --"
|
||||
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?search={malicious_search}", headers=auth_headers
|
||||
f"/api/v1/marketplace/product?search={malicious_search}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Should not crash and should return normal response
|
||||
@@ -40,17 +41,23 @@ class TestInputValidation:
|
||||
def test_parameter_validation(self, client, auth_headers):
|
||||
"""Test parameter validation for API endpoints"""
|
||||
# Test invalid pagination parameters
|
||||
response = client.get("/api/v1/marketplace/product?limit=-1", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=-1", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
response = client.get("/api/v1/marketplace/product?skip=-1", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?skip=-1", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
def test_json_validation(self, client, auth_headers):
|
||||
"""Test JSON validation for POST requests"""
|
||||
# Test invalid JSON structure
|
||||
response = client.post(
|
||||
"/api/v1/marketplace/product", headers=auth_headers, content="invalid json content"
|
||||
"/api/v1/marketplace/product",
|
||||
headers=auth_headers,
|
||||
content="invalid json content",
|
||||
)
|
||||
assert response.status_code == 422 # JSON decode error
|
||||
|
||||
@@ -58,6 +65,8 @@ class TestInputValidation:
|
||||
response = client.post(
|
||||
"/api/v1/marketplace/product",
|
||||
headers=auth_headers,
|
||||
json={"title": "Test MarketplaceProduct"}, # Missing required marketplace_product_id
|
||||
json={
|
||||
"title": "Test MarketplaceProduct"
|
||||
}, # Missing required marketplace_product_id
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
@@ -33,12 +33,15 @@ class TestIntegrationFlows:
|
||||
"quantity": 50,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory", headers=auth_headers, json=inventory_data
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# 3. Get product with inventory info
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product/{product['marketplace_product_id']}", headers=auth_headers
|
||||
f"/api/v1/marketplace/product/{product['marketplace_product_id']}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
product_detail = response.json()
|
||||
@@ -55,14 +58,15 @@ class TestIntegrationFlows:
|
||||
|
||||
# 5. Search for product
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?search=Updated Integration", headers=auth_headers
|
||||
"/api/v1/marketplace/product?search=Updated Integration",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["total"] == 1
|
||||
|
||||
def test_product_workflow(self, client, auth_headers):
|
||||
"""Test vendor creation and product management workflow"""
|
||||
# 1. Create a vendor
|
||||
# 1. Create a vendor
|
||||
vendor_data = {
|
||||
"vendor_code": "FLOWVENDOR",
|
||||
"name": "Integration Flow Vendor",
|
||||
@@ -91,7 +95,9 @@ class TestIntegrationFlows:
|
||||
# This would test the vendor -product association
|
||||
|
||||
# 4. Get vendor details
|
||||
response = client.get(f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/vendor/{vendor ['vendor_code']}", headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_inventory_operations_workflow(self, client, auth_headers):
|
||||
|
||||
@@ -28,7 +28,9 @@ class TestPerformance:
|
||||
|
||||
# Time the request
|
||||
start_time = time.time()
|
||||
response = client.get("/api/v1/marketplace/product?limit=100", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?limit=100", headers=auth_headers
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -54,7 +56,9 @@ class TestPerformance:
|
||||
|
||||
# Time search request
|
||||
start_time = time.time()
|
||||
response = client.get("/api/v1/marketplace/product?search=Searchable", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product?search=Searchable", headers=auth_headers
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -115,7 +119,8 @@ class TestPerformance:
|
||||
for offset in offsets:
|
||||
start_time = time.time()
|
||||
response = client.get(
|
||||
f"/api/v1/marketplace/product?skip={offset}&limit=20", headers=auth_headers
|
||||
f"/api/v1/marketplace/product?skip={offset}&limit=20",
|
||||
headers=auth_headers,
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ System tests for error handling across the LetzVendor API.
|
||||
Tests the complete error handling flow from FastAPI through custom exception handlers
|
||||
to ensure proper HTTP status codes, error structures, and client-friendly responses.
|
||||
"""
|
||||
import pytest
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.system
|
||||
class TestErrorHandling:
|
||||
@@ -16,9 +17,7 @@ class TestErrorHandling:
|
||||
def test_invalid_json_request(self, client, auth_headers):
|
||||
"""Test handling of malformed JSON requests"""
|
||||
response = client.post(
|
||||
"/api/v1/vendor",
|
||||
headers=auth_headers,
|
||||
content="{ invalid json syntax"
|
||||
"/api/v1/vendor", headers=auth_headers, content="{ invalid json syntax"
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
@@ -31,9 +30,7 @@ class TestErrorHandling:
|
||||
"""Test validation errors for missing required fields"""
|
||||
# Missing name
|
||||
response = client.post(
|
||||
"/api/v1/vendor",
|
||||
headers=auth_headers,
|
||||
json={"vendor_code": "TESTVENDOR"}
|
||||
"/api/v1/vendor", headers=auth_headers, json={"vendor_code": "TESTVENDOR"}
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
@@ -48,10 +45,7 @@ class TestErrorHandling:
|
||||
response = client.post(
|
||||
"/api/v1/vendor",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"vendor_code": "INVALID@VENDOR!",
|
||||
"name": "Test Vendor"
|
||||
}
|
||||
json={"vendor_code": "INVALID@VENDOR!", "name": "Test Vendor"},
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
@@ -92,7 +86,7 @@ class TestErrorHandling:
|
||||
assert data["status_code"] == 401
|
||||
|
||||
def test_vendor_not_found(self, client, auth_headers):
|
||||
"""Test accessing non-existent vendor """
|
||||
"""Test accessing non-existent vendor"""
|
||||
response = client.get("/api/v1/vendor/NONEXISTENT", headers=auth_headers)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -104,7 +98,9 @@ class TestErrorHandling:
|
||||
|
||||
def test_product_not_found(self, client, auth_headers):
|
||||
"""Test accessing non-existent product"""
|
||||
response = client.get("/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers)
|
||||
response = client.get(
|
||||
"/api/v1/marketplace/product/NONEXISTENT", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
@@ -117,7 +113,7 @@ class TestErrorHandling:
|
||||
"""Test creating vendor with duplicate vendor code"""
|
||||
vendor_data = {
|
||||
"vendor_code": test_vendor.vendor_code,
|
||||
"name": "Duplicate Vendor"
|
||||
"name": "Duplicate Vendor",
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
|
||||
@@ -128,25 +124,34 @@ class TestErrorHandling:
|
||||
assert data["status_code"] == 409
|
||||
assert data["details"]["vendor_code"] == test_vendor.vendor_code.upper()
|
||||
|
||||
def test_duplicate_product_creation(self, client, auth_headers, test_marketplace_product):
|
||||
def test_duplicate_product_creation(
|
||||
self, client, auth_headers, test_marketplace_product
|
||||
):
|
||||
"""Test creating product with duplicate product ID"""
|
||||
product_data = {
|
||||
"marketplace_product_id": test_marketplace_product.marketplace_product_id,
|
||||
"title": "Duplicate MarketplaceProduct",
|
||||
"gtin": "1234567890123"
|
||||
"gtin": "1234567890123",
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
|
||||
response = client.post(
|
||||
"/api/v1/marketplace/product", headers=auth_headers, json=product_data
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
data = response.json()
|
||||
assert data["error_code"] == "PRODUCT_ALREADY_EXISTS"
|
||||
assert data["status_code"] == 409
|
||||
assert data["details"]["marketplace_product_id"] == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
data["details"]["marketplace_product_id"]
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_unauthorized_vendor_access(self, client, auth_headers, inactive_vendor):
|
||||
"""Test accessing vendor without proper permissions"""
|
||||
response = client.get(f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers)
|
||||
response = client.get(
|
||||
f"/api/v1/vendor/{inactive_vendor.vendor_code}", headers=auth_headers
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
data = response.json()
|
||||
@@ -154,27 +159,33 @@ class TestErrorHandling:
|
||||
assert data["status_code"] == 403
|
||||
assert data["details"]["vendor_code"] == inactive_vendor.vendor_code
|
||||
|
||||
def test_insufficient_permissions(self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"):
|
||||
def test_insufficient_permissions(
|
||||
self, client, auth_headers, admin_only_endpoint="/api/v1/admin/users"
|
||||
):
|
||||
"""Test accessing admin endpoints with regular user"""
|
||||
response = client.get(admin_only_endpoint, headers=auth_headers)
|
||||
|
||||
assert response.status_code in [403, 404] # 403 for permission denied, 404 if endpoint doesn't exist
|
||||
assert response.status_code in [
|
||||
403,
|
||||
404,
|
||||
] # 403 for permission denied, 404 if endpoint doesn't exist
|
||||
if response.status_code == 403:
|
||||
data = response.json()
|
||||
assert data["error_code"] in ["ADMIN_REQUIRED", "INSUFFICIENT_PERMISSIONS"]
|
||||
assert data["status_code"] == 403
|
||||
|
||||
def test_business_logic_violation_max_vendors(self, client, auth_headers, monkeypatch):
|
||||
def test_business_logic_violation_max_vendors(
|
||||
self, client, auth_headers, monkeypatch
|
||||
):
|
||||
"""Test business logic violation - creating too many vendors"""
|
||||
# This test would require mocking the vendor limit check
|
||||
# For now, test the error structure when creating multiple vendors
|
||||
vendors_created = []
|
||||
for i in range(6): # Assume limit is 5
|
||||
vendor_data = {
|
||||
"vendor_code": f"VENDOR{i:03d}",
|
||||
"name": f"Test Vendor {i}"
|
||||
}
|
||||
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
|
||||
vendor_data = {"vendor_code": f"VENDOR{i:03d}", "name": f"Test Vendor {i}"}
|
||||
response = client.post(
|
||||
"/api/v1/vendor", headers=auth_headers, json=vendor_data
|
||||
)
|
||||
vendors_created.append(response)
|
||||
|
||||
# At least one should succeed, and if limit is enforced, later ones should fail
|
||||
@@ -193,10 +204,12 @@ class TestErrorHandling:
|
||||
product_data = {
|
||||
"marketplace_product_id": "TESTPROD001",
|
||||
"title": "Test MarketplaceProduct",
|
||||
"gtin": "invalid_gtin_format"
|
||||
"gtin": "invalid_gtin_format",
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/marketplace/product", headers=auth_headers, json=product_data)
|
||||
response = client.post(
|
||||
"/api/v1/marketplace/product", headers=auth_headers, json=product_data
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
@@ -204,13 +217,15 @@ class TestErrorHandling:
|
||||
assert data["status_code"] == 422
|
||||
assert data["details"]["field"] == "gtin"
|
||||
|
||||
def test_inventory_insufficient_quantity(self, client, auth_headers, test_vendor, test_marketplace_product):
|
||||
def test_inventory_insufficient_quantity(
|
||||
self, client, auth_headers, test_vendor, test_marketplace_product
|
||||
):
|
||||
"""Test business logic error for insufficient inventory"""
|
||||
# First create some inventory
|
||||
inventory_data = {
|
||||
"gtin": test_marketplace_product.gtin,
|
||||
"location": "WAREHOUSE_A",
|
||||
"quantity": 5
|
||||
"quantity": 5,
|
||||
}
|
||||
client.post("/api/v1/inventory", headers=auth_headers, json=inventory_data)
|
||||
|
||||
@@ -218,9 +233,11 @@ class TestErrorHandling:
|
||||
remove_data = {
|
||||
"gtin": test_marketplace_product.gtin,
|
||||
"location": "WAREHOUSE_A",
|
||||
"quantity": 10 # More than the 5 we added
|
||||
"quantity": 10, # More than the 5 we added
|
||||
}
|
||||
response = client.post("/api/v1/inventory/remove", headers=auth_headers, json=remove_data)
|
||||
response = client.post(
|
||||
"/api/v1/inventory/remove", headers=auth_headers, json=remove_data
|
||||
)
|
||||
|
||||
# This should ALWAYS fail with insufficient inventory error
|
||||
assert response.status_code == 400
|
||||
@@ -257,7 +274,7 @@ class TestErrorHandling:
|
||||
response = client.post(
|
||||
"/api/v1/vendor",
|
||||
headers=headers,
|
||||
content="<vendor ><code>TEST</code></vendor >"
|
||||
content="<vendor ><code>TEST</code></vendor >",
|
||||
)
|
||||
|
||||
assert response.status_code in [400, 415, 422]
|
||||
@@ -268,7 +285,7 @@ class TestErrorHandling:
|
||||
vendor_data = {
|
||||
"vendor_code": "LARGEVENDOR",
|
||||
"name": "Large Vendor",
|
||||
"description": large_description
|
||||
"description": large_description,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/vendor", headers=auth_headers, json=vendor_data)
|
||||
@@ -310,10 +327,12 @@ class TestErrorHandling:
|
||||
# Test invalid marketplace
|
||||
import_data = {
|
||||
"marketplace": "INVALID_MARKETPLACE",
|
||||
"vendor_code": test_vendor.vendor_code
|
||||
"vendor_code": test_vendor.vendor_code,
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
|
||||
response = client.post(
|
||||
"/api/v1/imports", headers=auth_headers, json=import_data
|
||||
)
|
||||
|
||||
if response.status_code == 422:
|
||||
data = response.json()
|
||||
@@ -329,10 +348,12 @@ class TestErrorHandling:
|
||||
# Test with potentially problematic external data
|
||||
import_data = {
|
||||
"marketplace": "LETZSHOP",
|
||||
"external_url": "https://nonexistent-marketplace.com/api"
|
||||
"external_url": "https://nonexistent-marketplace.com/api",
|
||||
}
|
||||
|
||||
response = client.post("/api/v1/imports", headers=auth_headers, json=import_data)
|
||||
response = client.post(
|
||||
"/api/v1/imports", headers=auth_headers, json=import_data
|
||||
)
|
||||
|
||||
# If it's a real external service error, check structure
|
||||
if response.status_code == 502:
|
||||
@@ -356,7 +377,9 @@ class TestErrorHandling:
|
||||
# All error responses should have these fields
|
||||
required_fields = ["error_code", "message", "status_code"]
|
||||
for field in required_fields:
|
||||
assert field in data, f"Missing {field} in error response for {endpoint}"
|
||||
assert (
|
||||
field in data
|
||||
), f"Missing {field} in error response for {endpoint}"
|
||||
|
||||
# Details field should be present (can be empty dict)
|
||||
assert "details" in data
|
||||
@@ -420,5 +443,7 @@ class TestErrorRecovery:
|
||||
|
||||
# Check that error was logged (if your app logs 404s as errors)
|
||||
# Adjust based on your logging configuration
|
||||
error_logs = [record for record in caplog.records if record.levelno >= logging.ERROR]
|
||||
error_logs = [
|
||||
record for record in caplog.records if record.levelno >= logging.ERROR
|
||||
]
|
||||
# May or may not have logs depending on whether 404s are logged as errors
|
||||
|
||||
@@ -12,21 +12,18 @@ Tests cover:
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from fastapi import HTTPException
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.exceptions import (AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
InvalidCredentialsException, InvalidTokenException,
|
||||
TokenExpiredException, UserNotActiveException)
|
||||
from middleware.auth import AuthManager
|
||||
from app.exceptions import (
|
||||
InvalidTokenException,
|
||||
TokenExpiredException,
|
||||
UserNotActiveException,
|
||||
InvalidCredentialsException,
|
||||
AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
|
||||
@@ -124,7 +121,9 @@ class TestUserAuthentication:
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
result = auth_manager.authenticate_user(
|
||||
mock_db, "test@example.com", "password123"
|
||||
)
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Decode without verification to check payload
|
||||
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
|
||||
payload = jwt.decode(
|
||||
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
|
||||
)
|
||||
|
||||
assert payload["sub"] == "42"
|
||||
assert payload["username"] == "testuser"
|
||||
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
|
||||
"""Test tokens are different for different users."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
|
||||
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
|
||||
user1 = Mock(
|
||||
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
|
||||
)
|
||||
user2 = Mock(
|
||||
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
|
||||
)
|
||||
|
||||
token1 = auth_manager.create_access_token(user1)["access_token"]
|
||||
token2 = auth_manager.create_access_token(user2)["access_token"]
|
||||
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
|
||||
payload = jwt.decode(
|
||||
token_data["access_token"],
|
||||
auth_manager.secret_key,
|
||||
algorithms=[auth_manager.algorithm]
|
||||
algorithms=[auth_manager.algorithm],
|
||||
)
|
||||
|
||||
assert payload["role"] == "admin"
|
||||
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
|
||||
# Create token without 'sub' field
|
||||
payload = {
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token without 'exp' field
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser"
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser"}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
# Create token with different algorithm
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
|
||||
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
|
||||
|
||||
# Create a token with expiration in the past
|
||||
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": past_time.timestamp()
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Mock jwt.decode to bypass its expiration check and test line 205
|
||||
with patch('middleware.auth.jwt.decode') as mock_decode:
|
||||
with patch("middleware.auth.jwt.decode") as mock_decode:
|
||||
mock_decode.return_value = payload
|
||||
|
||||
with pytest.raises(TokenExpiredException):
|
||||
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
|
||||
|
||||
# Existing admin user
|
||||
existing_admin = Mock(spec=User)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = (
|
||||
existing_admin
|
||||
)
|
||||
|
||||
result = auth_manager.create_default_admin_user(mock_db)
|
||||
|
||||
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_default_configuration(self):
|
||||
"""Test AuthManager uses default configuration."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.algorithm == "HS256"
|
||||
assert auth_manager.token_expire_minutes == 30
|
||||
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
assert (
|
||||
auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
)
|
||||
|
||||
def test_custom_configuration(self):
|
||||
"""Test AuthManager uses environment variables."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'custom-secret-key',
|
||||
'JWT_EXPIRE_MINUTES': '60'
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
|
||||
):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.secret_key == "custom-secret-key"
|
||||
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_partial_custom_configuration(self):
|
||||
"""Test AuthManager with partial environment configuration."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_EXPIRE_MINUTES': '120'
|
||||
}, clear=False):
|
||||
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.token_expire_minutes == 120
|
||||
@@ -656,9 +662,11 @@ class TestEdgeCases:
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Should still verify successfully (JWT doesn't validate iat by default)
|
||||
result = auth_manager.verify_token(token)
|
||||
@@ -698,7 +706,9 @@ class TestEdgeCases:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Mock jose.jwt.decode to raise an unexpected exception
|
||||
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
|
||||
with patch(
|
||||
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
|
||||
):
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
|
||||
@@ -10,16 +10,13 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.context import (
|
||||
ContextManager,
|
||||
ContextMiddleware,
|
||||
RequestContext,
|
||||
get_request_context,
|
||||
)
|
||||
from middleware.context import (ContextManager, ContextMiddleware,
|
||||
RequestContext, get_request_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
|
||||
def test_is_admin_context_from_subdomain(self):
|
||||
"""Test _is_admin_context with admin subdomain."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/dashboard"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_from_path(self):
|
||||
"""Test _is_admin_context with admin path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "localhost", "/admin/users")
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_both(self):
|
||||
"""Test _is_admin_context with both subdomain and path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/admin/users"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_not_admin_context(self):
|
||||
"""Test _is_admin_context returns False for non-admin."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context with /vendor/ path."""
|
||||
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
|
||||
|
||||
def test_is_vendor_dashboard_context_nested(self):
|
||||
"""Test _is_vendor_dashboard_context with nested vendor path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context_vendors_plural(self):
|
||||
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
|
||||
@@ -373,7 +391,7 @@ class TestContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'context_type')
|
||||
assert hasattr(request.state, "context_type")
|
||||
assert request.state.context_type == RequestContext.API
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@@ -565,7 +583,7 @@ class TestEdgeCases:
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
# No state attribute at all
|
||||
delattr(request, 'state')
|
||||
delattr(request, "state")
|
||||
|
||||
# Should still work, falling back to url.path
|
||||
with pytest.raises(AttributeError):
|
||||
|
||||
@@ -12,16 +12,18 @@ Tests cover:
|
||||
- Edge cases and isolation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
from app.exceptions.base import RateLimitException
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions.base import RateLimitException
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limiter():
|
||||
"""Reset rate limiter state before each test to ensure isolation."""
|
||||
@@ -34,6 +36,7 @@ def reset_rate_limiter():
|
||||
# Rate Limit Decorator Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimitDecorator:
|
||||
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_allows_within_limit(self):
|
||||
"""Test decorator allows requests within rate limit."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_blocks_exceeding_limit(self):
|
||||
"""Test decorator blocks requests exceeding rate limit."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_blocked():
|
||||
return {"status": "ok"}
|
||||
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_preserves_function_metadata(self):
|
||||
"""Test decorator preserves original function metadata."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
"""Test endpoint docstring."""
|
||||
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_args_and_kwargs(self):
|
||||
"""Test decorator works with function arguments."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint(arg1, arg2, kwarg1=None):
|
||||
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
|
||||
|
||||
result = await test_endpoint("value1", "value2", kwarg1="value3")
|
||||
|
||||
assert result == {
|
||||
"arg1": "value1",
|
||||
"arg2": "value2",
|
||||
"kwarg1": "value3"
|
||||
}
|
||||
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_default_parameters(self):
|
||||
"""Test decorator uses default parameters."""
|
||||
|
||||
@rate_limit() # Use defaults
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_includes_retry_after(self):
|
||||
"""Test rate limit exception includes retry_after."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=60)
|
||||
async def test_endpoint_retry():
|
||||
return {"status": "ok"}
|
||||
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_zero_max_requests(self):
|
||||
"""Test decorator with max_requests=0 blocks all requests."""
|
||||
|
||||
@rate_limit(max_requests=0, window_seconds=3600)
|
||||
async def test_endpoint_zero():
|
||||
return {"status": "ok"}
|
||||
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_very_short_window(self):
|
||||
"""Test decorator with very short time window."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=1)
|
||||
async def test_endpoint_short():
|
||||
return {"status": "ok"}
|
||||
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_multiple_functions_separate_limits(self):
|
||||
"""Test that different functions have separate rate limits."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=3600)
|
||||
async def endpoint1():
|
||||
return {"endpoint": "1"}
|
||||
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_exception_in_function(self):
|
||||
"""Test decorator handles exceptions from wrapped function."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint_error():
|
||||
raise ValueError("Function error")
|
||||
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_isolation_between_tests(self):
|
||||
"""Test that rate limiter state is properly isolated between tests."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_isolation():
|
||||
return {"status": "ok"}
|
||||
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_dict(self):
|
||||
"""Test decorator correctly returns dictionary."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_dict():
|
||||
return {"key": "value", "number": 42}
|
||||
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_list(self):
|
||||
"""Test decorator correctly returns list."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_list():
|
||||
return [1, 2, 3, 4, 5]
|
||||
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_none(self):
|
||||
"""Test decorator correctly returns None."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_none():
|
||||
return None
|
||||
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_object(self):
|
||||
"""Test decorator correctly returns custom objects."""
|
||||
|
||||
class TestObject:
|
||||
def __init__(self):
|
||||
self.name = "test_object"
|
||||
|
||||
@@ -11,9 +11,10 @@ Tests cover:
|
||||
- Edge cases (missing client info, etc.)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.logging import LoggingMiddleware
|
||||
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify request was logged
|
||||
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify response was logged
|
||||
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Process-Time" in response.headers
|
||||
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should log "unknown" for client
|
||||
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
"unknown" in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_exceptions(self):
|
||||
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger, \
|
||||
pytest.raises(Exception):
|
||||
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
|
||||
Exception
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify error was logged
|
||||
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = slow_call_next
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
process_time = float(result.headers["X-Process-Time"])
|
||||
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should still have process time, even if very small
|
||||
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify method was logged
|
||||
assert any(method in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
method in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_different_status_codes(self):
|
||||
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=status_code, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify status code was logged
|
||||
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
str(status_code) in str(call)
|
||||
for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@@ -11,10 +11,11 @@ Tests cover:
|
||||
- Edge cases and concurrency scenarios
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
|
||||
# Add requests at different times
|
||||
now = datetime.now(timezone.utc)
|
||||
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
|
||||
limiter = RateLimiter()
|
||||
client_id = "long_window_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
|
||||
result = limiter.allow_request(
|
||||
client_id, max_requests=10, window_seconds=86400 * 365
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
|
||||
client_id = "same_client"
|
||||
|
||||
# Allow with one limit
|
||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
is True
|
||||
)
|
||||
|
||||
# Check with stricter limit
|
||||
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_rate_limiter_unicode_client_id(self):
|
||||
"""Test rate limiter with unicode client ID."""
|
||||
|
||||
@@ -11,15 +11,14 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.theme_context import (
|
||||
ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme,
|
||||
)
|
||||
from middleware.theme_context import (ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -42,7 +41,14 @@ class TestThemeContextManager:
|
||||
"""Test default theme has all required colors."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
|
||||
required_colors = [
|
||||
"primary",
|
||||
"secondary",
|
||||
"accent",
|
||||
"background",
|
||||
"text",
|
||||
"border",
|
||||
]
|
||||
for color in required_colors:
|
||||
assert color in theme["colors"]
|
||||
assert theme["colors"][color].startswith("#")
|
||||
@@ -79,10 +85,7 @@ class TestThemeContextManager:
|
||||
mock_theme = Mock()
|
||||
|
||||
# Mock to_dict to return actual dictionary
|
||||
custom_theme_dict = {
|
||||
"theme_name": "custom",
|
||||
"colors": {"primary": "#ff0000"}
|
||||
}
|
||||
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
|
||||
mock_theme.to_dict.return_value = custom_theme_dict
|
||||
|
||||
# Correct filter chain: query().filter().first()
|
||||
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
|
||||
mock_db = MagicMock()
|
||||
mock_theme = {"theme_name": "test_theme"}
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'theme')
|
||||
assert hasattr(request.state, "theme")
|
||||
assert request.state.theme["theme_name"] == "default"
|
||||
call_next.assert_called_once()
|
||||
|
||||
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
|
||||
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify database was closed
|
||||
|
||||
@@ -11,17 +11,16 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.vendor_context import (
|
||||
VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context,
|
||||
)
|
||||
from middleware.vendor_context import (VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -39,7 +38,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -55,7 +54,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -71,7 +70,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -87,7 +86,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -140,7 +139,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -153,7 +152,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "www.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -166,7 +165,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "api.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -179,7 +178,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -198,12 +197,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = True
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -218,12 +216,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = False
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -232,12 +229,11 @@ class TestVendorContextManager:
|
||||
def test_get_vendor_from_custom_domain_not_found(self):
|
||||
"""Test custom domain not found in database."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "nonexistent.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -249,12 +245,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -266,12 +261,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "path",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "path", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -291,12 +285,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "VENDOR1" # Uppercase
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -311,10 +304,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -325,10 +315,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendors/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendors/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -339,10 +326,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -353,10 +337,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -425,21 +406,24 @@ class TestVendorContextManager:
|
||||
# Static File Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
],
|
||||
)
|
||||
def test_is_static_file_request(self, path):
|
||||
"""Test static file detection for various paths and extensions."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -447,12 +431,15 @@ class TestVendorContextManager:
|
||||
|
||||
assert VendorContextManager.is_static_file_request(request) is True
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
],
|
||||
)
|
||||
def test_is_not_static_file_request(self, path):
|
||||
"""Test non-static file paths."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_api_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
|
||||
with patch.object(
|
||||
VendorContextManager, "is_static_file_request", return_value=True
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
|
||||
mock_vendor.name = "Test Vendor"
|
||||
mock_vendor.subdomain = "vendor1"
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
|
||||
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
|
||||
), patch.object(
|
||||
VendorContextManager, "extract_clean_path", return_value="/shop/products"
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "nonexistent"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=None
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=None
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -714,7 +708,7 @@ class TestEdgeCases:
|
||||
request.headers = {"host": "shop.vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -735,11 +729,14 @@ class TestEdgeCases:
|
||||
|
||||
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
|
||||
|
||||
with patch('middleware.vendor_context.logger') as mock_logger:
|
||||
with patch("middleware.vendor_context.logger") as mock_logger:
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is None
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called()
|
||||
warning_message = str(mock_logger.warning.call_args)
|
||||
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
|
||||
assert (
|
||||
"No active vendor found for subdomain" in warning_message
|
||||
and "nonexistent" in warning_message
|
||||
)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# tests/unit/models/test_database_models.py
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.vendor import Vendor, VendorUser, Role
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.user import User
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.product import Product
|
||||
from models.database.customer import Customer, CustomerAddress
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.database.product import Product
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Role, Vendor, VendorUser
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -277,7 +278,7 @@ class TestMarketplaceProductModel:
|
||||
vendor_id=test_vendor.id,
|
||||
marketplace_product_id="UNIQUE_001",
|
||||
title="Product 1",
|
||||
marketplace="Letzshop"
|
||||
marketplace="Letzshop",
|
||||
)
|
||||
db.add(product1)
|
||||
db.commit()
|
||||
@@ -288,7 +289,7 @@ class TestMarketplaceProductModel:
|
||||
vendor_id=test_vendor.id,
|
||||
marketplace_product_id="UNIQUE_001",
|
||||
title="Product 2",
|
||||
marketplace="Letzshop"
|
||||
marketplace="Letzshop",
|
||||
)
|
||||
db.add(product2)
|
||||
db.commit()
|
||||
@@ -515,7 +516,9 @@ class TestCustomerModel:
|
||||
class TestOrderModel:
|
||||
"""Test Order model"""
|
||||
|
||||
def test_order_creation(self, db, test_vendor, test_customer, test_customer_address):
|
||||
def test_order_creation(
|
||||
self, db, test_vendor, test_customer, test_customer_address
|
||||
):
|
||||
"""Test Order model with customer relationship"""
|
||||
order = Order(
|
||||
vendor_id=test_vendor.id,
|
||||
@@ -563,7 +566,9 @@ class TestOrderModel:
|
||||
assert float(order_item.unit_price) == 49.99
|
||||
assert float(order_item.total_price) == 99.98
|
||||
|
||||
def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address):
|
||||
def test_order_number_uniqueness(
|
||||
self, db, test_vendor, test_customer, test_customer_address
|
||||
):
|
||||
"""Test order_number unique constraint"""
|
||||
order1 = Order(
|
||||
vendor_id=test_vendor.id,
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
# tests/unit/services/test_admin_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (
|
||||
UserNotFoundException,
|
||||
UserStatusChangeException,
|
||||
CannotModifySelfException,
|
||||
VendorNotFoundException,
|
||||
VendorVerificationException,
|
||||
AdminOperationException,
|
||||
)
|
||||
from app.exceptions import (AdminOperationException, CannotModifySelfException,
|
||||
UserNotFoundException, UserStatusChangeException,
|
||||
VendorNotFoundException,
|
||||
VendorVerificationException)
|
||||
from app.services.admin_service import AdminService
|
||||
from app.services.stats_service import stats_service
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
@@ -85,7 +81,9 @@ class TestAdminService:
|
||||
assert exception.error_code == "CANNOT_MODIFY_SELF"
|
||||
assert "deactivate account" in exception.message
|
||||
|
||||
def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin):
|
||||
def test_toggle_user_status_cannot_modify_admin(
|
||||
self, db, test_admin, another_admin
|
||||
):
|
||||
"""Test that admin cannot modify another admin"""
|
||||
with pytest.raises(UserStatusChangeException) as exc_info:
|
||||
self.service.toggle_user_status(db, another_admin.id, test_admin.id)
|
||||
@@ -148,7 +146,7 @@ class TestAdminService:
|
||||
assert "99999" in exception.message
|
||||
|
||||
def test_toggle_vendor_status_deactivate(self, db, test_vendor):
|
||||
"""Test deactivating a vendor """
|
||||
"""Test deactivating a vendor"""
|
||||
original_status = test_vendor.is_active
|
||||
|
||||
vendor, message = self.service.toggle_vendor_status(db, test_vendor.id)
|
||||
@@ -170,21 +168,26 @@ class TestAdminService:
|
||||
assert exception.error_code == "VENDOR_NOT_FOUND"
|
||||
|
||||
# Marketplace Import Jobs Tests
|
||||
def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_no_filters(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting marketplace import jobs without filters"""
|
||||
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
|
||||
|
||||
assert len(result) >= 1
|
||||
# Find our test job in the results
|
||||
test_job = next(
|
||||
(job for job in result if job.job_id == test_marketplace_import_job.id), None
|
||||
(job for job in result if job.job_id == test_marketplace_import_job.id),
|
||||
None,
|
||||
)
|
||||
assert test_job is not None
|
||||
assert test_job.marketplace == test_marketplace_import_job.marketplace
|
||||
assert test_job.vendor_name == test_marketplace_import_job.name
|
||||
assert test_job.status == test_marketplace_import_job.status
|
||||
|
||||
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_marketplace_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by marketplace"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10
|
||||
@@ -192,9 +195,14 @@ class TestAdminService:
|
||||
|
||||
assert len(result) >= 1
|
||||
for job in result:
|
||||
assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower()
|
||||
assert (
|
||||
test_marketplace_import_job.marketplace.lower()
|
||||
in job.marketplace.lower()
|
||||
)
|
||||
|
||||
def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_vendor_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by vendor name"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10
|
||||
@@ -204,7 +212,9 @@ class TestAdminService:
|
||||
for job in result:
|
||||
assert test_marketplace_import_job.name.lower() in job.vendor_name.lower()
|
||||
|
||||
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_status_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by status"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, status=test_marketplace_import_job.status, skip=0, limit=10
|
||||
@@ -214,7 +224,9 @@ class TestAdminService:
|
||||
for job in result:
|
||||
assert job.status == test_marketplace_import_job.status
|
||||
|
||||
def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_pagination(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test marketplace import jobs pagination"""
|
||||
result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1)
|
||||
result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# tests/test_auth_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions.auth import (
|
||||
UserAlreadyExistsException,
|
||||
InvalidCredentialsException,
|
||||
UserNotActiveException,
|
||||
)
|
||||
from app.exceptions.auth import (InvalidCredentialsException,
|
||||
UserAlreadyExistsException,
|
||||
UserNotActiveException)
|
||||
from app.exceptions.base import ValidationException
|
||||
from app.services.auth_service import AuthService
|
||||
from models.schema.auth import UserLogin, UserRegister
|
||||
@@ -218,11 +216,14 @@ class TestAuthService:
|
||||
|
||||
def test_create_access_token_failure(self, test_user, monkeypatch):
|
||||
"""Test creating access token handles failures"""
|
||||
|
||||
# Mock the auth_manager to raise an exception
|
||||
def mock_create_token(*args, **kwargs):
|
||||
raise Exception("Token creation failed")
|
||||
|
||||
monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token)
|
||||
monkeypatch.setattr(
|
||||
self.service.auth_manager, "create_access_token", mock_create_token
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
self.service.create_access_token(test_user)
|
||||
@@ -250,11 +251,14 @@ class TestAuthService:
|
||||
|
||||
def test_hash_password_failure(self, monkeypatch):
|
||||
"""Test password hashing handles failures"""
|
||||
|
||||
# Mock the auth_manager to raise an exception
|
||||
def mock_hash_password(*args, **kwargs):
|
||||
raise Exception("Hashing failed")
|
||||
|
||||
monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password)
|
||||
monkeypatch.setattr(
|
||||
self.service.auth_manager, "hash_password", mock_hash_password
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
self.service.hash_password("testpassword")
|
||||
@@ -267,9 +271,7 @@ class TestAuthService:
|
||||
def test_register_user_database_error(self, db_with_error):
|
||||
"""Test user registration handles database errors"""
|
||||
user_data = UserRegister(
|
||||
email="test@example.com",
|
||||
username="testuser",
|
||||
password="password123"
|
||||
email="test@example.com", username="testuser", password="password123"
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
|
||||
@@ -3,19 +3,17 @@ import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InsufficientInventoryException,
|
||||
InvalidInventoryOperationException,
|
||||
InvalidQuantityException,
|
||||
InventoryNotFoundException,
|
||||
InventoryValidationException,
|
||||
NegativeInventoryException, ValidationException)
|
||||
from app.services.inventory_service import InventoryService
|
||||
from app.exceptions import (
|
||||
InventoryNotFoundException,
|
||||
InsufficientInventoryException,
|
||||
InvalidInventoryOperationException,
|
||||
InventoryValidationException,
|
||||
NegativeInventoryException,
|
||||
InvalidQuantityException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.schema.inventory import (InventoryAdd, InventoryCreate,
|
||||
InventoryUpdate)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -40,10 +38,14 @@ class TestInventoryService:
|
||||
def test_normalize_gtin_valid(self):
|
||||
"""Test GTIN normalization with valid GTINs."""
|
||||
# Test various valid GTIN formats - these should remain unchanged
|
||||
assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("1234567890123") == "1234567890123"
|
||||
) # EAN-13
|
||||
assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A
|
||||
assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8
|
||||
assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
|
||||
assert (
|
||||
self.service._normalize_gtin("12345678901234") == "12345678901234"
|
||||
) # GTIN-14
|
||||
|
||||
# Test with decimal points (should be removed)
|
||||
assert self.service._normalize_gtin("1234567890123.0") == "1234567890123"
|
||||
@@ -52,11 +54,17 @@ class TestInventoryService:
|
||||
assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123"
|
||||
|
||||
# Test short GTINs being padded
|
||||
assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13
|
||||
assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("123") == "0000000000123"
|
||||
) # Padded to EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("12345") == "0000000012345"
|
||||
) # Padded to EAN-13
|
||||
|
||||
# Test long GTINs being truncated
|
||||
assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
|
||||
assert (
|
||||
self.service._normalize_gtin("123456789012345") == "3456789012345"
|
||||
) # Truncated to 13
|
||||
|
||||
def test_normalize_gtin_edge_cases(self):
|
||||
"""Test GTIN normalization edge cases."""
|
||||
@@ -65,9 +73,15 @@ class TestInventoryService:
|
||||
assert self.service._normalize_gtin(123) == "0000000000123"
|
||||
|
||||
# Test mixed valid/invalid characters
|
||||
assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
|
||||
assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
|
||||
assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
|
||||
assert (
|
||||
self.service._normalize_gtin("123-456-789-012") == "123456789012"
|
||||
) # Dashes removed
|
||||
assert (
|
||||
self.service._normalize_gtin("123 456 789 012") == "123456789012"
|
||||
) # Spaces removed
|
||||
assert (
|
||||
self.service._normalize_gtin("ABC123456789012DEF") == "123456789012"
|
||||
) # Letters removed
|
||||
|
||||
def test_set_inventory_new_entry_success(self, db):
|
||||
"""Test setting inventory for a new GTIN/location combination successfully."""
|
||||
@@ -162,7 +176,9 @@ class TestInventoryService:
|
||||
|
||||
def test_add_inventory_invalid_gtin_validation_error(self, db):
|
||||
"""Test adding inventory with invalid GTIN returns InventoryValidationException."""
|
||||
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
|
||||
inventory_data = InventoryAdd(
|
||||
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
|
||||
with pytest.raises(InventoryValidationException) as exc_info:
|
||||
self.service.add_inventory(db, inventory_data)
|
||||
@@ -180,11 +196,12 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVALID_QUANTITY"
|
||||
assert "Quantity must be positive" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_remove_inventory_success(self, db, test_inventory):
|
||||
"""Test removing inventory successfully."""
|
||||
original_quantity = test_inventory.quantity
|
||||
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
|
||||
remove_quantity = min(
|
||||
10, original_quantity
|
||||
) # Ensure we don't remove more than available
|
||||
|
||||
inventory_data = InventoryAdd(
|
||||
gtin=test_inventory.gtin,
|
||||
@@ -212,7 +229,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
|
||||
assert exc_info.value.details["gtin"] == test_inventory.gtin
|
||||
assert exc_info.value.details["location"] == test_inventory.location
|
||||
assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
|
||||
assert (
|
||||
exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
|
||||
)
|
||||
assert exc_info.value.details["available_quantity"] == test_inventory.quantity
|
||||
|
||||
def test_remove_inventory_nonexistent_entry_not_found(self, db):
|
||||
@@ -231,7 +250,9 @@ class TestInventoryService:
|
||||
|
||||
def test_remove_inventory_invalid_gtin_validation_error(self, db):
|
||||
"""Test removing inventory with invalid GTIN returns InventoryValidationException."""
|
||||
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
|
||||
inventory_data = InventoryAdd(
|
||||
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10
|
||||
)
|
||||
|
||||
with pytest.raises(InventoryValidationException) as exc_info:
|
||||
self.service.remove_inventory(db, inventory_data)
|
||||
@@ -254,7 +275,9 @@ class TestInventoryService:
|
||||
# The service prevents negative inventory through InsufficientInventoryException
|
||||
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
|
||||
|
||||
def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_inventory_by_gtin_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting inventory summary by GTIN successfully."""
|
||||
result = self.service.get_inventory_by_gtin(db, test_inventory.gtin)
|
||||
|
||||
@@ -265,14 +288,20 @@ class TestInventoryService:
|
||||
assert result.locations[0].quantity == test_inventory.quantity
|
||||
assert result.product_title == test_marketplace_product.title
|
||||
|
||||
def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product):
|
||||
def test_get_inventory_by_gtin_multiple_locations_success(
|
||||
self, db, test_marketplace_product
|
||||
):
|
||||
"""Test getting inventory summary with multiple locations successfully."""
|
||||
unique_gtin = test_marketplace_product.gtin
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create multiple inventory entries for the same GTIN with unique locations
|
||||
inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50)
|
||||
inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30)
|
||||
inventory1 = Inventory(
|
||||
gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
|
||||
)
|
||||
|
||||
db.add(inventory1)
|
||||
db.add(inventory2)
|
||||
@@ -301,7 +330,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED"
|
||||
assert "Invalid GTIN format" in str(exc_info.value)
|
||||
|
||||
def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_total_inventory_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting total inventory for a GTIN successfully."""
|
||||
result = self.service.get_total_inventory(db, test_inventory.gtin)
|
||||
|
||||
@@ -364,7 +395,9 @@ class TestInventoryService:
|
||||
|
||||
result = self.service.get_all_inventory(db, skip=2, limit=2)
|
||||
|
||||
assert len(result) <= 2 # Should be at most 2, might be less if other records exist
|
||||
assert (
|
||||
len(result) <= 2
|
||||
) # Should be at most 2, might be less if other records exist
|
||||
|
||||
def test_update_inventory_success(self, db, test_inventory):
|
||||
"""Test updating inventory quantity successfully."""
|
||||
@@ -404,7 +437,9 @@ class TestInventoryService:
|
||||
assert result is True
|
||||
|
||||
# Verify the inventory is actually deleted
|
||||
deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
deleted_inventory = (
|
||||
db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
)
|
||||
assert deleted_inventory is None
|
||||
|
||||
def test_delete_inventory_not_found_error(self, db):
|
||||
@@ -415,7 +450,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVENTORY_NOT_FOUND"
|
||||
assert "99999" in str(exc_info.value)
|
||||
|
||||
def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_low_inventory_items_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting low inventory items successfully."""
|
||||
# Set inventory to a low value
|
||||
test_inventory.quantity = 5
|
||||
@@ -424,7 +461,9 @@ class TestInventoryService:
|
||||
result = self.service.get_low_inventory_items(db, threshold=10)
|
||||
|
||||
assert len(result) >= 1
|
||||
low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None)
|
||||
low_inventory_item = next(
|
||||
(item for item in result if item["gtin"] == test_inventory.gtin), None
|
||||
)
|
||||
assert low_inventory_item is not None
|
||||
assert low_inventory_item["current_quantity"] == 5
|
||||
assert low_inventory_item["location"] == test_inventory.location
|
||||
@@ -440,9 +479,13 @@ class TestInventoryService:
|
||||
|
||||
def test_get_inventory_summary_by_location_success(self, db, test_inventory):
|
||||
"""Test getting inventory summary by location successfully."""
|
||||
result = self.service.get_inventory_summary_by_location(db, test_inventory.location)
|
||||
result = self.service.get_inventory_summary_by_location(
|
||||
db, test_inventory.location
|
||||
)
|
||||
|
||||
assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase
|
||||
assert (
|
||||
result["location"] == test_inventory.location.upper()
|
||||
) # Service normalizes to uppercase
|
||||
assert result["total_items"] >= 1
|
||||
assert result["total_quantity"] >= test_inventory.quantity
|
||||
assert result["unique_gtins"] >= 1
|
||||
@@ -450,7 +493,9 @@ class TestInventoryService:
|
||||
def test_get_inventory_summary_by_location_empty_result(self, db):
|
||||
"""Test getting inventory summary for location with no inventory."""
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}")
|
||||
result = self.service.get_inventory_summary_by_location(
|
||||
db, f"EMPTY_LOCATION_{unique_id}"
|
||||
)
|
||||
|
||||
assert result["total_items"] == 0
|
||||
assert result["total_quantity"] == 0
|
||||
@@ -459,12 +504,16 @@ class TestInventoryService:
|
||||
def test_validate_quantity_edge_cases(self, db):
|
||||
"""Test quantity validation with edge cases."""
|
||||
# Test zero quantity with allow_zero=True (should succeed)
|
||||
inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0)
|
||||
inventory_data = InventoryCreate(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=0
|
||||
)
|
||||
result = self.service.set_inventory(db, inventory_data)
|
||||
assert result.quantity == 0
|
||||
|
||||
# Test zero quantity with add_inventory (should fail - doesn't allow zero)
|
||||
inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0)
|
||||
inventory_data_add = InventoryAdd(
|
||||
gtin="1234567890123", location="WAREHOUSE_B", quantity=0
|
||||
)
|
||||
with pytest.raises(InvalidQuantityException):
|
||||
self.service.add_inventory(db, inventory_data_add)
|
||||
|
||||
@@ -477,10 +526,10 @@ class TestInventoryService:
|
||||
exception = exc_info.value
|
||||
|
||||
# Verify exception structure matches WizamartException.to_dict()
|
||||
assert hasattr(exception, 'error_code')
|
||||
assert hasattr(exception, 'message')
|
||||
assert hasattr(exception, 'status_code')
|
||||
assert hasattr(exception, 'details')
|
||||
assert hasattr(exception, "error_code")
|
||||
assert hasattr(exception, "message")
|
||||
assert hasattr(exception, "status_code")
|
||||
assert hasattr(exception, "details")
|
||||
|
||||
assert isinstance(exception.error_code, str)
|
||||
assert isinstance(exception.message, str)
|
||||
|
||||
@@ -4,19 +4,18 @@ from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions.marketplace_import_job import (
|
||||
ImportJobNotFoundException,
|
||||
ImportJobNotOwnedException,
|
||||
ImportJobCannotBeCancelledException,
|
||||
ImportJobCannotBeDeletedException,
|
||||
)
|
||||
from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException
|
||||
from app.exceptions.base import ValidationException
|
||||
from app.services.marketplace_import_job_service import MarketplaceImportJobService
|
||||
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
|
||||
from app.exceptions.marketplace_import_job import (
|
||||
ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException,
|
||||
ImportJobNotFoundException, ImportJobNotOwnedException)
|
||||
from app.exceptions.vendor import (UnauthorizedVendorAccessException,
|
||||
VendorNotFoundException)
|
||||
from app.services.marketplace_import_job_service import \
|
||||
MarketplaceImportJobService
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -31,7 +30,9 @@ class TestMarketplaceService:
|
||||
test_vendor.owner_user_id = test_user.id
|
||||
db.commit()
|
||||
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code, test_user
|
||||
)
|
||||
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
assert result.owner_user_id == test_user.id
|
||||
@@ -39,8 +40,10 @@ class TestMarketplaceService:
|
||||
def test_validate_vendor_access_admin_can_access_any_vendor(
|
||||
self, db, test_vendor, test_admin
|
||||
):
|
||||
"""Test that admin users can access any vendor """
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin)
|
||||
"""Test that admin users can access any vendor"""
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code, test_admin
|
||||
)
|
||||
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
@@ -57,7 +60,7 @@ class TestMarketplaceService:
|
||||
def test_validate_vendor_access_permission_denied(
|
||||
self, db, test_vendor, test_user, other_user
|
||||
):
|
||||
"""Test vendor access validation when user doesn't own the vendor """
|
||||
"""Test vendor access validation when user doesn't own the vendor"""
|
||||
# Set the vendor owner to a different user
|
||||
test_vendor.owner_user_id = other_user.id
|
||||
db.commit()
|
||||
@@ -93,7 +96,7 @@ class TestMarketplaceService:
|
||||
assert result.vendor_name == test_vendor.name
|
||||
|
||||
def test_create_import_job_invalid_vendor(self, db, test_user):
|
||||
"""Test import job creation with invalid vendor """
|
||||
"""Test import job creation with invalid vendor"""
|
||||
request = MarketplaceImportJobRequest(
|
||||
url="https://example.com/products.csv",
|
||||
marketplace="Amazon",
|
||||
@@ -108,7 +111,9 @@ class TestMarketplaceService:
|
||||
assert exception.error_code == "VENDOR_NOT_FOUND"
|
||||
assert "INVALID_VENDOR" in exception.message
|
||||
|
||||
def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user):
|
||||
def test_create_import_job_unauthorized_access(
|
||||
self, db, test_vendor, test_user, other_user
|
||||
):
|
||||
"""Test import job creation with unauthorized vendor access"""
|
||||
# Set the vendor owner to a different user
|
||||
test_vendor.owner_user_id = other_user.id
|
||||
@@ -127,7 +132,9 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS"
|
||||
|
||||
def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user):
|
||||
def test_get_import_job_by_id_success(
|
||||
self, db, test_marketplace_import_job, test_user
|
||||
):
|
||||
"""Test getting import job by ID for job owner"""
|
||||
result = self.service.get_import_job_by_id(
|
||||
db, test_marketplace_import_job.id, test_user
|
||||
@@ -161,14 +168,18 @@ class TestMarketplaceService:
|
||||
):
|
||||
"""Test access denied when user doesn't own the job"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.get_import_job_by_id(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
assert exception.status_code == 403
|
||||
assert str(test_marketplace_import_job.id) in exception.message
|
||||
|
||||
def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user):
|
||||
def test_get_import_jobs_user_filter(
|
||||
self, db, test_marketplace_import_job, test_user
|
||||
):
|
||||
"""Test getting import jobs filtered by user"""
|
||||
jobs = self.service.get_import_jobs(db, test_user)
|
||||
|
||||
@@ -176,7 +187,9 @@ class TestMarketplaceService:
|
||||
assert any(job.id == test_marketplace_import_job.id for job in jobs)
|
||||
assert test_marketplace_import_job.user_id == test_user.id
|
||||
|
||||
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin):
|
||||
def test_get_import_jobs_admin_sees_all(
|
||||
self, db, test_marketplace_import_job, test_admin
|
||||
):
|
||||
"""Test that admin sees all import jobs"""
|
||||
jobs = self.service.get_import_jobs(db, test_admin)
|
||||
|
||||
@@ -192,7 +205,9 @@ class TestMarketplaceService:
|
||||
)
|
||||
|
||||
assert len(jobs) >= 1
|
||||
assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs)
|
||||
assert any(
|
||||
job.marketplace == test_marketplace_import_job.marketplace for job in jobs
|
||||
)
|
||||
|
||||
def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor):
|
||||
"""Test getting import jobs with pagination"""
|
||||
@@ -330,10 +345,14 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
|
||||
def test_cancel_import_job_access_denied(
|
||||
self, db, test_marketplace_import_job, other_user
|
||||
):
|
||||
"""Test cancelling import job without access"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.cancel_import_job(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
@@ -347,7 +366,9 @@ class TestMarketplaceService:
|
||||
db.commit()
|
||||
|
||||
with pytest.raises(ImportJobCannotBeCancelledException) as exc_info:
|
||||
self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user)
|
||||
self.service.cancel_import_job(
|
||||
db, test_marketplace_import_job.id, test_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED"
|
||||
@@ -396,10 +417,14 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
|
||||
def test_delete_import_job_access_denied(
|
||||
self, db, test_marketplace_import_job, other_user
|
||||
):
|
||||
"""Test deleting import job without access"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.delete_import_job(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.delete_import_job(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
@@ -440,11 +465,15 @@ class TestMarketplaceService:
|
||||
db.commit()
|
||||
|
||||
# Test with lowercase vendor code
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code.lower(), test_user
|
||||
)
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
# Test with uppercase vendor code
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code.upper(), test_user
|
||||
)
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
def test_create_import_job_database_error(self, db_with_error, test_user):
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# tests/test_product_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InvalidMarketplaceProductDataException,
|
||||
MarketplaceProductAlreadyExistsException,
|
||||
MarketplaceProductNotFoundException,
|
||||
MarketplaceProductValidationException,
|
||||
ValidationException)
|
||||
from app.services.marketplace_product_service import MarketplaceProductService
|
||||
from app.exceptions import (
|
||||
MarketplaceProductNotFoundException,
|
||||
MarketplaceProductAlreadyExistsException,
|
||||
InvalidMarketplaceProductDataException,
|
||||
MarketplaceProductValidationException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.schema.marketplace_product import (MarketplaceProductCreate,
|
||||
MarketplaceProductUpdate)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -98,7 +97,10 @@ class TestProductService:
|
||||
assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS"
|
||||
assert test_marketplace_product.marketplace_product_id in str(exc_info.value)
|
||||
assert exc_info.value.status_code == 409
|
||||
assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
exc_info.value.details.get("marketplace_product_id")
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_create_product_invalid_price(self, db):
|
||||
"""Test product creation with invalid price raises InvalidMarketplaceProductDataException"""
|
||||
@@ -117,9 +119,14 @@ class TestProductService:
|
||||
|
||||
def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product):
|
||||
"""Test successful product retrieval by ID"""
|
||||
product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id)
|
||||
product = self.service.get_product_by_id_or_raise(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
product.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert product.title == test_marketplace_product.title
|
||||
|
||||
def test_get_product_by_id_or_raise_not_found(self, db):
|
||||
@@ -152,21 +159,35 @@ class TestProductService:
|
||||
assert total >= 1
|
||||
assert len(products) >= 1
|
||||
# Verify search worked by checking that title contains search term
|
||||
found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None)
|
||||
found_product = next(
|
||||
(
|
||||
p
|
||||
for p in products
|
||||
if p.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert found_product is not None
|
||||
|
||||
def test_update_product_success(self, db, test_marketplace_product):
|
||||
"""Test successful product update"""
|
||||
update_data = MarketplaceProductUpdate(
|
||||
title="Updated MarketplaceProduct Title",
|
||||
price="39.99"
|
||||
title="Updated MarketplaceProduct Title", price="39.99"
|
||||
)
|
||||
|
||||
updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
updated_product = self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert updated_product.title == "Updated MarketplaceProduct Title"
|
||||
assert updated_product.price == "39.99" # Price is stored as string after processing
|
||||
assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged
|
||||
assert (
|
||||
updated_product.price == "39.99"
|
||||
) # Price is stored as string after processing
|
||||
assert (
|
||||
updated_product.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
) # ID unchanged
|
||||
|
||||
def test_update_product_not_found(self, db):
|
||||
"""Test updating non-existent product raises MarketplaceProductNotFoundException"""
|
||||
@@ -183,7 +204,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(gtin="invalid_gtin")
|
||||
|
||||
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
|
||||
assert "Invalid GTIN format" in str(exc_info.value)
|
||||
@@ -194,7 +217,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(title="")
|
||||
|
||||
with pytest.raises(MarketplaceProductValidationException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED"
|
||||
assert "MarketplaceProduct title cannot be empty" in str(exc_info.value)
|
||||
@@ -205,7 +230,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(price="invalid_price")
|
||||
|
||||
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
|
||||
assert "Invalid price format" in str(exc_info.value)
|
||||
@@ -213,12 +240,16 @@ class TestProductService:
|
||||
|
||||
def test_delete_product_success(self, db, test_marketplace_product):
|
||||
"""Test successful product deletion"""
|
||||
result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id)
|
||||
result = self.service.delete_product(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify product is deleted
|
||||
deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id)
|
||||
deleted_product = self.service.get_product_by_id(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert deleted_product is None
|
||||
|
||||
def test_delete_product_not_found(self, db):
|
||||
@@ -229,10 +260,14 @@ class TestProductService:
|
||||
assert exc_info.value.error_code == "PRODUCT_NOT_FOUND"
|
||||
assert "NONEXISTENT" in str(exc_info.value)
|
||||
|
||||
def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory):
|
||||
def test_get_inventory_info_success(
|
||||
self, db, test_marketplace_product_with_inventory
|
||||
):
|
||||
"""Test getting inventory info for product with inventory"""
|
||||
# Extract the product from the dictionary
|
||||
marketplace_product = test_marketplace_product_with_inventory['marketplace_product']
|
||||
marketplace_product = test_marketplace_product_with_inventory[
|
||||
"marketplace_product"
|
||||
]
|
||||
|
||||
inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin)
|
||||
|
||||
@@ -243,13 +278,17 @@ class TestProductService:
|
||||
|
||||
def test_get_inventory_info_no_inventory(self, db, test_marketplace_product):
|
||||
"""Test getting inventory info for product without inventory"""
|
||||
inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123")
|
||||
inventory_info = self.service.get_inventory_info(
|
||||
db, test_marketplace_product.gtin or "1234567890123"
|
||||
)
|
||||
|
||||
assert inventory_info is None
|
||||
|
||||
def test_product_exists_true(self, db, test_marketplace_product):
|
||||
"""Test product_exists returns True for existing product"""
|
||||
exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id)
|
||||
exists = self.service.product_exists(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert exists is True
|
||||
|
||||
def test_product_exists_false(self, db):
|
||||
@@ -265,7 +304,9 @@ class TestProductService:
|
||||
csv_lines = list(csv_generator)
|
||||
|
||||
assert len(csv_lines) > 1 # Header + at least one data row
|
||||
assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header
|
||||
assert csv_lines[0].startswith(
|
||||
"marketplace_product_id,title,description"
|
||||
) # Check header
|
||||
|
||||
# Check that test product appears in CSV
|
||||
csv_content = "".join(csv_lines)
|
||||
@@ -274,8 +315,7 @@ class TestProductService:
|
||||
def test_generate_csv_export_with_filters(self, db, test_marketplace_product):
|
||||
"""Test CSV export with marketplace filter"""
|
||||
csv_generator = self.service.generate_csv_export(
|
||||
db,
|
||||
marketplace=test_marketplace_product.marketplace
|
||||
db, marketplace=test_marketplace_product.marketplace
|
||||
)
|
||||
|
||||
csv_lines = list(csv_generator)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import pytest
|
||||
|
||||
from app.services.stats_service import StatsService
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -15,7 +15,9 @@ class TestStatsService:
|
||||
"""Setup method following the same pattern as other service tests"""
|
||||
self.service = StatsService()
|
||||
|
||||
def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory):
|
||||
def test_get_comprehensive_stats_basic(
|
||||
self, db, test_marketplace_product, test_inventory
|
||||
):
|
||||
"""Test getting comprehensive stats with basic data"""
|
||||
stats = self.service.get_comprehensive_stats(db)
|
||||
|
||||
@@ -31,7 +33,9 @@ class TestStatsService:
|
||||
assert stats["total_inventory_entries"] >= 1
|
||||
assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10
|
||||
|
||||
def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product):
|
||||
def test_get_comprehensive_stats_multiple_products(
|
||||
self, db, test_marketplace_product
|
||||
):
|
||||
"""Test comprehensive stats with multiple products across different dimensions"""
|
||||
# Create products with different brands, categories, marketplaces
|
||||
additional_products = [
|
||||
@@ -87,7 +91,7 @@ class TestStatsService:
|
||||
brand=None, # Null brand
|
||||
google_product_category=None, # Null category
|
||||
marketplace=None, # Null marketplace
|
||||
vendor_name=None, # Null vendor
|
||||
vendor_name=None, # Null vendor
|
||||
price="10.00",
|
||||
currency="EUR",
|
||||
),
|
||||
@@ -97,7 +101,7 @@ class TestStatsService:
|
||||
brand="", # Empty brand
|
||||
google_product_category="", # Empty category
|
||||
marketplace="", # Empty marketplace
|
||||
vendor_name="", # Empty vendor
|
||||
vendor_name="", # Empty vendor
|
||||
price="15.00",
|
||||
currency="EUR",
|
||||
),
|
||||
@@ -124,7 +128,11 @@ class TestStatsService:
|
||||
|
||||
# Find our test marketplace in the results
|
||||
test_marketplace_stat = next(
|
||||
(stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace),
|
||||
(
|
||||
stat
|
||||
for stat in stats
|
||||
if stat["marketplace"] == test_marketplace_product.marketplace
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert test_marketplace_stat is not None
|
||||
@@ -309,7 +317,9 @@ class TestStatsService:
|
||||
|
||||
count = self.service._get_unique_marketplaces_count(db)
|
||||
|
||||
assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace
|
||||
assert (
|
||||
count >= 2
|
||||
) # At least Amazon and eBay, plus test_marketplace_product marketplace
|
||||
assert isinstance(count, int)
|
||||
|
||||
def test_get_unique_vendors_count(self, db, test_marketplace_product):
|
||||
@@ -338,7 +348,9 @@ class TestStatsService:
|
||||
|
||||
count = self.service._get_unique_vendors_count(db)
|
||||
|
||||
assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor
|
||||
assert (
|
||||
count >= 2
|
||||
) # At least VendorA and VendorB, plus test_marketplace_product vendor
|
||||
assert isinstance(count, int)
|
||||
|
||||
def test_get_inventory_statistics(self, db, test_inventory):
|
||||
@@ -438,7 +450,7 @@ class TestStatsService:
|
||||
db.add_all(marketplace_products)
|
||||
db.commit()
|
||||
|
||||
vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace")
|
||||
vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace")
|
||||
|
||||
assert len(vendors) == 2
|
||||
assert "TestVendor1" in vendors
|
||||
@@ -482,7 +494,9 @@ class TestStatsService:
|
||||
|
||||
def test_get_products_by_marketplace_not_found(self, db):
|
||||
"""Test getting product count for non-existent marketplace"""
|
||||
count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace")
|
||||
count = self.service._get_products_by_marketplace_count(
|
||||
db, "NonExistentMarketplace"
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
# tests/test_vendor_service.py (updated to use custom exceptions)
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InvalidVendorDataException,
|
||||
MarketplaceProductNotFoundException,
|
||||
MaxVendorsReachedException,
|
||||
ProductAlreadyExistsException,
|
||||
UnauthorizedVendorAccessException,
|
||||
ValidationException, VendorAlreadyExistsException,
|
||||
VendorNotFoundException)
|
||||
from app.services.vendor_service import VendorService
|
||||
from app.exceptions import (
|
||||
VendorNotFoundException,
|
||||
VendorAlreadyExistsException,
|
||||
UnauthorizedVendorAccessException,
|
||||
InvalidVendorDataException,
|
||||
MarketplaceProductNotFoundException,
|
||||
ProductAlreadyExistsException,
|
||||
MaxVendorsReachedException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.vendor import VendorCreate
|
||||
from models.schema.product import ProductCreate
|
||||
from models.schema.vendor import VendorCreate
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -38,15 +35,17 @@ class TestVendorService:
|
||||
assert vendor is not None
|
||||
assert vendor.vendor_code == "NEWVENDOR"
|
||||
assert vendor.owner_user_id == test_user.id
|
||||
assert vendor.is_verified is False # Regular user creates unverified vendor
|
||||
assert vendor.is_verified is False # Regular user creates unverified vendor
|
||||
|
||||
def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory):
|
||||
"""Test admin creates verified vendor automatically"""
|
||||
vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor")
|
||||
vendor_data = VendorCreate(
|
||||
vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor"
|
||||
)
|
||||
|
||||
vendor = self.service.create_vendor(db, vendor_data, test_admin)
|
||||
|
||||
assert vendor.is_verified is True # Admin creates verified vendor
|
||||
assert vendor.is_verified is True # Admin creates verified vendor
|
||||
|
||||
def test_create_vendor_duplicate_code(self, db, test_user, test_vendor):
|
||||
"""Test vendor creation fails with duplicate vendor code"""
|
||||
@@ -88,7 +87,9 @@ class TestVendorService:
|
||||
|
||||
def test_create_vendor_invalid_code_format(self, db, test_user):
|
||||
"""Test vendor creation fails with invalid vendor code format"""
|
||||
vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor")
|
||||
vendor_data = VendorCreate(
|
||||
vendor_code="INVALID@CODE!", vendor_name="Test Vendor"
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidVendorDataException) as exc_info:
|
||||
self.service.create_vendor(db, vendor_data, test_user)
|
||||
@@ -105,7 +106,9 @@ class TestVendorService:
|
||||
def mock_check_vendor_limit(self, db, user):
|
||||
raise MaxVendorsReachedException(max_vendors=5, user_id=user.id)
|
||||
|
||||
monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit)
|
||||
monkeypatch.setattr(
|
||||
VendorService, "_check_vendor_limit", mock_check_vendor_limit
|
||||
)
|
||||
|
||||
vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor")
|
||||
|
||||
@@ -118,7 +121,9 @@ class TestVendorService:
|
||||
assert exception.details["max_vendors"] == 5
|
||||
assert exception.details["user_id"] == test_user.id
|
||||
|
||||
def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor):
|
||||
def test_get_vendors_regular_user(
|
||||
self, db, test_user, test_vendor, inactive_vendor
|
||||
):
|
||||
"""Test regular user can only see active verified vendors and own vendors"""
|
||||
vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10)
|
||||
|
||||
@@ -127,7 +132,7 @@ class TestVendorService:
|
||||
assert inactive_vendor.vendor_code not in vendor_codes
|
||||
|
||||
def test_get_vendors_admin_user(
|
||||
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
|
||||
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
|
||||
):
|
||||
"""Test admin user can see all vendors with filters"""
|
||||
vendors, total = self.service.get_vendors(
|
||||
@@ -140,14 +145,16 @@ class TestVendorService:
|
||||
assert verified_vendor.vendor_code in vendor_codes
|
||||
|
||||
def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor):
|
||||
"""Test vendor owner can access their own vendor """
|
||||
vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user)
|
||||
"""Test vendor owner can access their own vendor"""
|
||||
vendor = self.service.get_vendor_by_code(
|
||||
db, test_vendor.vendor_code.lower(), test_user
|
||||
)
|
||||
|
||||
assert vendor is not None
|
||||
assert vendor.id == test_vendor.id
|
||||
|
||||
def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor):
|
||||
"""Test admin can access any vendor """
|
||||
"""Test admin can access any vendor"""
|
||||
vendor = self.service.get_vendor_by_code(
|
||||
db, test_vendor.vendor_code.lower(), test_admin
|
||||
)
|
||||
@@ -178,16 +185,14 @@ class TestVendorService:
|
||||
assert exception.details["user_id"] == test_user.id
|
||||
|
||||
def test_add_product_to_vendor_success(self, db, test_vendor, unique_product):
|
||||
"""Test successfully adding product to vendor """
|
||||
"""Test successfully adding product to vendor"""
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id=unique_product.marketplace_product_id,
|
||||
price="15.99",
|
||||
is_featured=True,
|
||||
)
|
||||
|
||||
product = self.service.add_product_to_catalog(
|
||||
db, test_vendor, product_data
|
||||
)
|
||||
product = self.service.add_product_to_catalog(db, test_vendor, product_data)
|
||||
|
||||
assert product is not None
|
||||
assert product.vendor_id == test_vendor.id
|
||||
@@ -195,7 +200,9 @@ class TestVendorService:
|
||||
|
||||
def test_add_product_to_vendor_product_not_found(self, db, test_vendor):
|
||||
"""Test adding non-existent product to vendor fails"""
|
||||
product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99")
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id="NONEXISTENT", price="15.99"
|
||||
)
|
||||
|
||||
with pytest.raises(MarketplaceProductNotFoundException) as exc_info:
|
||||
self.service.add_product_to_catalog(db, test_vendor, product_data)
|
||||
@@ -209,7 +216,8 @@ class TestVendorService:
|
||||
def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product):
|
||||
"""Test adding product that's already in vendor fails"""
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99"
|
||||
marketplace_product_id=test_product.marketplace_product.marketplace_product_id,
|
||||
price="15.99",
|
||||
)
|
||||
|
||||
with pytest.raises(ProductAlreadyExistsException) as exc_info:
|
||||
@@ -219,11 +227,12 @@ class TestVendorService:
|
||||
assert exception.status_code == 409
|
||||
assert exception.error_code == "PRODUCT_ALREADY_EXISTS"
|
||||
assert exception.details["vendor_code"] == test_vendor.vendor_code
|
||||
assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
exception.details["marketplace_product_id"]
|
||||
== test_product.marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_get_products_owner_access(
|
||||
self, db, test_user, test_vendor, test_product
|
||||
):
|
||||
def test_get_products_owner_access(self, db, test_user, test_vendor, test_product):
|
||||
"""Test vendor owner can get vendor products"""
|
||||
products, total = self.service.get_products(db, test_vendor, test_user)
|
||||
|
||||
@@ -291,7 +300,9 @@ class TestVendorService:
|
||||
assert exception.error_code == "VALIDATION_ERROR"
|
||||
assert "Failed to retrieve vendors" in exception.message
|
||||
|
||||
def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch):
|
||||
def test_add_product_database_error(
|
||||
self, db, test_vendor, unique_product, monkeypatch
|
||||
):
|
||||
"""Test add product handles database errors gracefully"""
|
||||
|
||||
def mock_commit():
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestCSVProcessor:
|
||||
def test_download_csv_encoding_fallback(self, mock_get):
|
||||
"""Test CSV download with encoding fallback"""
|
||||
# Create content with special characters that would fail UTF-8 if not properly encoded
|
||||
special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
|
||||
special_content = (
|
||||
"marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
@@ -40,9 +42,7 @@ class TestCSVProcessor:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
# Create bytes that will fail most encodings
|
||||
mock_response.content = (
|
||||
b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
|
||||
)
|
||||
mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user