style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -12,21 +12,18 @@ Tests cover:
- Error handling and edge cases
"""
import pytest
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timedelta, timezone
from jose import jwt
from fastapi import HTTPException
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi import HTTPException
from jose import jwt
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, UserNotActiveException)
from middleware.auth import AuthManager
from app.exceptions import (
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException,
AdminRequiredException,
InsufficientPermissionsException,
)
from models.database.user import User
@@ -124,7 +121,9 @@ class TestUserAuthentication:
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
result = auth_manager.authenticate_user(
mock_db, "test@example.com", "password123"
)
assert result is mock_user
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
token = token_data["access_token"]
# Decode without verification to check payload
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
payload = jwt.decode(
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
)
assert payload["sub"] == "42"
assert payload["username"] == "testuser"
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
"""Test tokens are different for different users."""
auth_manager = AuthManager()
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
user1 = Mock(
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
)
user2 = Mock(
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
)
token1 = auth_manager.create_access_token(user1)["access_token"]
token2 = auth_manager.create_access_token(user2)["access_token"]
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
payload = jwt.decode(
token_data["access_token"],
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
algorithms=[auth_manager.algorithm],
)
assert payload["role"] == "admin"
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
# Create token without 'sub' field
payload = {
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
auth_manager = AuthManager()
# Create token without 'exp' field
payload = {
"sub": "1",
"username": "testuser"
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser"}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
payload = {
"sub": "1",
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
# Create token with different algorithm
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
# Create a token with expiration in the past
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
payload = {
"sub": "1",
"username": "testuser",
"exp": past_time.timestamp()
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Mock jwt.decode to bypass its expiration check and test line 205
with patch('middleware.auth.jwt.decode') as mock_decode:
with patch("middleware.auth.jwt.decode") as mock_decode:
mock_decode.return_value = payload
with pytest.raises(TokenExpiredException):
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
# Existing admin user
existing_admin = Mock(spec=User)
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
mock_db.query.return_value.filter.return_value.first.return_value = (
existing_admin
)
result = auth_manager.create_default_admin_user(mock_db)
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
def test_default_configuration(self):
"""Test AuthManager uses default configuration."""
with patch.dict('os.environ', {}, clear=True):
with patch.dict("os.environ", {}, clear=True):
auth_manager = AuthManager()
assert auth_manager.algorithm == "HS256"
assert auth_manager.token_expire_minutes == 30
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
assert (
auth_manager.secret_key == "your-secret-key-change-in-production-please"
)
def test_custom_configuration(self):
"""Test AuthManager uses environment variables."""
with patch.dict('os.environ', {
'JWT_SECRET_KEY': 'custom-secret-key',
'JWT_EXPIRE_MINUTES': '60'
}):
with patch.dict(
"os.environ",
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
):
auth_manager = AuthManager()
assert auth_manager.secret_key == "custom-secret-key"
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
def test_partial_custom_configuration(self):
"""Test AuthManager with partial environment configuration."""
with patch.dict('os.environ', {
'JWT_EXPIRE_MINUTES': '120'
}, clear=False):
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
auth_manager = AuthManager()
assert auth_manager.token_expire_minutes == 120
@@ -656,9 +662,11 @@ class TestEdgeCases:
"sub": "1",
"username": "testuser",
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Should still verify successfully (JWT doesn't validate iat by default)
result = auth_manager.verify_token(token)
@@ -698,7 +706,9 @@ class TestEdgeCases:
token = token_data["access_token"]
# Mock jose.jwt.decode to raise an unexpected exception
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
with patch(
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
):
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)

View File

@@ -10,16 +10,13 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request
from middleware.context import (
ContextManager,
ContextMiddleware,
RequestContext,
get_request_context,
)
from middleware.context import (ContextManager, ContextMiddleware,
RequestContext, get_request_context)
@pytest.mark.unit
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
def test_is_admin_context_from_subdomain(self):
"""Test _is_admin_context with admin subdomain."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/dashboard"
)
is True
)
def test_is_admin_context_from_path(self):
"""Test _is_admin_context with admin path."""
request = Mock()
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
assert (
ContextManager._is_admin_context(request, "localhost", "/admin/users")
is True
)
def test_is_admin_context_both(self):
"""Test _is_admin_context with both subdomain and path."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/admin/users"
)
is True
)
def test_is_not_admin_context(self):
"""Test _is_admin_context returns False for non-admin."""
request = Mock()
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
assert (
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
is False
)
def test_is_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context with /vendor/ path."""
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
def test_is_vendor_dashboard_context_nested(self):
"""Test _is_vendor_dashboard_context with nested vendor path."""
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
assert (
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
)
def test_is_not_vendor_dashboard_context_vendors_plural(self):
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
assert (
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
is False
)
def test_is_not_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
@@ -373,7 +391,7 @@ class TestContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'context_type')
assert hasattr(request.state, "context_type")
assert request.state.context_type == RequestContext.API
call_next.assert_called_once_with(request)
@@ -565,7 +583,7 @@ class TestEdgeCases:
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
# No state attribute at all
delattr(request, 'state')
delattr(request, "state")
# Should still work, falling back to url.path
with pytest.raises(AttributeError):

