style: apply black and isort formatting across entire codebase

- Standardize quote style (single to double quotes)
- Reorder and group imports alphabetically
- Fix line breaks and indentation for consistency
- Apply PEP 8 formatting standards

Also updated Makefile to exclude both venv and .venv from code quality checks.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-28 19:30:17 +01:00
parent 13f0094743
commit 21c13ca39b
236 changed files with 8450 additions and 6545 deletions

View File

@@ -12,21 +12,18 @@ Tests cover:
- Error handling and edge cases
"""
import pytest
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timedelta, timezone
from jose import jwt
from fastapi import HTTPException
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastapi import HTTPException
from jose import jwt
from app.exceptions import (AdminRequiredException,
InsufficientPermissionsException,
InvalidCredentialsException, InvalidTokenException,
TokenExpiredException, UserNotActiveException)
from middleware.auth import AuthManager
from app.exceptions import (
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException,
AdminRequiredException,
InsufficientPermissionsException,
)
from models.database.user import User
@@ -124,7 +121,9 @@ class TestUserAuthentication:
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
result = auth_manager.authenticate_user(
mock_db, "test@example.com", "password123"
)
assert result is mock_user
@@ -192,7 +191,9 @@ class TestJWTTokenCreation:
token = token_data["access_token"]
# Decode without verification to check payload
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
payload = jwt.decode(
token, auth_manager.secret_key, algorithms=[auth_manager.algorithm]
)
assert payload["sub"] == "42"
assert payload["username"] == "testuser"
@@ -205,8 +206,12 @@ class TestJWTTokenCreation:
"""Test tokens are different for different users."""
auth_manager = AuthManager()
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
user1 = Mock(
spec=User, id=1, username="user1", email="user1@test.com", role="customer"
)
user2 = Mock(
spec=User, id=2, username="user2", email="user2@test.com", role="vendor"
)
token1 = auth_manager.create_access_token(user1)["access_token"]
token2 = auth_manager.create_access_token(user2)["access_token"]
@@ -227,7 +232,7 @@ class TestJWTTokenCreation:
payload = jwt.decode(
token_data["access_token"],
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
algorithms=[auth_manager.algorithm],
)
assert payload["role"] == "admin"
@@ -311,9 +316,11 @@ class TestJWTTokenVerification:
# Create token without 'sub' field
payload = {
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -325,11 +332,10 @@ class TestJWTTokenVerification:
auth_manager = AuthManager()
# Create token without 'exp' field
payload = {
"sub": "1",
"username": "testuser"
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser"}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
@@ -343,7 +349,7 @@ class TestJWTTokenVerification:
payload = {
"sub": "1",
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
# Create token with different algorithm
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
@@ -357,15 +363,13 @@ class TestJWTTokenVerification:
# Create a token with expiration in the past
past_time = datetime.now(timezone.utc) - timedelta(minutes=1)
payload = {
"sub": "1",
"username": "testuser",
"exp": past_time.timestamp()
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
payload = {"sub": "1", "username": "testuser", "exp": past_time.timestamp()}
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Mock jwt.decode to bypass its expiration check and test line 205
with patch('middleware.auth.jwt.decode') as mock_decode:
with patch("middleware.auth.jwt.decode") as mock_decode:
mock_decode.return_value = payload
with pytest.raises(TokenExpiredException):
@@ -580,7 +584,9 @@ class TestCreateDefaultAdminUser:
# Existing admin user
existing_admin = Mock(spec=User)
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
mock_db.query.return_value.filter.return_value.first.return_value = (
existing_admin
)
result = auth_manager.create_default_admin_user(mock_db)
@@ -599,19 +605,21 @@ class TestAuthManagerConfiguration:
def test_default_configuration(self):
"""Test AuthManager uses default configuration."""
with patch.dict('os.environ', {}, clear=True):
with patch.dict("os.environ", {}, clear=True):
auth_manager = AuthManager()
assert auth_manager.algorithm == "HS256"
assert auth_manager.token_expire_minutes == 30
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
assert (
auth_manager.secret_key == "your-secret-key-change-in-production-please"
)
def test_custom_configuration(self):
"""Test AuthManager uses environment variables."""
with patch.dict('os.environ', {
'JWT_SECRET_KEY': 'custom-secret-key',
'JWT_EXPIRE_MINUTES': '60'
}):
with patch.dict(
"os.environ",
{"JWT_SECRET_KEY": "custom-secret-key", "JWT_EXPIRE_MINUTES": "60"},
):
auth_manager = AuthManager()
assert auth_manager.secret_key == "custom-secret-key"
@@ -619,9 +627,7 @@ class TestAuthManagerConfiguration:
def test_partial_custom_configuration(self):
"""Test AuthManager with partial environment configuration."""
with patch.dict('os.environ', {
'JWT_EXPIRE_MINUTES': '120'
}, clear=False):
with patch.dict("os.environ", {"JWT_EXPIRE_MINUTES": "120"}, clear=False):
auth_manager = AuthManager()
assert auth_manager.token_expire_minutes == 120
@@ -656,9 +662,11 @@ class TestEdgeCases:
"sub": "1",
"username": "testuser",
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
"exp": datetime.now(timezone.utc) + timedelta(hours=2),
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
token = jwt.encode(
payload, auth_manager.secret_key, algorithm=auth_manager.algorithm
)
# Should still verify successfully (JWT doesn't validate iat by default)
result = auth_manager.verify_token(token)
@@ -698,7 +706,9 @@ class TestEdgeCases:
token = token_data["access_token"]
# Mock jose.jwt.decode to raise an unexpected exception
with patch('middleware.auth.jwt.decode', side_effect=RuntimeError("Unexpected error")):
with patch(
"middleware.auth.jwt.decode", side_effect=RuntimeError("Unexpected error")
):
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)

View File

@@ -10,16 +10,13 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request
from middleware.context import (
ContextManager,
ContextMiddleware,
RequestContext,
get_request_context,
)
from middleware.context import (ContextManager, ContextMiddleware,
RequestContext, get_request_context)
@pytest.mark.unit
@@ -321,22 +318,38 @@ class TestContextManagerHelpers:
def test_is_admin_context_from_subdomain(self):
"""Test _is_admin_context with admin subdomain."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/dashboard"
)
is True
)
def test_is_admin_context_from_path(self):
"""Test _is_admin_context with admin path."""
request = Mock()
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
assert (
ContextManager._is_admin_context(request, "localhost", "/admin/users")
is True
)
def test_is_admin_context_both(self):
"""Test _is_admin_context with both subdomain and path."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
assert (
ContextManager._is_admin_context(
request, "admin.platform.com", "/admin/users"
)
is True
)
def test_is_not_admin_context(self):
"""Test _is_admin_context returns False for non-admin."""
request = Mock()
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
assert (
ContextManager._is_admin_context(request, "vendor.platform.com", "/shop")
is False
)
def test_is_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context with /vendor/ path."""
@@ -344,11 +357,16 @@ class TestContextManagerHelpers:
def test_is_vendor_dashboard_context_nested(self):
"""Test _is_vendor_dashboard_context with nested vendor path."""
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
assert (
ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
)
def test_is_not_vendor_dashboard_context_vendors_plural(self):
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
assert (
ContextManager._is_vendor_dashboard_context("/vendors/shop123/products")
is False
)
def test_is_not_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
@@ -373,7 +391,7 @@ class TestContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'context_type')
assert hasattr(request.state, "context_type")
assert request.state.context_type == RequestContext.API
call_next.assert_called_once_with(request)
@@ -565,7 +583,7 @@ class TestEdgeCases:
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
# No state attribute at all
delattr(request, 'state')
delattr(request, "state")
# Should still work, falling back to url.path
with pytest.raises(AttributeError):

View File

