revamping documentation

This commit is contained in:
2025-11-17 22:59:42 +01:00
parent bbd64a6f21
commit 807033be16
107 changed files with 11973 additions and 28413 deletions

View File

@@ -0,0 +1,2 @@
# tests/unit/middleware/__init__.py
"""Unit tests - fast, isolated component tests."""

View File

@@ -0,0 +1,657 @@
# tests/unit/middleware/test_auth.py
"""
Comprehensive unit tests for AuthManager.
Tests cover:
- Password hashing and verification
- JWT token creation and validation
- User authentication
- Token expiration handling
- Role-based access control
- Admin/vendor/customer permission checks
- Error handling and edge cases
"""
import pytest
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timedelta, timezone
from jose import jwt
from middleware.auth import AuthManager
from app.exceptions import (
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException,
AdminRequiredException,
InsufficientPermissionsException,
)
from models.database.user import User
@pytest.mark.unit
@pytest.mark.auth
class TestPasswordHashing:
"""Test suite for password hashing functionality."""
def test_hash_password(self):
"""Test password hashing creates different hash for each call."""
auth_manager = AuthManager()
password = "test_password_123"
hash1 = auth_manager.hash_password(password)
hash2 = auth_manager.hash_password(password)
# Hashes should be different due to salt
assert hash1 != hash2
# Both should be valid bcrypt hashes (start with $2b$)
assert hash1.startswith("$2b$")
assert hash2.startswith("$2b$")
def test_verify_password_correct(self):
"""Test password verification with correct password."""
auth_manager = AuthManager()
password = "test_password_123"
hashed = auth_manager.hash_password(password)
assert auth_manager.verify_password(password, hashed) is True
def test_verify_password_incorrect(self):
"""Test password verification with incorrect password."""
auth_manager = AuthManager()
password = "test_password_123"
wrong_password = "wrong_password_456"
hashed = auth_manager.hash_password(password)
assert auth_manager.verify_password(wrong_password, hashed) is False
def test_verify_password_empty(self):
"""Test password verification with empty password."""
auth_manager = AuthManager()
password = "test_password_123"
hashed = auth_manager.hash_password(password)
assert auth_manager.verify_password("", hashed) is False
def test_hash_password_special_characters(self):
"""Test hashing password with special characters."""
auth_manager = AuthManager()
password = "P@ssw0rd!#$%^&*()_+-=[]{}|;:,.<>?"
hashed = auth_manager.hash_password(password)
assert auth_manager.verify_password(password, hashed) is True
def test_hash_password_unicode(self):
"""Test hashing password with unicode characters."""
auth_manager = AuthManager()
password = "パスワード123こんにちは"
hashed = auth_manager.hash_password(password)
assert auth_manager.verify_password(password, hashed) is True
@pytest.mark.unit
@pytest.mark.auth
class TestUserAuthentication:
"""Test suite for user authentication."""
def test_authenticate_user_success_with_username(self):
"""Test successful authentication with username."""
auth_manager = AuthManager()
mock_db = Mock()
mock_user = Mock(spec=User)
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.hashed_password = auth_manager.hash_password("password123")
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "testuser", "password123")
assert result is mock_user
def test_authenticate_user_success_with_email(self):
"""Test successful authentication with email."""
auth_manager = AuthManager()
mock_db = Mock()
mock_user = Mock(spec=User)
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.hashed_password = auth_manager.hash_password("password123")
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
assert result is mock_user
def test_authenticate_user_not_found(self):
"""Test authentication with non-existent user."""
auth_manager = AuthManager()
mock_db = Mock()
mock_db.query.return_value.filter.return_value.first.return_value = None
result = auth_manager.authenticate_user(mock_db, "nonexistent", "password123")
assert result is None
def test_authenticate_user_wrong_password(self):
"""Test authentication with wrong password."""
auth_manager = AuthManager()
mock_db = Mock()
mock_user = Mock(spec=User)
mock_user.hashed_password = auth_manager.hash_password("correctpassword")
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = auth_manager.authenticate_user(mock_db, "testuser", "wrongpassword")
assert result is None
@pytest.mark.unit
@pytest.mark.auth
class TestJWTTokenCreation:
"""Test suite for JWT token creation."""
def test_create_access_token_structure(self):
"""Test JWT token creation returns correct structure."""
auth_manager = AuthManager()
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
token_data = auth_manager.create_access_token(mock_user)
assert "access_token" in token_data
assert "token_type" in token_data
assert "expires_in" in token_data
assert token_data["token_type"] == "bearer"
assert isinstance(token_data["expires_in"], int)
assert token_data["expires_in"] == auth_manager.token_expire_minutes * 60
def test_create_access_token_payload(self):
"""Test JWT token contains correct payload."""
auth_manager = AuthManager()
mock_user = Mock(spec=User)
mock_user.id = 42
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "vendor"
token_data = auth_manager.create_access_token(mock_user)
token = token_data["access_token"]
# Decode without verification to check payload
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
assert payload["sub"] == "42"
assert payload["username"] == "testuser"
assert payload["email"] == "test@example.com"
assert payload["role"] == "vendor"
assert "exp" in payload
assert "iat" in payload
def test_create_access_token_different_users(self):
"""Test tokens are different for different users."""
auth_manager = AuthManager()
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
token1 = auth_manager.create_access_token(user1)["access_token"]
token2 = auth_manager.create_access_token(user2)["access_token"]
assert token1 != token2
def test_create_access_token_admin_role(self):
"""Test token creation for admin user."""
auth_manager = AuthManager()
admin_user = Mock(spec=User)
admin_user.id = 1
admin_user.username = "admin"
admin_user.email = "admin@example.com"
admin_user.role = "admin"
token_data = auth_manager.create_access_token(admin_user)
payload = jwt.decode(
token_data["access_token"],
auth_manager.secret_key,
algorithms=[auth_manager.algorithm]
)
assert payload["role"] == "admin"
@pytest.mark.unit
@pytest.mark.auth
class TestJWTTokenVerification:
"""Test suite for JWT token verification."""
def test_verify_token_success(self):
"""Test successful token verification."""
auth_manager = AuthManager()
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
token_data = auth_manager.create_access_token(mock_user)
token = token_data["access_token"]
result = auth_manager.verify_token(token)
assert result["user_id"] == 1
assert result["username"] == "testuser"
assert result["email"] == "test@example.com"
assert result["role"] == "customer"
def test_verify_token_expired(self):
"""Test token verification with expired token."""
auth_manager = AuthManager()
auth_manager.token_expire_minutes = -1 # Set to negative to force expiration
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
token_data = auth_manager.create_access_token(mock_user)
token = token_data["access_token"]
# Reset to normal
auth_manager.token_expire_minutes = 30
with pytest.raises(TokenExpiredException):
auth_manager.verify_token(token)
def test_verify_token_invalid(self):
"""Test token verification with invalid token."""
auth_manager = AuthManager()
with pytest.raises(InvalidTokenException):
auth_manager.verify_token("invalid.token.here")
def test_verify_token_tampered(self):
"""Test token verification with tampered token."""
auth_manager = AuthManager()
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
token = auth_manager.create_access_token(mock_user)["access_token"]
# Tamper with token
parts = token.split(".")
tampered_token = ".".join([parts[0], parts[1], "tampered"])
with pytest.raises(InvalidTokenException):
auth_manager.verify_token(tampered_token)
def test_verify_token_missing_user_id(self):
"""Test token verification with missing user ID."""
auth_manager = AuthManager()
# Create token without 'sub' field
payload = {
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
assert "missing user identifier" in str(exc_info.value.message)
def test_verify_token_missing_expiration(self):
"""Test token verification with missing expiration."""
auth_manager = AuthManager()
# Create token without 'exp' field
payload = {
"sub": "1",
"username": "testuser"
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
with pytest.raises(InvalidTokenException) as exc_info:
auth_manager.verify_token(token)
assert "missing expiration" in str(exc_info.value.message)
def test_verify_token_wrong_algorithm(self):
"""Test token verification with different algorithm."""
auth_manager = AuthManager()
payload = {
"sub": "1",
"username": "testuser",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
}
# Create token with different algorithm
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
with pytest.raises(InvalidTokenException):
auth_manager.verify_token(token)
@pytest.mark.unit
@pytest.mark.auth
class TestGetCurrentUser:
"""Test suite for get_current_user functionality."""
def test_get_current_user_success(self):
"""Test successfully getting current user."""
auth_manager = AuthManager()
mock_db = Mock()
# Create mock user
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
mock_user.is_active = True
# Setup database mock
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
# Create valid token
token_data = auth_manager.create_access_token(mock_user)
# Create mock credentials
mock_credentials = Mock()
mock_credentials.credentials = token_data["access_token"]
result = auth_manager.get_current_user(mock_db, mock_credentials)
assert result is mock_user
def test_get_current_user_not_found(self):
"""Test get_current_user when user doesn't exist in database."""
auth_manager = AuthManager()
mock_db = Mock()
# Setup database to return None
mock_db.query.return_value.filter.return_value.first.return_value = None
# Create mock user for token
mock_user = Mock(spec=User)
mock_user.id = 999
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
token_data = auth_manager.create_access_token(mock_user)
mock_credentials = Mock()
mock_credentials.credentials = token_data["access_token"]
with pytest.raises(InvalidCredentialsException):
auth_manager.get_current_user(mock_db, mock_credentials)
def test_get_current_user_inactive(self):
"""Test get_current_user with inactive user."""
auth_manager = AuthManager()
mock_db = Mock()
mock_user = Mock(spec=User)
mock_user.id = 1
mock_user.username = "testuser"
mock_user.email = "test@example.com"
mock_user.role = "customer"
mock_user.is_active = False # Inactive user
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
token_data = auth_manager.create_access_token(mock_user)
mock_credentials = Mock()
mock_credentials.credentials = token_data["access_token"]
with pytest.raises(UserNotActiveException):
auth_manager.get_current_user(mock_db, mock_credentials)
@pytest.mark.unit
@pytest.mark.auth
class TestRoleRequirements:
"""Test suite for role-based access control."""
def test_require_admin_success(self):
"""Test require_admin with admin user."""
auth_manager = AuthManager()
admin_user = Mock(spec=User)
admin_user.role = "admin"
result = auth_manager.require_admin(admin_user)
assert result is admin_user
def test_require_admin_failure(self):
"""Test require_admin with non-admin user."""
auth_manager = AuthManager()
customer_user = Mock(spec=User)
customer_user.role = "customer"
with pytest.raises(AdminRequiredException):
auth_manager.require_admin(customer_user)
def test_require_vendor_with_vendor_role(self):
"""Test require_vendor with vendor user."""
auth_manager = AuthManager()
vendor_user = Mock(spec=User)
vendor_user.role = "vendor"
result = auth_manager.require_vendor(vendor_user)
assert result is vendor_user
def test_require_vendor_with_admin_role(self):
"""Test require_vendor with admin user (admin can access vendor areas)."""
auth_manager = AuthManager()
admin_user = Mock(spec=User)
admin_user.role = "admin"
result = auth_manager.require_vendor(admin_user)
assert result is admin_user
def test_require_vendor_failure(self):
"""Test require_vendor with customer user."""
auth_manager = AuthManager()
customer_user = Mock(spec=User)
customer_user.role = "customer"
with pytest.raises(InsufficientPermissionsException) as exc_info:
auth_manager.require_vendor(customer_user)
assert exc_info.value.details.get("required_permission") == "vendor"
def test_require_customer_with_customer_role(self):
"""Test require_customer with customer user."""
auth_manager = AuthManager()
customer_user = Mock(spec=User)
customer_user.role = "customer"
result = auth_manager.require_customer(customer_user)
assert result is customer_user
def test_require_customer_with_admin_role(self):
"""Test require_customer with admin user (admin can access customer areas)."""
auth_manager = AuthManager()
admin_user = Mock(spec=User)
admin_user.role = "admin"
result = auth_manager.require_customer(admin_user)
assert result is admin_user
def test_require_customer_failure(self):
"""Test require_customer with vendor user."""
auth_manager = AuthManager()
vendor_user = Mock(spec=User)
vendor_user.role = "vendor"
with pytest.raises(InsufficientPermissionsException) as exc_info:
auth_manager.require_customer(vendor_user)
assert exc_info.value.details.get("required_permission") == "customer"
@pytest.mark.unit
@pytest.mark.auth
class TestCreateDefaultAdminUser:
"""Test suite for default admin user creation."""
def test_create_default_admin_user_first_time(self):
"""Test creating default admin user when none exists."""
auth_manager = AuthManager()
mock_db = Mock()
# No existing admin user
mock_db.query.return_value.filter.return_value.first.return_value = None
result = auth_manager.create_default_admin_user(mock_db)
# Verify admin user was created
mock_db.add.assert_called_once()
mock_db.commit.assert_called_once()
mock_db.refresh.assert_called_once()
# Verify the created user
created_user = mock_db.add.call_args[0][0]
assert created_user.username == "admin"
assert created_user.email == "admin@example.com"
assert created_user.role == "admin"
assert created_user.is_active is True
assert auth_manager.verify_password("admin123", created_user.hashed_password)
def test_create_default_admin_user_already_exists(self):
"""Test creating default admin user when one already exists."""
auth_manager = AuthManager()
mock_db = Mock()
# Existing admin user
existing_admin = Mock(spec=User)
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
result = auth_manager.create_default_admin_user(mock_db)
# Should not create new user
mock_db.add.assert_not_called()
mock_db.commit.assert_not_called()
# Should return existing user
assert result is existing_admin
@pytest.mark.unit
@pytest.mark.auth
class TestAuthManagerConfiguration:
"""Test suite for AuthManager configuration."""
def test_default_configuration(self):
"""Test AuthManager uses default configuration."""
with patch.dict('os.environ', {}, clear=True):
auth_manager = AuthManager()
assert auth_manager.algorithm == "HS256"
assert auth_manager.token_expire_minutes == 30
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
def test_custom_configuration(self):
"""Test AuthManager uses environment variables."""
with patch.dict('os.environ', {
'JWT_SECRET_KEY': 'custom-secret-key',
'JWT_EXPIRE_MINUTES': '60'
}):
auth_manager = AuthManager()
assert auth_manager.secret_key == "custom-secret-key"
assert auth_manager.token_expire_minutes == 60
def test_partial_custom_configuration(self):
"""Test AuthManager with partial environment configuration."""
with patch.dict('os.environ', {
'JWT_EXPIRE_MINUTES': '120'
}, clear=False):
auth_manager = AuthManager()
assert auth_manager.token_expire_minutes == 120
# Secret key should use default or existing env var
assert auth_manager.secret_key is not None
@pytest.mark.unit
@pytest.mark.auth
class TestEdgeCases:
"""Test suite for edge cases and error scenarios."""
def test_verify_password_with_none(self):
"""Test password verification with None values."""
auth_manager = AuthManager()
# This should not raise an exception, just return False
with pytest.raises(Exception):
auth_manager.verify_password(None, None)
def test_token_with_future_iat(self):
"""Test token with issued_at time in the future."""
auth_manager = AuthManager()
payload = {
"sub": "1",
"username": "testuser",
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
}
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
# Should still verify successfully (JWT doesn't validate iat by default)
result = auth_manager.verify_token(token)
assert result["user_id"] == 1
def test_authenticate_user_case_sensitivity(self):
"""Test that username/email authentication is case-sensitive."""
auth_manager = AuthManager()
mock_db = Mock()
mock_user = Mock(spec=User)
mock_user.username = "TestUser"
mock_user.email = "test@example.com"
mock_user.hashed_password = auth_manager.hash_password("password123")
# This will depend on database collation, but generally should be case-sensitive
mock_db.query.return_value.filter.return_value.first.return_value = None
result = auth_manager.authenticate_user(mock_db, "testuser", "password123")
# Result depends on how the filter is implemented
# This test documents the expected behavior
assert result is None or result is mock_user

View File

@@ -0,0 +1,573 @@
# tests/unit/middleware/test_context_middleware.py
"""
Comprehensive unit tests for ContextMiddleware and ContextManager.
Tests cover:
- Context detection for API, Admin, Vendor Dashboard, Shop, and Fallback
- Clean path usage for correct context detection
- Host and path-based context determination
- Middleware state injection
- Edge cases and error handling
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from fastapi import Request
from middleware.context_middleware import (
ContextManager,
ContextMiddleware,
RequestContext,
get_request_context,
)
@pytest.mark.unit
class TestRequestContextEnum:
"""Test suite for RequestContext enum."""
def test_request_context_values(self):
"""Test RequestContext enum has correct values."""
assert RequestContext.API.value == "api"
assert RequestContext.ADMIN.value == "admin"
assert RequestContext.VENDOR_DASHBOARD.value == "vendor"
assert RequestContext.SHOP.value == "shop"
assert RequestContext.FALLBACK.value == "fallback"
def test_request_context_types(self):
"""Test RequestContext enum values are strings."""
for context in RequestContext:
assert isinstance(context.value, str)
@pytest.mark.unit
class TestContextManagerDetection:
"""Test suite for ContextManager.detect_context()."""
# ========================================================================
# API Context Tests (Highest Priority)
# ========================================================================
def test_detect_api_context(self):
"""Test API context detection."""
request = Mock(spec=Request)
request.url = Mock(path="/api/v1/vendors")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/api/v1/vendors")
context = ContextManager.detect_context(request)
assert context == RequestContext.API
def test_detect_api_context_nested_path(self):
"""Test API context detection with nested path."""
request = Mock(spec=Request)
request.url = Mock(path="/api/v1/vendors/123/products")
request.headers = {"host": "platform.com"}
request.state = Mock(clean_path="/api/v1/vendors/123/products")
context = ContextManager.detect_context(request)
assert context == RequestContext.API
def test_detect_api_context_with_clean_path(self):
"""Test API context detection uses clean_path when available."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/testvendor/api/products")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/api/products")
context = ContextManager.detect_context(request)
assert context == RequestContext.API
# ========================================================================
# Admin Context Tests
# ========================================================================
def test_detect_admin_context_from_subdomain(self):
"""Test admin context detection from subdomain."""
request = Mock(spec=Request)
request.url = Mock(path="/dashboard")
request.headers = {"host": "admin.platform.com"}
request.state = Mock(clean_path="/dashboard")
context = ContextManager.detect_context(request)
assert context == RequestContext.ADMIN
def test_detect_admin_context_from_path(self):
"""Test admin context detection from path."""
request = Mock(spec=Request)
request.url = Mock(path="/admin/dashboard")
request.headers = {"host": "platform.com"}
request.state = Mock(clean_path="/admin/dashboard")
context = ContextManager.detect_context(request)
assert context == RequestContext.ADMIN
def test_detect_admin_context_with_port(self):
"""Test admin context detection with port number."""
request = Mock(spec=Request)
request.url = Mock(path="/dashboard")
request.headers = {"host": "admin.localhost:8000"}
request.state = Mock(clean_path="/dashboard")
context = ContextManager.detect_context(request)
assert context == RequestContext.ADMIN
def test_detect_admin_context_nested_path(self):
"""Test admin context detection with nested admin path."""
request = Mock(spec=Request)
request.url = Mock(path="/admin/vendors/list")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/admin/vendors/list")
context = ContextManager.detect_context(request)
assert context == RequestContext.ADMIN
# ========================================================================
# Vendor Dashboard Context Tests
# ========================================================================
def test_detect_vendor_dashboard_context(self):
"""Test vendor dashboard context detection."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/testvendor/dashboard")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/dashboard")
context = ContextManager.detect_context(request)
assert context == RequestContext.VENDOR_DASHBOARD
def test_detect_vendor_dashboard_context_direct_path(self):
"""Test vendor dashboard with direct /vendor/ path."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/settings")
request.headers = {"host": "testvendor.platform.com"}
request.state = Mock(clean_path="/vendor/settings")
context = ContextManager.detect_context(request)
assert context == RequestContext.VENDOR_DASHBOARD
def test_not_detect_vendors_plural_as_dashboard(self):
"""Test that /vendors/ path is not detected as vendor dashboard."""
request = Mock(spec=Request)
request.url = Mock(path="/vendors/testvendor/shop")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/shop")
# Should not be vendor dashboard
context = ContextManager.detect_context(request)
assert context != RequestContext.VENDOR_DASHBOARD
# ========================================================================
# Shop Context Tests
# ========================================================================
def test_detect_shop_context_with_vendor_state(self):
"""Test shop context detection when vendor exists in request state."""
request = Mock(spec=Request)
request.url = Mock(path="/products")
request.headers = {"host": "testvendor.platform.com"}
mock_vendor = Mock()
mock_vendor.name = "Test Vendor"
request.state = Mock(clean_path="/products", vendor=mock_vendor)
context = ContextManager.detect_context(request)
assert context == RequestContext.SHOP
def test_detect_shop_context_from_shop_path(self):
"""Test shop context detection from /shop/ path."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/shop/products", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.SHOP
def test_detect_shop_context_custom_domain(self):
"""Test shop context with custom domain and vendor."""
request = Mock(spec=Request)
request.url = Mock(path="/products")
request.headers = {"host": "customdomain.com"}
mock_vendor = Mock(name="Custom Vendor")
request.state = Mock(clean_path="/products", vendor=mock_vendor)
context = ContextManager.detect_context(request)
assert context == RequestContext.SHOP
# ========================================================================
# Fallback Context Tests
# ========================================================================
def test_detect_fallback_context(self):
"""Test fallback context for unknown paths."""
request = Mock(spec=Request)
request.url = Mock(path="/random/path")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/random/path", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.FALLBACK
def test_detect_fallback_context_root(self):
"""Test fallback context for root path."""
request = Mock(spec=Request)
request.url = Mock(path="/")
request.headers = {"host": "platform.com"}
request.state = Mock(clean_path="/", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.FALLBACK
def test_detect_fallback_context_no_vendor(self):
"""Test fallback context when no vendor context exists."""
request = Mock(spec=Request)
request.url = Mock(path="/about")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/about", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.FALLBACK
# ========================================================================
# Clean Path Tests
# ========================================================================
def test_uses_clean_path_when_available(self):
"""Test that clean_path is used over original path."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/testvendor/api/products")
request.headers = {"host": "localhost"}
# clean_path shows the rewritten path
request.state = Mock(clean_path="/api/products")
context = ContextManager.detect_context(request)
# Should detect as API based on clean_path, not original path
assert context == RequestContext.API
def test_falls_back_to_original_path(self):
"""Test falls back to original path when clean_path not set."""
request = Mock(spec=Request)
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
request.state = Mock(spec=[]) # No clean_path attribute
context = ContextManager.detect_context(request)
assert context == RequestContext.API
# ========================================================================
# Priority Order Tests
# ========================================================================
def test_api_has_highest_priority(self):
"""Test API context takes precedence over admin."""
request = Mock(spec=Request)
request.url = Mock(path="/api/admin/users")
request.headers = {"host": "admin.platform.com"}
request.state = Mock(clean_path="/api/admin/users")
context = ContextManager.detect_context(request)
# API should win even though it's admin subdomain
assert context == RequestContext.API
def test_admin_has_priority_over_shop(self):
"""Test admin context takes precedence over shop."""
request = Mock(spec=Request)
request.url = Mock(path="/admin/shops")
request.headers = {"host": "localhost"}
mock_vendor = Mock()
request.state = Mock(clean_path="/admin/shops", vendor=mock_vendor)
context = ContextManager.detect_context(request)
# Admin should win even though vendor exists
assert context == RequestContext.ADMIN
def test_vendor_dashboard_priority_over_shop(self):
"""Test vendor dashboard takes precedence over shop."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/settings")
request.headers = {"host": "testvendor.platform.com"}
mock_vendor = Mock()
request.state = Mock(clean_path="/vendor/settings", vendor=mock_vendor)
context = ContextManager.detect_context(request)
assert context == RequestContext.VENDOR_DASHBOARD
@pytest.mark.unit
class TestContextManagerHelpers:
"""Test suite for ContextManager helper methods."""
def test_is_admin_context_from_subdomain(self):
"""Test _is_admin_context with admin subdomain."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
def test_is_admin_context_from_path(self):
"""Test _is_admin_context with admin path."""
request = Mock()
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
def test_is_admin_context_both(self):
"""Test _is_admin_context with both subdomain and path."""
request = Mock()
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
def test_is_not_admin_context(self):
"""Test _is_admin_context returns False for non-admin."""
request = Mock()
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
def test_is_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context with /vendor/ path."""
assert ContextManager._is_vendor_dashboard_context("/vendor/settings") is True
def test_is_vendor_dashboard_context_nested(self):
"""Test _is_vendor_dashboard_context with nested vendor path."""
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
def test_is_not_vendor_dashboard_context_vendors_plural(self):
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
def test_is_not_vendor_dashboard_context(self):
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
assert ContextManager._is_vendor_dashboard_context("/shop/products") is False
@pytest.mark.unit
class TestContextMiddleware:
"""Test suite for ContextMiddleware."""
@pytest.mark.asyncio
async def test_middleware_sets_context(self):
"""Test middleware successfully sets context in request state."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/api/vendors", vendor=None)
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'context_type')
assert request.state.context_type == RequestContext.API
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_sets_admin_context(self):
"""Test middleware sets admin context."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/admin/dashboard")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/admin/dashboard")
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert request.state.context_type == RequestContext.ADMIN
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_sets_vendor_dashboard_context(self):
"""Test middleware sets vendor dashboard context."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/vendor/settings")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/vendor/settings")
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert request.state.context_type == RequestContext.VENDOR_DASHBOARD
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_sets_shop_context(self):
"""Test middleware sets shop context."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/products")
request.headers = {"host": "shop.platform.com"}
mock_vendor = Mock()
request.state = Mock(clean_path="/products", vendor=mock_vendor)
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert request.state.context_type == RequestContext.SHOP
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_sets_fallback_context(self):
"""Test middleware sets fallback context."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/random")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/random", vendor=None)
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert request.state.context_type == RequestContext.FALLBACK
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_returns_response(self):
"""Test middleware returns response from call_next."""
middleware = ContextMiddleware(app=None)
request = Mock(spec=Request)
request.url = Mock(path="/api/test")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/api/test")
expected_response = Mock()
call_next = AsyncMock(return_value=expected_response)
response = await middleware.dispatch(request, call_next)
assert response is expected_response
@pytest.mark.unit
class TestGetRequestContextHelper:
"""Test suite for get_request_context helper function."""
def test_get_request_context_exists(self):
"""Test getting request context when it exists."""
request = Mock(spec=Request)
request.state.context_type = RequestContext.API
context = get_request_context(request)
assert context == RequestContext.API
def test_get_request_context_default(self):
"""Test getting request context returns FALLBACK as default."""
request = Mock(spec=Request)
request.state = Mock(spec=[]) # No context_type attribute
context = get_request_context(request)
assert context == RequestContext.FALLBACK
def test_get_request_context_for_all_types(self):
"""Test getting all context types."""
for expected_context in RequestContext:
request = Mock(spec=Request)
request.state.context_type = expected_context
context = get_request_context(request)
assert context == expected_context
@pytest.mark.unit
class TestEdgeCases:
"""Test suite for edge cases and error scenarios."""
def test_detect_context_empty_path(self):
"""Test context detection with empty path."""
request = Mock(spec=Request)
request.url = Mock(path="")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.FALLBACK
def test_detect_context_missing_host(self):
"""Test context detection with missing host header."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
request.headers = {}
request.state = Mock(clean_path="/shop/products", vendor=None)
context = ContextManager.detect_context(request)
assert context == RequestContext.SHOP
def test_detect_context_case_sensitivity(self):
"""Test that context detection is case-sensitive for paths."""
request = Mock(spec=Request)
request.url = Mock(path="/API/vendors") # Uppercase
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/API/vendors")
context = ContextManager.detect_context(request)
# Should NOT match /api/ because it's case-sensitive
assert context != RequestContext.API
def test_detect_context_path_with_query_params(self):
"""Test context detection handles path with query parameters."""
request = Mock(spec=Request)
request.url = Mock(path="/api/vendors?page=1")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/api/vendors?page=1")
# path.startswith should still work
context = ContextManager.detect_context(request)
assert context == RequestContext.API
def test_detect_context_admin_substring(self):
"""Test that 'admin' substring doesn't trigger false positive."""
request = Mock(spec=Request)
request.url = Mock(path="/administration/docs")
request.headers = {"host": "localhost"}
request.state = Mock(clean_path="/administration/docs")
context = ContextManager.detect_context(request)
# Should match because path starts with /admin
assert context == RequestContext.ADMIN
def test_detect_context_no_state_attribute(self):
"""Test context detection when request has no state."""
request = Mock(spec=Request)
request.url = Mock(path="/api/vendors")
request.headers = {"host": "localhost"}
# No state attribute at all
delattr(request, 'state')
# Should still work, falling back to url.path
with pytest.raises(AttributeError):
# This will raise because we're trying to access request.state
ContextManager.detect_context(request)