View File

@@ -12,16 +12,18 @@ Tests cover:
- Edge cases and isolation
"""
import pytest
from unittest.mock import Mock
from middleware.decorators import rate_limit, rate_limiter
from app.exceptions.base import RateLimitException
import pytest
from app.exceptions.base import RateLimitException
from middleware.decorators import rate_limit, rate_limiter
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(autouse=True)
def reset_rate_limiter():
"""Reset rate limiter state before each test to ensure isolation."""
@@ -34,6 +36,7 @@ def reset_rate_limiter():
# Rate Limit Decorator Tests
# =============================================================================
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimitDecorator:
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_allows_within_limit(self):
"""Test decorator allows requests within rate limit."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_blocked():
return {"status": "ok"}
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_preserves_function_metadata(self):
"""Test decorator preserves original function metadata."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
"""Test endpoint docstring."""
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_with_args_and_kwargs(self):
"""Test decorator works with function arguments."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint(arg1, arg2, kwarg1=None):
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
result = await test_endpoint("value1", "value2", kwarg1="value3")
assert result == {
"arg1": "value1",
"arg2": "value2",
"kwarg1": "value3"
}
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
@pytest.mark.asyncio
async def test_decorator_default_parameters(self):
"""Test decorator uses default parameters."""
@rate_limit() # Use defaults
async def test_endpoint():
return {"status": "ok"}
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after."""
@rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint_retry():
return {"status": "ok"}
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_zero_max_requests(self):
"""Test decorator with max_requests=0 blocks all requests."""
@rate_limit(max_requests=0, window_seconds=3600)
async def test_endpoint_zero():
return {"status": "ok"}
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_very_short_window(self):
"""Test decorator with very short time window."""
@rate_limit(max_requests=1, window_seconds=1)
async def test_endpoint_short():
return {"status": "ok"}
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_multiple_functions_separate_limits(self):
"""Test that different functions have separate rate limits."""
@rate_limit(max_requests=1, window_seconds=3600)
async def endpoint1():
return {"endpoint": "1"}
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_exception_in_function(self):
"""Test decorator handles exceptions from wrapped function."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint_error():
raise ValueError("Function error")
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_isolation_between_tests(self):
"""Test that rate limiter state is properly isolated between tests."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_isolation():
return {"status": "ok"}
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_dict(self):
"""Test decorator correctly returns dictionary."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_dict():
return {"key": "value", "number": 42}
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_list(self):
"""Test decorator correctly returns list."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_list():
return [1, 2, 3, 4, 5]
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_none(self):
"""Test decorator correctly returns None."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_none():
return None
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_object(self):
"""Test decorator correctly returns custom objects."""
class TestObject:
def __init__(self):
self.name = "test_object"

View File

