style: apply black and isort formatting across entire codebase
- Standardize quote style (single to double quotes) - Reorder and group imports alphabetically - Fix line breaks and indentation for consistency - Apply PEP 8 formatting standards Also updated Makefile to exclude both venv and .venv from code quality checks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -12,21 +12,18 @@ Tests cover:
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
from fastapi import HTTPException
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from jose import jwt
|
||||
|
||||
from app.exceptions import (AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
InvalidCredentialsException, InvalidTokenException,
|
||||
TokenExpiredException, UserNotActiveException)
|
||||
from middleware.auth import AuthManager
|
||||
from app.exceptions import (
|
||||
InvalidTokenException,
|
||||
TokenExpiredException,
|
||||
UserNotActiveException,
|
||||
InvalidCredentialsException,
|
||||
AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
|
||||
@@ -124,7 +121,9 @@ class TestUserAuthentication:
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
result = auth_manager.authenticate_user(
|
||||
mock_db, "test@example.com", "password123"
|
||||
)
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Decode without verification to check payload
|
||||
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
|
||||
payload = jwt.decode(
|
||||
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
|
||||
)
|
||||
|
||||
assert payload["sub"] == "42"
|
||||
assert payload["username"] == "testuser"
|
||||
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
|
||||
"""Test tokens are different for different users."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
|
||||
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
|
||||
user1 = Mock(
|
||||
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
|
||||
)
|
||||
user2 = Mock(
|
||||
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
|
||||
)
|
||||
|
||||
token1 = auth_manager.create_access_token(user1)["access_token"]
|
||||
token2 = auth_manager.create_access_token(user2)["access_token"]
|
||||
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
|
||||
payload = jwt.decode(
|
||||
token_data["access_token"],
|
||||
auth_manager.secret_key,
|
||||
algorithms=[auth_manager.algorithm]
|
||||
algorithms=[auth_manager.algorithm],
|
||||
)
|
||||
|
||||
assert payload["role"] == "admin"
|
||||
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
|
||||
# Create token without 'sub' field
|
||||
payload = {
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token without 'exp' field
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser"
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser"}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
|
||||
}
|
||||
# Create token with different algorithm
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
|
||||
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
|
||||
|
||||
# Create a token with expiration in the past
|
||||
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": past_time.timestamp()
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Mock jwt.decode to bypass its expiration check and test line 205
|
||||
with patch('middleware.auth.jwt.decode') as mock_decode:
|
||||
with patch("middleware.auth.jwt.decode") as mock_decode:
|
||||
mock_decode.return_value = payload
|
||||
|
||||
with pytest.raises(TokenExpiredException):
|
||||
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
|
||||
|
||||
# Existing admin user
|
||||
existing_admin = Mock(spec=User)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = (
|
||||
existing_admin
|
||||
)
|
||||
|
||||
result = auth_manager.create_default_admin_user(mock_db)
|
||||
|
||||
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_default_configuration(self):
|
||||
"""Test AuthManager uses default configuration."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.algorithm == "HS256"
|
||||
assert auth_manager.token_expire_minutes == 30
|
||||
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
assert (
|
||||
auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
)
|
||||
|
||||
def test_custom_configuration(self):
|
||||
"""Test AuthManager uses environment variables."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'custom-secret-key',
|
||||
'JWT_EXPIRE_MINUTES': '60'
|
||||
}):
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
|
||||
):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.secret_key == "custom-secret-key"
|
||||
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
|
||||
|
||||
def test_partial_custom_configuration(self):
|
||||
"""Test AuthManager with partial environment configuration."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_EXPIRE_MINUTES': '120'
|
||||
}, clear=False):
|
||||
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.token_expire_minutes == 120
|
||||
@@ -656,9 +662,11 @@ class TestEdgeCases:
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
token = jwt.encode(
|
||||
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
|
||||
)
|
||||
|
||||
# Should still verify successfully (JWT doesn't validate iat by default)
|
||||
result = auth_manager.verify_token(token)
|
||||
@@ -698,7 +706,9 @@ class TestEdgeCases:
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Mock jose.jwt.decode to raise an unexpected exception
|
||||
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
|
||||
with patch(
|
||||
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
|
||||
):
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
|
||||
@@ -10,16 +10,13 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.context import (
|
||||
ContextManager,
|
||||
ContextMiddleware,
|
||||
RequestContext,
|
||||
get_request_context,
|
||||
)
|
||||
from middleware.context import (ContextManager, ContextMiddleware,
|
||||
RequestContext, get_request_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
|
||||
def test_is_admin_context_from_subdomain(self):
|
||||
"""Test _is_admin_context with admin subdomain."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/dashboard"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_from_path(self):
|
||||
"""Test _is_admin_context with admin path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "localhost", "/admin/users")
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_admin_context_both(self):
|
||||
"""Test _is_admin_context with both subdomain and path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
|
||||
assert (
|
||||
ContextManager._is_admin_context(
|
||||
request, "admin.platform.com", "/admin/users"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_is_not_admin_context(self):
|
||||
"""Test _is_admin_context returns False for non-admin."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
|
||||
assert (
|
||||
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context with /vendor/ path."""
|
||||
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
|
||||
|
||||
def test_is_vendor_dashboard_context_nested(self):
|
||||
"""Test _is_vendor_dashboard_context with nested vendor path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context_vendors_plural(self):
|
||||
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
|
||||
assert (
|
||||
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_not_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
|
||||
@@ -373,7 +391,7 @@ class TestContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'context_type')
|
||||
assert hasattr(request.state, "context_type")
|
||||
assert request.state.context_type == RequestContext.API
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@@ -565,7 +583,7 @@ class TestEdgeCases:
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
# No state attribute at all
|
||||
delattr(request, 'state')
|
||||
delattr(request, "state")
|
||||
|
||||
# Should still work, falling back to url.path
|
||||
with pytest.raises(AttributeError):
|
||||
|
||||
@@ -12,16 +12,18 @@ Tests cover:
|
||||
- Edge cases and isolation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
from app.exceptions.base import RateLimitException
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions.base import RateLimitException
|
||||
from middleware.decorators import rate_limit, rate_limiter
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_rate_limiter():
|
||||
"""Reset rate limiter state before each test to ensure isolation."""
|
||||
@@ -34,6 +36,7 @@ def reset_rate_limiter():
|
||||
# Rate Limit Decorator Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimitDecorator:
|
||||
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_allows_within_limit(self):
|
||||
"""Test decorator allows requests within rate limit."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_blocks_exceeding_limit(self):
|
||||
"""Test decorator blocks requests exceeding rate limit."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_blocked():
|
||||
return {"status": "ok"}
|
||||
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_preserves_function_metadata(self):
|
||||
"""Test decorator preserves original function metadata."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
"""Test endpoint docstring."""
|
||||
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_args_and_kwargs(self):
|
||||
"""Test decorator works with function arguments."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint(arg1, arg2, kwarg1=None):
|
||||
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
|
||||
|
||||
result = await test_endpoint("value1", "value2", kwarg1="value3")
|
||||
|
||||
assert result == {
|
||||
"arg1": "value1",
|
||||
"arg2": "value2",
|
||||
"kwarg1": "value3"
|
||||
}
|
||||
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_default_parameters(self):
|
||||
"""Test decorator uses default parameters."""
|
||||
|
||||
@rate_limit() # Use defaults
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_includes_retry_after(self):
|
||||
"""Test rate limit exception includes retry_after."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=60)
|
||||
async def test_endpoint_retry():
|
||||
return {"status": "ok"}
|
||||
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_zero_max_requests(self):
|
||||
"""Test decorator with max_requests=0 blocks all requests."""
|
||||
|
||||
@rate_limit(max_requests=0, window_seconds=3600)
|
||||
async def test_endpoint_zero():
|
||||
return {"status": "ok"}
|
||||
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_very_short_window(self):
|
||||
"""Test decorator with very short time window."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=1)
|
||||
async def test_endpoint_short():
|
||||
return {"status": "ok"}
|
||||
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_multiple_functions_separate_limits(self):
|
||||
"""Test that different functions have separate rate limits."""
|
||||
|
||||
@rate_limit(max_requests=1, window_seconds=3600)
|
||||
async def endpoint1():
|
||||
return {"endpoint": "1"}
|
||||
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_exception_in_function(self):
|
||||
"""Test decorator handles exceptions from wrapped function."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint_error():
|
||||
raise ValueError("Function error")
|
||||
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_isolation_between_tests(self):
|
||||
"""Test that rate limiter state is properly isolated between tests."""
|
||||
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint_isolation():
|
||||
return {"status": "ok"}
|
||||
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_dict(self):
|
||||
"""Test decorator correctly returns dictionary."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_dict():
|
||||
return {"key": "value", "number": 42}
|
||||
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_list(self):
|
||||
"""Test decorator correctly returns list."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_list():
|
||||
return [1, 2, 3, 4, 5]
|
||||
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_none(self):
|
||||
"""Test decorator correctly returns None."""
|
||||
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def return_none():
|
||||
return None
|
||||
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_returns_object(self):
|
||||
"""Test decorator correctly returns custom objects."""
|
||||
|
||||
class TestObject:
|
||||
def __init__(self):
|
||||
self.name = "test_object"
|
||||
|
||||
@@ -11,9 +11,10 @@ Tests cover:
|
||||
- Edge cases (missing client info, etc.)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.logging import LoggingMiddleware
|
||||
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify request was logged
|
||||
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify response was logged
|
||||
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Process-Time" in response.headers
|
||||
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should log "unknown" for client
|
||||
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
"unknown" in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_exceptions(self):
|
||||
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger, \
|
||||
pytest.raises(Exception):
|
||||
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
|
||||
Exception
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify error was logged
|
||||
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
|
||||
|
||||
call_next = slow_call_next
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
process_time = float(result.headers["X-Process-Time"])
|
||||
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger'):
|
||||
with patch("middleware.logging.logger"):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should still have process time, even if very small
|
||||
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=200, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify method was logged
|
||||
assert any(method in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
method in str(call) for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_different_status_codes(self):
|
||||
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
|
||||
response = Mock(status_code=status_code, headers={})
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging.logger') as mock_logger:
|
||||
with patch("middleware.logging.logger") as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify status code was logged
|
||||
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
|
||||
assert any(
|
||||
str(status_code) in str(call)
|
||||
for call in mock_logger.info.call_args_list
|
||||
)
|
||||
|
||||
@@ -11,10 +11,11 @@ Tests cover:
|
||||
- Edge cases and concurrency scenarios
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
|
||||
# Add requests at different times
|
||||
now = datetime.now(timezone.utc)
|
||||
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
|
||||
limiter = RateLimiter()
|
||||
client_id = "long_window_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
|
||||
result = limiter.allow_request(
|
||||
client_id, max_requests=10, window_seconds=86400 * 365
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
|
||||
client_id = "same_client"
|
||||
|
||||
# Allow with one limit
|
||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
is True
|
||||
)
|
||||
|
||||
# Check with stricter limit
|
||||
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
|
||||
assert (
|
||||
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_rate_limiter_unicode_client_id(self):
|
||||
"""Test rate limiter with unicode client ID."""
|
||||
|
||||
@@ -11,15 +11,14 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.theme_context import (
|
||||
ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme,
|
||||
)
|
||||
from middleware.theme_context import (ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -42,7 +41,14 @@ class TestThemeContextManager:
|
||||
"""Test default theme has all required colors."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
|
||||
required_colors = [
|
||||
"primary",
|
||||
"secondary",
|
||||
"accent",
|
||||
"background",
|
||||
"text",
|
||||
"border",
|
||||
]
|
||||
for color in required_colors:
|
||||
assert color in theme["colors"]
|
||||
assert theme["colors"][color].startswith("#")
|
||||
@@ -79,10 +85,7 @@ class TestThemeContextManager:
|
||||
mock_theme = Mock()
|
||||
|
||||
# Mock to_dict to return actual dictionary
|
||||
custom_theme_dict = {
|
||||
"theme_name": "custom",
|
||||
"colors": {"primary": "#ff0000"}
|
||||
}
|
||||
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
|
||||
mock_theme.to_dict.return_value = custom_theme_dict
|
||||
|
||||
# Correct filter chain: query().filter().first()
|
||||
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
|
||||
mock_db = MagicMock()
|
||||
mock_theme = {"theme_name": "test_theme"}
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'theme')
|
||||
assert hasattr(request.state, "theme")
|
||||
assert request.state.theme["theme_name"] == "default"
|
||||
call_next.assert_called_once()
|
||||
|
||||
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
|
||||
with patch(
|
||||
"middleware.theme_context.get_db", return_value=iter([mock_db])
|
||||
), patch.object(
|
||||
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
|
||||
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify database was closed
|
||||
|
||||
@@ -11,17 +11,16 @@ Tests cover:
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from fastapi import Request, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.vendor_context import (
|
||||
VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context,
|
||||
)
|
||||
from middleware.vendor_context import (VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -39,7 +38,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -55,7 +54,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "customdomain1.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -71,7 +70,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -87,7 +86,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "vendor1.platform.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -140,7 +139,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -153,7 +152,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "www.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -166,7 +165,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "api.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -179,7 +178,7 @@ class TestVendorContextManager:
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -198,12 +197,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = True
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -218,12 +216,11 @@ class TestVendorContextManager:
|
||||
mock_vendor.is_active = False
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor_domain
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -232,12 +229,11 @@ class TestVendorContextManager:
|
||||
def test_get_vendor_from_custom_domain_not_found(self):
|
||||
"""Test custom domain not found in database."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "nonexistent.com"
|
||||
}
|
||||
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -249,12 +245,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -266,12 +261,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "path",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
context = {"detection_method": "path", "subdomain": "vendor1"}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -291,12 +285,11 @@ class TestVendorContextManager:
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
|
||||
mock_vendor
|
||||
)
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "VENDOR1" # Uppercase
|
||||
}
|
||||
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
@@ -311,10 +304,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -325,10 +315,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendors/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendors/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -339,10 +326,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -353,10 +337,7 @@ class TestVendorContextManager:
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
@@ -425,21 +406,24 @@ class TestVendorContextManager:
|
||||
# Static File Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
"/static/", # Path starting with /static/ but no extension
|
||||
"/media/uploads", # Path starting with /media/ but no extension
|
||||
"/subfolder/favicon.ico", # favicon.ico in subfolder
|
||||
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
|
||||
],
|
||||
)
|
||||
def test_is_static_file_request(self, path):
|
||||
"""Test static file detection for various paths and extensions."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -447,12 +431,15 @@ class TestVendorContextManager:
|
||||
|
||||
assert VendorContextManager.is_static_file_request(request) is True
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
],
|
||||
)
|
||||
def test_is_not_static_file_request(self, path):
|
||||
"""Test non-static file paths."""
|
||||
request = Mock(spec=Request)
|
||||
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
|
||||
with patch.object(VendorContextManager, "is_api_request", return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
|
||||
with patch.object(
|
||||
VendorContextManager, "is_static_file_request", return_value=True
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
|
||||
mock_vendor.name = "Test Vendor"
|
||||
mock_vendor.subdomain = "vendor1"
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
|
||||
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
|
||||
), patch.object(
|
||||
VendorContextManager, "extract_clean_path", return_value="/shop/products"
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "nonexistent"
|
||||
}
|
||||
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=vendor_context
|
||||
), patch.object(
|
||||
VendorContextManager, "get_vendor_from_context", return_value=None
|
||||
), patch(
|
||||
"middleware.vendor_context.get_db", return_value=iter([mock_db])
|
||||
):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
|
||||
with patch.object(
|
||||
VendorContextManager, "detect_vendor_context", return_value=None
|
||||
):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
@@ -714,7 +708,7 @@ class TestEdgeCases:
|
||||
request.headers = {"host": "shop.vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
with patch("middleware.vendor_context.settings") as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
@@ -735,11 +729,14 @@ class TestEdgeCases:
|
||||
|
||||
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
|
||||
|
||||
with patch('middleware.vendor_context.logger') as mock_logger:
|
||||
with patch("middleware.vendor_context.logger") as mock_logger:
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is None
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called()
|
||||
warning_message = str(mock_logger.warning.call_args)
|
||||
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
|
||||
assert (
|
||||
"No active vendor found for subdomain" in warning_message
|
||||
and "nonexistent" in warning_message
|
||||
)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# tests/unit/models/test_database_models.py
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.vendor import Vendor, VendorUser, Role
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.user import User
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.product import Product
|
||||
from models.database.customer import Customer, CustomerAddress
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.order import Order, OrderItem
|
||||
from models.database.product import Product
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Role, Vendor, VendorUser
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -277,7 +278,7 @@ class TestMarketplaceProductModel:
|
||||
vendor_id=test_vendor.id,
|
||||
marketplace_product_id="UNIQUE_001",
|
||||
title="Product 1",
|
||||
marketplace="Letzshop"
|
||||
marketplace="Letzshop",
|
||||
)
|
||||
db.add(product1)
|
||||
db.commit()
|
||||
@@ -288,7 +289,7 @@ class TestMarketplaceProductModel:
|
||||
vendor_id=test_vendor.id,
|
||||
marketplace_product_id="UNIQUE_001",
|
||||
title="Product 2",
|
||||
marketplace="Letzshop"
|
||||
marketplace="Letzshop",
|
||||
)
|
||||
db.add(product2)
|
||||
db.commit()
|
||||
@@ -515,7 +516,9 @@ class TestCustomerModel:
|
||||
class TestOrderModel:
|
||||
"""Test Order model"""
|
||||
|
||||
def test_order_creation(self, db, test_vendor, test_customer, test_customer_address):
|
||||
def test_order_creation(
|
||||
self, db, test_vendor, test_customer, test_customer_address
|
||||
):
|
||||
"""Test Order model with customer relationship"""
|
||||
order = Order(
|
||||
vendor_id=test_vendor.id,
|
||||
@@ -563,7 +566,9 @@ class TestOrderModel:
|
||||
assert float(order_item.unit_price) == 49.99
|
||||
assert float(order_item.total_price) == 99.98
|
||||
|
||||
def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address):
|
||||
def test_order_number_uniqueness(
|
||||
self, db, test_vendor, test_customer, test_customer_address
|
||||
):
|
||||
"""Test order_number unique constraint"""
|
||||
order1 = Order(
|
||||
vendor_id=test_vendor.id,
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
# tests/unit/services/test_admin_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (
|
||||
UserNotFoundException,
|
||||
UserStatusChangeException,
|
||||
CannotModifySelfException,
|
||||
VendorNotFoundException,
|
||||
VendorVerificationException,
|
||||
AdminOperationException,
|
||||
)
|
||||
from app.exceptions import (AdminOperationException, CannotModifySelfException,
|
||||
UserNotFoundException, UserStatusChangeException,
|
||||
VendorNotFoundException,
|
||||
VendorVerificationException)
|
||||
from app.services.admin_service import AdminService
|
||||
from app.services.stats_service import stats_service
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
@@ -85,7 +81,9 @@ class TestAdminService:
|
||||
assert exception.error_code == "CANNOT_MODIFY_SELF"
|
||||
assert "deactivate account" in exception.message
|
||||
|
||||
def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin):
|
||||
def test_toggle_user_status_cannot_modify_admin(
|
||||
self, db, test_admin, another_admin
|
||||
):
|
||||
"""Test that admin cannot modify another admin"""
|
||||
with pytest.raises(UserStatusChangeException) as exc_info:
|
||||
self.service.toggle_user_status(db, another_admin.id, test_admin.id)
|
||||
@@ -148,7 +146,7 @@ class TestAdminService:
|
||||
assert "99999" in exception.message
|
||||
|
||||
def test_toggle_vendor_status_deactivate(self, db, test_vendor):
|
||||
"""Test deactivating a vendor """
|
||||
"""Test deactivating a vendor"""
|
||||
original_status = test_vendor.is_active
|
||||
|
||||
vendor, message = self.service.toggle_vendor_status(db, test_vendor.id)
|
||||
@@ -170,21 +168,26 @@ class TestAdminService:
|
||||
assert exception.error_code == "VENDOR_NOT_FOUND"
|
||||
|
||||
# Marketplace Import Jobs Tests
|
||||
def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_no_filters(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test getting marketplace import jobs without filters"""
|
||||
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
|
||||
|
||||
assert len(result) >= 1
|
||||
# Find our test job in the results
|
||||
test_job = next(
|
||||
(job for job in result if job.job_id == test_marketplace_import_job.id), None
|
||||
(job for job in result if job.job_id == test_marketplace_import_job.id),
|
||||
None,
|
||||
)
|
||||
assert test_job is not None
|
||||
assert test_job.marketplace == test_marketplace_import_job.marketplace
|
||||
assert test_job.vendor_name == test_marketplace_import_job.name
|
||||
assert test_job.status == test_marketplace_import_job.status
|
||||
|
||||
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_marketplace_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by marketplace"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10
|
||||
@@ -192,9 +195,14 @@ class TestAdminService:
|
||||
|
||||
assert len(result) >= 1
|
||||
for job in result:
|
||||
assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower()
|
||||
assert (
|
||||
test_marketplace_import_job.marketplace.lower()
|
||||
in job.marketplace.lower()
|
||||
)
|
||||
|
||||
def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_vendor_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by vendor name"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10
|
||||
@@ -204,7 +212,9 @@ class TestAdminService:
|
||||
for job in result:
|
||||
assert test_marketplace_import_job.name.lower() in job.vendor_name.lower()
|
||||
|
||||
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_with_status_filter(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test filtering marketplace import jobs by status"""
|
||||
result = self.service.get_marketplace_import_jobs(
|
||||
db, status=test_marketplace_import_job.status, skip=0, limit=10
|
||||
@@ -214,7 +224,9 @@ class TestAdminService:
|
||||
for job in result:
|
||||
assert job.status == test_marketplace_import_job.status
|
||||
|
||||
def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job):
|
||||
def test_get_marketplace_import_jobs_pagination(
|
||||
self, db, test_marketplace_import_job
|
||||
):
|
||||
"""Test marketplace import jobs pagination"""
|
||||
result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1)
|
||||
result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# tests/test_auth_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions.auth import (
|
||||
UserAlreadyExistsException,
|
||||
InvalidCredentialsException,
|
||||
UserNotActiveException,
|
||||
)
|
||||
from app.exceptions.auth import (InvalidCredentialsException,
|
||||
UserAlreadyExistsException,
|
||||
UserNotActiveException)
|
||||
from app.exceptions.base import ValidationException
|
||||
from app.services.auth_service import AuthService
|
||||
from models.schema.auth import UserLogin, UserRegister
|
||||
@@ -218,11 +216,14 @@ class TestAuthService:
|
||||
|
||||
def test_create_access_token_failure(self, test_user, monkeypatch):
|
||||
"""Test creating access token handles failures"""
|
||||
|
||||
# Mock the auth_manager to raise an exception
|
||||
def mock_create_token(*args, **kwargs):
|
||||
raise Exception("Token creation failed")
|
||||
|
||||
monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token)
|
||||
monkeypatch.setattr(
|
||||
self.service.auth_manager, "create_access_token", mock_create_token
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
self.service.create_access_token(test_user)
|
||||
@@ -250,11 +251,14 @@ class TestAuthService:
|
||||
|
||||
def test_hash_password_failure(self, monkeypatch):
|
||||
"""Test password hashing handles failures"""
|
||||
|
||||
# Mock the auth_manager to raise an exception
|
||||
def mock_hash_password(*args, **kwargs):
|
||||
raise Exception("Hashing failed")
|
||||
|
||||
monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password)
|
||||
monkeypatch.setattr(
|
||||
self.service.auth_manager, "hash_password", mock_hash_password
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
self.service.hash_password("testpassword")
|
||||
@@ -267,9 +271,7 @@ class TestAuthService:
|
||||
def test_register_user_database_error(self, db_with_error):
|
||||
"""Test user registration handles database errors"""
|
||||
user_data = UserRegister(
|
||||
email="test@example.com",
|
||||
username="testuser",
|
||||
password="password123"
|
||||
email="test@example.com", username="testuser", password="password123"
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationException) as exc_info:
|
||||
|
||||
@@ -3,19 +3,17 @@ import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InsufficientInventoryException,
|
||||
InvalidInventoryOperationException,
|
||||
InvalidQuantityException,
|
||||
InventoryNotFoundException,
|
||||
InventoryValidationException,
|
||||
NegativeInventoryException, ValidationException)
|
||||
from app.services.inventory_service import InventoryService
|
||||
from app.exceptions import (
|
||||
InventoryNotFoundException,
|
||||
InsufficientInventoryException,
|
||||
InvalidInventoryOperationException,
|
||||
InventoryValidationException,
|
||||
NegativeInventoryException,
|
||||
InvalidQuantityException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.schema.inventory import (InventoryAdd, InventoryCreate,
|
||||
InventoryUpdate)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -40,10 +38,14 @@ class TestInventoryService:
|
||||
def test_normalize_gtin_valid(self):
|
||||
"""Test GTIN normalization with valid GTINs."""
|
||||
# Test various valid GTIN formats - these should remain unchanged
|
||||
assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("1234567890123") == "1234567890123"
|
||||
) # EAN-13
|
||||
assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A
|
||||
assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8
|
||||
assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
|
||||
assert (
|
||||
self.service._normalize_gtin("12345678901234") == "12345678901234"
|
||||
) # GTIN-14
|
||||
|
||||
# Test with decimal points (should be removed)
|
||||
assert self.service._normalize_gtin("1234567890123.0") == "1234567890123"
|
||||
@@ -52,11 +54,17 @@ class TestInventoryService:
|
||||
assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123"
|
||||
|
||||
# Test short GTINs being padded
|
||||
assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13
|
||||
assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("123") == "0000000000123"
|
||||
) # Padded to EAN-13
|
||||
assert (
|
||||
self.service._normalize_gtin("12345") == "0000000012345"
|
||||
) # Padded to EAN-13
|
||||
|
||||
# Test long GTINs being truncated
|
||||
assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
|
||||
assert (
|
||||
self.service._normalize_gtin("123456789012345") == "3456789012345"
|
||||
) # Truncated to 13
|
||||
|
||||
def test_normalize_gtin_edge_cases(self):
|
||||
"""Test GTIN normalization edge cases."""
|
||||
@@ -65,9 +73,15 @@ class TestInventoryService:
|
||||
assert self.service._normalize_gtin(123) == "0000000000123"
|
||||
|
||||
# Test mixed valid/invalid characters
|
||||
assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
|
||||
assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
|
||||
assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
|
||||
assert (
|
||||
self.service._normalize_gtin("123-456-789-012") == "123456789012"
|
||||
) # Dashes removed
|
||||
assert (
|
||||
self.service._normalize_gtin("123 456 789 012") == "123456789012"
|
||||
) # Spaces removed
|
||||
assert (
|
||||
self.service._normalize_gtin("ABC123456789012DEF") == "123456789012"
|
||||
) # Letters removed
|
||||
|
||||
def test_set_inventory_new_entry_success(self, db):
|
||||
"""Test setting inventory for a new GTIN/location combination successfully."""
|
||||
@@ -162,7 +176,9 @@ class TestInventoryService:
|
||||
|
||||
def test_add_inventory_invalid_gtin_validation_error(self, db):
|
||||
"""Test adding inventory with invalid GTIN returns InventoryValidationException."""
|
||||
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
|
||||
inventory_data = InventoryAdd(
|
||||
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50
|
||||
)
|
||||
|
||||
with pytest.raises(InventoryValidationException) as exc_info:
|
||||
self.service.add_inventory(db, inventory_data)
|
||||
@@ -180,11 +196,12 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVALID_QUANTITY"
|
||||
assert "Quantity must be positive" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_remove_inventory_success(self, db, test_inventory):
|
||||
"""Test removing inventory successfully."""
|
||||
original_quantity = test_inventory.quantity
|
||||
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
|
||||
remove_quantity = min(
|
||||
10, original_quantity
|
||||
) # Ensure we don't remove more than available
|
||||
|
||||
inventory_data = InventoryAdd(
|
||||
gtin=test_inventory.gtin,
|
||||
@@ -212,7 +229,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
|
||||
assert exc_info.value.details["gtin"] == test_inventory.gtin
|
||||
assert exc_info.value.details["location"] == test_inventory.location
|
||||
assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
|
||||
assert (
|
||||
exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
|
||||
)
|
||||
assert exc_info.value.details["available_quantity"] == test_inventory.quantity
|
||||
|
||||
def test_remove_inventory_nonexistent_entry_not_found(self, db):
|
||||
@@ -231,7 +250,9 @@ class TestInventoryService:
|
||||
|
||||
def test_remove_inventory_invalid_gtin_validation_error(self, db):
|
||||
"""Test removing inventory with invalid GTIN returns InventoryValidationException."""
|
||||
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
|
||||
inventory_data = InventoryAdd(
|
||||
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10
|
||||
)
|
||||
|
||||
with pytest.raises(InventoryValidationException) as exc_info:
|
||||
self.service.remove_inventory(db, inventory_data)
|
||||
@@ -254,7 +275,9 @@ class TestInventoryService:
|
||||
# The service prevents negative inventory through InsufficientInventoryException
|
||||
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
|
||||
|
||||
def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_inventory_by_gtin_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting inventory summary by GTIN successfully."""
|
||||
result = self.service.get_inventory_by_gtin(db, test_inventory.gtin)
|
||||
|
||||
@@ -265,14 +288,20 @@ class TestInventoryService:
|
||||
assert result.locations[0].quantity == test_inventory.quantity
|
||||
assert result.product_title == test_marketplace_product.title
|
||||
|
||||
def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product):
|
||||
def test_get_inventory_by_gtin_multiple_locations_success(
|
||||
self, db, test_marketplace_product
|
||||
):
|
||||
"""Test getting inventory summary with multiple locations successfully."""
|
||||
unique_gtin = test_marketplace_product.gtin
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Create multiple inventory entries for the same GTIN with unique locations
|
||||
inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50)
|
||||
inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30)
|
||||
inventory1 = Inventory(
|
||||
gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
|
||||
)
|
||||
inventory2 = Inventory(
|
||||
gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
|
||||
)
|
||||
|
||||
db.add(inventory1)
|
||||
db.add(inventory2)
|
||||
@@ -301,7 +330,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED"
|
||||
assert "Invalid GTIN format" in str(exc_info.value)
|
||||
|
||||
def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_total_inventory_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting total inventory for a GTIN successfully."""
|
||||
result = self.service.get_total_inventory(db, test_inventory.gtin)
|
||||
|
||||
@@ -364,7 +395,9 @@ class TestInventoryService:
|
||||
|
||||
result = self.service.get_all_inventory(db, skip=2, limit=2)
|
||||
|
||||
assert len(result) <= 2 # Should be at most 2, might be less if other records exist
|
||||
assert (
|
||||
len(result) <= 2
|
||||
) # Should be at most 2, might be less if other records exist
|
||||
|
||||
def test_update_inventory_success(self, db, test_inventory):
|
||||
"""Test updating inventory quantity successfully."""
|
||||
@@ -404,7 +437,9 @@ class TestInventoryService:
|
||||
assert result is True
|
||||
|
||||
# Verify the inventory is actually deleted
|
||||
deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
deleted_inventory = (
|
||||
db.query(Inventory).filter(Inventory.id == inventory_id).first()
|
||||
)
|
||||
assert deleted_inventory is None
|
||||
|
||||
def test_delete_inventory_not_found_error(self, db):
|
||||
@@ -415,7 +450,9 @@ class TestInventoryService:
|
||||
assert exc_info.value.error_code == "INVENTORY_NOT_FOUND"
|
||||
assert "99999" in str(exc_info.value)
|
||||
|
||||
def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product):
|
||||
def test_get_low_inventory_items_success(
|
||||
self, db, test_inventory, test_marketplace_product
|
||||
):
|
||||
"""Test getting low inventory items successfully."""
|
||||
# Set inventory to a low value
|
||||
test_inventory.quantity = 5
|
||||
@@ -424,7 +461,9 @@ class TestInventoryService:
|
||||
result = self.service.get_low_inventory_items(db, threshold=10)
|
||||
|
||||
assert len(result) >= 1
|
||||
low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None)
|
||||
low_inventory_item = next(
|
||||
(item for item in result if item["gtin"] == test_inventory.gtin), None
|
||||
)
|
||||
assert low_inventory_item is not None
|
||||
assert low_inventory_item["current_quantity"] == 5
|
||||
assert low_inventory_item["location"] == test_inventory.location
|
||||
@@ -440,9 +479,13 @@ class TestInventoryService:
|
||||
|
||||
def test_get_inventory_summary_by_location_success(self, db, test_inventory):
|
||||
"""Test getting inventory summary by location successfully."""
|
||||
result = self.service.get_inventory_summary_by_location(db, test_inventory.location)
|
||||
result = self.service.get_inventory_summary_by_location(
|
||||
db, test_inventory.location
|
||||
)
|
||||
|
||||
assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase
|
||||
assert (
|
||||
result["location"] == test_inventory.location.upper()
|
||||
) # Service normalizes to uppercase
|
||||
assert result["total_items"] >= 1
|
||||
assert result["total_quantity"] >= test_inventory.quantity
|
||||
assert result["unique_gtins"] >= 1
|
||||
@@ -450,7 +493,9 @@ class TestInventoryService:
|
||||
def test_get_inventory_summary_by_location_empty_result(self, db):
|
||||
"""Test getting inventory summary for location with no inventory."""
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}")
|
||||
result = self.service.get_inventory_summary_by_location(
|
||||
db, f"EMPTY_LOCATION_{unique_id}"
|
||||
)
|
||||
|
||||
assert result["total_items"] == 0
|
||||
assert result["total_quantity"] == 0
|
||||
@@ -459,12 +504,16 @@ class TestInventoryService:
|
||||
def test_validate_quantity_edge_cases(self, db):
|
||||
"""Test quantity validation with edge cases."""
|
||||
# Test zero quantity with allow_zero=True (should succeed)
|
||||
inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0)
|
||||
inventory_data = InventoryCreate(
|
||||
gtin="1234567890123", location="WAREHOUSE_A", quantity=0
|
||||
)
|
||||
result = self.service.set_inventory(db, inventory_data)
|
||||
assert result.quantity == 0
|
||||
|
||||
# Test zero quantity with add_inventory (should fail - doesn't allow zero)
|
||||
inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0)
|
||||
inventory_data_add = InventoryAdd(
|
||||
gtin="1234567890123", location="WAREHOUSE_B", quantity=0
|
||||
)
|
||||
with pytest.raises(InvalidQuantityException):
|
||||
self.service.add_inventory(db, inventory_data_add)
|
||||
|
||||
@@ -477,10 +526,10 @@ class TestInventoryService:
|
||||
exception = exc_info.value
|
||||
|
||||
# Verify exception structure matches WizamartException.to_dict()
|
||||
assert hasattr(exception, 'error_code')
|
||||
assert hasattr(exception, 'message')
|
||||
assert hasattr(exception, 'status_code')
|
||||
assert hasattr(exception, 'details')
|
||||
assert hasattr(exception, "error_code")
|
||||
assert hasattr(exception, "message")
|
||||
assert hasattr(exception, "status_code")
|
||||
assert hasattr(exception, "details")
|
||||
|
||||
assert isinstance(exception.error_code, str)
|
||||
assert isinstance(exception.message, str)
|
||||
|
||||
@@ -4,19 +4,18 @@ from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.exceptions.marketplace_import_job import (
|
||||
ImportJobNotFoundException,
|
||||
ImportJobNotOwnedException,
|
||||
ImportJobCannotBeCancelledException,
|
||||
ImportJobCannotBeDeletedException,
|
||||
)
|
||||
from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException
|
||||
from app.exceptions.base import ValidationException
|
||||
from app.services.marketplace_import_job_service import MarketplaceImportJobService
|
||||
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
|
||||
from app.exceptions.marketplace_import_job import (
|
||||
ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException,
|
||||
ImportJobNotFoundException, ImportJobNotOwnedException)
|
||||
from app.exceptions.vendor import (UnauthorizedVendorAccessException,
|
||||
VendorNotFoundException)
|
||||
from app.services.marketplace_import_job_service import \
|
||||
MarketplaceImportJobService
|
||||
from models.database.marketplace_import_job import MarketplaceImportJob
|
||||
from models.database.vendor import Vendor
|
||||
from models.database.user import User
|
||||
from models.database.vendor import Vendor
|
||||
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -31,7 +30,9 @@ class TestMarketplaceService:
|
||||
test_vendor.owner_user_id = test_user.id
|
||||
db.commit()
|
||||
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code, test_user
|
||||
)
|
||||
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
assert result.owner_user_id == test_user.id
|
||||
@@ -39,8 +40,10 @@ class TestMarketplaceService:
|
||||
def test_validate_vendor_access_admin_can_access_any_vendor(
|
||||
self, db, test_vendor, test_admin
|
||||
):
|
||||
"""Test that admin users can access any vendor """
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin)
|
||||
"""Test that admin users can access any vendor"""
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code, test_admin
|
||||
)
|
||||
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
@@ -57,7 +60,7 @@ class TestMarketplaceService:
|
||||
def test_validate_vendor_access_permission_denied(
|
||||
self, db, test_vendor, test_user, other_user
|
||||
):
|
||||
"""Test vendor access validation when user doesn't own the vendor """
|
||||
"""Test vendor access validation when user doesn't own the vendor"""
|
||||
# Set the vendor owner to a different user
|
||||
test_vendor.owner_user_id = other_user.id
|
||||
db.commit()
|
||||
@@ -93,7 +96,7 @@ class TestMarketplaceService:
|
||||
assert result.vendor_name == test_vendor.name
|
||||
|
||||
def test_create_import_job_invalid_vendor(self, db, test_user):
|
||||
"""Test import job creation with invalid vendor """
|
||||
"""Test import job creation with invalid vendor"""
|
||||
request = MarketplaceImportJobRequest(
|
||||
url="https://example.com/products.csv",
|
||||
marketplace="Amazon",
|
||||
@@ -108,7 +111,9 @@ class TestMarketplaceService:
|
||||
assert exception.error_code == "VENDOR_NOT_FOUND"
|
||||
assert "INVALID_VENDOR" in exception.message
|
||||
|
||||
def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user):
|
||||
def test_create_import_job_unauthorized_access(
|
||||
self, db, test_vendor, test_user, other_user
|
||||
):
|
||||
"""Test import job creation with unauthorized vendor access"""
|
||||
# Set the vendor owner to a different user
|
||||
test_vendor.owner_user_id = other_user.id
|
||||
@@ -127,7 +132,9 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS"
|
||||
|
||||
def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user):
|
||||
def test_get_import_job_by_id_success(
|
||||
self, db, test_marketplace_import_job, test_user
|
||||
):
|
||||
"""Test getting import job by ID for job owner"""
|
||||
result = self.service.get_import_job_by_id(
|
||||
db, test_marketplace_import_job.id, test_user
|
||||
@@ -161,14 +168,18 @@ class TestMarketplaceService:
|
||||
):
|
||||
"""Test access denied when user doesn't own the job"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.get_import_job_by_id(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
assert exception.status_code == 403
|
||||
assert str(test_marketplace_import_job.id) in exception.message
|
||||
|
||||
def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user):
|
||||
def test_get_import_jobs_user_filter(
|
||||
self, db, test_marketplace_import_job, test_user
|
||||
):
|
||||
"""Test getting import jobs filtered by user"""
|
||||
jobs = self.service.get_import_jobs(db, test_user)
|
||||
|
||||
@@ -176,7 +187,9 @@ class TestMarketplaceService:
|
||||
assert any(job.id == test_marketplace_import_job.id for job in jobs)
|
||||
assert test_marketplace_import_job.user_id == test_user.id
|
||||
|
||||
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin):
|
||||
def test_get_import_jobs_admin_sees_all(
|
||||
self, db, test_marketplace_import_job, test_admin
|
||||
):
|
||||
"""Test that admin sees all import jobs"""
|
||||
jobs = self.service.get_import_jobs(db, test_admin)
|
||||
|
||||
@@ -192,7 +205,9 @@ class TestMarketplaceService:
|
||||
)
|
||||
|
||||
assert len(jobs) >= 1
|
||||
assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs)
|
||||
assert any(
|
||||
job.marketplace == test_marketplace_import_job.marketplace for job in jobs
|
||||
)
|
||||
|
||||
def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor):
|
||||
"""Test getting import jobs with pagination"""
|
||||
@@ -330,10 +345,14 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
|
||||
def test_cancel_import_job_access_denied(
|
||||
self, db, test_marketplace_import_job, other_user
|
||||
):
|
||||
"""Test cancelling import job without access"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.cancel_import_job(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
@@ -347,7 +366,9 @@ class TestMarketplaceService:
|
||||
db.commit()
|
||||
|
||||
with pytest.raises(ImportJobCannotBeCancelledException) as exc_info:
|
||||
self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user)
|
||||
self.service.cancel_import_job(
|
||||
db, test_marketplace_import_job.id, test_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED"
|
||||
@@ -396,10 +417,14 @@ class TestMarketplaceService:
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
|
||||
|
||||
def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
|
||||
def test_delete_import_job_access_denied(
|
||||
self, db, test_marketplace_import_job, other_user
|
||||
):
|
||||
"""Test deleting import job without access"""
|
||||
with pytest.raises(ImportJobNotOwnedException) as exc_info:
|
||||
self.service.delete_import_job(db, test_marketplace_import_job.id, other_user)
|
||||
self.service.delete_import_job(
|
||||
db, test_marketplace_import_job.id, other_user
|
||||
)
|
||||
|
||||
exception = exc_info.value
|
||||
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
|
||||
@@ -440,11 +465,15 @@ class TestMarketplaceService:
|
||||
db.commit()
|
||||
|
||||
# Test with lowercase vendor code
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code.lower(), test_user
|
||||
)
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
# Test with uppercase vendor code
|
||||
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user)
|
||||
result = self.service.validate_vendor_access(
|
||||
db, test_vendor.vendor_code.upper(), test_user
|
||||
)
|
||||
assert result.vendor_code == test_vendor.vendor_code
|
||||
|
||||
def test_create_import_job_database_error(self, db_with_error, test_user):
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# tests/test_product_service.py
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InvalidMarketplaceProductDataException,
|
||||
MarketplaceProductAlreadyExistsException,
|
||||
MarketplaceProductNotFoundException,
|
||||
MarketplaceProductValidationException,
|
||||
ValidationException)
|
||||
from app.services.marketplace_product_service import MarketplaceProductService
|
||||
from app.exceptions import (
|
||||
MarketplaceProductNotFoundException,
|
||||
MarketplaceProductAlreadyExistsException,
|
||||
InvalidMarketplaceProductDataException,
|
||||
MarketplaceProductValidationException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.schema.marketplace_product import (MarketplaceProductCreate,
|
||||
MarketplaceProductUpdate)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -98,7 +97,10 @@ class TestProductService:
|
||||
assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS"
|
||||
assert test_marketplace_product.marketplace_product_id in str(exc_info.value)
|
||||
assert exc_info.value.status_code == 409
|
||||
assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
exc_info.value.details.get("marketplace_product_id")
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_create_product_invalid_price(self, db):
|
||||
"""Test product creation with invalid price raises InvalidMarketplaceProductDataException"""
|
||||
@@ -117,9 +119,14 @@ class TestProductService:
|
||||
|
||||
def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product):
|
||||
"""Test successful product retrieval by ID"""
|
||||
product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id)
|
||||
product = self.service.get_product_by_id_or_raise(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
product.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert product.title == test_marketplace_product.title
|
||||
|
||||
def test_get_product_by_id_or_raise_not_found(self, db):
|
||||
@@ -152,21 +159,35 @@ class TestProductService:
|
||||
assert total >= 1
|
||||
assert len(products) >= 1
|
||||
# Verify search worked by checking that title contains search term
|
||||
found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None)
|
||||
found_product = next(
|
||||
(
|
||||
p
|
||||
for p in products
|
||||
if p.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert found_product is not None
|
||||
|
||||
def test_update_product_success(self, db, test_marketplace_product):
|
||||
"""Test successful product update"""
|
||||
update_data = MarketplaceProductUpdate(
|
||||
title="Updated MarketplaceProduct Title",
|
||||
price="39.99"
|
||||
title="Updated MarketplaceProduct Title", price="39.99"
|
||||
)
|
||||
|
||||
updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
updated_product = self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert updated_product.title == "Updated MarketplaceProduct Title"
|
||||
assert updated_product.price == "39.99" # Price is stored as string after processing
|
||||
assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged
|
||||
assert (
|
||||
updated_product.price == "39.99"
|
||||
) # Price is stored as string after processing
|
||||
assert (
|
||||
updated_product.marketplace_product_id
|
||||
== test_marketplace_product.marketplace_product_id
|
||||
) # ID unchanged
|
||||
|
||||
def test_update_product_not_found(self, db):
|
||||
"""Test updating non-existent product raises MarketplaceProductNotFoundException"""
|
||||
@@ -183,7 +204,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(gtin="invalid_gtin")
|
||||
|
||||
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
|
||||
assert "Invalid GTIN format" in str(exc_info.value)
|
||||
@@ -194,7 +217,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(title="")
|
||||
|
||||
with pytest.raises(MarketplaceProductValidationException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED"
|
||||
assert "MarketplaceProduct title cannot be empty" in str(exc_info.value)
|
||||
@@ -205,7 +230,9 @@ class TestProductService:
|
||||
update_data = MarketplaceProductUpdate(price="invalid_price")
|
||||
|
||||
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
|
||||
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
|
||||
self.service.update_product(
|
||||
db, test_marketplace_product.marketplace_product_id, update_data
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
|
||||
assert "Invalid price format" in str(exc_info.value)
|
||||
@@ -213,12 +240,16 @@ class TestProductService:
|
||||
|
||||
def test_delete_product_success(self, db, test_marketplace_product):
|
||||
"""Test successful product deletion"""
|
||||
result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id)
|
||||
result = self.service.delete_product(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify product is deleted
|
||||
deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id)
|
||||
deleted_product = self.service.get_product_by_id(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert deleted_product is None
|
||||
|
||||
def test_delete_product_not_found(self, db):
|
||||
@@ -229,10 +260,14 @@ class TestProductService:
|
||||
assert exc_info.value.error_code == "PRODUCT_NOT_FOUND"
|
||||
assert "NONEXISTENT" in str(exc_info.value)
|
||||
|
||||
def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory):
|
||||
def test_get_inventory_info_success(
|
||||
self, db, test_marketplace_product_with_inventory
|
||||
):
|
||||
"""Test getting inventory info for product with inventory"""
|
||||
# Extract the product from the dictionary
|
||||
marketplace_product = test_marketplace_product_with_inventory['marketplace_product']
|
||||
marketplace_product = test_marketplace_product_with_inventory[
|
||||
"marketplace_product"
|
||||
]
|
||||
|
||||
inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin)
|
||||
|
||||
@@ -243,13 +278,17 @@ class TestProductService:
|
||||
|
||||
def test_get_inventory_info_no_inventory(self, db, test_marketplace_product):
|
||||
"""Test getting inventory info for product without inventory"""
|
||||
inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123")
|
||||
inventory_info = self.service.get_inventory_info(
|
||||
db, test_marketplace_product.gtin or "1234567890123"
|
||||
)
|
||||
|
||||
assert inventory_info is None
|
||||
|
||||
def test_product_exists_true(self, db, test_marketplace_product):
|
||||
"""Test product_exists returns True for existing product"""
|
||||
exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id)
|
||||
exists = self.service.product_exists(
|
||||
db, test_marketplace_product.marketplace_product_id
|
||||
)
|
||||
assert exists is True
|
||||
|
||||
def test_product_exists_false(self, db):
|
||||
@@ -265,7 +304,9 @@ class TestProductService:
|
||||
csv_lines = list(csv_generator)
|
||||
|
||||
assert len(csv_lines) > 1 # Header + at least one data row
|
||||
assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header
|
||||
assert csv_lines[0].startswith(
|
||||
"marketplace_product_id,title,description"
|
||||
) # Check header
|
||||
|
||||
# Check that test product appears in CSV
|
||||
csv_content = "".join(csv_lines)
|
||||
@@ -274,8 +315,7 @@ class TestProductService:
|
||||
def test_generate_csv_export_with_filters(self, db, test_marketplace_product):
|
||||
"""Test CSV export with marketplace filter"""
|
||||
csv_generator = self.service.generate_csv_export(
|
||||
db,
|
||||
marketplace=test_marketplace_product.marketplace
|
||||
db, marketplace=test_marketplace_product.marketplace
|
||||
)
|
||||
|
||||
csv_lines = list(csv_generator)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import pytest
|
||||
|
||||
from app.services.stats_service import StatsService
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
from models.database.inventory import Inventory
|
||||
from models.database.marketplace_product import MarketplaceProduct
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -15,7 +15,9 @@ class TestStatsService:
|
||||
"""Setup method following the same pattern as other service tests"""
|
||||
self.service = StatsService()
|
||||
|
||||
def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory):
|
||||
def test_get_comprehensive_stats_basic(
|
||||
self, db, test_marketplace_product, test_inventory
|
||||
):
|
||||
"""Test getting comprehensive stats with basic data"""
|
||||
stats = self.service.get_comprehensive_stats(db)
|
||||
|
||||
@@ -31,7 +33,9 @@ class TestStatsService:
|
||||
assert stats["total_inventory_entries"] >= 1
|
||||
assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10
|
||||
|
||||
def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product):
|
||||
def test_get_comprehensive_stats_multiple_products(
|
||||
self, db, test_marketplace_product
|
||||
):
|
||||
"""Test comprehensive stats with multiple products across different dimensions"""
|
||||
# Create products with different brands, categories, marketplaces
|
||||
additional_products = [
|
||||
@@ -87,7 +91,7 @@ class TestStatsService:
|
||||
brand=None, # Null brand
|
||||
google_product_category=None, # Null category
|
||||
marketplace=None, # Null marketplace
|
||||
vendor_name=None, # Null vendor
|
||||
vendor_name=None, # Null vendor
|
||||
price="10.00",
|
||||
currency="EUR",
|
||||
),
|
||||
@@ -97,7 +101,7 @@ class TestStatsService:
|
||||
brand="", # Empty brand
|
||||
google_product_category="", # Empty category
|
||||
marketplace="", # Empty marketplace
|
||||
vendor_name="", # Empty vendor
|
||||
vendor_name="", # Empty vendor
|
||||
price="15.00",
|
||||
currency="EUR",
|
||||
),
|
||||
@@ -124,7 +128,11 @@ class TestStatsService:
|
||||
|
||||
# Find our test marketplace in the results
|
||||
test_marketplace_stat = next(
|
||||
(stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace),
|
||||
(
|
||||
stat
|
||||
for stat in stats
|
||||
if stat["marketplace"] == test_marketplace_product.marketplace
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert test_marketplace_stat is not None
|
||||
@@ -309,7 +317,9 @@ class TestStatsService:
|
||||
|
||||
count = self.service._get_unique_marketplaces_count(db)
|
||||
|
||||
assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace
|
||||
assert (
|
||||
count >= 2
|
||||
) # At least Amazon and eBay, plus test_marketplace_product marketplace
|
||||
assert isinstance(count, int)
|
||||
|
||||
def test_get_unique_vendors_count(self, db, test_marketplace_product):
|
||||
@@ -338,7 +348,9 @@ class TestStatsService:
|
||||
|
||||
count = self.service._get_unique_vendors_count(db)
|
||||
|
||||
assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor
|
||||
assert (
|
||||
count >= 2
|
||||
) # At least VendorA and VendorB, plus test_marketplace_product vendor
|
||||
assert isinstance(count, int)
|
||||
|
||||
def test_get_inventory_statistics(self, db, test_inventory):
|
||||
@@ -438,7 +450,7 @@ class TestStatsService:
|
||||
db.add_all(marketplace_products)
|
||||
db.commit()
|
||||
|
||||
vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace")
|
||||
vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace")
|
||||
|
||||
assert len(vendors) == 2
|
||||
assert "TestVendor1" in vendors
|
||||
@@ -482,7 +494,9 @@ class TestStatsService:
|
||||
|
||||
def test_get_products_by_marketplace_not_found(self, db):
|
||||
"""Test getting product count for non-existent marketplace"""
|
||||
count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace")
|
||||
count = self.service._get_products_by_marketplace_count(
|
||||
db, "NonExistentMarketplace"
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
# tests/test_vendor_service.py (updated to use custom exceptions)
|
||||
import pytest
|
||||
|
||||
from app.exceptions import (InvalidVendorDataException,
|
||||
MarketplaceProductNotFoundException,
|
||||
MaxVendorsReachedException,
|
||||
ProductAlreadyExistsException,
|
||||
UnauthorizedVendorAccessException,
|
||||
ValidationException, VendorAlreadyExistsException,
|
||||
VendorNotFoundException)
|
||||
from app.services.vendor_service import VendorService
|
||||
from app.exceptions import (
|
||||
VendorNotFoundException,
|
||||
VendorAlreadyExistsException,
|
||||
UnauthorizedVendorAccessException,
|
||||
InvalidVendorDataException,
|
||||
MarketplaceProductNotFoundException,
|
||||
ProductAlreadyExistsException,
|
||||
MaxVendorsReachedException,
|
||||
ValidationException,
|
||||
)
|
||||
from models.schema.vendor import VendorCreate
|
||||
from models.schema.product import ProductCreate
|
||||
from models.schema.vendor import VendorCreate
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -38,15 +35,17 @@ class TestVendorService:
|
||||
assert vendor is not None
|
||||
assert vendor.vendor_code == "NEWVENDOR"
|
||||
assert vendor.owner_user_id == test_user.id
|
||||
assert vendor.is_verified is False # Regular user creates unverified vendor
|
||||
assert vendor.is_verified is False # Regular user creates unverified vendor
|
||||
|
||||
def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory):
|
||||
"""Test admin creates verified vendor automatically"""
|
||||
vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor")
|
||||
vendor_data = VendorCreate(
|
||||
vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor"
|
||||
)
|
||||
|
||||
vendor = self.service.create_vendor(db, vendor_data, test_admin)
|
||||
|
||||
assert vendor.is_verified is True # Admin creates verified vendor
|
||||
assert vendor.is_verified is True # Admin creates verified vendor
|
||||
|
||||
def test_create_vendor_duplicate_code(self, db, test_user, test_vendor):
|
||||
"""Test vendor creation fails with duplicate vendor code"""
|
||||
@@ -88,7 +87,9 @@ class TestVendorService:
|
||||
|
||||
def test_create_vendor_invalid_code_format(self, db, test_user):
|
||||
"""Test vendor creation fails with invalid vendor code format"""
|
||||
vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor")
|
||||
vendor_data = VendorCreate(
|
||||
vendor_code="INVALID@CODE!", vendor_name="Test Vendor"
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidVendorDataException) as exc_info:
|
||||
self.service.create_vendor(db, vendor_data, test_user)
|
||||
@@ -105,7 +106,9 @@ class TestVendorService:
|
||||
def mock_check_vendor_limit(self, db, user):
|
||||
raise MaxVendorsReachedException(max_vendors=5, user_id=user.id)
|
||||
|
||||
monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit)
|
||||
monkeypatch.setattr(
|
||||
VendorService, "_check_vendor_limit", mock_check_vendor_limit
|
||||
)
|
||||
|
||||
vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor")
|
||||
|
||||
@@ -118,7 +121,9 @@ class TestVendorService:
|
||||
assert exception.details["max_vendors"] == 5
|
||||
assert exception.details["user_id"] == test_user.id
|
||||
|
||||
def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor):
|
||||
def test_get_vendors_regular_user(
|
||||
self, db, test_user, test_vendor, inactive_vendor
|
||||
):
|
||||
"""Test regular user can only see active verified vendors and own vendors"""
|
||||
vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10)
|
||||
|
||||
@@ -127,7 +132,7 @@ class TestVendorService:
|
||||
assert inactive_vendor.vendor_code not in vendor_codes
|
||||
|
||||
def test_get_vendors_admin_user(
|
||||
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
|
||||
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
|
||||
):
|
||||
"""Test admin user can see all vendors with filters"""
|
||||
vendors, total = self.service.get_vendors(
|
||||
@@ -140,14 +145,16 @@ class TestVendorService:
|
||||
assert verified_vendor.vendor_code in vendor_codes
|
||||
|
||||
def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor):
|
||||
"""Test vendor owner can access their own vendor """
|
||||
vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user)
|
||||
"""Test vendor owner can access their own vendor"""
|
||||
vendor = self.service.get_vendor_by_code(
|
||||
db, test_vendor.vendor_code.lower(), test_user
|
||||
)
|
||||
|
||||
assert vendor is not None
|
||||
assert vendor.id == test_vendor.id
|
||||
|
||||
def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor):
|
||||
"""Test admin can access any vendor """
|
||||
"""Test admin can access any vendor"""
|
||||
vendor = self.service.get_vendor_by_code(
|
||||
db, test_vendor.vendor_code.lower(), test_admin
|
||||
)
|
||||
@@ -178,16 +185,14 @@ class TestVendorService:
|
||||
assert exception.details["user_id"] == test_user.id
|
||||
|
||||
def test_add_product_to_vendor_success(self, db, test_vendor, unique_product):
|
||||
"""Test successfully adding product to vendor """
|
||||
"""Test successfully adding product to vendor"""
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id=unique_product.marketplace_product_id,
|
||||
price="15.99",
|
||||
is_featured=True,
|
||||
)
|
||||
|
||||
product = self.service.add_product_to_catalog(
|
||||
db, test_vendor, product_data
|
||||
)
|
||||
product = self.service.add_product_to_catalog(db, test_vendor, product_data)
|
||||
|
||||
assert product is not None
|
||||
assert product.vendor_id == test_vendor.id
|
||||
@@ -195,7 +200,9 @@ class TestVendorService:
|
||||
|
||||
def test_add_product_to_vendor_product_not_found(self, db, test_vendor):
|
||||
"""Test adding non-existent product to vendor fails"""
|
||||
product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99")
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id="NONEXISTENT", price="15.99"
|
||||
)
|
||||
|
||||
with pytest.raises(MarketplaceProductNotFoundException) as exc_info:
|
||||
self.service.add_product_to_catalog(db, test_vendor, product_data)
|
||||
@@ -209,7 +216,8 @@ class TestVendorService:
|
||||
def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product):
|
||||
"""Test adding product that's already in vendor fails"""
|
||||
product_data = ProductCreate(
|
||||
marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99"
|
||||
marketplace_product_id=test_product.marketplace_product.marketplace_product_id,
|
||||
price="15.99",
|
||||
)
|
||||
|
||||
with pytest.raises(ProductAlreadyExistsException) as exc_info:
|
||||
@@ -219,11 +227,12 @@ class TestVendorService:
|
||||
assert exception.status_code == 409
|
||||
assert exception.error_code == "PRODUCT_ALREADY_EXISTS"
|
||||
assert exception.details["vendor_code"] == test_vendor.vendor_code
|
||||
assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id
|
||||
assert (
|
||||
exception.details["marketplace_product_id"]
|
||||
== test_product.marketplace_product.marketplace_product_id
|
||||
)
|
||||
|
||||
def test_get_products_owner_access(
|
||||
self, db, test_user, test_vendor, test_product
|
||||
):
|
||||
def test_get_products_owner_access(self, db, test_user, test_vendor, test_product):
|
||||
"""Test vendor owner can get vendor products"""
|
||||
products, total = self.service.get_products(db, test_vendor, test_user)
|
||||
|
||||
@@ -291,7 +300,9 @@ class TestVendorService:
|
||||
assert exception.error_code == "VALIDATION_ERROR"
|
||||
assert "Failed to retrieve vendors" in exception.message
|
||||
|
||||
def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch):
|
||||
def test_add_product_database_error(
|
||||
self, db, test_vendor, unique_product, monkeypatch
|
||||
):
|
||||
"""Test add product handles database errors gracefully"""
|
||||
|
||||
def mock_commit():
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestCSVProcessor:
|
||||
def test_download_csv_encoding_fallback(self, mock_get):
|
||||
"""Test CSV download with encoding fallback"""
|
||||
# Create content with special characters that would fail UTF-8 if not properly encoded
|
||||
special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
|
||||
special_content = (
|
||||
"marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
@@ -40,9 +42,7 @@ class TestCSVProcessor:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
# Create bytes that will fail most encodings
|
||||
mock_response.content = (
|
||||
b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
|
||||
)
|
||||
mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user