View File

@@ -0,0 +1,536 @@
# tests/unit/middleware/test_rate_limiter.py
"""
Comprehensive unit tests for RateLimiter.
Tests cover:
- Request allowance within limits
- Request blocking when exceeding limits
- Sliding window algorithm
- Cleanup of old entries
- Client statistics
- Edge cases and concurrency scenarios
"""
import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta, timezone
from collections import deque
from middleware.rate_limiter import RateLimiter
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterBasic:
"""Test suite for basic rate limiter functionality."""
def test_rate_limiter_initialization(self):
"""Test rate limiter initializes correctly."""
limiter = RateLimiter()
assert isinstance(limiter.clients, dict)
assert limiter.cleanup_interval == 3600
assert isinstance(limiter.last_cleanup, datetime)
def test_allow_first_request(self):
"""Test rate limiter allows first request."""
limiter = RateLimiter()
client_id = "test_client_1"
result = limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
assert result is True
assert client_id in limiter.clients
assert len(limiter.clients[client_id]) == 1
def test_allow_multiple_requests_within_limit(self):
"""Test rate limiter allows multiple requests within limit."""
limiter = RateLimiter()
client_id = "test_client_2"
max_requests = 10
# Make 10 requests (at the limit)
for i in range(max_requests):
result = limiter.allow_request(client_id, max_requests, 3600)
assert result is True, f"Request {i+1} should be allowed"
assert len(limiter.clients[client_id]) == max_requests
def test_block_request_exceeding_limit(self):
"""Test rate limiter blocks requests exceeding limit."""
limiter = RateLimiter()
client_id = "test_client_blocked"
max_requests = 3
# Use up the allowed requests
for _ in range(max_requests):
assert limiter.allow_request(client_id, max_requests, 3600) is True
# Next request should be blocked
result = limiter.allow_request(client_id, max_requests, 3600)
assert result is False
# Client should still have only max_requests entries
assert len(limiter.clients[client_id]) == max_requests
def test_different_clients_separate_limits(self):
"""Test different clients have separate rate limits."""
limiter = RateLimiter()
client1 = "client_1"
client2 = "client_2"
max_requests = 5
# Client 1 makes requests
for _ in range(max_requests):
assert limiter.allow_request(client1, max_requests, 3600) is True
# Client 1 is blocked
assert limiter.allow_request(client1, max_requests, 3600) is False
# Client 2 should still be allowed
assert limiter.allow_request(client2, max_requests, 3600) is True
# Verify separate tracking
assert len(limiter.clients[client1]) == max_requests
assert len(limiter.clients[client2]) == 1
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterSlidingWindow:
"""Test suite for sliding window algorithm."""
def test_sliding_window_removes_old_requests(self):
"""Test sliding window removes requests outside time window."""
limiter = RateLimiter()
client_id = "test_client_window"
max_requests = 3
window_seconds = 10
# Manually add old requests
old_time = datetime.now(timezone.utc) - timedelta(seconds=15)
limiter.clients[client_id].append(old_time)
limiter.clients[client_id].append(old_time)
# These old requests should be removed, so new request should be allowed
result = limiter.allow_request(client_id, max_requests, window_seconds)
assert result is True
assert len(limiter.clients[client_id]) == 1 # Only the new request
def test_sliding_window_keeps_recent_requests(self):
"""Test sliding window keeps requests within time window."""
limiter = RateLimiter()
client_id = "test_client_recent"
max_requests = 3
window_seconds = 60
# Add recent requests
recent_time = datetime.now(timezone.utc) - timedelta(seconds=30)
limiter.clients[client_id].append(recent_time)
limiter.clients[client_id].append(recent_time)
# These requests are within window, so we can only add 1 more
result = limiter.allow_request(client_id, max_requests, window_seconds)
assert result is True
# Now at limit
result = limiter.allow_request(client_id, max_requests, window_seconds)
assert result is False
def test_sliding_window_mixed_old_and_recent(self):
"""Test sliding window with mix of old and recent requests."""
limiter = RateLimiter()
client_id = "test_client_mixed"
max_requests = 3
window_seconds = 30
# Add old requests (outside window)
old_time = datetime.now(timezone.utc) - timedelta(seconds=60)
limiter.clients[client_id].append(old_time)
limiter.clients[client_id].append(old_time)
# Add recent request (within window)
recent_time = datetime.now(timezone.utc) - timedelta(seconds=10)
limiter.clients[client_id].append(recent_time)
# Old requests removed, only 1 recent request, so 2 more allowed
assert limiter.allow_request(client_id, max_requests, window_seconds) is True
assert limiter.allow_request(client_id, max_requests, window_seconds) is True
# Now at limit
assert limiter.allow_request(client_id, max_requests, window_seconds) is False
def test_sliding_window_with_zero_window(self):
"""Test rate limiter with very short window."""
limiter = RateLimiter()
client_id = "test_client_zero_window"
max_requests = 5
window_seconds = 1 # 1 second window
# Add old request
old_time = datetime.now(timezone.utc) - timedelta(seconds=2)
limiter.clients[client_id].append(old_time)
# Should allow request because old one is outside 1-second window
result = limiter.allow_request(client_id, max_requests, window_seconds)
assert result is True
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterCleanup:
"""Test suite for cleanup functionality."""
def test_cleanup_removes_old_entries(self):
"""Test cleanup removes entries older than 24 hours."""
limiter = RateLimiter()
# Add clients with old requests
old_time = datetime.now(timezone.utc) - timedelta(hours=25)
limiter.clients["old_client_1"].append(old_time)
limiter.clients["old_client_2"].append(old_time)
# Add client with recent requests
recent_time = datetime.now(timezone.utc) - timedelta(hours=1)
limiter.clients["recent_client"].append(recent_time)
# Run cleanup
limiter._cleanup_old_entries()
# Old clients should be removed
assert "old_client_1" not in limiter.clients
assert "old_client_2" not in limiter.clients
# Recent client should remain
assert "recent_client" in limiter.clients
def test_cleanup_removes_empty_clients(self):
"""Test cleanup removes clients with no requests."""
limiter = RateLimiter()
# Add empty clients
limiter.clients["empty_client_1"] = deque()
limiter.clients["empty_client_2"] = deque()
# Add client with requests
limiter.clients["active_client"].append(datetime.now(timezone.utc))
# Run cleanup
limiter._cleanup_old_entries()
# Empty clients should be removed
assert "empty_client_1" not in limiter.clients
assert "empty_client_2" not in limiter.clients
# Active client should remain
assert "active_client" in limiter.clients
def test_cleanup_partial_removal(self):
"""Test cleanup removes only old requests, keeps recent ones."""
limiter = RateLimiter()
client_id = "mixed_client"
# Add old requests
old_time = datetime.now(timezone.utc) - timedelta(hours=30)
limiter.clients[client_id].append(old_time)
limiter.clients[client_id].append(old_time)
# Add recent requests
recent_time = datetime.now(timezone.utc) - timedelta(hours=1)
limiter.clients[client_id].append(recent_time)
limiter.clients[client_id].append(recent_time)
# Run cleanup
limiter._cleanup_old_entries()
# Client should remain with only recent requests
assert client_id in limiter.clients
assert len(limiter.clients[client_id]) == 2
def test_automatic_cleanup_triggers(self):
"""Test automatic cleanup triggers after interval."""
limiter = RateLimiter()
limiter.cleanup_interval = 0 # Force immediate cleanup
# Set last_cleanup to past
limiter.last_cleanup = datetime.now(timezone.utc) - timedelta(hours=2)
# Add old client
old_time = datetime.now(timezone.utc) - timedelta(hours=25)
limiter.clients["old_client"].append(old_time)
# Make request (should trigger cleanup)
limiter.allow_request("new_client", 10, 3600)
# Old client should be cleaned up
assert "old_client" not in limiter.clients
def test_cleanup_does_not_affect_active_clients(self):
"""Test cleanup doesn't remove clients with recent requests."""
limiter = RateLimiter()
# Add multiple active clients
now = datetime.now(timezone.utc)
for i in range(5):
limiter.clients[f"client_{i}"].append(now - timedelta(hours=i))
# Run cleanup
limiter._cleanup_old_entries()
# All clients should still exist (all within 24 hours)
assert len(limiter.clients) == 5
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterStatistics:
"""Test suite for client statistics functionality."""
def test_get_client_stats_empty(self):
"""Test getting stats for client with no requests."""
limiter = RateLimiter()
client_id = "new_client"
stats = limiter.get_client_stats(client_id)
assert stats["requests_last_hour"] == 0
assert stats["requests_last_day"] == 0
assert stats["total_tracked_requests"] == 0
def test_get_client_stats_with_requests(self):
"""Test getting stats for client with requests."""
limiter = RateLimiter()
client_id = "active_client"
# Add requests at different times
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
stats = limiter.get_client_stats(client_id)
assert stats["requests_last_hour"] == 1
assert stats["requests_last_day"] == 3
assert stats["total_tracked_requests"] == 3
def test_get_client_stats_old_requests(self):
"""Test stats exclude requests older than tracking period."""
limiter = RateLimiter()
client_id = "old_requests_client"
# Add very old requests
now = datetime.now(timezone.utc)
limiter.clients[client_id].append(now - timedelta(days=2))
limiter.clients[client_id].append(now - timedelta(days=3))
stats = limiter.get_client_stats(client_id)
assert stats["requests_last_hour"] == 0
assert stats["requests_last_day"] == 0
assert stats["total_tracked_requests"] == 2 # Still tracked, just not counted
def test_get_client_stats_nonexistent_client(self):
"""Test getting stats for client that doesn't exist."""
limiter = RateLimiter()
stats = limiter.get_client_stats("nonexistent_client")
assert stats["requests_last_hour"] == 0
assert stats["requests_last_day"] == 0
assert stats["total_tracked_requests"] == 0
def test_get_client_stats_boundary_cases(self):
"""Test stats at exact hour/day boundaries."""
limiter = RateLimiter()
client_id = "boundary_client"
now = datetime.now(timezone.utc)
# Exactly 1 hour ago (should be included)
limiter.clients[client_id].append(now - timedelta(hours=1, seconds=1))
# Exactly 24 hours ago (should be excluded)
limiter.clients[client_id].append(now - timedelta(days=1, seconds=1))
stats = limiter.get_client_stats(client_id)
# Boundary behavior depends on > vs >= comparison
assert stats["requests_last_hour"] >= 0
assert stats["requests_last_day"] >= 1
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterEdgeCases:
"""Test suite for edge cases and error scenarios."""
def test_rate_limiter_with_zero_max_requests(self):
"""Test rate limiter with max_requests=0."""
limiter = RateLimiter()
client_id = "zero_limit_client"
result = limiter.allow_request(client_id, max_requests=0, window_seconds=3600)
# Should be blocked immediately
assert result is False
def test_rate_limiter_with_negative_max_requests(self):
"""Test rate limiter with negative max_requests."""
limiter = RateLimiter()
client_id = "negative_limit_client"
result = limiter.allow_request(client_id, max_requests=-1, window_seconds=3600)
# Should be blocked
assert result is False
def test_rate_limiter_with_large_max_requests(self):
"""Test rate limiter with very large max_requests."""
limiter = RateLimiter()
client_id = "large_limit_client"
max_requests = 1000000
result = limiter.allow_request(client_id, max_requests, 3600)
# Should be allowed
assert result is True
def test_rate_limiter_very_short_window(self):
"""Test rate limiter with very short time window."""
limiter = RateLimiter()
client_id = "short_window_client"
result = limiter.allow_request(client_id, max_requests=1, window_seconds=1)
assert result is True
def test_rate_limiter_very_long_window(self):
"""Test rate limiter with very long time window."""
limiter = RateLimiter()
client_id = "long_window_client"
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
assert result is True
def test_rate_limiter_same_client_different_limits(self):
"""Test same client with different rate limits."""
limiter = RateLimiter()
client_id = "same_client"
# Allow with one limit
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
# Check with stricter limit
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
def test_rate_limiter_unicode_client_id(self):
"""Test rate limiter with unicode client ID."""
limiter = RateLimiter()
client_id = "クライアント_123"
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
assert result is True
assert client_id in limiter.clients
def test_rate_limiter_special_characters_client_id(self):
"""Test rate limiter with special characters in client ID."""
limiter = RateLimiter()
client_id = "client!@#$%^&*()_+-=[]{}|;:,.<>?"
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
assert result is True
assert client_id in limiter.clients
def test_rate_limiter_empty_client_id(self):
"""Test rate limiter with empty client ID."""
limiter = RateLimiter()
client_id = ""
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
assert result is True
assert client_id in limiter.clients
def test_rate_limiter_concurrent_same_client(self):
"""Test rate limiter behavior with rapid requests from same client."""
limiter = RateLimiter()
client_id = "concurrent_client"
max_requests = 3
# Simulate rapid requests
results = []
for _ in range(5):
results.append(limiter.allow_request(client_id, max_requests, 3600))
# First 3 should be True, rest False
assert results[:3] == [True, True, True]
assert results[3:] == [False, False]
def test_cleanup_updates_last_cleanup_time(self):
"""Test that cleanup updates last_cleanup timestamp."""
limiter = RateLimiter()
old_cleanup_time = limiter.last_cleanup
# Force cleanup
limiter.cleanup_interval = 0
limiter.allow_request("test", 10, 3600)
# last_cleanup should be updated
assert limiter.last_cleanup > old_cleanup_time
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimiterMemoryManagement:
"""Test suite for memory management and performance."""
def test_limiter_does_not_grow_indefinitely(self):
"""Test that old entries are cleaned up to prevent memory leaks."""
limiter = RateLimiter()
limiter.cleanup_interval = 0 # Force cleanup on every request
# Simulate many requests over time
for i in range(100):
limiter.allow_request(f"client_{i}", max_requests=10, window_seconds=3600)
# Force cleanup
limiter._cleanup_old_entries()
# Should have cleaned up clients with no recent activity
# Exact number depends on timing, but should be less than 100
assert len(limiter.clients) <= 100
def test_deque_efficiency(self):
"""Test that deque is used for efficient popleft operations."""
limiter = RateLimiter()
client_id = "efficiency_test"
# Add many old requests
old_time = datetime.now(timezone.utc) - timedelta(hours=2)
for _ in range(1000):
limiter.clients[client_id].append(old_time)
# This should efficiently remove all old requests
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
# Should only have the new request
assert len(limiter.clients[client_id]) == 1
def test_multiple_clients_independence(self):
"""Test that multiple clients don't interfere with each other."""
limiter = RateLimiter()
num_clients = 100
# Create many clients with requests
for i in range(num_clients):
limiter.allow_request(f"client_{i}", max_requests=5, window_seconds=3600)
# Each client should have exactly 1 request
assert len(limiter.clients) == num_clients
for i in range(num_clients):
assert len(limiter.clients[f"client_{i}"]) == 1