@@ -11,9 +11,10 @@ Tests cover:
- Edge cases (missing client info, etc.)
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from middleware.logging import LoggingMiddleware
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify request was logged
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
result = await middleware.dispatch(request, call_next)
# Verify response was logged
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in response.headers
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Should log "unknown" for client
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
assert any(
"unknown" in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_exceptions(self):
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
call_next = AsyncMock(side_effect=Exception("Test error"))
with patch('middleware.logging.logger') as mock_logger, \
pytest.raises(Exception):
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
Exception
):
await middleware.dispatch(request, call_next)
# Verify error was logged
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
call_next = slow_call_next
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
# Should still have process time, even if very small
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify method was logged
assert any(method in str(call) for call in mock_logger.info.call_args_list)
assert any(
method in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_different_status_codes(self):
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
response = Mock(status_code=status_code, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify status code was logged
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
assert any(
str(status_code) in str(call)
for call in mock_logger.info.call_args_list
)

View File

@@ -11,10 +11,11 @@ Tests cover:
- Edge cases and concurrency scenarios
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from collections import deque
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
import pytest
from middleware.rate_limiter import RateLimiter
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
# Add requests at different times
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
stats = limiter.get_client_stats(client_id)
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
limiter = RateLimiter()
client_id = "long_window_client"
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
result = limiter.allow_request(
client_id, max_requests=10, window_seconds=86400 * 365
)
assert result is True
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
client_id = "same_client"
# Allow with one limit
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
# Check with stricter limit
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
assert (
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
is False
)
def test_rate_limiter_unicode_client_id(self):
"""Test rate limiter with unicode client ID."""

View File

@@ -11,15 +11,14 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request
from middleware.theme_context import (
ThemeContextManager,
ThemeContextMiddleware,
get_current_theme,
)
from middleware.theme_context import (ThemeContextManager,
ThemeContextMiddleware,
get_current_theme)
@pytest.mark.unit
@@ -42,7 +41,14 @@ class TestThemeContextManager:
"""Test default theme has all required colors."""
theme = ThemeContextManager.get_default_theme()
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
required_colors = [
"primary",
"secondary",
"accent",
"background",
"text",
"border",
]
for color in required_colors:
assert color in theme["colors"]
assert theme["colors"][color].startswith("#")
@@ -79,10 +85,7 @@ class TestThemeContextManager:
mock_theme = Mock()
# Mock to_dict to return actual dictionary
custom_theme_dict = {
"theme_name": "custom",
"colors": {"primary": "#ff0000"}
}
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
mock_theme.to_dict.return_value = custom_theme_dict
# Correct filter chain: query().filter().first()
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
mock_theme = {"theme_name": "test_theme"}
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
):
await middleware.dispatch(request, call_next)
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'theme')
assert hasattr(request.state, "theme")
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
):
await middleware.dispatch(request, call_next)
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
# Verify database was closed

View File

@@ -11,17 +11,16 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from fastapi import Request, HTTPException
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from middleware.vendor_context import (
VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context,
)
from middleware.vendor_context import (VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context)
@pytest.mark.unit
@@ -39,7 +38,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -55,7 +54,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -71,7 +70,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -87,7 +86,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -140,7 +139,7 @@ class TestVendorContextManager:
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -153,7 +152,7 @@ class TestVendorContextManager:
request.headers = {"host": "www.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -166,7 +165,7 @@ class TestVendorContextManager:
request.headers = {"host": "api.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -179,7 +178,7 @@ class TestVendorContextManager:
request.headers = {"host": "localhost"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -198,12 +197,11 @@ class TestVendorContextManager:
mock_vendor.is_active = True
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -218,12 +216,11 @@ class TestVendorContextManager:
mock_vendor.is_active = False
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -232,12 +229,11 @@ class TestVendorContextManager:
def test_get_vendor_from_custom_domain_not_found(self):
"""Test custom domain not found in database."""
mock_db = Mock(spec=Session)
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
None
)
context = {
"detection_method": "custom_domain",
"domain": "nonexistent.com"
}
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -249,12 +245,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -266,12 +261,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "path",
"subdomain": "vendor1"
}
context = {"detection_method": "path", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -291,12 +285,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "VENDOR1" # Uppercase
}
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -311,10 +304,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -325,10 +315,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendors/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendors/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -339,10 +326,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -353,10 +337,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -425,21 +406,24 @@ class TestVendorContextManager:
# Static File Detection Tests
# ========================================================================
@pytest.mark.parametrize("path", [
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
])
@pytest.mark.parametrize(
"path",
[
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
],
)
def test_is_static_file_request(self, path):
"""Test static file detection for various paths and extensions."""
request = Mock(spec=Request)
@@ -447,12 +431,15 @@ class TestVendorContextManager:
assert VendorContextManager.is_static_file_request(request) is True
@pytest.mark.parametrize("path", [
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
])
@pytest.mark.parametrize(
"path",
[
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
],
)
def test_is_not_static_file_request(self, path):
"""Test non-static file paths."""
request = Mock(spec=Request)
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
with patch.object(VendorContextManager, "is_api_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
with patch.object(
VendorContextManager, "is_static_file_request", return_value=True
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
mock_vendor.name = "Test Vendor"
mock_vendor.subdomain = "vendor1"
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
), patch.object(
VendorContextManager, "extract_clean_path", return_value="/shop/products"
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
vendor_context = {
"detection_method": "subdomain",
"subdomain": "nonexistent"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=None
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=None
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -714,7 +708,7 @@ class TestEdgeCases:
request.headers = {"host": "shop.vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -735,11 +729,14 @@ class TestEdgeCases:
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
with patch('middleware.vendor_context.logger') as mock_logger:
with patch("middleware.vendor_context.logger") as mock_logger:
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
# Verify warning was logged
mock_logger.warning.assert_called()
warning_message = str(mock_logger.warning.call_args)
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
assert (
"No active vendor found for subdomain" in warning_message
and "nonexistent" in warning_message
)