style: apply black and isort formatting across entire codebase
- Standardize quote style (single to double quotes) - Reorder and group imports alphabetically - Fix line breaks and indentation for consistency - Apply PEP 8 formatting standards Also updated Makefile to exclude both venv and .venv from code quality checks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -12,21 +12,18 @@ Tests cover:
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from fastapi import HTTPException
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.exceptions import (AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
InvalidCredentialsException, InvalidTokenException,
|
||||
TokenExpiredException, UserNotActiveException)
|
||||
from middleware.auth import AuthManager
|
||||
from app.exceptions import (
|
||||
InvalidTokenException,
|
||||
TokenExpiredException,
|
||||
UserNotActiveException,
|
||||
InvalidCredentialsException,
|
||||
AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
|
||||
@@ -124,7 +121,9 @@ class TestUserAuthentication:
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
result = auth_manager.authenticate_user(
|
||||
mock_db, "test@example.com", "password123"
|
||||
)
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Decode without verification to check payload
|
||||
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
|
||||
payload = jwt.decode(
|
||||
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
|
||||
)
|
||||
|
||||
assert payload["sub"] == "42"
|
||||
assert payload["username"] == "testuser"
|
||||
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
|
||||
"""Test tokens are different for different users."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
|
||||
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
|
||||
user1 = Mock(
|
||||
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
|
||||
)
|
||||
user2 = Mock(
|
||||
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
|
||||
)
|
||||
|
||||
token1 = auth_manager.create_access_token(user1)["access_token"]
|
||||
token2 = auth_manager.create_access_token(user2)["access_token"]
|
||||
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
|
||||
payload = jwt.decode(
|
||||
token_data["access_token"],
|
||||
auth_manager.secret_key,
|
||||
algorithms=[auth_manager.algorithm]
|
||||
algorithms=[auth_manager.algorithm],
|
||||
)
|
||||
|
||||
assert payload["role"] == "admin"
|
||||
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
|
||||
# Create token without 'sub' field
|
||||
payload = {
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token without 'exp' field
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser"
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser"}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
# Create token with different algorithm
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
|
||||
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
|
||||
|
||||
# Create a token with expiration in the past
|
||||
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": past_time.timestamp()
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Mock jwt.decode to bypass its expiration check and test line 205
|
||||
with patch('middleware.auth.jwt.decode') as mock_decode:
|
||||
with patch("middleware.auth.jwt.decode") as mock_decode:
|
||||
mock_decode.return_value = payload
|
||||
|
||||
with pytest.raises(TokenExpiredException):
|
||||
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
|
||||
|
||||
# Existing admin user
|
||||
existing_admin = Mock(spec=User)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = (
|
||||
existing_admin
|
||||
)
|
||||
|
||||
result = auth_manager.create_default_admin_user(mock_db)
|
||||
|
||||
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_default_configuration(self):
|
||||
"""Test AuthManager uses default configuration."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.algorithm == "HS256"
|
||||
assert auth_manager.token_expire_minutes == 30
|
||||
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
assert (
|
||||
auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
)
|
||||
|
||||
def test_custom_configuration(self):
|
||||
"""Test AuthManager uses environment variables."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'custom-secret-key',
|
||||
'JWT_EXPIRE_MINUTES': '60'
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
|
||||
):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.secret_key == "custom-secret-key"
|
||||
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_partial_custom_configuration(self):
|
||||
"""Test AuthManager with partial environment configuration."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_EXPIRE_MINUTES': '120'
|
||||
}, clear=False):
|
||||
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.token_expire_minutes == 120
|
||||
@@ -656,9 +662,11 @@ class TestEdgeCases:
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Should still verify successfully (JWT doesn't validate iat by default)
|
||||
result = auth_manager.verify_token(token)
|
||||
@@ -698,7 +706,9 @@ class TestEdgeCases:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Mock jose.jwt.decode to raise an unexpected exception
|
||||
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
|
||||
with patch(
|
||||
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
|
||||
):
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
|
||||
@@ -10,16 +10,13 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.context import (
|
||||
ContextManager,
|
||||
ContextMiddleware,
|
||||
RequestContext,
|
||||
get_request_context,
|
||||
)
|
||||
from middleware.context import (ContextManager, ContextMiddleware,
|
||||
RequestContext, get_request_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
|
||||
def test_is_admin_context_from_subdomain(self):
|
||||
"""Test _is_admin_context with admin subdomain."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/dashboard"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_from_path(self):
|
||||
"""Test _is_admin_context with admin path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "localhost", "/admin/users")
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_both(self):
|
||||
"""Test _is_admin_context with both subdomain and path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/admin/users"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_not_admin_context(self):
|
||||
"""Test _is_admin_context returns False for non-admin."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context with /vendor/ path."""
|
||||
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
|
||||
|
||||
def test_is_vendor_dashboard_context_nested(self):
|
||||
"""Test _is_vendor_dashboard_context with nested vendor path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context_vendors_plural(self):
|
||||
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
|
||||
@@ -373,7 +391,7 @@ class TestContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'context_type')
|
||||
assert hasattr(request.state, "context_type")
|
||||
assert request.state.context_type == RequestContext.API
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@@ -565,7 +583,7 @@ class TestEdgeCases:
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
# No state attribute at all
|
||||
delattr(request, 'state')
|
||||
delattr(request, "state")
|
||||
|
||||
# Should still work, falling back to url.path
|
||||
with pytest.raises(AttributeError):
|
||||
|
||||
@@ -12,16 +12,18 @@ Tests cover:
|
||||
- Edge cases and isolation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
from app.exceptions.base import RateLimitException
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions.base import RateLimitException
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limiter():
|
||||
"""Reset rate limiter state before each test to ensure isolation."""
|
||||
@@ -34,6 +36,7 @@ def reset_rate_limiter():
|
||||
# Rate Limit Decorator Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimitDecorator:
|
||||
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_allows_within_limit(self):
|
||||
"""Test decorator allows requests within rate limit."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_blocks_exceeding_limit(self):
|
||||
"""Test decorator blocks requests exceeding rate limit."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_blocked():
|
||||
return {"status": "ok"}
|
||||
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_preserves_function_metadata(self):
|
||||
"""Test decorator preserves original function metadata."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
"""Test endpoint docstring."""
|
||||
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_args_and_kwargs(self):
|
||||
"""Test decorator works with function arguments."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint(arg1, arg2, kwarg1=None):
|
||||
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
|
||||
|
||||
result = await test_endpoint("value1", "value2", kwarg1="value3")
|
||||
|
||||
assert result == {
|
||||
"arg1": "value1",
|
||||
"arg2": "value2",
|
||||
"kwarg1": "value3"
|
||||
}
|
||||
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_default_parameters(self):
|
||||
"""Test decorator uses default parameters."""
|
||||
|
||||
@rate_limit() # Use defaults
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_includes_retry_after(self):
|
||||
"""Test rate limit exception includes retry_after."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=60)
|
||||
async def test_endpoint_retry():
|
||||
return {"status": "ok"}
|
||||
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_zero_max_requests(self):
|
||||
"""Test decorator with max_requests=0 blocks all requests."""
|
||||
|
||||
@rate_limit(max_requests=0, window_seconds=3600)
|
||||
async def test_endpoint_zero():
|
||||
return {"status": "ok"}
|
||||
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_very_short_window(self):
|
||||
"""Test decorator with very short time window."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=1)
|
||||
async def test_endpoint_short():
|
||||
return {"status": "ok"}
|
||||
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_multiple_functions_separate_limits(self):
|
||||
"""Test that different functions have separate rate limits."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=3600)
|
||||
async def endpoint1():
|
||||
return {"endpoint": "1"}
|
||||
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_exception_in_function(self):
|
||||
"""Test decorator handles exceptions from wrapped function."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint_error():
|
||||
raise ValueError("Function error")
|
||||
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_isolation_between_tests(self):
|
||||
"""Test that rate limiter state is properly isolated between tests."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_isolation():
|
||||
return {"status": "ok"}
|
||||
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_dict(self):
|
||||
"""Test decorator correctly returns dictionary."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_dict():
|
||||
return {"key": "value", "number": 42}
|
||||
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_list(self):
|
||||
"""Test decorator correctly returns list."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_list():
|
||||
return [1, 2, 3, 4, 5]
|
||||
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_none(self):
|
||||
"""Test decorator correctly returns None."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_none():
|
||||
return None
|
||||
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_object(self):
|
||||
"""Test decorator correctly returns custom objects."""
|
||||
|
||||
class TestObject:
|
||||
def __init__(self):
|
||||
self.name = "test_object"
|
||||
|
||||
@@ -11,9 +11,10 @@ Tests cover:
|
||||
- Edge cases (missing client info, etc.)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.logging import LoggingMiddleware
|
||||
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify request was logged
|
||||
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify response was logged
|
||||
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Process-Time" in response.headers
|
||||
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should log "unknown" for client
|
||||
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
"unknown" in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_exceptions(self):
|
||||
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger, \
|
||||
pytest.raises(Exception):
|
||||
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
|
||||
Exception
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify error was logged
|
||||
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = slow_call_next
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
process_time = float(result.headers["X-Process-Time"])
|
||||
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should still have process time, even if very small
|
||||
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify method was logged
|
||||
assert any(method in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
method in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_different_status_codes(self):
|
||||
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=status_code, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify status code was logged
|
||||
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
str(status_code) in str(call)
|
||||
for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@@ -11,10 +11,11 @@ Tests cover:
|
||||
- Edge cases and concurrency scenarios
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
|
||||
# Add requests at different times
|
||||
now = datetime.now(timezone.utc)
|
||||
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
|
||||
limiter = RateLimiter()
|
||||
client_id = "long_window_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
|
||||
result = limiter.allow_request(
|
||||
client_id, max_requests=10, window_seconds=86400 * 365
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
|
||||
client_id = "same_client"
|
||||
|
||||
# Allow with one limit
|
||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
is True
|
||||
)
|
||||
|
||||
# Check with stricter limit
|
||||
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_rate_limiter_unicode_client_id(self):
|
||||
"""Test rate limiter with unicode client ID."""
|
||||
|
||||
@@ -11,15 +11,14 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.theme_context import (
|
||||
ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme,
|
||||
)
|
||||
from middleware.theme_context import (ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -42,7 +41,14 @@ class TestThemeContextManager:
|
||||
"""Test default theme has all required colors."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
|
||||
required_colors = [
|
||||
"primary",
|
||||
"secondary",
|
||||
"accent",
|
||||
"background",
|
||||
"text",
|
||||
"border",
|
||||
]
|
||||
for color in required_colors:
|
||||
assert color in theme["colors"]
|
||||
assert theme["colors"][color].startswith("#")
|
||||
@@ -79,10 +85,7 @@ class TestThemeContextManager:
|
||||
mock_theme = Mock()
|
||||
|
||||
# Mock to_dict to return actual dictionary
|
||||
custom_theme_dict = {
|
||||
"theme_name": "custom",
|
||||
"colors": {"primary": "#ff0000"}
|
||||
}
|
||||
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
|
||||
mock_theme.to_dict.return_value = custom_theme_dict
|
||||
|
||||
# Correct filter chain: query().filter().first()
|
||||
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
|
||||
mock_db = MagicMock()
|
||||
mock_theme = {"theme_name": "test_theme"}
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'theme')
|
||||
assert hasattr(request.state, "theme")
|
||||
assert request.state.theme["theme_name"] == "default"
|
||||
call_next.assert_called_once()
|
||||
|
||||
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
|
||||
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify database was closed
|
||||
|
||||
@@ -11,17 +11,16 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.vendor_context import (
|
||||
VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context,
|
||||
)
|
||||
from middleware.vendor_context import (VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -39,7 +38,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -55,7 +54,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -71,7 +70,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -87,7 +86,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -140,7 +139,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -153,7 +152,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "www.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -166,7 +165,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "api.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -179,7 +178,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -198,12 +197,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = True
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -218,12 +216,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = False
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -232,12 +229,11 @@ class TestVendorContextManager:
|
||||
def test_get_vendor_from_custom_domain_not_found(self):
|
||||
"""Test custom domain not found in database."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "nonexistent.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -249,12 +245,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -266,12 +261,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "path",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "path", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -291,12 +285,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "VENDOR1" # Uppercase
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -311,10 +304,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -325,10 +315,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendors/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendors/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -339,10 +326,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -353,10 +337,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -425,21 +406,24 @@ class TestVendorContextManager:
|
||||
# Static File Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
],
|
||||
)
|
||||
def test_is_static_file_request(self, path):
|
||||
"""Test static file detection for various paths and extensions."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -447,12 +431,15 @@ class TestVendorContextManager:
|
||||
|
||||
assert VendorContextManager.is_static_file_request(request) is True
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
],
|
||||
)
|
||||
def test_is_not_static_file_request(self, path):
|
||||
"""Test non-static file paths."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_api_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
|
||||
with patch.object(
|
||||
VendorContextManager, "is_static_file_request", return_value=True
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
|
||||
mock_vendor.name = "Test Vendor"
|
||||
mock_vendor.subdomain = "vendor1"
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
|
||||
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
|
||||
), patch.object(
|
||||
VendorContextManager, "extract_clean_path", return_value="/shop/products"
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "nonexistent"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=None
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=None
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -714,7 +708,7 @@ class TestEdgeCases:
|
||||
request.headers = {"host": "shop.vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -735,11 +729,14 @@ class TestEdgeCases:
|
||||
|
||||
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
|
||||
|
||||
with patch('middleware.vendor_context.logger') as mock_logger:
|
||||
with patch("middleware.vendor_context.logger") as mock_logger:
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is None
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called()
|
||||
warning_message = str(mock_logger.warning.call_args)
|
||||
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
|
||||
assert (
|
||||
"No active vendor found for subdomain" in warning_message
|
||||
and "nonexistent" in warning_message
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user