View File

@@ -0,0 +1,589 @@
# tests/unit/middleware/test_theme_logging_path_decorators.py
"""
Comprehensive unit tests for remaining middleware components:
- ThemeContextMiddleware and ThemeContextManager
- LoggingMiddleware
- path_rewrite_middleware
- rate_limit decorator
Tests cover:
- Theme loading and caching
- Request/response logging
- Path rewriting for vendor routing
- Rate limit decorators
- Edge cases and error handling
"""
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock, patch
from fastapi import Request
import time
from middleware.theme_context import (
ThemeContextManager,
ThemeContextMiddleware,
get_current_theme,
)
from middleware.logging_middleware import LoggingMiddleware
from middleware.path_rewrite_middleware import path_rewrite_middleware
from middleware.decorators import rate_limit
from app.exceptions.base import RateLimitException
# =============================================================================
# Theme Context Tests
# =============================================================================
@pytest.mark.unit
class TestThemeContextManager:
"""Test suite for ThemeContextManager."""
def test_get_default_theme_structure(self):
"""Test default theme has correct structure."""
theme = ThemeContextManager.get_default_theme()
assert "theme_name" in theme
assert "colors" in theme
assert "fonts" in theme
assert "branding" in theme
assert "layout" in theme
assert "social_links" in theme
assert "css_variables" in theme
def test_get_default_theme_colors(self):
"""Test default theme has all required colors."""
theme = ThemeContextManager.get_default_theme()
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
for color in required_colors:
assert color in theme["colors"]
assert theme["colors"][color].startswith("#")
def test_get_default_theme_fonts(self):
"""Test default theme has font configuration."""
theme = ThemeContextManager.get_default_theme()
assert "heading" in theme["fonts"]
assert "body" in theme["fonts"]
assert isinstance(theme["fonts"]["heading"], str)
assert isinstance(theme["fonts"]["body"], str)
def test_get_default_theme_branding(self):
"""Test default theme branding structure."""
theme = ThemeContextManager.get_default_theme()
assert "logo" in theme["branding"]
assert "logo_dark" in theme["branding"]
assert "favicon" in theme["branding"]
assert "banner" in theme["branding"]
def test_get_default_theme_css_variables(self):
"""Test default theme has CSS variables."""
theme = ThemeContextManager.get_default_theme()
assert "--color-primary" in theme["css_variables"]
assert "--font-heading" in theme["css_variables"]
assert "--font-body" in theme["css_variables"]
def test_get_vendor_theme_with_custom_theme(self):
"""Test getting vendor-specific theme."""
mock_db = Mock()
mock_theme = Mock()
mock_theme.to_dict.return_value = {
"theme_name": "custom",
"colors": {"primary": "#ff0000"}
}
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_theme
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
assert theme["theme_name"] == "custom"
assert theme["colors"]["primary"] == "#ff0000"
mock_theme.to_dict.assert_called_once()
def test_get_vendor_theme_fallback_to_default(self):
"""Test falling back to default theme when no custom theme exists."""
mock_db = Mock()
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
assert theme["theme_name"] == "default"
assert "colors" in theme
assert "fonts" in theme
def test_get_vendor_theme_inactive_theme(self):
"""Test that inactive themes are not returned."""
mock_db = Mock()
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
# Should return default theme
assert theme["theme_name"] == "default"
@pytest.mark.unit
class TestThemeContextMiddleware:
"""Test suite for ThemeContextMiddleware."""
@pytest.mark.asyncio
async def test_middleware_loads_theme_for_vendor(self):
"""Test middleware loads theme when vendor exists."""
middleware = ThemeContextMiddleware(app=None)
request = Mock(spec=Request)
mock_vendor = Mock()
mock_vendor.id = 1
mock_vendor.name = "Test Vendor"
request.state = Mock(vendor=mock_vendor)
call_next = AsyncMock(return_value=Mock())
mock_db = MagicMock()
mock_theme = {"theme_name": "test_theme"}
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
await middleware.dispatch(request, call_next)
assert request.state.theme == mock_theme
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_uses_default_theme_no_vendor(self):
"""Test middleware uses default theme when no vendor."""
middleware = ThemeContextMiddleware(app=None)
request = Mock(spec=Request)
request.state = Mock(vendor=None)
call_next = AsyncMock(return_value=Mock())
await middleware.dispatch(request, call_next)
assert hasattr(request.state, 'theme')
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_handles_theme_loading_error(self):
"""Test middleware handles errors gracefully."""
middleware = ThemeContextMiddleware(app=None)
request = Mock(spec=Request)
mock_vendor = Mock(id=1, name="Test Vendor")
request.state = Mock(vendor=mock_vendor)
call_next = AsyncMock(return_value=Mock())
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
await middleware.dispatch(request, call_next)
# Should fallback to default theme
assert request.state.theme["theme_name"] == "default"
call_next.assert_called_once()
def test_get_current_theme_exists(self):
"""Test getting current theme when it exists."""
request = Mock(spec=Request)
test_theme = {"theme_name": "test"}
request.state.theme = test_theme
theme = get_current_theme(request)
assert theme == test_theme
def test_get_current_theme_default(self):
"""Test getting theme returns default when not set."""
request = Mock(spec=Request)
request.state = Mock(spec=[]) # No theme attribute
theme = get_current_theme(request)
assert theme["theme_name"] == "default"
# =============================================================================
# Logging Middleware Tests
# =============================================================================
@pytest.mark.unit
class TestLoggingMiddleware:
"""Test suite for LoggingMiddleware."""
@pytest.mark.asyncio
async def test_middleware_logs_request(self):
"""Test middleware logs incoming request."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock(path="/api/vendors")
request.client = Mock(host="127.0.0.1")
call_next = AsyncMock(return_value=Mock(status_code=200))
with patch('middleware.logging_middleware.logger') as mock_logger:
await middleware.dispatch(request, call_next)
# Verify request was logged
assert mock_logger.info.call_count >= 1
first_call = mock_logger.info.call_args_list[0]
assert "GET" in str(first_call)
assert "/api/vendors" in str(first_call)
@pytest.mark.asyncio
async def test_middleware_logs_response(self):
"""Test middleware logs response with status code and duration."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "POST"
request.url = Mock(path="/api/products")
request.client = Mock(host="127.0.0.1")
response = Mock()
response.status_code = 201
response.headers = {}
call_next = AsyncMock(return_value=response)
with patch('middleware.logging_middleware.logger') as mock_logger:
result = await middleware.dispatch(request, call_next)
# Verify response was logged
assert mock_logger.info.call_count >= 2 # Request + Response
last_call = mock_logger.info.call_args_list[-1]
assert "201" in str(last_call)
@pytest.mark.asyncio
async def test_middleware_adds_process_time_header(self):
"""Test middleware adds X-Process-Time header."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock(path="/test")
request.client = Mock(host="127.0.0.1")
response = Mock()
response.status_code = 200
response.headers = {}
call_next = AsyncMock(return_value=response)
with patch('middleware.logging_middleware.logger'):
result = await middleware.dispatch(request, call_next)
assert "X-Process-Time" in response.headers
# Should be a numeric string
process_time = float(response.headers["X-Process-Time"])
assert process_time >= 0
@pytest.mark.asyncio
async def test_middleware_handles_no_client(self):
"""Test middleware handles requests with no client info."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock(path="/test")
request.client = None # No client info
call_next = AsyncMock(return_value=Mock(status_code=200))
with patch('middleware.logging_middleware.logger') as mock_logger:
await middleware.dispatch(request, call_next)
# Should log "unknown" for client
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
@pytest.mark.asyncio
async def test_middleware_logs_exceptions(self):
"""Test middleware logs exceptions."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock(path="/error")
request.client = Mock(host="127.0.0.1")
call_next = AsyncMock(side_effect=Exception("Test error"))
with patch('middleware.logging_middleware.logger') as mock_logger, \
pytest.raises(Exception):
await middleware.dispatch(request, call_next)
# Verify error was logged
mock_logger.error.assert_called_once()
assert "Test error" in str(mock_logger.error.call_args)
@pytest.mark.asyncio
async def test_middleware_timing_accuracy(self):
"""Test middleware timing is reasonably accurate."""
middleware = LoggingMiddleware(app=None)
request = Mock(spec=Request)
request.method = "GET"
request.url = Mock(path="/slow")
request.client = Mock(host="127.0.0.1")
async def slow_call_next(req):
await asyncio.sleep(0.1) # 100ms delay
response = Mock(status_code=200, headers={})
return response
call_next = slow_call_next
import asyncio
with patch('middleware.logging_middleware.logger'):
result = await middleware.dispatch(request, call_next)
process_time = float(result.headers["X-Process-Time"])
# Should be at least 0.1 seconds
assert process_time >= 0.1
# =============================================================================
# Path Rewrite Middleware Tests
# =============================================================================
@pytest.mark.unit
class TestPathRewriteMiddleware:
"""Test suite for path_rewrite_middleware."""
@pytest.mark.asyncio
async def test_rewrites_path_when_clean_path_different(self):
"""Test path is rewritten when clean_path differs from original."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/testvendor/shop/products")
request.state = Mock(clean_path="/shop/products")
request.scope = {"path": "/vendor/testvendor/shop/products"}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
# Path should be rewritten in scope
assert request.scope["path"] == "/shop/products"
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_does_not_rewrite_when_paths_same(self):
"""Test path is not rewritten when clean_path same as original."""
request = Mock(spec=Request)
original_path = "/shop/products"
request.url = Mock(path=original_path)
request.state = Mock(clean_path=original_path)
request.scope = {"path": original_path}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
# Path should remain unchanged
assert request.scope["path"] == original_path
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_does_nothing_when_no_clean_path(self):
"""Test middleware does nothing when no clean_path set."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
request.state = Mock(spec=[]) # No clean_path attribute
original_path = "/shop/products"
request.scope = {"path": original_path}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
# Path should remain unchanged
assert request.scope["path"] == original_path
call_next.assert_called_once()
@pytest.mark.asyncio
async def test_updates_request_url(self):
"""Test middleware updates request._url."""
request = Mock(spec=Request)
original_url = Mock(path="/vendor/test/shop")
request.url = original_url
request.url.replace = Mock(return_value=Mock(path="/shop"))
request.state = Mock(clean_path="/shop")
request.scope = {"path": "/vendor/test/shop"}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
# URL replace should have been called
request.url.replace.assert_called_once_with(path="/shop")
@pytest.mark.asyncio
async def test_preserves_vendor_context(self):
"""Test middleware preserves vendor context in request.state."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/testvendor/products")
mock_vendor = Mock()
request.state = Mock(clean_path="/products", vendor=mock_vendor)
request.scope = {"path": "/vendor/testvendor/products"}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
# Vendor should still be accessible
assert request.state.vendor is mock_vendor
# =============================================================================
# Rate Limit Decorator Tests
# =============================================================================
@pytest.mark.unit
@pytest.mark.auth
class TestRateLimitDecorator:
"""Test suite for rate_limit decorator."""
@pytest.mark.asyncio
async def test_decorator_allows_within_limit(self):
"""Test decorator allows requests within rate limit."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
result = await test_endpoint()
assert result == {"status": "ok"}
@pytest.mark.asyncio
async def test_decorator_blocks_exceeding_limit(self):
"""Test decorator blocks requests exceeding rate limit."""
@rate_limit(max_requests=2, window_seconds=3600)
async def test_endpoint():
return {"status": "ok"}
# First two should succeed
await test_endpoint()
await test_endpoint()
# Third should raise exception
with pytest.raises(RateLimitException) as exc_info:
await test_endpoint()
assert exc_info.value.status_code == 429
assert "Rate limit exceeded" in exc_info.value.message
@pytest.mark.asyncio
async def test_decorator_preserves_function_metadata(self):
"""Test decorator preserves original function metadata."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint():
"""Test endpoint docstring."""
return {"status": "ok"}
assert test_endpoint.__name__ == "test_endpoint"
assert test_endpoint.__doc__ == "Test endpoint docstring."
@pytest.mark.asyncio
async def test_decorator_with_args_and_kwargs(self):
"""Test decorator works with function arguments."""
@rate_limit(max_requests=10, window_seconds=3600)
async def test_endpoint(arg1, arg2, kwarg1=None):
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
result = await test_endpoint("value1", "value2", kwarg1="value3")
assert result == {
"arg1": "value1",
"arg2": "value2",
"kwarg1": "value3"
}
@pytest.mark.asyncio
async def test_decorator_default_parameters(self):
"""Test decorator uses default parameters."""
@rate_limit() # Use defaults
async def test_endpoint():
return {"status": "ok"}
result = await test_endpoint()
assert result == {"status": "ok"}
@pytest.mark.asyncio
async def test_decorator_exception_includes_retry_after(self):
"""Test rate limit exception includes retry_after."""
@rate_limit(max_requests=1, window_seconds=60)
async def test_endpoint():
return {"status": "ok"}
await test_endpoint() # Use up limit
with pytest.raises(RateLimitException) as exc_info:
await test_endpoint()
assert exc_info.value.details.get("retry_after") == 60
# =============================================================================
# Edge Cases and Integration Tests
# =============================================================================
@pytest.mark.unit
class TestMiddlewareEdgeCases:
"""Test suite for edge cases across middleware."""
@pytest.mark.asyncio
async def test_theme_middleware_closes_db_connection(self):
"""Test theme middleware properly closes database connection."""
middleware = ThemeContextMiddleware(app=None)
request = Mock(spec=Request)
mock_vendor = Mock(id=1, name="Test")
request.state = Mock(vendor=mock_vendor)
call_next = AsyncMock(return_value=Mock())
mock_db = MagicMock()
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
# Verify database was closed
mock_db.close.assert_called_once()
@pytest.mark.asyncio
async def test_path_rewrite_with_query_parameters(self):
"""Test path rewrite preserves query parameters."""
request = Mock(spec=Request)
original_url = Mock(path="/vendor/test/shop?page=1")
request.url = original_url
request.url.replace = Mock(return_value=Mock(path="/shop?page=1"))
request.state = Mock(clean_path="/shop?page=1")
request.scope = {"path": "/vendor/test/shop?page=1"}
call_next = AsyncMock(return_value=Mock())
await path_rewrite_middleware(request, call_next)
request.url.replace.assert_called_once_with(path="/shop?page=1")
def test_theme_default_immutability(self):
"""Test that getting default theme doesn't share state."""
theme1 = ThemeContextManager.get_default_theme()
theme2 = ThemeContextManager.get_default_theme()
# Modify theme1
theme1["colors"]["primary"] = "#000000"
# theme2 should not be affected (if properly implemented)
# Note: This test documents expected behavior
assert theme2["colors"]["primary"] == "#6366f1"