@@ -12,16 +12,18 @@ Tests cover:
- Edge cases and isolation
"""
import pytest
from unittest.mock import Mock
from middleware.decorators import rate_limit, rate_limiter
from app.exceptions.base import RateLimitException
import pytest
from app.exceptions.base import RateLimitException
from middleware.decorators import rate_limit, rate_limiter
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(autouse=True)
def reset_rate_limiter():
"""Reset rate limiter state before each test to ensure isolation."""
@@ -34,6 +36,7 @@ def reset_rate_limiter():
# Rate Limit Decorator Tests
# =============================================================================
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimitDecorator:
@@ -42,6 +45,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_allows_within_limit(self):
"""Test decorator allows requests within rate limit."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
@@ -53,6 +57,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_blocked():
return {"status": "ok"}
@@ -71,6 +76,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_preserves_function_metadata(self):
"""Test decorator preserves original function metadata."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
"""Test endpoint docstring."""
@@ -82,21 +88,19 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_with_args_and_kwargs(self):
"""Test decorator works with function arguments."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint(arg1, arg2, kwarg1=None):
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
result = await test_endpoint("value1", "value2", kwarg1="value3")
assert result == {
"arg1": "value1",
"arg2": "value2",
"kwarg1": "value3"
}
assert result == {"arg1": "value1", "arg2": "value2", "kwarg1": "value3"}
@pytest.mark.asyncio
async def test_decorator_default_parameters(self):
"""Test decorator uses default parameters."""
@rate_limit() # Use defaults
async def test_endpoint():
return {"status": "ok"}
@@ -108,6 +112,7 @@ class TestRateLimitDecorator:
@pytest.mark.asyncio
async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after."""
@rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint_retry():
return {"status": "ok"}
@@ -128,6 +133,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_zero_max_requests(self):
"""Test decorator with max_requests=0 blocks all requests."""
@rate_limit(max_requests=0, window_seconds=3600)
async def test_endpoint_zero():
return {"status": "ok"}
@@ -139,6 +145,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_very_short_window(self):
"""Test decorator with very short time window."""
@rate_limit(max_requests=1, window_seconds=1)
async def test_endpoint_short():
return {"status": "ok"}
@@ -154,6 +161,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_multiple_functions_separate_limits(self):
"""Test that different functions have separate rate limits."""
@rate_limit(max_requests=1, window_seconds=3600)
async def endpoint1():
return {"endpoint": "1"}
@@ -178,6 +186,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_with_exception_in_function(self):
"""Test decorator handles exceptions from wrapped function."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint_error():
raise ValueError("Function error")
@@ -191,6 +200,7 @@ class TestRateLimitDecoratorEdgeCases:
@pytest.mark.asyncio
async def test_decorator_isolation_between_tests(self):
"""Test that rate limiter state is properly isolated between tests."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint_isolation():
return {"status": "ok"}
@@ -212,6 +222,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_dict(self):
"""Test decorator correctly returns dictionary."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_dict():
return {"key": "value", "number": 42}
@@ -222,6 +233,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_list(self):
"""Test decorator correctly returns list."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_list():
return [1, 2, 3, 4, 5]
@@ -232,6 +244,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_none(self):
"""Test decorator correctly returns None."""
@rate_limit(max_requests=10, window_seconds=3600)
async def return_none():
return None
@@ -242,6 +255,7 @@ class TestRateLimitDecoratorReturnValues:
@pytest.mark.asyncio
async def test_decorator_returns_object(self):
"""Test decorator correctly returns custom objects."""
class TestObject:
def __init__(self):
self.name = "test_object"

View File

@@ -11,9 +11,10 @@ Tests cover:
- Edge cases (missing client info, etc.)
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastapi import Request
from middleware.logging import LoggingMiddleware
@@ -40,7 +41,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify request was logged
@@ -65,7 +66,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
result = await middleware.dispatch(request, call_next)
# Verify response was logged
@@ -89,7 +90,7 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in response.headers
@@ -113,11 +114,13 @@ class TestLoggingMiddleware:
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Should log "unknown" for client
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
assert any(
"unknown" in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_exceptions(self):
@@ -131,8 +134,9 @@ class TestLoggingMiddleware:
call_next = AsyncMock(side_effect=Exception("Test error"))
with patch('middleware.logging.logger') as mock_logger, \
pytest.raises(Exception):
with patch("middleware.logging.logger") as mock_logger, pytest.raises(
Exception
):
await middleware.dispatch(request, call_next)
# Verify error was logged
@@ -156,7 +160,7 @@ class TestLoggingMiddleware:
call_next = slow_call_next
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
@@ -181,7 +185,7 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger'):
with patch("middleware.logging.logger"):
result = await middleware.dispatch(request, call_next)
# Should still have process time, even if very small
@@ -205,11 +209,13 @@ class TestLoggingEdgeCases:
response = Mock(status_code=200, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify method was logged
assert any(method in str(call) for call in mock_logger.info.call_args_list)
assert any(
method in str(call) for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_middleware_logs_different_status_codes(self):
@@ -227,8 +233,11 @@ class TestLoggingEdgeCases:
response = Mock(status_code=status_code, headers={})
call_next = AsyncMock(return_value=response)
with patch('middleware.logging.logger') as mock_logger:
with patch("middleware.logging.logger") as mock_logger:
await middleware.dispatch(request, call_next)
# Verify status code was logged
assert any(str(status_code) in str(call) for call in mock_logger.info.call_args_list)
assert any(
str(status_code) in str(call)
for call in mock_logger.info.call_args_list
)

View File

@@ -11,10 +11,11 @@ Tests cover:
- Edge cases and concurrency scenarios
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from collections import deque
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
import pytest
from middleware.rate_limiter import RateLimiter
@@ -306,8 +307,8 @@ class TestRateLimiterStatistics:
# Add requests at different times
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
stats = limiter.get_client_stats(client_id)
@@ -411,7 +412,9 @@ class TestRateLimiterEdgeCases:
limiter = RateLimiter()
client_id = "long_window_client"
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
result = limiter.allow_request(
client_id, max_requests=10, window_seconds=86400 * 365
)
assert result is True
@@ -421,10 +424,16 @@ class TestRateLimiterEdgeCases:
client_id = "same_client"
# Allow with one limit
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
assert (
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
is True
)
# Check with stricter limit
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
assert (
limiter.allow_request(client_id, max_requests=1, window_seconds=3600)
is False
)
def test_rate_limiter_unicode_client_id(self):
"""Test rate limiter with unicode client ID."""

View File

@@ -11,15 +11,14 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request
from middleware.theme_context import (
ThemeContextManager,
ThemeContextMiddleware,
get_current_theme,
)
from middleware.theme_context import (ThemeContextManager,
ThemeContextMiddleware,
get_current_theme)
@pytest.mark.unit
@@ -42,7 +41,14 @@ class TestThemeContextManager:
"""Test default theme has all required colors."""
theme = ThemeContextManager.get_default_theme()
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
required_colors = [
"primary",
"secondary",
"accent",
"background",
"text",
"border",
]
for color in required_colors:
assert color in theme["colors"]
assert theme["colors"][color].startswith("#")
@@ -79,10 +85,7 @@ class TestThemeContextManager:
mock_theme = Mock()
# Mock to_dict to return actual dictionary
custom_theme_dict = {
"theme_name": "custom",
"colors": {"primary": "#ff0000"}
}
custom_theme_dict = {"theme_name": "custom", "colors": {"primary": "#ff0000"}}
mock_theme.to_dict.return_value = custom_theme_dict
# Correct filter chain: query().filter().first()
@@ -141,8 +144,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
mock_theme = {"theme_name": "test_theme"}
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", return_value=mock_theme
):
await middleware.dispatch(request, call_next)
@@ -161,7 +167,7 @@ class TestThemeContextMiddleware:
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'theme')
assert hasattr(request.state, "theme")
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
@@ -178,8 +184,11 @@ class TestThemeContextMiddleware:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
with patch(
"middleware.theme_context.get_db", return_value=iter([mock_db])
), patch.object(
ThemeContextManager, "get_vendor_theme", side_effect=Exception("DB Error")
):
await middleware.dispatch(request, call_next)
@@ -224,7 +233,7 @@ class TestThemeEdgeCases:
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
with patch("middleware.theme_context.get_db", return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
# Verify database was closed

View File

@@ -11,17 +11,16 @@ Tests cover:
- Edge cases and error handling
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from fastapi import Request, HTTPException
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from middleware.vendor_context import (
VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context,
)
from middleware.vendor_context import (VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context)
@pytest.mark.unit
@@ -39,7 +38,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -55,7 +54,7 @@ class TestVendorContextManager:
request.headers = {"host": "customdomain1.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -71,7 +70,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -87,7 +86,7 @@ class TestVendorContextManager:
request.headers = {"host": "vendor1.platform.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -140,7 +139,7 @@ class TestVendorContextManager:
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -153,7 +152,7 @@ class TestVendorContextManager:
request.headers = {"host": "www.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -166,7 +165,7 @@ class TestVendorContextManager:
request.headers = {"host": "api.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -179,7 +178,7 @@ class TestVendorContextManager:
request.headers = {"host": "localhost"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -198,12 +197,11 @@ class TestVendorContextManager:
mock_vendor.is_active = True
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -218,12 +216,11 @@ class TestVendorContextManager:
mock_vendor.is_active = False
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor_domain
)
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
context = {"detection_method": "custom_domain", "domain": "customdomain1.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -232,12 +229,11 @@ class TestVendorContextManager:
def test_get_vendor_from_custom_domain_not_found(self):
"""Test custom domain not found in database."""
mock_db = Mock(spec=Session)
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = (
None
)
context = {
"detection_method": "custom_domain",
"domain": "nonexistent.com"
}
context = {"detection_method": "custom_domain", "domain": "nonexistent.com"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -249,12 +245,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
context = {"detection_method": "subdomain", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -266,12 +261,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "path",
"subdomain": "vendor1"
}
context = {"detection_method": "path", "subdomain": "vendor1"}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -291,12 +285,11 @@ class TestVendorContextManager:
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = (
mock_vendor
)
context = {
"detection_method": "subdomain",
"subdomain": "VENDOR1" # Uppercase
}
context = {"detection_method": "subdomain", "subdomain": "VENDOR1"} # Uppercase
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
@@ -311,10 +304,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -325,10 +315,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendors/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendors/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendors/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -339,10 +326,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
vendor_context = {"detection_method": "path", "path_prefix": "/vendor/vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -353,10 +337,7 @@ class TestVendorContextManager:
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
@@ -425,21 +406,24 @@ class TestVendorContextManager:
# Static File Detection Tests
# ========================================================================
@pytest.mark.parametrize("path", [
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
])
@pytest.mark.parametrize(
"path",
[
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
"/static/", # Path starting with /static/ but no extension
"/media/uploads", # Path starting with /media/ but no extension
"/subfolder/favicon.ico", # favicon.ico in subfolder
"/favicon.ico.bak", # Contains favicon.ico but doesn't end with static extension (hits line 226)
],
)
def test_is_static_file_request(self, path):
"""Test static file detection for various paths and extensions."""
request = Mock(spec=Request)
@@ -447,12 +431,15 @@ class TestVendorContextManager:
assert VendorContextManager.is_static_file_request(request) is True
@pytest.mark.parametrize("path", [
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
])
@pytest.mark.parametrize(
"path",
[
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
],
)
def test_is_not_static_file_request(self, path):
"""Test non-static file paths."""
request = Mock(spec=Request)
@@ -478,7 +465,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
with patch.object(VendorContextManager, "is_admin_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -498,7 +485,7 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
with patch.object(VendorContextManager, "is_api_request", return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -517,7 +504,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
with patch.object(
VendorContextManager, "is_static_file_request", return_value=True
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -540,17 +529,19 @@ class TestVendorContextMiddleware:
mock_vendor.name = "Test Vendor"
mock_vendor.subdomain = "vendor1"
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "vendor1"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=mock_vendor
), patch.object(
VendorContextManager, "extract_clean_path", return_value="/shop/products"
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -571,16 +562,17 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
vendor_context = {
"detection_method": "subdomain",
"subdomain": "nonexistent"
}
vendor_context = {"detection_method": "subdomain", "subdomain": "nonexistent"}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=vendor_context
), patch.object(
VendorContextManager, "get_vendor_from_context", return_value=None
), patch(
"middleware.vendor_context.get_db", return_value=iter([mock_db])
):
await middleware.dispatch(request, call_next)
@@ -601,7 +593,9 @@ class TestVendorContextMiddleware:
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
with patch.object(
VendorContextManager, "detect_vendor_context", return_value=None
):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
@@ -714,7 +708,7 @@ class TestEdgeCases:
request.headers = {"host": "shop.vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
with patch("middleware.vendor_context.settings") as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
@@ -735,11 +729,14 @@ class TestEdgeCases:
context = {"subdomain": "nonexistent", "detection_method": "subdomain"}
with patch('middleware.vendor_context.logger') as mock_logger:
with patch("middleware.vendor_context.logger") as mock_logger:
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
# Verify warning was logged
mock_logger.warning.assert_called()
warning_message = str(mock_logger.warning.call_args)
assert "No active vendor found for subdomain" in warning_message and "nonexistent" in warning_message
assert (
"No active vendor found for subdomain" in warning_message
and "nonexistent" in warning_message
)

View File

@@ -1,16 +1,17 @@
# tests/unit/models/test_database_models.py
import pytest
from datetime import datetime, timezone
import pytest
from sqlalchemy.exc import IntegrityError
from models.database.marketplace_product import MarketplaceProduct
from models.database.vendor import Vendor, VendorUser, Role
from models.database.inventory import Inventory
from models.database.user import User
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.product import Product
from models.database.customer import Customer, CustomerAddress
from models.database.inventory import Inventory
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.marketplace_product import MarketplaceProduct
from models.database.order import Order, OrderItem
from models.database.product import Product
from models.database.user import User
from models.database.vendor import Role, Vendor, VendorUser
@pytest.mark.unit
@@ -277,7 +278,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 1",
marketplace="Letzshop"
marketplace="Letzshop",
)
db.add(product1)
db.commit()
@@ -288,7 +289,7 @@ class TestMarketplaceProductModel:
vendor_id=test_vendor.id,
marketplace_product_id="UNIQUE_001",
title="Product 2",
marketplace="Letzshop"
marketplace="Letzshop",
)
db.add(product2)
db.commit()
@@ -515,7 +516,9 @@ class TestCustomerModel:
class TestOrderModel:
"""Test Order model"""
def test_order_creation(self, db, test_vendor, test_customer, test_customer_address):
def test_order_creation(
self, db, test_vendor, test_customer, test_customer_address
):
"""Test Order model with customer relationship"""
order = Order(
vendor_id=test_vendor.id,
@@ -563,7 +566,9 @@ class TestOrderModel:
assert float(order_item.unit_price) == 49.99
assert float(order_item.total_price) == 99.98
def test_order_number_uniqueness(self, db, test_vendor, test_customer, test_customer_address):
def test_order_number_uniqueness(
self, db, test_vendor, test_customer, test_customer_address
):
"""Test order_number unique constraint"""
order1 = Order(
vendor_id=test_vendor.id,

View File

@@ -1,14 +1,10 @@
# tests/unit/services/test_admin_service.py
import pytest
from app.exceptions import (
UserNotFoundException,
UserStatusChangeException,
CannotModifySelfException,
VendorNotFoundException,
VendorVerificationException,
AdminOperationException,
)
from app.exceptions import (AdminOperationException, CannotModifySelfException,
UserNotFoundException, UserStatusChangeException,
VendorNotFoundException,
VendorVerificationException)
from app.services.admin_service import AdminService
from app.services.stats_service import stats_service
from models.database.marketplace_import_job import MarketplaceImportJob
@@ -85,7 +81,9 @@ class TestAdminService:
assert exception.error_code == "CANNOT_MODIFY_SELF"
assert "deactivate account" in exception.message
def test_toggle_user_status_cannot_modify_admin(self, db, test_admin, another_admin):
def test_toggle_user_status_cannot_modify_admin(
self, db, test_admin, another_admin
):
"""Test that admin cannot modify another admin"""
with pytest.raises(UserStatusChangeException) as exc_info:
self.service.toggle_user_status(db, another_admin.id, test_admin.id)
@@ -148,7 +146,7 @@ class TestAdminService:
assert "99999" in exception.message
def test_toggle_vendor_status_deactivate(self, db, test_vendor):
"""Test deactivating a vendor """
"""Test deactivating a vendor"""
original_status = test_vendor.is_active
vendor, message = self.service.toggle_vendor_status(db, test_vendor.id)
@@ -170,21 +168,26 @@ class TestAdminService:
assert exception.error_code == "VENDOR_NOT_FOUND"
# Marketplace Import Jobs Tests
def test_get_marketplace_import_jobs_no_filters(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_no_filters(
self, db, test_marketplace_import_job
):
"""Test getting marketplace import jobs without filters"""
result = self.service.get_marketplace_import_jobs(db, skip=0, limit=10)
assert len(result) >= 1
# Find our test job in the results
test_job = next(
(job for job in result if job.job_id == test_marketplace_import_job.id), None
(job for job in result if job.job_id == test_marketplace_import_job.id),
None,
)
assert test_job is not None
assert test_job.marketplace == test_marketplace_import_job.marketplace
assert test_job.vendor_name == test_marketplace_import_job.name
assert test_job.status == test_marketplace_import_job.status
def test_get_marketplace_import_jobs_with_marketplace_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_marketplace_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by marketplace"""
result = self.service.get_marketplace_import_jobs(
db, marketplace=test_marketplace_import_job.marketplace, skip=0, limit=10
@@ -192,9 +195,14 @@ class TestAdminService:
assert len(result) >= 1
for job in result:
assert test_marketplace_import_job.marketplace.lower() in job.marketplace.lower()
assert (
test_marketplace_import_job.marketplace.lower()
in job.marketplace.lower()
)
def test_get_marketplace_import_jobs_with_vendor_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_vendor_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by vendor name"""
result = self.service.get_marketplace_import_jobs(
db, vendor_name=test_marketplace_import_job.name, skip=0, limit=10
@@ -204,7 +212,9 @@ class TestAdminService:
for job in result:
assert test_marketplace_import_job.name.lower() in job.vendor_name.lower()
def test_get_marketplace_import_jobs_with_status_filter(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_with_status_filter(
self, db, test_marketplace_import_job
):
"""Test filtering marketplace import jobs by status"""
result = self.service.get_marketplace_import_jobs(
db, status=test_marketplace_import_job.status, skip=0, limit=10
@@ -214,7 +224,9 @@ class TestAdminService:
for job in result:
assert job.status == test_marketplace_import_job.status
def test_get_marketplace_import_jobs_pagination(self, db, test_marketplace_import_job):
def test_get_marketplace_import_jobs_pagination(
self, db, test_marketplace_import_job
):
"""Test marketplace import jobs pagination"""
result_page1 = self.service.get_marketplace_import_jobs(db, skip=0, limit=1)
result_page2 = self.service.get_marketplace_import_jobs(db, skip=1, limit=1)

View File

@@ -1,11 +1,9 @@
# tests/test_auth_service.py
import pytest
from app.exceptions.auth import (
UserAlreadyExistsException,
InvalidCredentialsException,
UserNotActiveException,
)
from app.exceptions.auth import (InvalidCredentialsException,
UserAlreadyExistsException,
UserNotActiveException)
from app.exceptions.base import ValidationException
from app.services.auth_service import AuthService
from models.schema.auth import UserLogin, UserRegister
@@ -218,11 +216,14 @@ class TestAuthService:
def test_create_access_token_failure(self, test_user, monkeypatch):
"""Test creating access token handles failures"""
# Mock the auth_manager to raise an exception
def mock_create_token(*args, **kwargs):
raise Exception("Token creation failed")
monkeypatch.setattr(self.service.auth_manager, "create_access_token", mock_create_token)
monkeypatch.setattr(
self.service.auth_manager, "create_access_token", mock_create_token
)
with pytest.raises(ValidationException) as exc_info:
self.service.create_access_token(test_user)
@@ -250,11 +251,14 @@ class TestAuthService:
def test_hash_password_failure(self, monkeypatch):
"""Test password hashing handles failures"""
# Mock the auth_manager to raise an exception
def mock_hash_password(*args, **kwargs):
raise Exception("Hashing failed")
monkeypatch.setattr(self.service.auth_manager, "hash_password", mock_hash_password)
monkeypatch.setattr(
self.service.auth_manager, "hash_password", mock_hash_password
)
with pytest.raises(ValidationException) as exc_info:
self.service.hash_password("testpassword")
@@ -267,9 +271,7 @@ class TestAuthService:
def test_register_user_database_error(self, db_with_error):
"""Test user registration handles database errors"""
user_data = UserRegister(
email="test@example.com",
username="testuser",
password="password123"
email="test@example.com", username="testuser", password="password123"
)
with pytest.raises(ValidationException) as exc_info:

View File

@@ -3,19 +3,17 @@ import uuid
import pytest
from app.exceptions import (InsufficientInventoryException,
InvalidInventoryOperationException,
InvalidQuantityException,
InventoryNotFoundException,
InventoryValidationException,
NegativeInventoryException, ValidationException)
from app.services.inventory_service import InventoryService
from app.exceptions import (
InventoryNotFoundException,
InsufficientInventoryException,
InvalidInventoryOperationException,
InventoryValidationException,
NegativeInventoryException,
InvalidQuantityException,
ValidationException,
)
from models.schema.inventory import InventoryAdd, InventoryCreate, InventoryUpdate
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
from models.schema.inventory import (InventoryAdd, InventoryCreate,
InventoryUpdate)
@pytest.mark.unit
@@ -40,10 +38,14 @@ class TestInventoryService:
def test_normalize_gtin_valid(self):
"""Test GTIN normalization with valid GTINs."""
# Test various valid GTIN formats - these should remain unchanged
assert self.service._normalize_gtin("1234567890123") == "1234567890123" # EAN-13
assert (
self.service._normalize_gtin("1234567890123") == "1234567890123"
) # EAN-13
assert self.service._normalize_gtin("123456789012") == "123456789012" # UPC-A
assert self.service._normalize_gtin("12345678") == "12345678" # EAN-8
assert self.service._normalize_gtin("12345678901234") == "12345678901234" # GTIN-14
assert (
self.service._normalize_gtin("12345678901234") == "12345678901234"
) # GTIN-14
# Test with decimal points (should be removed)
assert self.service._normalize_gtin("1234567890123.0") == "1234567890123"
@@ -52,11 +54,17 @@ class TestInventoryService:
assert self.service._normalize_gtin(" 1234567890123 ") == "1234567890123"
# Test short GTINs being padded
assert self.service._normalize_gtin("123") == "0000000000123" # Padded to EAN-13
assert self.service._normalize_gtin("12345") == "0000000012345" # Padded to EAN-13
assert (
self.service._normalize_gtin("123") == "0000000000123"
) # Padded to EAN-13
assert (
self.service._normalize_gtin("12345") == "0000000012345"
) # Padded to EAN-13
# Test long GTINs being truncated
assert self.service._normalize_gtin("123456789012345") == "3456789012345" # Truncated to 13
assert (
self.service._normalize_gtin("123456789012345") == "3456789012345"
) # Truncated to 13
def test_normalize_gtin_edge_cases(self):
"""Test GTIN normalization edge cases."""
@@ -65,9 +73,15 @@ class TestInventoryService:
assert self.service._normalize_gtin(123) == "0000000000123"
# Test mixed valid/invalid characters
assert self.service._normalize_gtin("123-456-789-012") == "123456789012" # Dashes removed
assert self.service._normalize_gtin("123 456 789 012") == "123456789012" # Spaces removed
assert self.service._normalize_gtin("ABC123456789012DEF") == "123456789012" # Letters removed
assert (
self.service._normalize_gtin("123-456-789-012") == "123456789012"
) # Dashes removed
assert (
self.service._normalize_gtin("123 456 789 012") == "123456789012"
) # Spaces removed
assert (
self.service._normalize_gtin("ABC123456789012DEF") == "123456789012"
) # Letters removed
def test_set_inventory_new_entry_success(self, db):
"""Test setting inventory for a new GTIN/location combination successfully."""
@@ -162,7 +176,9 @@ class TestInventoryService:
def test_add_inventory_invalid_gtin_validation_error(self, db):
"""Test adding inventory with invalid GTIN returns InventoryValidationException."""
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50)
inventory_data = InventoryAdd(
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=50
)
with pytest.raises(InventoryValidationException) as exc_info:
self.service.add_inventory(db, inventory_data)
@@ -180,11 +196,12 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVALID_QUANTITY"
assert "Quantity must be positive" in str(exc_info.value)
def test_remove_inventory_success(self, db, test_inventory):
"""Test removing inventory successfully."""
original_quantity = test_inventory.quantity
remove_quantity = min(10, original_quantity) # Ensure we don't remove more than available
remove_quantity = min(
10, original_quantity
) # Ensure we don't remove more than available
inventory_data = InventoryAdd(
gtin=test_inventory.gtin,
@@ -212,7 +229,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
assert exc_info.value.details["gtin"] == test_inventory.gtin
assert exc_info.value.details["location"] == test_inventory.location
assert exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
assert (
exc_info.value.details["requested_quantity"] == test_inventory.quantity + 10
)
assert exc_info.value.details["available_quantity"] == test_inventory.quantity
def test_remove_inventory_nonexistent_entry_not_found(self, db):
@@ -231,7 +250,9 @@ class TestInventoryService:
def test_remove_inventory_invalid_gtin_validation_error(self, db):
"""Test removing inventory with invalid GTIN returns InventoryValidationException."""
inventory_data = InventoryAdd(gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10)
inventory_data = InventoryAdd(
gtin="invalid_gtin", location="WAREHOUSE_A", quantity=10
)
with pytest.raises(InventoryValidationException) as exc_info:
self.service.remove_inventory(db, inventory_data)
@@ -254,7 +275,9 @@ class TestInventoryService:
# The service prevents negative inventory through InsufficientInventoryException
assert exc_info.value.error_code == "INSUFFICIENT_INVENTORY"
def test_get_inventory_by_gtin_success(self, db, test_inventory, test_marketplace_product):
def test_get_inventory_by_gtin_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting inventory summary by GTIN successfully."""
result = self.service.get_inventory_by_gtin(db, test_inventory.gtin)
@@ -265,14 +288,20 @@ class TestInventoryService:
assert result.locations[0].quantity == test_inventory.quantity
assert result.product_title == test_marketplace_product.title
def test_get_inventory_by_gtin_multiple_locations_success(self, db, test_marketplace_product):
def test_get_inventory_by_gtin_multiple_locations_success(
self, db, test_marketplace_product
):
"""Test getting inventory summary with multiple locations successfully."""
unique_gtin = test_marketplace_product.gtin
unique_id = str(uuid.uuid4())[:8]
# Create multiple inventory entries for the same GTIN with unique locations
inventory1 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50)
inventory2 = Inventory(gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30)
inventory1 = Inventory(
gtin=unique_gtin, location=f"WAREHOUSE_A_{unique_id}", quantity=50
)
inventory2 = Inventory(
gtin=unique_gtin, location=f"WAREHOUSE_B_{unique_id}", quantity=30
)
db.add(inventory1)
db.add(inventory2)
@@ -301,7 +330,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_VALIDATION_FAILED"
assert "Invalid GTIN format" in str(exc_info.value)
def test_get_total_inventory_success(self, db, test_inventory, test_marketplace_product):
def test_get_total_inventory_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting total inventory for a GTIN successfully."""
result = self.service.get_total_inventory(db, test_inventory.gtin)
@@ -364,7 +395,9 @@ class TestInventoryService:
result = self.service.get_all_inventory(db, skip=2, limit=2)
assert len(result) <= 2 # Should be at most 2, might be less if other records exist
assert (
len(result) <= 2
) # Should be at most 2, might be less if other records exist
def test_update_inventory_success(self, db, test_inventory):
"""Test updating inventory quantity successfully."""
@@ -404,7 +437,9 @@ class TestInventoryService:
assert result is True
# Verify the inventory is actually deleted
deleted_inventory = db.query(Inventory).filter(Inventory.id == inventory_id).first()
deleted_inventory = (
db.query(Inventory).filter(Inventory.id == inventory_id).first()
)
assert deleted_inventory is None
def test_delete_inventory_not_found_error(self, db):
@@ -415,7 +450,9 @@ class TestInventoryService:
assert exc_info.value.error_code == "INVENTORY_NOT_FOUND"
assert "99999" in str(exc_info.value)
def test_get_low_inventory_items_success(self, db, test_inventory, test_marketplace_product):
def test_get_low_inventory_items_success(
self, db, test_inventory, test_marketplace_product
):
"""Test getting low inventory items successfully."""
# Set inventory to a low value
test_inventory.quantity = 5
@@ -424,7 +461,9 @@ class TestInventoryService:
result = self.service.get_low_inventory_items(db, threshold=10)
assert len(result) >= 1
low_inventory_item = next((item for item in result if item["gtin"] == test_inventory.gtin), None)
low_inventory_item = next(
(item for item in result if item["gtin"] == test_inventory.gtin), None
)
assert low_inventory_item is not None
assert low_inventory_item["current_quantity"] == 5
assert low_inventory_item["location"] == test_inventory.location
@@ -440,9 +479,13 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_success(self, db, test_inventory):
"""Test getting inventory summary by location successfully."""
result = self.service.get_inventory_summary_by_location(db, test_inventory.location)
result = self.service.get_inventory_summary_by_location(
db, test_inventory.location
)
assert result["location"] == test_inventory.location.upper() # Service normalizes to uppercase
assert (
result["location"] == test_inventory.location.upper()
) # Service normalizes to uppercase
assert result["total_items"] >= 1
assert result["total_quantity"] >= test_inventory.quantity
assert result["unique_gtins"] >= 1
@@ -450,7 +493,9 @@ class TestInventoryService:
def test_get_inventory_summary_by_location_empty_result(self, db):
"""Test getting inventory summary for location with no inventory."""
unique_id = str(uuid.uuid4())[:8]
result = self.service.get_inventory_summary_by_location(db, f"EMPTY_LOCATION_{unique_id}")
result = self.service.get_inventory_summary_by_location(
db, f"EMPTY_LOCATION_{unique_id}"
)
assert result["total_items"] == 0
assert result["total_quantity"] == 0
@@ -459,12 +504,16 @@ class TestInventoryService:
def test_validate_quantity_edge_cases(self, db):
"""Test quantity validation with edge cases."""
# Test zero quantity with allow_zero=True (should succeed)
inventory_data = InventoryCreate(gtin="1234567890123", location="WAREHOUSE_A", quantity=0)
inventory_data = InventoryCreate(
gtin="1234567890123", location="WAREHOUSE_A", quantity=0
)
result = self.service.set_inventory(db, inventory_data)
assert result.quantity == 0
# Test zero quantity with add_inventory (should fail - doesn't allow zero)
inventory_data_add = InventoryAdd(gtin="1234567890123", location="WAREHOUSE_B", quantity=0)
inventory_data_add = InventoryAdd(
gtin="1234567890123", location="WAREHOUSE_B", quantity=0
)
with pytest.raises(InvalidQuantityException):
self.service.add_inventory(db, inventory_data_add)
@@ -477,10 +526,10 @@ class TestInventoryService:
exception = exc_info.value
# Verify exception structure matches WizamartException.to_dict()
assert hasattr(exception, 'error_code')
assert hasattr(exception, 'message')
assert hasattr(exception, 'status_code')
assert hasattr(exception, 'details')
assert hasattr(exception, "error_code")
assert hasattr(exception, "message")
assert hasattr(exception, "status_code")
assert hasattr(exception, "details")
assert isinstance(exception.error_code, str)
assert isinstance(exception.message, str)

View File

@@ -4,19 +4,18 @@ from datetime import datetime
import pytest
from app.exceptions.marketplace_import_job import (
ImportJobNotFoundException,
ImportJobNotOwnedException,
ImportJobCannotBeCancelledException,
ImportJobCannotBeDeletedException,
)
from app.exceptions.vendor import VendorNotFoundException, UnauthorizedVendorAccessException
from app.exceptions.base import ValidationException
from app.services.marketplace_import_job_service import MarketplaceImportJobService
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
from app.exceptions.marketplace_import_job import (
ImportJobCannotBeCancelledException, ImportJobCannotBeDeletedException,
ImportJobNotFoundException, ImportJobNotOwnedException)
from app.exceptions.vendor import (UnauthorizedVendorAccessException,
VendorNotFoundException)
from app.services.marketplace_import_job_service import \
MarketplaceImportJobService
from models.database.marketplace_import_job import MarketplaceImportJob
from models.database.vendor import Vendor
from models.database.user import User
from models.database.vendor import Vendor
from models.schema.marketplace_import_job import MarketplaceImportJobRequest
@pytest.mark.unit
@@ -31,7 +30,9 @@ class TestMarketplaceService:
test_vendor.owner_user_id = test_user.id
db.commit()
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code, test_user
)
assert result.vendor_code == test_vendor.vendor_code
assert result.owner_user_id == test_user.id
@@ -39,8 +40,10 @@ class TestMarketplaceService:
def test_validate_vendor_access_admin_can_access_any_vendor(
self, db, test_vendor, test_admin
):
"""Test that admin users can access any vendor """
result = self.service.validate_vendor_access(db, test_vendor.vendor_code, test_admin)
"""Test that admin users can access any vendor"""
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code, test_admin
)
assert result.vendor_code == test_vendor.vendor_code
@@ -57,7 +60,7 @@ class TestMarketplaceService:
def test_validate_vendor_access_permission_denied(
self, db, test_vendor, test_user, other_user
):
"""Test vendor access validation when user doesn't own the vendor """
"""Test vendor access validation when user doesn't own the vendor"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
db.commit()
@@ -93,7 +96,7 @@ class TestMarketplaceService:
assert result.vendor_name == test_vendor.name
def test_create_import_job_invalid_vendor(self, db, test_user):
"""Test import job creation with invalid vendor """
"""Test import job creation with invalid vendor"""
request = MarketplaceImportJobRequest(
url="https://example.com/products.csv",
marketplace="Amazon",
@@ -108,7 +111,9 @@ class TestMarketplaceService:
assert exception.error_code == "VENDOR_NOT_FOUND"
assert "INVALID_VENDOR" in exception.message
def test_create_import_job_unauthorized_access(self, db, test_vendor, test_user, other_user):
def test_create_import_job_unauthorized_access(
self, db, test_vendor, test_user, other_user
):
"""Test import job creation with unauthorized vendor access"""
# Set the vendor owner to a different user
test_vendor.owner_user_id = other_user.id
@@ -127,7 +132,9 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "UNAUTHORIZED_VENDOR_ACCESS"
def test_get_import_job_by_id_success(self, db, test_marketplace_import_job, test_user):
def test_get_import_job_by_id_success(
self, db, test_marketplace_import_job, test_user
):
"""Test getting import job by ID for job owner"""
result = self.service.get_import_job_by_id(
db, test_marketplace_import_job.id, test_user
@@ -161,14 +168,18 @@ class TestMarketplaceService:
):
"""Test access denied when user doesn't own the job"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.get_import_job_by_id(db, test_marketplace_import_job.id, other_user)
self.service.get_import_job_by_id(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
assert exception.status_code == 403
assert str(test_marketplace_import_job.id) in exception.message
def test_get_import_jobs_user_filter(self, db, test_marketplace_import_job, test_user):
def test_get_import_jobs_user_filter(
self, db, test_marketplace_import_job, test_user
):
"""Test getting import jobs filtered by user"""
jobs = self.service.get_import_jobs(db, test_user)
@@ -176,7 +187,9 @@ class TestMarketplaceService:
assert any(job.id == test_marketplace_import_job.id for job in jobs)
assert test_marketplace_import_job.user_id == test_user.id
def test_get_import_jobs_admin_sees_all(self, db, test_marketplace_import_job, test_admin):
def test_get_import_jobs_admin_sees_all(
self, db, test_marketplace_import_job, test_admin
):
"""Test that admin sees all import jobs"""
jobs = self.service.get_import_jobs(db, test_admin)
@@ -192,7 +205,9 @@ class TestMarketplaceService:
)
assert len(jobs) >= 1
assert any(job.marketplace == test_marketplace_import_job.marketplace for job in jobs)
assert any(
job.marketplace == test_marketplace_import_job.marketplace for job in jobs
)
def test_get_import_jobs_with_pagination(self, db, test_user, test_vendor):
"""Test getting import jobs with pagination"""
@@ -330,10 +345,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
def test_cancel_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
def test_cancel_import_job_access_denied(
self, db, test_marketplace_import_job, other_user
):
"""Test cancelling import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.cancel_import_job(db, test_marketplace_import_job.id, other_user)
self.service.cancel_import_job(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -347,7 +366,9 @@ class TestMarketplaceService:
db.commit()
with pytest.raises(ImportJobCannotBeCancelledException) as exc_info:
self.service.cancel_import_job(db, test_marketplace_import_job.id, test_user)
self.service.cancel_import_job(
db, test_marketplace_import_job.id, test_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_CANNOT_BE_CANCELLED"
@@ -396,10 +417,14 @@ class TestMarketplaceService:
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_FOUND"
def test_delete_import_job_access_denied(self, db, test_marketplace_import_job, other_user):
def test_delete_import_job_access_denied(
self, db, test_marketplace_import_job, other_user
):
"""Test deleting import job without access"""
with pytest.raises(ImportJobNotOwnedException) as exc_info:
self.service.delete_import_job(db, test_marketplace_import_job.id, other_user)
self.service.delete_import_job(
db, test_marketplace_import_job.id, other_user
)
exception = exc_info.value
assert exception.error_code == "IMPORT_JOB_NOT_OWNED"
@@ -440,11 +465,15 @@ class TestMarketplaceService:
db.commit()
# Test with lowercase vendor code
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.lower(), test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code.lower(), test_user
)
assert result.vendor_code == test_vendor.vendor_code
# Test with uppercase vendor code
result = self.service.validate_vendor_access(db, test_vendor.vendor_code.upper(), test_user)
result = self.service.validate_vendor_access(
db, test_vendor.vendor_code.upper(), test_user
)
assert result.vendor_code == test_vendor.vendor_code
def test_create_import_job_database_error(self, db_with_error, test_user):

View File

@@ -1,16 +1,15 @@
# tests/test_product_service.py
import pytest
from app.exceptions import (InvalidMarketplaceProductDataException,
MarketplaceProductAlreadyExistsException,
MarketplaceProductNotFoundException,
MarketplaceProductValidationException,
ValidationException)
from app.services.marketplace_product_service import MarketplaceProductService
from app.exceptions import (
MarketplaceProductNotFoundException,
MarketplaceProductAlreadyExistsException,
InvalidMarketplaceProductDataException,
MarketplaceProductValidationException,
ValidationException,
)
from models.schema.marketplace_product import MarketplaceProductCreate, MarketplaceProductUpdate
from models.database.marketplace_product import MarketplaceProduct
from models.schema.marketplace_product import (MarketplaceProductCreate,
MarketplaceProductUpdate)
@pytest.mark.unit
@@ -98,7 +97,10 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_ALREADY_EXISTS"
assert test_marketplace_product.marketplace_product_id in str(exc_info.value)
assert exc_info.value.status_code == 409
assert exc_info.value.details.get("marketplace_product_id") == test_marketplace_product.marketplace_product_id
assert (
exc_info.value.details.get("marketplace_product_id")
== test_marketplace_product.marketplace_product_id
)
def test_create_product_invalid_price(self, db):
"""Test product creation with invalid price raises InvalidMarketplaceProductDataException"""
@@ -117,9 +119,14 @@ class TestProductService:
def test_get_product_by_id_or_raise_success(self, db, test_marketplace_product):
"""Test successful product retrieval by ID"""
product = self.service.get_product_by_id_or_raise(db, test_marketplace_product.marketplace_product_id)
product = self.service.get_product_by_id_or_raise(
db, test_marketplace_product.marketplace_product_id
)
assert product.marketplace_product_id == test_marketplace_product.marketplace_product_id
assert (
product.marketplace_product_id
== test_marketplace_product.marketplace_product_id
)
assert product.title == test_marketplace_product.title
def test_get_product_by_id_or_raise_not_found(self, db):
@@ -152,21 +159,35 @@ class TestProductService:
assert total >= 1
assert len(products) >= 1
# Verify search worked by checking that title contains search term
found_product = next((p for p in products if p.marketplace_product_id == test_marketplace_product.marketplace_product_id), None)
found_product = next(
(
p
for p in products
if p.marketplace_product_id
== test_marketplace_product.marketplace_product_id
),
None,
)
assert found_product is not None
def test_update_product_success(self, db, test_marketplace_product):
"""Test successful product update"""
update_data = MarketplaceProductUpdate(
title="Updated MarketplaceProduct Title",
price="39.99"
title="Updated MarketplaceProduct Title", price="39.99"
)
updated_product = self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
updated_product = self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert updated_product.title == "Updated MarketplaceProduct Title"
assert updated_product.price == "39.99" # Price is stored as string after processing
assert updated_product.marketplace_product_id == test_marketplace_product.marketplace_product_id # ID unchanged
assert (
updated_product.price == "39.99"
) # Price is stored as string after processing
assert (
updated_product.marketplace_product_id
== test_marketplace_product.marketplace_product_id
) # ID unchanged
def test_update_product_not_found(self, db):
"""Test updating non-existent product raises MarketplaceProductNotFoundException"""
@@ -183,7 +204,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(gtin="invalid_gtin")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid GTIN format" in str(exc_info.value)
@@ -194,7 +217,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(title="")
with pytest.raises(MarketplaceProductValidationException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "PRODUCT_VALIDATION_FAILED"
assert "MarketplaceProduct title cannot be empty" in str(exc_info.value)
@@ -205,7 +230,9 @@ class TestProductService:
update_data = MarketplaceProductUpdate(price="invalid_price")
with pytest.raises(InvalidMarketplaceProductDataException) as exc_info:
self.service.update_product(db, test_marketplace_product.marketplace_product_id, update_data)
self.service.update_product(
db, test_marketplace_product.marketplace_product_id, update_data
)
assert exc_info.value.error_code == "INVALID_PRODUCT_DATA"
assert "Invalid price format" in str(exc_info.value)
@@ -213,12 +240,16 @@ class TestProductService:
def test_delete_product_success(self, db, test_marketplace_product):
"""Test successful product deletion"""
result = self.service.delete_product(db, test_marketplace_product.marketplace_product_id)
result = self.service.delete_product(
db, test_marketplace_product.marketplace_product_id
)
assert result is True
# Verify product is deleted
deleted_product = self.service.get_product_by_id(db, test_marketplace_product.marketplace_product_id)
deleted_product = self.service.get_product_by_id(
db, test_marketplace_product.marketplace_product_id
)
assert deleted_product is None
def test_delete_product_not_found(self, db):
@@ -229,10 +260,14 @@ class TestProductService:
assert exc_info.value.error_code == "PRODUCT_NOT_FOUND"
assert "NONEXISTENT" in str(exc_info.value)
def test_get_inventory_info_success(self, db, test_marketplace_product_with_inventory):
def test_get_inventory_info_success(
self, db, test_marketplace_product_with_inventory
):
"""Test getting inventory info for product with inventory"""
# Extract the product from the dictionary
marketplace_product = test_marketplace_product_with_inventory['marketplace_product']
marketplace_product = test_marketplace_product_with_inventory[
"marketplace_product"
]
inventory_info = self.service.get_inventory_info(db, marketplace_product.gtin)
@@ -243,13 +278,17 @@ class TestProductService:
def test_get_inventory_info_no_inventory(self, db, test_marketplace_product):
"""Test getting inventory info for product without inventory"""
inventory_info = self.service.get_inventory_info(db, test_marketplace_product.gtin or "1234567890123")
inventory_info = self.service.get_inventory_info(
db, test_marketplace_product.gtin or "1234567890123"
)
assert inventory_info is None
def test_product_exists_true(self, db, test_marketplace_product):
"""Test product_exists returns True for existing product"""
exists = self.service.product_exists(db, test_marketplace_product.marketplace_product_id)
exists = self.service.product_exists(
db, test_marketplace_product.marketplace_product_id
)
assert exists is True
def test_product_exists_false(self, db):
@@ -265,7 +304,9 @@ class TestProductService:
csv_lines = list(csv_generator)
assert len(csv_lines) > 1 # Header + at least one data row
assert csv_lines[0].startswith("marketplace_product_id,title,description") # Check header
assert csv_lines[0].startswith(
"marketplace_product_id,title,description"
) # Check header
# Check that test product appears in CSV
csv_content = "".join(csv_lines)
@@ -274,8 +315,7 @@ class TestProductService:
def test_generate_csv_export_with_filters(self, db, test_marketplace_product):
"""Test CSV export with marketplace filter"""
csv_generator = self.service.generate_csv_export(
db,
marketplace=test_marketplace_product.marketplace
db, marketplace=test_marketplace_product.marketplace
)
csv_lines = list(csv_generator)

View File

@@ -2,8 +2,8 @@
import pytest
from app.services.stats_service import StatsService
from models.database.marketplace_product import MarketplaceProduct
from models.database.inventory import Inventory
from models.database.marketplace_product import MarketplaceProduct
@pytest.mark.unit
@@ -15,7 +15,9 @@ class TestStatsService:
"""Setup method following the same pattern as other service tests"""
self.service = StatsService()
def test_get_comprehensive_stats_basic(self, db, test_marketplace_product, test_inventory):
def test_get_comprehensive_stats_basic(
self, db, test_marketplace_product, test_inventory
):
"""Test getting comprehensive stats with basic data"""
stats = self.service.get_comprehensive_stats(db)
@@ -31,7 +33,9 @@ class TestStatsService:
assert stats["total_inventory_entries"] >= 1
assert stats["total_inventory_quantity"] >= 10 # test_inventory has quantity 10
def test_get_comprehensive_stats_multiple_products(self, db, test_marketplace_product):
def test_get_comprehensive_stats_multiple_products(
self, db, test_marketplace_product
):
"""Test comprehensive stats with multiple products across different dimensions"""
# Create products with different brands, categories, marketplaces
additional_products = [
@@ -87,7 +91,7 @@ class TestStatsService:
brand=None, # Null brand
google_product_category=None, # Null category
marketplace=None, # Null marketplace
vendor_name=None, # Null vendor
vendor_name=None, # Null vendor
price="10.00",
currency="EUR",
),
@@ -97,7 +101,7 @@ class TestStatsService:
brand="", # Empty brand
google_product_category="", # Empty category
marketplace="", # Empty marketplace
vendor_name="", # Empty vendor
vendor_name="", # Empty vendor
price="15.00",
currency="EUR",
),
@@ -124,7 +128,11 @@ class TestStatsService:
# Find our test marketplace in the results
test_marketplace_stat = next(
(stat for stat in stats if stat["marketplace"] == test_marketplace_product.marketplace),
(
stat
for stat in stats
if stat["marketplace"] == test_marketplace_product.marketplace
),
None,
)
assert test_marketplace_stat is not None
@@ -309,7 +317,9 @@ class TestStatsService:
count = self.service._get_unique_marketplaces_count(db)
assert count >= 2 # At least Amazon and eBay, plus test_marketplace_product marketplace
assert (
count >= 2
) # At least Amazon and eBay, plus test_marketplace_product marketplace
assert isinstance(count, int)
def test_get_unique_vendors_count(self, db, test_marketplace_product):
@@ -338,7 +348,9 @@ class TestStatsService:
count = self.service._get_unique_vendors_count(db)
assert count >= 2 # At least VendorA and VendorB, plus test_marketplace_product vendor
assert (
count >= 2
) # At least VendorA and VendorB, plus test_marketplace_product vendor
assert isinstance(count, int)
def test_get_inventory_statistics(self, db, test_inventory):
@@ -438,7 +450,7 @@ class TestStatsService:
db.add_all(marketplace_products)
db.commit()
vendors =self.service._get_vendors_by_marketplace(db, "TestMarketplace")
vendors = self.service._get_vendors_by_marketplace(db, "TestMarketplace")
assert len(vendors) == 2
assert "TestVendor1" in vendors
@@ -482,7 +494,9 @@ class TestStatsService:
def test_get_products_by_marketplace_not_found(self, db):
"""Test getting product count for non-existent marketplace"""
count = self.service._get_products_by_marketplace_count(db, "NonExistentMarketplace")
count = self.service._get_products_by_marketplace_count(
db, "NonExistentMarketplace"
)
assert count == 0

View File

@@ -1,19 +1,16 @@
# tests/test_vendor_service.py (updated to use custom exceptions)
import pytest
from app.exceptions import (InvalidVendorDataException,
MarketplaceProductNotFoundException,
MaxVendorsReachedException,
ProductAlreadyExistsException,
UnauthorizedVendorAccessException,
ValidationException, VendorAlreadyExistsException,
VendorNotFoundException)
from app.services.vendor_service import VendorService
from app.exceptions import (
VendorNotFoundException,
VendorAlreadyExistsException,
UnauthorizedVendorAccessException,
InvalidVendorDataException,
MarketplaceProductNotFoundException,
ProductAlreadyExistsException,
MaxVendorsReachedException,
ValidationException,
)
from models.schema.vendor import VendorCreate
from models.schema.product import ProductCreate
from models.schema.vendor import VendorCreate
@pytest.mark.unit
@@ -38,15 +35,17 @@ class TestVendorService:
assert vendor is not None
assert vendor.vendor_code == "NEWVENDOR"
assert vendor.owner_user_id == test_user.id
assert vendor.is_verified is False # Regular user creates unverified vendor
assert vendor.is_verified is False # Regular user creates unverified vendor
def test_create_vendor_admin_auto_verify(self, db, test_admin, vendor_factory):
"""Test admin creates verified vendor automatically"""
vendor_data = VendorCreate(vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor")
vendor_data = VendorCreate(
vendor_code="ADMINVENDOR", vendor_name="Admin Test Vendor"
)
vendor = self.service.create_vendor(db, vendor_data, test_admin)
assert vendor.is_verified is True # Admin creates verified vendor
assert vendor.is_verified is True # Admin creates verified vendor
def test_create_vendor_duplicate_code(self, db, test_user, test_vendor):
"""Test vendor creation fails with duplicate vendor code"""
@@ -88,7 +87,9 @@ class TestVendorService:
def test_create_vendor_invalid_code_format(self, db, test_user):
"""Test vendor creation fails with invalid vendor code format"""
vendor_data = VendorCreate(vendor_code="INVALID@CODE!", vendor_name="Test Vendor")
vendor_data = VendorCreate(
vendor_code="INVALID@CODE!", vendor_name="Test Vendor"
)
with pytest.raises(InvalidVendorDataException) as exc_info:
self.service.create_vendor(db, vendor_data, test_user)
@@ -105,7 +106,9 @@ class TestVendorService:
def mock_check_vendor_limit(self, db, user):
raise MaxVendorsReachedException(max_vendors=5, user_id=user.id)
monkeypatch.setattr(VendorService, "_check_vendor_limit", mock_check_vendor_limit)
monkeypatch.setattr(
VendorService, "_check_vendor_limit", mock_check_vendor_limit
)
vendor_data = VendorCreate(vendor_code="NEWVENDOR", vendor_name="New Vendor")
@@ -118,7 +121,9 @@ class TestVendorService:
assert exception.details["max_vendors"] == 5
assert exception.details["user_id"] == test_user.id
def test_get_vendors_regular_user(self, db, test_user, test_vendor, inactive_vendor):
def test_get_vendors_regular_user(
self, db, test_user, test_vendor, inactive_vendor
):
"""Test regular user can only see active verified vendors and own vendors"""
vendors, total = self.service.get_vendors(db, test_user, skip=0, limit=10)
@@ -127,7 +132,7 @@ class TestVendorService:
assert inactive_vendor.vendor_code not in vendor_codes
def test_get_vendors_admin_user(
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
self, db, test_admin, test_vendor, inactive_vendor, verified_vendor
):
"""Test admin user can see all vendors with filters"""
vendors, total = self.service.get_vendors(
@@ -140,14 +145,16 @@ class TestVendorService:
assert verified_vendor.vendor_code in vendor_codes
def test_get_vendor_by_code_owner_access(self, db, test_user, test_vendor):
"""Test vendor owner can access their own vendor """
vendor = self.service.get_vendor_by_code(db, test_vendor.vendor_code.lower(), test_user)
"""Test vendor owner can access their own vendor"""
vendor = self.service.get_vendor_by_code(
db, test_vendor.vendor_code.lower(), test_user
)
assert vendor is not None
assert vendor.id == test_vendor.id
def test_get_vendor_by_code_admin_access(self, db, test_admin, test_vendor):
"""Test admin can access any vendor """
"""Test admin can access any vendor"""
vendor = self.service.get_vendor_by_code(
db, test_vendor.vendor_code.lower(), test_admin
)
@@ -178,16 +185,14 @@ class TestVendorService:
assert exception.details["user_id"] == test_user.id
def test_add_product_to_vendor_success(self, db, test_vendor, unique_product):
"""Test successfully adding product to vendor """
"""Test successfully adding product to vendor"""
product_data = ProductCreate(
marketplace_product_id=unique_product.marketplace_product_id,
price="15.99",
is_featured=True,
)
product = self.service.add_product_to_catalog(
db, test_vendor, product_data
)
product = self.service.add_product_to_catalog(db, test_vendor, product_data)
assert product is not None
assert product.vendor_id == test_vendor.id
@@ -195,7 +200,9 @@ class TestVendorService:
def test_add_product_to_vendor_product_not_found(self, db, test_vendor):
"""Test adding non-existent product to vendor fails"""
product_data = ProductCreate(marketplace_product_id="NONEXISTENT", price="15.99")
product_data = ProductCreate(
marketplace_product_id="NONEXISTENT", price="15.99"
)
with pytest.raises(MarketplaceProductNotFoundException) as exc_info:
self.service.add_product_to_catalog(db, test_vendor, product_data)
@@ -209,7 +216,8 @@ class TestVendorService:
def test_add_product_to_vendor_already_exists(self, db, test_vendor, test_product):
"""Test adding product that's already in vendor fails"""
product_data = ProductCreate(
marketplace_product_id=test_product.marketplace_product.marketplace_product_id, price="15.99"
marketplace_product_id=test_product.marketplace_product.marketplace_product_id,
price="15.99",
)
with pytest.raises(ProductAlreadyExistsException) as exc_info:
@@ -219,11 +227,12 @@ class TestVendorService:
assert exception.status_code == 409
assert exception.error_code == "PRODUCT_ALREADY_EXISTS"
assert exception.details["vendor_code"] == test_vendor.vendor_code
assert exception.details["marketplace_product_id"] == test_product.marketplace_product.marketplace_product_id
assert (
exception.details["marketplace_product_id"]
== test_product.marketplace_product.marketplace_product_id
)
def test_get_products_owner_access(
self, db, test_user, test_vendor, test_product
):
def test_get_products_owner_access(self, db, test_user, test_vendor, test_product):
"""Test vendor owner can get vendor products"""
products, total = self.service.get_products(db, test_vendor, test_user)
@@ -291,7 +300,9 @@ class TestVendorService:
assert exception.error_code == "VALIDATION_ERROR"
assert "Failed to retrieve vendors" in exception.message
def test_add_product_database_error(self, db, test_vendor, unique_product, monkeypatch):
def test_add_product_database_error(
self, db, test_vendor, unique_product, monkeypatch
):
"""Test add product handles database errors gracefully"""
def mock_commit():

View File

@@ -18,7 +18,9 @@ class TestCSVProcessor:
def test_download_csv_encoding_fallback(self, mock_get):
"""Test CSV download with encoding fallback"""
# Create content with special characters that would fail UTF-8 if not properly encoded
special_content = "marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
special_content = (
"marketplace_product_id,title,price\nTEST001,Café MarketplaceProduct,10.99"
)
mock_response = Mock()
mock_response.status_code = 200
@@ -40,9 +42,7 @@ class TestCSVProcessor:
mock_response = Mock()
mock_response.status_code = 200
# Create bytes that will fail most encodings
mock_response.content = (
b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
)
mock_response.content = b"marketplace_product_id,title,price\nTEST001,\xff\xfe MarketplaceProduct,10.99"
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response