View File

@@ -1,23 +1,721 @@
from requests.cookies import MockRequest
# tests/unit/middleware/test_vendor_context.py
"""
Comprehensive unit tests for VendorContextMiddleware and VendorContextManager.
from middleware.vendor_context import VendorContextManager
Tests cover:
- Vendor detection from custom domains, subdomains, and path-based routing
- Database lookup and vendor validation
- Path extraction and cleanup
- Admin and API request detection
- Static file request detection
- Edge cases and error handling
"""
import pytest
from unittest.mock import Mock, MagicMock, patch, AsyncMock
from fastapi import Request, HTTPException
from sqlalchemy.orm import Session
from middleware.vendor_context import (
VendorContextManager,
VendorContextMiddleware,
get_current_vendor,
require_vendor_context,
)
def test_custom_domain_detection():
# Mock request with custom domain
request = MockRequest(host="customdomain1.com")
context = VendorContextManager.detect_vendor_context(request)
assert context["detection_method"] == "custom_domain"
assert context["domain"] == "customdomain1.com"
@pytest.mark.unit
@pytest.mark.vendors
class TestVendorContextManager:
"""Test suite for VendorContextManager static methods."""
def test_subdomain_detection():
request = MockRequest(host="vendor1.platform.com")
context = VendorContextManager.detect_vendor_context(request)
assert context["detection_method"] == "subdomain"
assert context["subdomain"] == "vendor1"
# ========================================================================
# Vendor Context Detection Tests
# ========================================================================
def test_path_detection():
request = MockRequest(host="localhost", path="/vendor/vendor1/")
context = VendorContextManager.detect_vendor_context(request)
assert context["detection_method"] == "path"
assert context["subdomain"] == "vendor1"
def test_detect_custom_domain(self):
"""Test custom domain detection."""
request = Mock(spec=Request)
request.headers = {"host": "customdomain1.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "custom_domain"
assert context["domain"] == "customdomain1.com"
assert context["host"] == "customdomain1.com"
def test_detect_custom_domain_with_port(self):
"""Test custom domain detection with port number."""
request = Mock(spec=Request)
request.headers = {"host": "customdomain1.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "custom_domain"
assert context["domain"] == "customdomain1.com"
assert context["host"] == "customdomain1.com"
def test_detect_subdomain(self):
"""Test subdomain detection."""
request = Mock(spec=Request)
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "subdomain"
assert context["subdomain"] == "vendor1"
assert context["host"] == "vendor1.platform.com"
def test_detect_subdomain_with_port(self):
"""Test subdomain detection with port number."""
request = Mock(spec=Request)
request.headers = {"host": "vendor1.platform.com:8000"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "subdomain"
assert context["subdomain"] == "vendor1"
def test_detect_path_vendor_singular(self):
"""Test path-based detection with /vendor/ prefix."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/vendor/vendor1/shop")
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "path"
assert context["subdomain"] == "vendor1"
assert context["path_prefix"] == "/vendor/vendor1"
assert context["full_prefix"] == "/vendor/"
def test_detect_path_vendors_plural(self):
"""Test path-based detection with /vendors/ prefix."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/vendors/vendor1/shop")
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "path"
assert context["subdomain"] == "vendor1"
assert context["path_prefix"] == "/vendors/vendor1"
assert context["full_prefix"] == "/vendors/"
def test_detect_no_vendor_context(self):
"""Test when no vendor context can be detected."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/random/path")
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_ignore_admin_subdomain(self):
"""Test that admin subdomain is not detected as vendor."""
request = Mock(spec=Request)
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_ignore_www_subdomain(self):
"""Test that www subdomain is not detected as vendor."""
request = Mock(spec=Request)
request.headers = {"host": "www.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_ignore_api_subdomain(self):
"""Test that api subdomain is not detected as vendor."""
request = Mock(spec=Request)
request.headers = {"host": "api.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_ignore_localhost(self):
"""Test that localhost is not detected as custom domain."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is None
# ========================================================================
# Vendor Database Lookup Tests
# ========================================================================
def test_get_vendor_from_custom_domain_context(self):
"""Test getting vendor from custom domain context."""
mock_db = Mock(spec=Session)
mock_vendor_domain = Mock()
mock_vendor = Mock()
mock_vendor.is_active = True
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is mock_vendor
assert vendor.is_active is True
def test_get_vendor_from_custom_domain_inactive_vendor(self):
"""Test getting inactive vendor from custom domain context."""
mock_db = Mock(spec=Session)
mock_vendor_domain = Mock()
mock_vendor = Mock()
mock_vendor.is_active = False
mock_vendor_domain.vendor = mock_vendor
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
context = {
"detection_method": "custom_domain",
"domain": "customdomain1.com"
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
def test_get_vendor_from_custom_domain_not_found(self):
"""Test custom domain not found in database."""
mock_db = Mock(spec=Session)
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
context = {
"detection_method": "custom_domain",
"domain": "nonexistent.com"
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is None
def test_get_vendor_from_subdomain_context(self):
"""Test getting vendor from subdomain context."""
mock_db = Mock(spec=Session)
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is mock_vendor
def test_get_vendor_from_path_context(self):
"""Test getting vendor from path context."""
mock_db = Mock(spec=Session)
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
context = {
"detection_method": "path",
"subdomain": "vendor1"
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is mock_vendor
def test_get_vendor_with_no_context(self):
"""Test getting vendor with no context."""
mock_db = Mock(spec=Session)
vendor = VendorContextManager.get_vendor_from_context(mock_db, None)
assert vendor is None
def test_get_vendor_subdomain_case_insensitive(self):
"""Test subdomain lookup is case-insensitive."""
mock_db = Mock(spec=Session)
mock_vendor = Mock()
mock_vendor.is_active = True
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
context = {
"detection_method": "subdomain",
"subdomain": "VENDOR1" # Uppercase
}
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
assert vendor is mock_vendor
# ========================================================================
# Path Extraction Tests
# ========================================================================
def test_extract_clean_path_from_vendor_path(self):
"""Test extracting clean path from /vendor/ prefix."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
assert clean_path == "/shop/products"
def test_extract_clean_path_from_vendors_path(self):
"""Test extracting clean path from /vendors/ prefix."""
request = Mock(spec=Request)
request.url = Mock(path="/vendors/vendor1/shop/products")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendors/vendor1"
}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
assert clean_path == "/shop/products"
def test_extract_clean_path_root(self):
"""Test extracting clean path when result is empty (should return /)."""
request = Mock(spec=Request)
request.url = Mock(path="/vendor/vendor1")
vendor_context = {
"detection_method": "path",
"path_prefix": "/vendor/vendor1"
}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
assert clean_path == "/"
def test_extract_clean_path_no_path_context(self):
"""Test extracting clean path for non-path detection methods."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
assert clean_path == "/shop/products"
def test_extract_clean_path_no_context(self):
"""Test extracting clean path with no vendor context."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
clean_path = VendorContextManager.extract_clean_path(request, None)
assert clean_path == "/shop/products"
# ========================================================================
# Request Type Detection Tests
# ========================================================================
def test_is_admin_request_admin_subdomain(self):
"""Test admin request detection from subdomain."""
request = Mock(spec=Request)
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/dashboard")
assert VendorContextManager.is_admin_request(request) is True
def test_is_admin_request_admin_path(self):
"""Test admin request detection from path."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/admin/dashboard")
assert VendorContextManager.is_admin_request(request) is True
def test_is_admin_request_with_port(self):
"""Test admin request detection with port number."""
request = Mock(spec=Request)
request.headers = {"host": "admin.localhost:8000"}
request.url = Mock(path="/dashboard")
assert VendorContextManager.is_admin_request(request) is True
def test_is_not_admin_request(self):
"""Test non-admin request."""
request = Mock(spec=Request)
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/shop")
assert VendorContextManager.is_admin_request(request) is False
def test_is_api_request(self):
"""Test API request detection."""
request = Mock(spec=Request)
request.url = Mock(path="/api/v1/vendors")
assert VendorContextManager.is_api_request(request) is True
def test_is_not_api_request(self):
"""Test non-API request."""
request = Mock(spec=Request)
request.url = Mock(path="/shop/products")
assert VendorContextManager.is_api_request(request) is False
# ========================================================================
# Static File Detection Tests
# ========================================================================
@pytest.mark.parametrize("path", [
"/static/css/style.css",
"/static/js/app.js",
"/media/images/product.png",
"/assets/logo.svg",
"/.well-known/security.txt",
"/favicon.ico",
"/image.jpg",
"/style.css",
"/app.webmanifest",
])
def test_is_static_file_request(self, path):
"""Test static file detection for various paths and extensions."""
request = Mock(spec=Request)
request.url = Mock(path=path)
assert VendorContextManager.is_static_file_request(request) is True
@pytest.mark.parametrize("path", [
"/shop/products",
"/admin/dashboard",
"/api/vendors",
"/about",
])
def test_is_not_static_file_request(self, path):
"""Test non-static file paths."""
request = Mock(spec=Request)
request.url = Mock(path=path)
assert VendorContextManager.is_static_file_request(request) is False
@pytest.mark.unit
@pytest.mark.vendors
class TestVendorContextMiddleware:
"""Test suite for VendorContextMiddleware."""
@pytest.mark.asyncio
async def test_middleware_skips_admin_request(self):
"""Test middleware skips vendor detection for admin requests."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "admin.platform.com"}
request.url = Mock(path="/admin/dashboard")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
assert request.state.vendor_context is None
assert request.state.clean_path == "/admin/dashboard"
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_skips_api_request(self):
"""Test middleware skips vendor detection for API requests."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/api/v1/vendors")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
assert request.state.vendor_context is None
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_skips_static_file_request(self):
"""Test middleware skips vendor detection for static files."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/static/css/style.css")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_detects_and_sets_vendor(self):
"""Test middleware successfully detects and sets vendor."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "vendor1.platform.com"}
request.url = Mock(path="/shop/products")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
mock_vendor = Mock()
mock_vendor.id = 1
mock_vendor.name = "Test Vendor"
mock_vendor.subdomain = "vendor1"
vendor_context = {
"detection_method": "subdomain",
"subdomain": "vendor1"
}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
assert request.state.vendor is mock_vendor
assert request.state.vendor_context == vendor_context
assert request.state.clean_path == "/shop/products"
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_vendor_not_found(self):
"""Test middleware when vendor context detected but vendor not in database."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "nonexistent.platform.com"}
request.url = Mock(path="/shop")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
vendor_context = {
"detection_method": "subdomain",
"subdomain": "nonexistent"
}
mock_db = MagicMock()
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
assert request.state.vendor_context == vendor_context
assert request.state.clean_path == "/shop"
call_next.assert_called_once_with(request)
@pytest.mark.asyncio
async def test_middleware_no_vendor_context(self):
"""Test middleware when no vendor context detected."""
middleware = VendorContextMiddleware(app=None)
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/random/path")
request.state = Mock()
call_next = AsyncMock(return_value=Mock())
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
await middleware.dispatch(request, call_next)
assert request.state.vendor is None
assert request.state.vendor_context is None
assert request.state.clean_path == "/random/path"
call_next.assert_called_once_with(request)
@pytest.mark.unit
@pytest.mark.vendors
class TestHelperFunctions:
"""Test suite for helper functions."""
def test_get_current_vendor_exists(self):
"""Test getting current vendor when it exists."""
request = Mock(spec=Request)
mock_vendor = Mock()
request.state.vendor = mock_vendor
vendor = get_current_vendor(request)
assert vendor is mock_vendor
def test_get_current_vendor_not_exists(self):
"""Test getting current vendor when it doesn't exist."""
request = Mock(spec=Request)
request.state = Mock(spec=[]) # vendor attribute doesn't exist
vendor = get_current_vendor(request)
assert vendor is None
def test_require_vendor_context_success(self):
"""Test require_vendor_context dependency with vendor present."""
request = Mock(spec=Request)
mock_vendor = Mock()
request.state.vendor = mock_vendor
dependency = require_vendor_context()
result = dependency(request)
assert result is mock_vendor
def test_require_vendor_context_failure(self):
"""Test require_vendor_context dependency raises HTTPException when no vendor."""
request = Mock(spec=Request)
request.state.vendor = None
dependency = require_vendor_context()
with pytest.raises(HTTPException) as exc_info:
dependency(request)
assert exc_info.value.status_code == 404
assert "Vendor not found" in exc_info.value.detail
@pytest.mark.unit
@pytest.mark.vendors
class TestEdgeCases:
"""Test suite for edge cases and error scenarios."""
def test_detect_vendor_context_empty_host(self):
"""Test vendor detection with empty host header."""
request = Mock(spec=Request)
request.headers = {"host": ""}
request.url = Mock(path="/")
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_detect_vendor_context_missing_host(self):
"""Test vendor detection with missing host header."""
request = Mock(spec=Request)
request.headers = {}
request.url = Mock(path="/")
context = VendorContextManager.detect_vendor_context(request)
assert context is None
def test_detect_vendor_path_with_trailing_slash(self):
"""Test path detection with trailing slash."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/vendor/vendor1/")
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "path"
assert context["subdomain"] == "vendor1"
def test_detect_vendor_path_without_trailing_slash(self):
"""Test path detection without trailing slash."""
request = Mock(spec=Request)
request.headers = {"host": "localhost"}
request.url = Mock(path="/vendor/vendor1")
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
assert context["detection_method"] == "path"
assert context["subdomain"] == "vendor1"
def test_detect_vendor_complex_subdomain(self):
"""Test detection with multiple subdomain levels."""
request = Mock(spec=Request)
request.headers = {"host": "shop.vendor1.platform.com"}
request.url = Mock(path="/")
with patch('middleware.vendor_context.settings') as mock_settings:
mock_settings.platform_domain = "platform.com"
context = VendorContextManager.detect_vendor_context(request)
assert context is not None
# Should detect 'shop' as subdomain since it's the first part
assert context["detection_method"] == "subdomain"
assert context["subdomain"] == "shop"