revamping documentation
This commit is contained in:
2
tests/unit/middleware/__init__.py
Normal file
2
tests/unit/middleware/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/middleware/__init__.py
|
||||
"""Unit tests - fast, isolated component tests."""
|
||||
657
tests/unit/middleware/test_auth.py
Normal file
657
tests/unit/middleware/test_auth.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# tests/unit/middleware/test_auth.py
|
||||
"""
|
||||
Comprehensive unit tests for AuthManager.
|
||||
|
||||
Tests cover:
|
||||
- Password hashing and verification
|
||||
- JWT token creation and validation
|
||||
- User authentication
|
||||
- Token expiration handling
|
||||
- Role-based access control
|
||||
- Admin/vendor/customer permission checks
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from jose import jwt
|
||||
|
||||
from middleware.auth import AuthManager
|
||||
from app.exceptions import (
|
||||
InvalidTokenException,
|
||||
TokenExpiredException,
|
||||
UserNotActiveException,
|
||||
InvalidCredentialsException,
|
||||
AdminRequiredException,
|
||||
InsufficientPermissionsException,
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestPasswordHashing:
|
||||
"""Test suite for password hashing functionality."""
|
||||
|
||||
def test_hash_password(self):
|
||||
"""Test password hashing creates different hash for each call."""
|
||||
auth_manager = AuthManager()
|
||||
password = "test_password_123"
|
||||
|
||||
hash1 = auth_manager.hash_password(password)
|
||||
hash2 = auth_manager.hash_password(password)
|
||||
|
||||
# Hashes should be different due to salt
|
||||
assert hash1 != hash2
|
||||
# Both should be valid bcrypt hashes (start with $2b$)
|
||||
assert hash1.startswith("$2b$")
|
||||
assert hash2.startswith("$2b$")
|
||||
|
||||
def test_verify_password_correct(self):
|
||||
"""Test password verification with correct password."""
|
||||
auth_manager = AuthManager()
|
||||
password = "test_password_123"
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
assert auth_manager.verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_password_incorrect(self):
|
||||
"""Test password verification with incorrect password."""
|
||||
auth_manager = AuthManager()
|
||||
password = "test_password_123"
|
||||
wrong_password = "wrong_password_456"
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
assert auth_manager.verify_password(wrong_password, hashed) is False
|
||||
|
||||
def test_verify_password_empty(self):
|
||||
"""Test password verification with empty password."""
|
||||
auth_manager = AuthManager()
|
||||
password = "test_password_123"
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
assert auth_manager.verify_password("", hashed) is False
|
||||
|
||||
def test_hash_password_special_characters(self):
|
||||
"""Test hashing password with special characters."""
|
||||
auth_manager = AuthManager()
|
||||
password = "P@ssw0rd!#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
assert auth_manager.verify_password(password, hashed) is True
|
||||
|
||||
def test_hash_password_unicode(self):
|
||||
"""Test hashing password with unicode characters."""
|
||||
auth_manager = AuthManager()
|
||||
password = "パスワード123こんにちは"
|
||||
hashed = auth_manager.hash_password(password)
|
||||
|
||||
assert auth_manager.verify_password(password, hashed) is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestUserAuthentication:
|
||||
"""Test suite for user authentication."""
|
||||
|
||||
def test_authenticate_user_success_with_username(self):
|
||||
"""Test successful authentication with username."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.hashed_password = auth_manager.hash_password("password123")
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "testuser", "password123")
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
def test_authenticate_user_success_with_email(self):
|
||||
"""Test successful authentication with email."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.hashed_password = auth_manager.hash_password("password123")
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
def test_authenticate_user_not_found(self):
|
||||
"""Test authentication with non-existent user."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "nonexistent", "password123")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_authenticate_user_wrong_password(self):
|
||||
"""Test authentication with wrong password."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.hashed_password = auth_manager.hash_password("correctpassword")
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "testuser", "wrongpassword")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestJWTTokenCreation:
|
||||
"""Test suite for JWT token creation."""
|
||||
|
||||
def test_create_access_token_structure(self):
|
||||
"""Test JWT token creation returns correct structure."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
|
||||
assert "access_token" in token_data
|
||||
assert "token_type" in token_data
|
||||
assert "expires_in" in token_data
|
||||
assert token_data["token_type"] == "bearer"
|
||||
assert isinstance(token_data["expires_in"], int)
|
||||
assert token_data["expires_in"] == auth_manager.token_expire_minutes * 60
|
||||
|
||||
def test_create_access_token_payload(self):
|
||||
"""Test JWT token contains correct payload."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 42
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "vendor"
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Decode without verification to check payload
|
||||
payload = jwt.decode(token, auth_manager.secret_key, algorithms=[auth_manager.algorithm])
|
||||
|
||||
assert payload["sub"] == "42"
|
||||
assert payload["username"] == "testuser"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert payload["role"] == "vendor"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
|
||||
def test_create_access_token_different_users(self):
|
||||
"""Test tokens are different for different users."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
user1 = Mock(spec=User, id=1, username="user1", email="user1@test.com", role="customer")
|
||||
user2 = Mock(spec=User, id=2, username="user2", email="user2@test.com", role="vendor")
|
||||
|
||||
token1 = auth_manager.create_access_token(user1)["access_token"]
|
||||
token2 = auth_manager.create_access_token(user2)["access_token"]
|
||||
|
||||
assert token1 != token2
|
||||
|
||||
def test_create_access_token_admin_role(self):
|
||||
"""Test token creation for admin user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
admin_user = Mock(spec=User)
|
||||
admin_user.id = 1
|
||||
admin_user.username = "admin"
|
||||
admin_user.email = "admin@example.com"
|
||||
admin_user.role = "admin"
|
||||
|
||||
token_data = auth_manager.create_access_token(admin_user)
|
||||
payload = jwt.decode(
|
||||
token_data["access_token"],
|
||||
auth_manager.secret_key,
|
||||
algorithms=[auth_manager.algorithm]
|
||||
)
|
||||
|
||||
assert payload["role"] == "admin"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestJWTTokenVerification:
|
||||
"""Test suite for JWT token verification."""
|
||||
|
||||
def test_verify_token_success(self):
|
||||
"""Test successful token verification."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
token = token_data["access_token"]
|
||||
|
||||
result = auth_manager.verify_token(token)
|
||||
|
||||
assert result["user_id"] == 1
|
||||
assert result["username"] == "testuser"
|
||||
assert result["email"] == "test@example.com"
|
||||
assert result["role"] == "customer"
|
||||
|
||||
def test_verify_token_expired(self):
|
||||
"""Test token verification with expired token."""
|
||||
auth_manager = AuthManager()
|
||||
auth_manager.token_expire_minutes = -1 # Set to negative to force expiration
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
token = token_data["access_token"]
|
||||
|
||||
# Reset to normal
|
||||
auth_manager.token_expire_minutes = 30
|
||||
|
||||
with pytest.raises(TokenExpiredException):
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
def test_verify_token_invalid(self):
|
||||
"""Test token verification with invalid token."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
with pytest.raises(InvalidTokenException):
|
||||
auth_manager.verify_token("invalid.token.here")
|
||||
|
||||
def test_verify_token_tampered(self):
|
||||
"""Test token verification with tampered token."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
|
||||
token = auth_manager.create_access_token(mock_user)["access_token"]
|
||||
|
||||
# Tamper with token
|
||||
parts = token.split(".")
|
||||
tampered_token = ".".join([parts[0], parts[1], "tampered"])
|
||||
|
||||
with pytest.raises(InvalidTokenException):
|
||||
auth_manager.verify_token(tampered_token)
|
||||
|
||||
def test_verify_token_missing_user_id(self):
|
||||
"""Test token verification with missing user ID."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token without 'sub' field
|
||||
payload = {
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
assert "missing user identifier" in str(exc_info.value.message)
|
||||
|
||||
def test_verify_token_missing_expiration(self):
|
||||
"""Test token verification with missing expiration."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# Create token without 'exp' field
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser"
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
|
||||
with pytest.raises(InvalidTokenException) as exc_info:
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
assert "missing expiration" in str(exc_info.value.message)
|
||||
|
||||
def test_verify_token_wrong_algorithm(self):
|
||||
"""Test token verification with different algorithm."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
}
|
||||
# Create token with different algorithm
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm="HS512")
|
||||
|
||||
with pytest.raises(InvalidTokenException):
|
||||
auth_manager.verify_token(token)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestGetCurrentUser:
|
||||
"""Test suite for get_current_user functionality."""
|
||||
|
||||
def test_get_current_user_success(self):
|
||||
"""Test successfully getting current user."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
# Create mock user
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
mock_user.is_active = True
|
||||
|
||||
# Setup database mock
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
# Create valid token
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
|
||||
# Create mock credentials
|
||||
mock_credentials = Mock()
|
||||
mock_credentials.credentials = token_data["access_token"]
|
||||
|
||||
result = auth_manager.get_current_user(mock_db, mock_credentials)
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
def test_get_current_user_not_found(self):
|
||||
"""Test get_current_user when user doesn't exist in database."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
# Setup database to return None
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Create mock user for token
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 999
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
|
||||
mock_credentials = Mock()
|
||||
mock_credentials.credentials = token_data["access_token"]
|
||||
|
||||
with pytest.raises(InvalidCredentialsException):
|
||||
auth_manager.get_current_user(mock_db, mock_credentials)
|
||||
|
||||
def test_get_current_user_inactive(self):
|
||||
"""Test get_current_user with inactive user."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.role = "customer"
|
||||
mock_user.is_active = False # Inactive user
|
||||
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
token_data = auth_manager.create_access_token(mock_user)
|
||||
|
||||
mock_credentials = Mock()
|
||||
mock_credentials.credentials = token_data["access_token"]
|
||||
|
||||
with pytest.raises(UserNotActiveException):
|
||||
auth_manager.get_current_user(mock_db, mock_credentials)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRoleRequirements:
|
||||
"""Test suite for role-based access control."""
|
||||
|
||||
def test_require_admin_success(self):
|
||||
"""Test require_admin with admin user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
admin_user = Mock(spec=User)
|
||||
admin_user.role = "admin"
|
||||
|
||||
result = auth_manager.require_admin(admin_user)
|
||||
|
||||
assert result is admin_user
|
||||
|
||||
def test_require_admin_failure(self):
|
||||
"""Test require_admin with non-admin user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
customer_user = Mock(spec=User)
|
||||
customer_user.role = "customer"
|
||||
|
||||
with pytest.raises(AdminRequiredException):
|
||||
auth_manager.require_admin(customer_user)
|
||||
|
||||
def test_require_vendor_with_vendor_role(self):
|
||||
"""Test require_vendor with vendor user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
vendor_user = Mock(spec=User)
|
||||
vendor_user.role = "vendor"
|
||||
|
||||
result = auth_manager.require_vendor(vendor_user)
|
||||
|
||||
assert result is vendor_user
|
||||
|
||||
def test_require_vendor_with_admin_role(self):
|
||||
"""Test require_vendor with admin user (admin can access vendor areas)."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
admin_user = Mock(spec=User)
|
||||
admin_user.role = "admin"
|
||||
|
||||
result = auth_manager.require_vendor(admin_user)
|
||||
|
||||
assert result is admin_user
|
||||
|
||||
def test_require_vendor_failure(self):
|
||||
"""Test require_vendor with customer user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
customer_user = Mock(spec=User)
|
||||
customer_user.role = "customer"
|
||||
|
||||
with pytest.raises(InsufficientPermissionsException) as exc_info:
|
||||
auth_manager.require_vendor(customer_user)
|
||||
|
||||
assert exc_info.value.details.get("required_permission") == "vendor"
|
||||
|
||||
def test_require_customer_with_customer_role(self):
|
||||
"""Test require_customer with customer user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
customer_user = Mock(spec=User)
|
||||
customer_user.role = "customer"
|
||||
|
||||
result = auth_manager.require_customer(customer_user)
|
||||
|
||||
assert result is customer_user
|
||||
|
||||
def test_require_customer_with_admin_role(self):
|
||||
"""Test require_customer with admin user (admin can access customer areas)."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
admin_user = Mock(spec=User)
|
||||
admin_user.role = "admin"
|
||||
|
||||
result = auth_manager.require_customer(admin_user)
|
||||
|
||||
assert result is admin_user
|
||||
|
||||
def test_require_customer_failure(self):
|
||||
"""Test require_customer with vendor user."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
vendor_user = Mock(spec=User)
|
||||
vendor_user.role = "vendor"
|
||||
|
||||
with pytest.raises(InsufficientPermissionsException) as exc_info:
|
||||
auth_manager.require_customer(vendor_user)
|
||||
|
||||
assert exc_info.value.details.get("required_permission") == "customer"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestCreateDefaultAdminUser:
|
||||
"""Test suite for default admin user creation."""
|
||||
|
||||
def test_create_default_admin_user_first_time(self):
|
||||
"""Test creating default admin user when none exists."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
# No existing admin user
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
result = auth_manager.create_default_admin_user(mock_db)
|
||||
|
||||
# Verify admin user was created
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
mock_db.refresh.assert_called_once()
|
||||
|
||||
# Verify the created user
|
||||
created_user = mock_db.add.call_args[0][0]
|
||||
assert created_user.username == "admin"
|
||||
assert created_user.email == "admin@example.com"
|
||||
assert created_user.role == "admin"
|
||||
assert created_user.is_active is True
|
||||
assert auth_manager.verify_password("admin123", created_user.hashed_password)
|
||||
|
||||
def test_create_default_admin_user_already_exists(self):
|
||||
"""Test creating default admin user when one already exists."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
# Existing admin user
|
||||
existing_admin = Mock(spec=User)
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = existing_admin
|
||||
|
||||
result = auth_manager.create_default_admin_user(mock_db)
|
||||
|
||||
# Should not create new user
|
||||
mock_db.add.assert_not_called()
|
||||
mock_db.commit.assert_not_called()
|
||||
|
||||
# Should return existing user
|
||||
assert result is existing_admin
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestAuthManagerConfiguration:
|
||||
"""Test suite for AuthManager configuration."""
|
||||
|
||||
def test_default_configuration(self):
|
||||
"""Test AuthManager uses default configuration."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.algorithm == "HS256"
|
||||
assert auth_manager.token_expire_minutes == 30
|
||||
assert auth_manager.secret_key == "your-secret-key-change-in-production-please"
|
||||
|
||||
def test_custom_configuration(self):
|
||||
"""Test AuthManager uses environment variables."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_SECRET_KEY': 'custom-secret-key',
|
||||
'JWT_EXPIRE_MINUTES': '60'
|
||||
}):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.secret_key == "custom-secret-key"
|
||||
assert auth_manager.token_expire_minutes == 60
|
||||
|
||||
def test_partial_custom_configuration(self):
|
||||
"""Test AuthManager with partial environment configuration."""
|
||||
with patch.dict('os.environ', {
|
||||
'JWT_EXPIRE_MINUTES': '120'
|
||||
}, clear=False):
|
||||
auth_manager = AuthManager()
|
||||
|
||||
assert auth_manager.token_expire_minutes == 120
|
||||
# Secret key should use default or existing env var
|
||||
assert auth_manager.secret_key is not None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestEdgeCases:
|
||||
"""Test suite for edge cases and error scenarios."""
|
||||
|
||||
def test_verify_password_with_none(self):
|
||||
"""Test password verification with None values."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
# This should not raise an exception, just return False
|
||||
with pytest.raises(Exception):
|
||||
auth_manager.verify_password(None, None)
|
||||
|
||||
def test_token_with_future_iat(self):
|
||||
"""Test token with issued_at time in the future."""
|
||||
auth_manager = AuthManager()
|
||||
|
||||
payload = {
|
||||
"sub": "1",
|
||||
"username": "testuser",
|
||||
"iat": datetime.now(timezone.utc) + timedelta(hours=1), # Future time
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=2)
|
||||
}
|
||||
token = jwt.encode(payload, auth_manager.secret_key, algorithm=auth_manager.algorithm)
|
||||
|
||||
# Should still verify successfully (JWT doesn't validate iat by default)
|
||||
result = auth_manager.verify_token(token)
|
||||
assert result["user_id"] == 1
|
||||
|
||||
def test_authenticate_user_case_sensitivity(self):
|
||||
"""Test that username/email authentication is case-sensitive."""
|
||||
auth_manager = AuthManager()
|
||||
mock_db = Mock()
|
||||
|
||||
mock_user = Mock(spec=User)
|
||||
mock_user.username = "TestUser"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.hashed_password = auth_manager.hash_password("password123")
|
||||
|
||||
# This will depend on database collation, but generally should be case-sensitive
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
result = auth_manager.authenticate_user(mock_db, "testuser", "password123")
|
||||
|
||||
# Result depends on how the filter is implemented
|
||||
# This test documents the expected behavior
|
||||
assert result is None or result is mock_user
|
||||
573
tests/unit/middleware/test_context_middleware.py
Normal file
573
tests/unit/middleware/test_context_middleware.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# tests/unit/middleware/test_context_middleware.py
|
||||
"""
|
||||
Comprehensive unit tests for ContextMiddleware and ContextManager.
|
||||
|
||||
Tests cover:
|
||||
- Context detection for API, Admin, Vendor Dashboard, Shop, and Fallback
|
||||
- Clean path usage for correct context detection
|
||||
- Host and path-based context determination
|
||||
- Middleware state injection
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from fastapi import Request
|
||||
|
||||
from middleware.context_middleware import (
|
||||
ContextManager,
|
||||
ContextMiddleware,
|
||||
RequestContext,
|
||||
get_request_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRequestContextEnum:
|
||||
"""Test suite for RequestContext enum."""
|
||||
|
||||
def test_request_context_values(self):
|
||||
"""Test RequestContext enum has correct values."""
|
||||
assert RequestContext.API.value == "api"
|
||||
assert RequestContext.ADMIN.value == "admin"
|
||||
assert RequestContext.VENDOR_DASHBOARD.value == "vendor"
|
||||
assert RequestContext.SHOP.value == "shop"
|
||||
assert RequestContext.FALLBACK.value == "fallback"
|
||||
|
||||
def test_request_context_types(self):
|
||||
"""Test RequestContext enum values are strings."""
|
||||
for context in RequestContext:
|
||||
assert isinstance(context.value, str)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContextManagerDetection:
|
||||
"""Test suite for ContextManager.detect_context()."""
|
||||
|
||||
# ========================================================================
|
||||
# API Context Tests (Highest Priority)
|
||||
# ========================================================================
|
||||
|
||||
def test_detect_api_context(self):
|
||||
"""Test API context detection."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/v1/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/api/v1/vendors")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_detect_api_context_nested_path(self):
|
||||
"""Test API context detection with nested path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/v1/vendors/123/products")
|
||||
request.headers = {"host": "platform.com"}
|
||||
request.state = Mock(clean_path="/api/v1/vendors/123/products")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_detect_api_context_with_clean_path(self):
|
||||
"""Test API context detection uses clean_path when available."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/testvendor/api/products")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/api/products")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
# ========================================================================
|
||||
# Admin Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_detect_admin_context_from_subdomain(self):
|
||||
"""Test admin context detection from subdomain."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/dashboard")
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.state = Mock(clean_path="/dashboard")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
def test_detect_admin_context_from_path(self):
|
||||
"""Test admin context detection from path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/admin/dashboard")
|
||||
request.headers = {"host": "platform.com"}
|
||||
request.state = Mock(clean_path="/admin/dashboard")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
def test_detect_admin_context_with_port(self):
|
||||
"""Test admin context detection with port number."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/dashboard")
|
||||
request.headers = {"host": "admin.localhost:8000"}
|
||||
request.state = Mock(clean_path="/dashboard")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
def test_detect_admin_context_nested_path(self):
|
||||
"""Test admin context detection with nested admin path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/admin/vendors/list")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/admin/vendors/list")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
# ========================================================================
|
||||
# Vendor Dashboard Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_detect_vendor_dashboard_context(self):
|
||||
"""Test vendor dashboard context detection."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/testvendor/dashboard")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/dashboard")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.VENDOR_DASHBOARD
|
||||
|
||||
def test_detect_vendor_dashboard_context_direct_path(self):
|
||||
"""Test vendor dashboard with direct /vendor/ path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/settings")
|
||||
request.headers = {"host": "testvendor.platform.com"}
|
||||
request.state = Mock(clean_path="/vendor/settings")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.VENDOR_DASHBOARD
|
||||
|
||||
def test_not_detect_vendors_plural_as_dashboard(self):
|
||||
"""Test that /vendors/ path is not detected as vendor dashboard."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendors/testvendor/shop")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/shop")
|
||||
|
||||
# Should not be vendor dashboard
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context != RequestContext.VENDOR_DASHBOARD
|
||||
|
||||
# ========================================================================
|
||||
# Shop Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_detect_shop_context_with_vendor_state(self):
|
||||
"""Test shop context detection when vendor exists in request state."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/products")
|
||||
request.headers = {"host": "testvendor.platform.com"}
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.name = "Test Vendor"
|
||||
request.state = Mock(clean_path="/products", vendor=mock_vendor)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.SHOP
|
||||
|
||||
def test_detect_shop_context_from_shop_path(self):
|
||||
"""Test shop context detection from /shop/ path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/shop/products", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.SHOP
|
||||
|
||||
def test_detect_shop_context_custom_domain(self):
|
||||
"""Test shop context with custom domain and vendor."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/products")
|
||||
request.headers = {"host": "customdomain.com"}
|
||||
mock_vendor = Mock(name="Custom Vendor")
|
||||
request.state = Mock(clean_path="/products", vendor=mock_vendor)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.SHOP
|
||||
|
||||
# ========================================================================
|
||||
# Fallback Context Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_detect_fallback_context(self):
|
||||
"""Test fallback context for unknown paths."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/random/path")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/random/path", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.FALLBACK
|
||||
|
||||
def test_detect_fallback_context_root(self):
|
||||
"""Test fallback context for root path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/")
|
||||
request.headers = {"host": "platform.com"}
|
||||
request.state = Mock(clean_path="/", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.FALLBACK
|
||||
|
||||
def test_detect_fallback_context_no_vendor(self):
|
||||
"""Test fallback context when no vendor context exists."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/about")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/about", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.FALLBACK
|
||||
|
||||
# ========================================================================
|
||||
# Clean Path Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_uses_clean_path_when_available(self):
|
||||
"""Test that clean_path is used over original path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/testvendor/api/products")
|
||||
request.headers = {"host": "localhost"}
|
||||
# clean_path shows the rewritten path
|
||||
request.state = Mock(clean_path="/api/products")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
# Should detect as API based on clean_path, not original path
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_falls_back_to_original_path(self):
|
||||
"""Test falls back to original path when clean_path not set."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(spec=[]) # No clean_path attribute
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
# ========================================================================
|
||||
# Priority Order Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_api_has_highest_priority(self):
|
||||
"""Test API context takes precedence over admin."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/admin/users")
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.state = Mock(clean_path="/api/admin/users")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
# API should win even though it's admin subdomain
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_admin_has_priority_over_shop(self):
|
||||
"""Test admin context takes precedence over shop."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/admin/shops")
|
||||
request.headers = {"host": "localhost"}
|
||||
mock_vendor = Mock()
|
||||
request.state = Mock(clean_path="/admin/shops", vendor=mock_vendor)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
# Admin should win even though vendor exists
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
def test_vendor_dashboard_priority_over_shop(self):
|
||||
"""Test vendor dashboard takes precedence over shop."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/settings")
|
||||
request.headers = {"host": "testvendor.platform.com"}
|
||||
mock_vendor = Mock()
|
||||
request.state = Mock(clean_path="/vendor/settings", vendor=mock_vendor)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.VENDOR_DASHBOARD
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContextManagerHelpers:
|
||||
"""Test suite for ContextManager helper methods."""
|
||||
|
||||
def test_is_admin_context_from_subdomain(self):
|
||||
"""Test _is_admin_context with admin subdomain."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/dashboard") is True
|
||||
|
||||
def test_is_admin_context_from_path(self):
|
||||
"""Test _is_admin_context with admin path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "localhost", "/admin/users") is True
|
||||
|
||||
def test_is_admin_context_both(self):
|
||||
"""Test _is_admin_context with both subdomain and path."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "admin.platform.com", "/admin/users") is True
|
||||
|
||||
def test_is_not_admin_context(self):
|
||||
"""Test _is_admin_context returns False for non-admin."""
|
||||
request = Mock()
|
||||
assert ContextManager._is_admin_context(request, "vendor.platform.com", "/shop") is False
|
||||
|
||||
def test_is_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context with /vendor/ path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendor/settings") is True
|
||||
|
||||
def test_is_vendor_dashboard_context_nested(self):
|
||||
"""Test _is_vendor_dashboard_context with nested vendor path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendor/products/list") is True
|
||||
|
||||
def test_is_not_vendor_dashboard_context_vendors_plural(self):
|
||||
"""Test _is_vendor_dashboard_context excludes /vendors/ path."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/vendors/shop123/products") is False
|
||||
|
||||
def test_is_not_vendor_dashboard_context(self):
|
||||
"""Test _is_vendor_dashboard_context returns False for non-vendor paths."""
|
||||
assert ContextManager._is_vendor_dashboard_context("/shop/products") is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestContextMiddleware:
|
||||
"""Test suite for ContextMiddleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_sets_context(self):
|
||||
"""Test middleware successfully sets context in request state."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/api/vendors", vendor=None)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'context_type')
|
||||
assert request.state.context_type == RequestContext.API
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_sets_admin_context(self):
|
||||
"""Test middleware sets admin context."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/admin/dashboard")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/admin/dashboard")
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.context_type == RequestContext.ADMIN
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_sets_vendor_dashboard_context(self):
|
||||
"""Test middleware sets vendor dashboard context."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/settings")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/vendor/settings")
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.context_type == RequestContext.VENDOR_DASHBOARD
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_sets_shop_context(self):
|
||||
"""Test middleware sets shop context."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/products")
|
||||
request.headers = {"host": "shop.platform.com"}
|
||||
mock_vendor = Mock()
|
||||
request.state = Mock(clean_path="/products", vendor=mock_vendor)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.context_type == RequestContext.SHOP
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_sets_fallback_context(self):
|
||||
"""Test middleware sets fallback context."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/random")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/random", vendor=None)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.context_type == RequestContext.FALLBACK
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_returns_response(self):
|
||||
"""Test middleware returns response from call_next."""
|
||||
middleware = ContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/test")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/api/test")
|
||||
|
||||
expected_response = Mock()
|
||||
call_next = AsyncMock(return_value=expected_response)
|
||||
|
||||
response = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert response is expected_response
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetRequestContextHelper:
|
||||
"""Test suite for get_request_context helper function."""
|
||||
|
||||
def test_get_request_context_exists(self):
|
||||
"""Test getting request context when it exists."""
|
||||
request = Mock(spec=Request)
|
||||
request.state.context_type = RequestContext.API
|
||||
|
||||
context = get_request_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_get_request_context_default(self):
|
||||
"""Test getting request context returns FALLBACK as default."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock(spec=[]) # No context_type attribute
|
||||
|
||||
context = get_request_context(request)
|
||||
|
||||
assert context == RequestContext.FALLBACK
|
||||
|
||||
def test_get_request_context_for_all_types(self):
|
||||
"""Test getting all context types."""
|
||||
for expected_context in RequestContext:
|
||||
request = Mock(spec=Request)
|
||||
request.state.context_type = expected_context
|
||||
|
||||
context = get_request_context(request)
|
||||
|
||||
assert context == expected_context
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestEdgeCases:
|
||||
"""Test suite for edge cases and error scenarios."""
|
||||
|
||||
def test_detect_context_empty_path(self):
|
||||
"""Test context detection with empty path."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.FALLBACK
|
||||
|
||||
def test_detect_context_missing_host(self):
|
||||
"""Test context detection with missing host header."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
request.headers = {}
|
||||
request.state = Mock(clean_path="/shop/products", vendor=None)
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.SHOP
|
||||
|
||||
def test_detect_context_case_sensitivity(self):
|
||||
"""Test that context detection is case-sensitive for paths."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/API/vendors") # Uppercase
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/API/vendors")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
# Should NOT match /api/ because it's case-sensitive
|
||||
assert context != RequestContext.API
|
||||
|
||||
def test_detect_context_path_with_query_params(self):
|
||||
"""Test context detection handles path with query parameters."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/vendors?page=1")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/api/vendors?page=1")
|
||||
|
||||
# path.startswith should still work
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
assert context == RequestContext.API
|
||||
|
||||
def test_detect_context_admin_substring(self):
|
||||
"""Test that 'admin' substring doesn't trigger false positive."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/administration/docs")
|
||||
request.headers = {"host": "localhost"}
|
||||
request.state = Mock(clean_path="/administration/docs")
|
||||
|
||||
context = ContextManager.detect_context(request)
|
||||
|
||||
# Should match because path starts with /admin
|
||||
assert context == RequestContext.ADMIN
|
||||
|
||||
def test_detect_context_no_state_attribute(self):
|
||||
"""Test context detection when request has no state."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.headers = {"host": "localhost"}
|
||||
# No state attribute at all
|
||||
delattr(request, 'state')
|
||||
|
||||
# Should still work, falling back to url.path
|
||||
with pytest.raises(AttributeError):
|
||||
# This will raise because we're trying to access request.state
|
||||
ContextManager.detect_context(request)
|
||||
536
tests/unit/middleware/test_rate_limiter.py
Normal file
536
tests/unit/middleware/test_rate_limiter.py
Normal file
@@ -0,0 +1,536 @@
|
||||
# tests/unit/middleware/test_rate_limiter.py
|
||||
"""
|
||||
Comprehensive unit tests for RateLimiter.
|
||||
|
||||
Tests cover:
|
||||
- Request allowance within limits
|
||||
- Request blocking when exceeding limits
|
||||
- Sliding window algorithm
|
||||
- Cleanup of old entries
|
||||
- Client statistics
|
||||
- Edge cases and concurrency scenarios
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import deque
|
||||
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterBasic:
|
||||
"""Test suite for basic rate limiter functionality."""
|
||||
|
||||
def test_rate_limiter_initialization(self):
|
||||
"""Test rate limiter initializes correctly."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
assert isinstance(limiter.clients, dict)
|
||||
assert limiter.cleanup_interval == 3600
|
||||
assert isinstance(limiter.last_cleanup, datetime)
|
||||
|
||||
def test_allow_first_request(self):
|
||||
"""Test rate limiter allows first request."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_1"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
|
||||
assert result is True
|
||||
assert client_id in limiter.clients
|
||||
assert len(limiter.clients[client_id]) == 1
|
||||
|
||||
def test_allow_multiple_requests_within_limit(self):
|
||||
"""Test rate limiter allows multiple requests within limit."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_2"
|
||||
max_requests = 10
|
||||
|
||||
# Make 10 requests (at the limit)
|
||||
for i in range(max_requests):
|
||||
result = limiter.allow_request(client_id, max_requests, 3600)
|
||||
assert result is True, f"Request {i+1} should be allowed"
|
||||
|
||||
assert len(limiter.clients[client_id]) == max_requests
|
||||
|
||||
def test_block_request_exceeding_limit(self):
|
||||
"""Test rate limiter blocks requests exceeding limit."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_blocked"
|
||||
max_requests = 3
|
||||
|
||||
# Use up the allowed requests
|
||||
for _ in range(max_requests):
|
||||
assert limiter.allow_request(client_id, max_requests, 3600) is True
|
||||
|
||||
# Next request should be blocked
|
||||
result = limiter.allow_request(client_id, max_requests, 3600)
|
||||
assert result is False
|
||||
|
||||
# Client should still have only max_requests entries
|
||||
assert len(limiter.clients[client_id]) == max_requests
|
||||
|
||||
def test_different_clients_separate_limits(self):
|
||||
"""Test different clients have separate rate limits."""
|
||||
limiter = RateLimiter()
|
||||
client1 = "client_1"
|
||||
client2 = "client_2"
|
||||
max_requests = 5
|
||||
|
||||
# Client 1 makes requests
|
||||
for _ in range(max_requests):
|
||||
assert limiter.allow_request(client1, max_requests, 3600) is True
|
||||
|
||||
# Client 1 is blocked
|
||||
assert limiter.allow_request(client1, max_requests, 3600) is False
|
||||
|
||||
# Client 2 should still be allowed
|
||||
assert limiter.allow_request(client2, max_requests, 3600) is True
|
||||
|
||||
# Verify separate tracking
|
||||
assert len(limiter.clients[client1]) == max_requests
|
||||
assert len(limiter.clients[client2]) == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterSlidingWindow:
|
||||
"""Test suite for sliding window algorithm."""
|
||||
|
||||
def test_sliding_window_removes_old_requests(self):
|
||||
"""Test sliding window removes requests outside time window."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_window"
|
||||
max_requests = 3
|
||||
window_seconds = 10
|
||||
|
||||
# Manually add old requests
|
||||
old_time = datetime.now(timezone.utc) - timedelta(seconds=15)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
|
||||
# These old requests should be removed, so new request should be allowed
|
||||
result = limiter.allow_request(client_id, max_requests, window_seconds)
|
||||
|
||||
assert result is True
|
||||
assert len(limiter.clients[client_id]) == 1 # Only the new request
|
||||
|
||||
def test_sliding_window_keeps_recent_requests(self):
|
||||
"""Test sliding window keeps requests within time window."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_recent"
|
||||
max_requests = 3
|
||||
window_seconds = 60
|
||||
|
||||
# Add recent requests
|
||||
recent_time = datetime.now(timezone.utc) - timedelta(seconds=30)
|
||||
limiter.clients[client_id].append(recent_time)
|
||||
limiter.clients[client_id].append(recent_time)
|
||||
|
||||
# These requests are within window, so we can only add 1 more
|
||||
result = limiter.allow_request(client_id, max_requests, window_seconds)
|
||||
assert result is True
|
||||
|
||||
# Now at limit
|
||||
result = limiter.allow_request(client_id, max_requests, window_seconds)
|
||||
assert result is False
|
||||
|
||||
def test_sliding_window_mixed_old_and_recent(self):
|
||||
"""Test sliding window with mix of old and recent requests."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_mixed"
|
||||
max_requests = 3
|
||||
window_seconds = 30
|
||||
|
||||
# Add old requests (outside window)
|
||||
old_time = datetime.now(timezone.utc) - timedelta(seconds=60)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
|
||||
# Add recent request (within window)
|
||||
recent_time = datetime.now(timezone.utc) - timedelta(seconds=10)
|
||||
limiter.clients[client_id].append(recent_time)
|
||||
|
||||
# Old requests removed, only 1 recent request, so 2 more allowed
|
||||
assert limiter.allow_request(client_id, max_requests, window_seconds) is True
|
||||
assert limiter.allow_request(client_id, max_requests, window_seconds) is True
|
||||
|
||||
# Now at limit
|
||||
assert limiter.allow_request(client_id, max_requests, window_seconds) is False
|
||||
|
||||
def test_sliding_window_with_zero_window(self):
|
||||
"""Test rate limiter with very short window."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "test_client_zero_window"
|
||||
max_requests = 5
|
||||
window_seconds = 1 # 1 second window
|
||||
|
||||
# Add old request
|
||||
old_time = datetime.now(timezone.utc) - timedelta(seconds=2)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
|
||||
# Should allow request because old one is outside 1-second window
|
||||
result = limiter.allow_request(client_id, max_requests, window_seconds)
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterCleanup:
|
||||
"""Test suite for cleanup functionality."""
|
||||
|
||||
def test_cleanup_removes_old_entries(self):
|
||||
"""Test cleanup removes entries older than 24 hours."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
# Add clients with old requests
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=25)
|
||||
limiter.clients["old_client_1"].append(old_time)
|
||||
limiter.clients["old_client_2"].append(old_time)
|
||||
|
||||
# Add client with recent requests
|
||||
recent_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
limiter.clients["recent_client"].append(recent_time)
|
||||
|
||||
# Run cleanup
|
||||
limiter._cleanup_old_entries()
|
||||
|
||||
# Old clients should be removed
|
||||
assert "old_client_1" not in limiter.clients
|
||||
assert "old_client_2" not in limiter.clients
|
||||
|
||||
# Recent client should remain
|
||||
assert "recent_client" in limiter.clients
|
||||
|
||||
def test_cleanup_removes_empty_clients(self):
|
||||
"""Test cleanup removes clients with no requests."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
# Add empty clients
|
||||
limiter.clients["empty_client_1"] = deque()
|
||||
limiter.clients["empty_client_2"] = deque()
|
||||
|
||||
# Add client with requests
|
||||
limiter.clients["active_client"].append(datetime.now(timezone.utc))
|
||||
|
||||
# Run cleanup
|
||||
limiter._cleanup_old_entries()
|
||||
|
||||
# Empty clients should be removed
|
||||
assert "empty_client_1" not in limiter.clients
|
||||
assert "empty_client_2" not in limiter.clients
|
||||
|
||||
# Active client should remain
|
||||
assert "active_client" in limiter.clients
|
||||
|
||||
def test_cleanup_partial_removal(self):
|
||||
"""Test cleanup removes only old requests, keeps recent ones."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "mixed_client"
|
||||
|
||||
# Add old requests
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=30)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
limiter.clients[client_id].append(old_time)
|
||||
|
||||
# Add recent requests
|
||||
recent_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
limiter.clients[client_id].append(recent_time)
|
||||
limiter.clients[client_id].append(recent_time)
|
||||
|
||||
# Run cleanup
|
||||
limiter._cleanup_old_entries()
|
||||
|
||||
# Client should remain with only recent requests
|
||||
assert client_id in limiter.clients
|
||||
assert len(limiter.clients[client_id]) == 2
|
||||
|
||||
def test_automatic_cleanup_triggers(self):
|
||||
"""Test automatic cleanup triggers after interval."""
|
||||
limiter = RateLimiter()
|
||||
limiter.cleanup_interval = 0 # Force immediate cleanup
|
||||
|
||||
# Set last_cleanup to past
|
||||
limiter.last_cleanup = datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
|
||||
# Add old client
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=25)
|
||||
limiter.clients["old_client"].append(old_time)
|
||||
|
||||
# Make request (should trigger cleanup)
|
||||
limiter.allow_request("new_client", 10, 3600)
|
||||
|
||||
# Old client should be cleaned up
|
||||
assert "old_client" not in limiter.clients
|
||||
|
||||
def test_cleanup_does_not_affect_active_clients(self):
|
||||
"""Test cleanup doesn't remove clients with recent requests."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
# Add multiple active clients
|
||||
now = datetime.now(timezone.utc)
|
||||
for i in range(5):
|
||||
limiter.clients[f"client_{i}"].append(now - timedelta(hours=i))
|
||||
|
||||
# Run cleanup
|
||||
limiter._cleanup_old_entries()
|
||||
|
||||
# All clients should still exist (all within 24 hours)
|
||||
assert len(limiter.clients) == 5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterStatistics:
|
||||
"""Test suite for client statistics functionality."""
|
||||
|
||||
def test_get_client_stats_empty(self):
|
||||
"""Test getting stats for client with no requests."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "new_client"
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
assert stats["requests_last_hour"] == 0
|
||||
assert stats["requests_last_day"] == 0
|
||||
assert stats["total_tracked_requests"] == 0
|
||||
|
||||
def test_get_client_stats_with_requests(self):
|
||||
"""Test getting stats for client with requests."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "active_client"
|
||||
|
||||
# Add requests at different times
|
||||
now = datetime.now(timezone.utc)
|
||||
limiter.clients[client_id].append(now - timedelta(minutes=30)) # Within hour
|
||||
limiter.clients[client_id].append(now - timedelta(hours=2)) # Within day
|
||||
limiter.clients[client_id].append(now - timedelta(hours=12)) # Within day
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
assert stats["requests_last_hour"] == 1
|
||||
assert stats["requests_last_day"] == 3
|
||||
assert stats["total_tracked_requests"] == 3
|
||||
|
||||
def test_get_client_stats_old_requests(self):
|
||||
"""Test stats exclude requests older than tracking period."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "old_requests_client"
|
||||
|
||||
# Add very old requests
|
||||
now = datetime.now(timezone.utc)
|
||||
limiter.clients[client_id].append(now - timedelta(days=2))
|
||||
limiter.clients[client_id].append(now - timedelta(days=3))
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
assert stats["requests_last_hour"] == 0
|
||||
assert stats["requests_last_day"] == 0
|
||||
assert stats["total_tracked_requests"] == 2 # Still tracked, just not counted
|
||||
|
||||
def test_get_client_stats_nonexistent_client(self):
|
||||
"""Test getting stats for client that doesn't exist."""
|
||||
limiter = RateLimiter()
|
||||
|
||||
stats = limiter.get_client_stats("nonexistent_client")
|
||||
|
||||
assert stats["requests_last_hour"] == 0
|
||||
assert stats["requests_last_day"] == 0
|
||||
assert stats["total_tracked_requests"] == 0
|
||||
|
||||
def test_get_client_stats_boundary_cases(self):
|
||||
"""Test stats at exact hour/day boundaries."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "boundary_client"
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Exactly 1 hour ago (should be included)
|
||||
limiter.clients[client_id].append(now - timedelta(hours=1, seconds=1))
|
||||
|
||||
# Exactly 24 hours ago (should be excluded)
|
||||
limiter.clients[client_id].append(now - timedelta(days=1, seconds=1))
|
||||
|
||||
stats = limiter.get_client_stats(client_id)
|
||||
|
||||
# Boundary behavior depends on > vs >= comparison
|
||||
assert stats["requests_last_hour"] >= 0
|
||||
assert stats["requests_last_day"] >= 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterEdgeCases:
|
||||
"""Test suite for edge cases and error scenarios."""
|
||||
|
||||
def test_rate_limiter_with_zero_max_requests(self):
|
||||
"""Test rate limiter with max_requests=0."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "zero_limit_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=0, window_seconds=3600)
|
||||
|
||||
# Should be blocked immediately
|
||||
assert result is False
|
||||
|
||||
def test_rate_limiter_with_negative_max_requests(self):
|
||||
"""Test rate limiter with negative max_requests."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "negative_limit_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=-1, window_seconds=3600)
|
||||
|
||||
# Should be blocked
|
||||
assert result is False
|
||||
|
||||
def test_rate_limiter_with_large_max_requests(self):
|
||||
"""Test rate limiter with very large max_requests."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "large_limit_client"
|
||||
max_requests = 1000000
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests, 3600)
|
||||
|
||||
# Should be allowed
|
||||
assert result is True
|
||||
|
||||
def test_rate_limiter_very_short_window(self):
|
||||
"""Test rate limiter with very short time window."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "short_window_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=1, window_seconds=1)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_rate_limiter_very_long_window(self):
|
||||
"""Test rate limiter with very long time window."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "long_window_client"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=10, window_seconds=86400*365)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_rate_limiter_same_client_different_limits(self):
|
||||
"""Test same client with different rate limits."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "same_client"
|
||||
|
||||
# Allow with one limit
|
||||
assert limiter.allow_request(client_id, max_requests=10, window_seconds=3600) is True
|
||||
|
||||
# Check with stricter limit
|
||||
assert limiter.allow_request(client_id, max_requests=1, window_seconds=3600) is False
|
||||
|
||||
def test_rate_limiter_unicode_client_id(self):
|
||||
"""Test rate limiter with unicode client ID."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "クライアント_123"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
|
||||
|
||||
assert result is True
|
||||
assert client_id in limiter.clients
|
||||
|
||||
def test_rate_limiter_special_characters_client_id(self):
|
||||
"""Test rate limiter with special characters in client ID."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "client!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
|
||||
|
||||
assert result is True
|
||||
assert client_id in limiter.clients
|
||||
|
||||
def test_rate_limiter_empty_client_id(self):
|
||||
"""Test rate limiter with empty client ID."""
|
||||
limiter = RateLimiter()
|
||||
client_id = ""
|
||||
|
||||
result = limiter.allow_request(client_id, max_requests=5, window_seconds=3600)
|
||||
|
||||
assert result is True
|
||||
assert client_id in limiter.clients
|
||||
|
||||
def test_rate_limiter_concurrent_same_client(self):
|
||||
"""Test rate limiter behavior with rapid requests from same client."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "concurrent_client"
|
||||
max_requests = 3
|
||||
|
||||
# Simulate rapid requests
|
||||
results = []
|
||||
for _ in range(5):
|
||||
results.append(limiter.allow_request(client_id, max_requests, 3600))
|
||||
|
||||
# First 3 should be True, rest False
|
||||
assert results[:3] == [True, True, True]
|
||||
assert results[3:] == [False, False]
|
||||
|
||||
def test_cleanup_updates_last_cleanup_time(self):
|
||||
"""Test that cleanup updates last_cleanup timestamp."""
|
||||
limiter = RateLimiter()
|
||||
old_cleanup_time = limiter.last_cleanup
|
||||
|
||||
# Force cleanup
|
||||
limiter.cleanup_interval = 0
|
||||
limiter.allow_request("test", 10, 3600)
|
||||
|
||||
# last_cleanup should be updated
|
||||
assert limiter.last_cleanup > old_cleanup_time
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimiterMemoryManagement:
|
||||
"""Test suite for memory management and performance."""
|
||||
|
||||
def test_limiter_does_not_grow_indefinitely(self):
|
||||
"""Test that old entries are cleaned up to prevent memory leaks."""
|
||||
limiter = RateLimiter()
|
||||
limiter.cleanup_interval = 0 # Force cleanup on every request
|
||||
|
||||
# Simulate many requests over time
|
||||
for i in range(100):
|
||||
limiter.allow_request(f"client_{i}", max_requests=10, window_seconds=3600)
|
||||
|
||||
# Force cleanup
|
||||
limiter._cleanup_old_entries()
|
||||
|
||||
# Should have cleaned up clients with no recent activity
|
||||
# Exact number depends on timing, but should be less than 100
|
||||
assert len(limiter.clients) <= 100
|
||||
|
||||
def test_deque_efficiency(self):
|
||||
"""Test that deque is used for efficient popleft operations."""
|
||||
limiter = RateLimiter()
|
||||
client_id = "efficiency_test"
|
||||
|
||||
# Add many old requests
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=2)
|
||||
for _ in range(1000):
|
||||
limiter.clients[client_id].append(old_time)
|
||||
|
||||
# This should efficiently remove all old requests
|
||||
limiter.allow_request(client_id, max_requests=10, window_seconds=3600)
|
||||
|
||||
# Should only have the new request
|
||||
assert len(limiter.clients[client_id]) == 1
|
||||
|
||||
def test_multiple_clients_independence(self):
|
||||
"""Test that multiple clients don't interfere with each other."""
|
||||
limiter = RateLimiter()
|
||||
num_clients = 100
|
||||
|
||||
# Create many clients with requests
|
||||
for i in range(num_clients):
|
||||
limiter.allow_request(f"client_{i}", max_requests=5, window_seconds=3600)
|
||||
|
||||
# Each client should have exactly 1 request
|
||||
assert len(limiter.clients) == num_clients
|
||||
for i in range(num_clients):
|
||||
assert len(limiter.clients[f"client_{i}"]) == 1
|
||||
589
tests/unit/middleware/test_theme_logging_path_decorators.py
Normal file
589
tests/unit/middleware/test_theme_logging_path_decorators.py
Normal file
@@ -0,0 +1,589 @@
|
||||
# tests/unit/middleware/test_theme_logging_path_decorators.py
|
||||
"""
|
||||
Comprehensive unit tests for remaining middleware components:
|
||||
- ThemeContextMiddleware and ThemeContextManager
|
||||
- LoggingMiddleware
|
||||
- path_rewrite_middleware
|
||||
- rate_limit decorator
|
||||
|
||||
Tests cover:
|
||||
- Theme loading and caching
|
||||
- Request/response logging
|
||||
- Path rewriting for vendor routing
|
||||
- Rate limit decorators
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
||||
from fastapi import Request
|
||||
import time
|
||||
|
||||
from middleware.theme_context import (
|
||||
ThemeContextManager,
|
||||
ThemeContextMiddleware,
|
||||
get_current_theme,
|
||||
)
|
||||
from middleware.logging_middleware import LoggingMiddleware
|
||||
from middleware.path_rewrite_middleware import path_rewrite_middleware
|
||||
from middleware.decorators import rate_limit
|
||||
from app.exceptions.base import RateLimitException
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Theme Context Tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestThemeContextManager:
|
||||
"""Test suite for ThemeContextManager."""
|
||||
|
||||
def test_get_default_theme_structure(self):
|
||||
"""Test default theme has correct structure."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
assert "theme_name" in theme
|
||||
assert "colors" in theme
|
||||
assert "fonts" in theme
|
||||
assert "branding" in theme
|
||||
assert "layout" in theme
|
||||
assert "social_links" in theme
|
||||
assert "css_variables" in theme
|
||||
|
||||
def test_get_default_theme_colors(self):
|
||||
"""Test default theme has all required colors."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
required_colors = ["primary", "secondary", "accent", "background", "text", "border"]
|
||||
for color in required_colors:
|
||||
assert color in theme["colors"]
|
||||
assert theme["colors"][color].startswith("#")
|
||||
|
||||
def test_get_default_theme_fonts(self):
|
||||
"""Test default theme has font configuration."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
assert "heading" in theme["fonts"]
|
||||
assert "body" in theme["fonts"]
|
||||
assert isinstance(theme["fonts"]["heading"], str)
|
||||
assert isinstance(theme["fonts"]["body"], str)
|
||||
|
||||
def test_get_default_theme_branding(self):
|
||||
"""Test default theme branding structure."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
assert "logo" in theme["branding"]
|
||||
assert "logo_dark" in theme["branding"]
|
||||
assert "favicon" in theme["branding"]
|
||||
assert "banner" in theme["branding"]
|
||||
|
||||
def test_get_default_theme_css_variables(self):
|
||||
"""Test default theme has CSS variables."""
|
||||
theme = ThemeContextManager.get_default_theme()
|
||||
|
||||
assert "--color-primary" in theme["css_variables"]
|
||||
assert "--font-heading" in theme["css_variables"]
|
||||
assert "--font-body" in theme["css_variables"]
|
||||
|
||||
def test_get_vendor_theme_with_custom_theme(self):
|
||||
"""Test getting vendor-specific theme."""
|
||||
mock_db = Mock()
|
||||
mock_theme = Mock()
|
||||
mock_theme.to_dict.return_value = {
|
||||
"theme_name": "custom",
|
||||
"colors": {"primary": "#ff0000"}
|
||||
}
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_theme
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
assert theme["theme_name"] == "custom"
|
||||
assert theme["colors"]["primary"] == "#ff0000"
|
||||
mock_theme.to_dict.assert_called_once()
|
||||
|
||||
def test_get_vendor_theme_fallback_to_default(self):
|
||||
"""Test falling back to default theme when no custom theme exists."""
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
assert theme["theme_name"] == "default"
|
||||
assert "colors" in theme
|
||||
assert "fonts" in theme
|
||||
|
||||
def test_get_vendor_theme_inactive_theme(self):
|
||||
"""Test that inactive themes are not returned."""
|
||||
mock_db = Mock()
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
theme = ThemeContextManager.get_vendor_theme(mock_db, vendor_id=1)
|
||||
|
||||
# Should return default theme
|
||||
assert theme["theme_name"] == "default"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestThemeContextMiddleware:
|
||||
"""Test suite for ThemeContextMiddleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_loads_theme_for_vendor(self):
|
||||
"""Test middleware loads theme when vendor exists."""
|
||||
middleware = ThemeContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.id = 1
|
||||
mock_vendor.name = "Test Vendor"
|
||||
request.state = Mock(vendor=mock_vendor)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_theme = {"theme_name": "test_theme"}
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', return_value=mock_theme):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.theme == mock_theme
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_uses_default_theme_no_vendor(self):
|
||||
"""Test middleware uses default theme when no vendor."""
|
||||
middleware = ThemeContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock(vendor=None)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert hasattr(request.state, 'theme')
|
||||
assert request.state.theme["theme_name"] == "default"
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_handles_theme_loading_error(self):
|
||||
"""Test middleware handles errors gracefully."""
|
||||
middleware = ThemeContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
mock_vendor = Mock(id=1, name="Test Vendor")
|
||||
request.state = Mock(vendor=mock_vendor)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])), \
|
||||
patch.object(ThemeContextManager, 'get_vendor_theme', side_effect=Exception("DB Error")):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should fallback to default theme
|
||||
assert request.state.theme["theme_name"] == "default"
|
||||
call_next.assert_called_once()
|
||||
|
||||
def test_get_current_theme_exists(self):
|
||||
"""Test getting current theme when it exists."""
|
||||
request = Mock(spec=Request)
|
||||
test_theme = {"theme_name": "test"}
|
||||
request.state.theme = test_theme
|
||||
|
||||
theme = get_current_theme(request)
|
||||
|
||||
assert theme == test_theme
|
||||
|
||||
def test_get_current_theme_default(self):
|
||||
"""Test getting theme returns default when not set."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock(spec=[]) # No theme attribute
|
||||
|
||||
theme = get_current_theme(request)
|
||||
|
||||
assert theme["theme_name"] == "default"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Logging Middleware Tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLoggingMiddleware:
|
||||
"""Test suite for LoggingMiddleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_request(self):
|
||||
"""Test middleware logs incoming request."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock(path="/api/vendors")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify request was logged
|
||||
assert mock_logger.info.call_count >= 1
|
||||
first_call = mock_logger.info.call_args_list[0]
|
||||
assert "GET" in str(first_call)
|
||||
assert "/api/vendors" in str(first_call)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_response(self):
|
||||
"""Test middleware logs response with status code and duration."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "POST"
|
||||
request.url = Mock(path="/api/products")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
response = Mock()
|
||||
response.status_code = 201
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify response was logged
|
||||
assert mock_logger.info.call_count >= 2 # Request + Response
|
||||
last_call = mock_logger.info.call_args_list[-1]
|
||||
assert "201" in str(last_call)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_adds_process_time_header(self):
|
||||
"""Test middleware adds X-Process-Time header."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock(path="/test")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
response = Mock()
|
||||
response.status_code = 200
|
||||
response.headers = {}
|
||||
|
||||
call_next = AsyncMock(return_value=response)
|
||||
|
||||
with patch('middleware.logging_middleware.logger'):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
assert "X-Process-Time" in response.headers
|
||||
# Should be a numeric string
|
||||
process_time = float(response.headers["X-Process-Time"])
|
||||
assert process_time >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_handles_no_client(self):
|
||||
"""Test middleware handles requests with no client info."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock(path="/test")
|
||||
request.client = None # No client info
|
||||
|
||||
call_next = AsyncMock(return_value=Mock(status_code=200))
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger:
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Should log "unknown" for client
|
||||
assert any("unknown" in str(call) for call in mock_logger.info.call_args_list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_logs_exceptions(self):
|
||||
"""Test middleware logs exceptions."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock(path="/error")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
call_next = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
with patch('middleware.logging_middleware.logger') as mock_logger, \
|
||||
pytest.raises(Exception):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_once()
|
||||
assert "Test error" in str(mock_logger.error.call_args)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_timing_accuracy(self):
|
||||
"""Test middleware timing is reasonably accurate."""
|
||||
middleware = LoggingMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.method = "GET"
|
||||
request.url = Mock(path="/slow")
|
||||
request.client = Mock(host="127.0.0.1")
|
||||
|
||||
async def slow_call_next(req):
|
||||
await asyncio.sleep(0.1) # 100ms delay
|
||||
response = Mock(status_code=200, headers={})
|
||||
return response
|
||||
|
||||
call_next = slow_call_next
|
||||
|
||||
import asyncio
|
||||
with patch('middleware.logging_middleware.logger'):
|
||||
result = await middleware.dispatch(request, call_next)
|
||||
|
||||
process_time = float(result.headers["X-Process-Time"])
|
||||
# Should be at least 0.1 seconds
|
||||
assert process_time >= 0.1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Path Rewrite Middleware Tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPathRewriteMiddleware:
|
||||
"""Test suite for path_rewrite_middleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rewrites_path_when_clean_path_different(self):
|
||||
"""Test path is rewritten when clean_path differs from original."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/testvendor/shop/products")
|
||||
request.state = Mock(clean_path="/shop/products")
|
||||
request.scope = {"path": "/vendor/testvendor/shop/products"}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
# Path should be rewritten in scope
|
||||
assert request.scope["path"] == "/shop/products"
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_rewrite_when_paths_same(self):
|
||||
"""Test path is not rewritten when clean_path same as original."""
|
||||
request = Mock(spec=Request)
|
||||
original_path = "/shop/products"
|
||||
request.url = Mock(path=original_path)
|
||||
request.state = Mock(clean_path=original_path)
|
||||
request.scope = {"path": original_path}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
# Path should remain unchanged
|
||||
assert request.scope["path"] == original_path
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_nothing_when_no_clean_path(self):
|
||||
"""Test middleware does nothing when no clean_path set."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
request.state = Mock(spec=[]) # No clean_path attribute
|
||||
original_path = "/shop/products"
|
||||
request.scope = {"path": original_path}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
# Path should remain unchanged
|
||||
assert request.scope["path"] == original_path
|
||||
call_next.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_request_url(self):
|
||||
"""Test middleware updates request._url."""
|
||||
request = Mock(spec=Request)
|
||||
original_url = Mock(path="/vendor/test/shop")
|
||||
request.url = original_url
|
||||
request.url.replace = Mock(return_value=Mock(path="/shop"))
|
||||
request.state = Mock(clean_path="/shop")
|
||||
request.scope = {"path": "/vendor/test/shop"}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
# URL replace should have been called
|
||||
request.url.replace.assert_called_once_with(path="/shop")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_vendor_context(self):
|
||||
"""Test middleware preserves vendor context in request.state."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/testvendor/products")
|
||||
mock_vendor = Mock()
|
||||
request.state = Mock(clean_path="/products", vendor=mock_vendor)
|
||||
request.scope = {"path": "/vendor/testvendor/products"}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
# Vendor should still be accessible
|
||||
assert request.state.vendor is mock_vendor
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rate Limit Decorator Tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.auth
|
||||
class TestRateLimitDecorator:
|
||||
"""Test suite for rate_limit decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_allows_within_limit(self):
|
||||
"""Test decorator allows requests within rate limit."""
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
result = await test_endpoint()
|
||||
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_blocks_exceeding_limit(self):
|
||||
"""Test decorator blocks requests exceeding rate limit."""
|
||||
@rate_limit(max_requests=2, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
# First two should succeed
|
||||
await test_endpoint()
|
||||
await test_endpoint()
|
||||
|
||||
# Third should raise exception
|
||||
with pytest.raises(RateLimitException) as exc_info:
|
||||
await test_endpoint()
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert "Rate limit exceeded" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_preserves_function_metadata(self):
|
||||
"""Test decorator preserves original function metadata."""
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint():
|
||||
"""Test endpoint docstring."""
|
||||
return {"status": "ok"}
|
||||
|
||||
assert test_endpoint.__name__ == "test_endpoint"
|
||||
assert test_endpoint.__doc__ == "Test endpoint docstring."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_args_and_kwargs(self):
|
||||
"""Test decorator works with function arguments."""
|
||||
@rate_limit(max_requests=10, window_seconds=3600)
|
||||
async def test_endpoint(arg1, arg2, kwarg1=None):
|
||||
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1}
|
||||
|
||||
result = await test_endpoint("value1", "value2", kwarg1="value3")
|
||||
|
||||
assert result == {
|
||||
"arg1": "value1",
|
||||
"arg2": "value2",
|
||||
"kwarg1": "value3"
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_default_parameters(self):
|
||||
"""Test decorator uses default parameters."""
|
||||
@rate_limit() # Use defaults
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
result = await test_endpoint()
|
||||
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_includes_retry_after(self):
|
||||
"""Test rate limit exception includes retry_after."""
|
||||
@rate_limit(max_requests=1, window_seconds=60)
|
||||
async def test_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
await test_endpoint() # Use up limit
|
||||
|
||||
with pytest.raises(RateLimitException) as exc_info:
|
||||
await test_endpoint()
|
||||
|
||||
assert exc_info.value.details.get("retry_after") == 60
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Edge Cases and Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestMiddlewareEdgeCases:
|
||||
"""Test suite for edge cases across middleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_theme_middleware_closes_db_connection(self):
|
||||
"""Test theme middleware properly closes database connection."""
|
||||
middleware = ThemeContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
mock_vendor = Mock(id=1, name="Test")
|
||||
request.state = Mock(vendor=mock_vendor)
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch('middleware.theme_context.get_db', return_value=iter([mock_db])):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
# Verify database was closed
|
||||
mock_db.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_rewrite_with_query_parameters(self):
|
||||
"""Test path rewrite preserves query parameters."""
|
||||
request = Mock(spec=Request)
|
||||
original_url = Mock(path="/vendor/test/shop?page=1")
|
||||
request.url = original_url
|
||||
request.url.replace = Mock(return_value=Mock(path="/shop?page=1"))
|
||||
request.state = Mock(clean_path="/shop?page=1")
|
||||
request.scope = {"path": "/vendor/test/shop?page=1"}
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
await path_rewrite_middleware(request, call_next)
|
||||
|
||||
request.url.replace.assert_called_once_with(path="/shop?page=1")
|
||||
|
||||
def test_theme_default_immutability(self):
|
||||
"""Test that getting default theme doesn't share state."""
|
||||
theme1 = ThemeContextManager.get_default_theme()
|
||||
theme2 = ThemeContextManager.get_default_theme()
|
||||
|
||||
# Modify theme1
|
||||
theme1["colors"]["primary"] = "#000000"
|
||||
|
||||
# theme2 should not be affected (if properly implemented)
|
||||
# Note: This test documents expected behavior
|
||||
assert theme2["colors"]["primary"] == "#6366f1"
|
||||
@@ -1,23 +1,721 @@
|
||||
from requests.cookies import MockRequest
|
||||
# tests/unit/middleware/test_vendor_context.py
|
||||
"""
|
||||
Comprehensive unit tests for VendorContextMiddleware and VendorContextManager.
|
||||
|
||||
from middleware.vendor_context import VendorContextManager
|
||||
Tests cover:
|
||||
- Vendor detection from custom domains, subdomains, and path-based routing
|
||||
- Database lookup and vendor validation
|
||||
- Path extraction and cleanup
|
||||
- Admin and API request detection
|
||||
- Static file request detection
|
||||
- Edge cases and error handling
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch, AsyncMock
|
||||
from fastapi import Request, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.vendor_context import (
|
||||
VendorContextManager,
|
||||
VendorContextMiddleware,
|
||||
get_current_vendor,
|
||||
require_vendor_context,
|
||||
)
|
||||
|
||||
|
||||
def test_custom_domain_detection():
|
||||
# Mock request with custom domain
|
||||
request = MockRequest(host="customdomain1.com")
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
assert context["detection_method"] == "custom_domain"
|
||||
assert context["domain"] == "customdomain1.com"
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.vendors
|
||||
class TestVendorContextManager:
|
||||
"""Test suite for VendorContextManager static methods."""
|
||||
|
||||
def test_subdomain_detection():
|
||||
request = MockRequest(host="vendor1.platform.com")
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
assert context["detection_method"] == "subdomain"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
# ========================================================================
|
||||
# Vendor Context Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_path_detection():
|
||||
request = MockRequest(host="localhost", path="/vendor/vendor1/")
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
assert context["detection_method"] == "path"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
def test_detect_custom_domain(self):
|
||||
"""Test custom domain detection."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "customdomain1.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "custom_domain"
|
||||
assert context["domain"] == "customdomain1.com"
|
||||
assert context["host"] == "customdomain1.com"
|
||||
|
||||
def test_detect_custom_domain_with_port(self):
|
||||
"""Test custom domain detection with port number."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "customdomain1.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "custom_domain"
|
||||
assert context["domain"] == "customdomain1.com"
|
||||
assert context["host"] == "customdomain1.com"
|
||||
|
||||
def test_detect_subdomain(self):
|
||||
"""Test subdomain detection."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "subdomain"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
assert context["host"] == "vendor1.platform.com"
|
||||
|
||||
def test_detect_subdomain_with_port(self):
|
||||
"""Test subdomain detection with port number."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "vendor1.platform.com:8000"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "subdomain"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
|
||||
def test_detect_path_vendor_singular(self):
|
||||
"""Test path-based detection with /vendor/ prefix."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/vendor/vendor1/shop")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "path"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
assert context["path_prefix"] == "/vendor/vendor1"
|
||||
assert context["full_prefix"] == "/vendor/"
|
||||
|
||||
def test_detect_path_vendors_plural(self):
|
||||
"""Test path-based detection with /vendors/ prefix."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/vendors/vendor1/shop")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "path"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
assert context["path_prefix"] == "/vendors/vendor1"
|
||||
assert context["full_prefix"] == "/vendors/"
|
||||
|
||||
def test_detect_no_vendor_context(self):
|
||||
"""Test when no vendor context can be detected."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/random/path")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_ignore_admin_subdomain(self):
|
||||
"""Test that admin subdomain is not detected as vendor."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_ignore_www_subdomain(self):
|
||||
"""Test that www subdomain is not detected as vendor."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "www.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_ignore_api_subdomain(self):
|
||||
"""Test that api subdomain is not detected as vendor."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "api.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_ignore_localhost(self):
|
||||
"""Test that localhost is not detected as custom domain."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
# ========================================================================
|
||||
# Vendor Database Lookup Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_get_vendor_from_custom_domain_context(self):
|
||||
"""Test getting vendor from custom domain context."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_vendor_domain = Mock()
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is mock_vendor
|
||||
assert vendor.is_active is True
|
||||
|
||||
def test_get_vendor_from_custom_domain_inactive_vendor(self):
|
||||
"""Test getting inactive vendor from custom domain context."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_vendor_domain = Mock()
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = False
|
||||
mock_vendor_domain.vendor = mock_vendor
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor_domain
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "customdomain1.com"
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is None
|
||||
|
||||
def test_get_vendor_from_custom_domain_not_found(self):
|
||||
"""Test custom domain not found in database."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
context = {
|
||||
"detection_method": "custom_domain",
|
||||
"domain": "nonexistent.com"
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is None
|
||||
|
||||
def test_get_vendor_from_subdomain_context(self):
|
||||
"""Test getting vendor from subdomain context."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is mock_vendor
|
||||
|
||||
def test_get_vendor_from_path_context(self):
|
||||
"""Test getting vendor from path context."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
|
||||
context = {
|
||||
"detection_method": "path",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is mock_vendor
|
||||
|
||||
def test_get_vendor_with_no_context(self):
|
||||
"""Test getting vendor with no context."""
|
||||
mock_db = Mock(spec=Session)
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, None)
|
||||
|
||||
assert vendor is None
|
||||
|
||||
def test_get_vendor_subdomain_case_insensitive(self):
|
||||
"""Test subdomain lookup is case-insensitive."""
|
||||
mock_db = Mock(spec=Session)
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.is_active = True
|
||||
|
||||
mock_db.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_vendor
|
||||
|
||||
context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "VENDOR1" # Uppercase
|
||||
}
|
||||
|
||||
vendor = VendorContextManager.get_vendor_from_context(mock_db, context)
|
||||
|
||||
assert vendor is mock_vendor
|
||||
|
||||
# ========================================================================
|
||||
# Path Extraction Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_extract_clean_path_from_vendor_path(self):
|
||||
"""Test extracting clean path from /vendor/ prefix."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
assert clean_path == "/shop/products"
|
||||
|
||||
def test_extract_clean_path_from_vendors_path(self):
|
||||
"""Test extracting clean path from /vendors/ prefix."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendors/vendor1/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendors/vendor1"
|
||||
}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
assert clean_path == "/shop/products"
|
||||
|
||||
def test_extract_clean_path_root(self):
|
||||
"""Test extracting clean path when result is empty (should return /)."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/vendor/vendor1")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "path",
|
||||
"path_prefix": "/vendor/vendor1"
|
||||
}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
assert clean_path == "/"
|
||||
|
||||
def test_extract_clean_path_no_path_context(self):
|
||||
"""Test extracting clean path for non-path detection methods."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, vendor_context)
|
||||
|
||||
assert clean_path == "/shop/products"
|
||||
|
||||
def test_extract_clean_path_no_context(self):
|
||||
"""Test extracting clean path with no vendor context."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
clean_path = VendorContextManager.extract_clean_path(request, None)
|
||||
|
||||
assert clean_path == "/shop/products"
|
||||
|
||||
# ========================================================================
|
||||
# Request Type Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
def test_is_admin_request_admin_subdomain(self):
|
||||
"""Test admin request detection from subdomain."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/dashboard")
|
||||
|
||||
assert VendorContextManager.is_admin_request(request) is True
|
||||
|
||||
def test_is_admin_request_admin_path(self):
|
||||
"""Test admin request detection from path."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/admin/dashboard")
|
||||
|
||||
assert VendorContextManager.is_admin_request(request) is True
|
||||
|
||||
def test_is_admin_request_with_port(self):
|
||||
"""Test admin request detection with port number."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "admin.localhost:8000"}
|
||||
request.url = Mock(path="/dashboard")
|
||||
|
||||
assert VendorContextManager.is_admin_request(request) is True
|
||||
|
||||
def test_is_not_admin_request(self):
|
||||
"""Test non-admin request."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/shop")
|
||||
|
||||
assert VendorContextManager.is_admin_request(request) is False
|
||||
|
||||
def test_is_api_request(self):
|
||||
"""Test API request detection."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/api/v1/vendors")
|
||||
|
||||
assert VendorContextManager.is_api_request(request) is True
|
||||
|
||||
def test_is_not_api_request(self):
|
||||
"""Test non-API request."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path="/shop/products")
|
||||
|
||||
assert VendorContextManager.is_api_request(request) is False
|
||||
|
||||
# ========================================================================
|
||||
# Static File Detection Tests
|
||||
# ========================================================================
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/static/css/style.css",
|
||||
"/static/js/app.js",
|
||||
"/media/images/product.png",
|
||||
"/assets/logo.svg",
|
||||
"/.well-known/security.txt",
|
||||
"/favicon.ico",
|
||||
"/image.jpg",
|
||||
"/style.css",
|
||||
"/app.webmanifest",
|
||||
])
|
||||
def test_is_static_file_request(self, path):
|
||||
"""Test static file detection for various paths and extensions."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path=path)
|
||||
|
||||
assert VendorContextManager.is_static_file_request(request) is True
|
||||
|
||||
@pytest.mark.parametrize("path", [
|
||||
"/shop/products",
|
||||
"/admin/dashboard",
|
||||
"/api/vendors",
|
||||
"/about",
|
||||
])
|
||||
def test_is_not_static_file_request(self, path):
|
||||
"""Test non-static file paths."""
|
||||
request = Mock(spec=Request)
|
||||
request.url = Mock(path=path)
|
||||
|
||||
assert VendorContextManager.is_static_file_request(request) is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.vendors
|
||||
class TestVendorContextMiddleware:
|
||||
"""Test suite for VendorContextMiddleware."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_skips_admin_request(self):
|
||||
"""Test middleware skips vendor detection for admin requests."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "admin.platform.com"}
|
||||
request.url = Mock(path="/admin/dashboard")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_admin_request', return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
assert request.state.vendor_context is None
|
||||
assert request.state.clean_path == "/admin/dashboard"
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_skips_api_request(self):
|
||||
"""Test middleware skips vendor detection for API requests."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/api/v1/vendors")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_api_request', return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
assert request.state.vendor_context is None
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_skips_static_file_request(self):
|
||||
"""Test middleware skips vendor detection for static files."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/static/css/style.css")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'is_static_file_request', return_value=True):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_detects_and_sets_vendor(self):
|
||||
"""Test middleware successfully detects and sets vendor."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "vendor1.platform.com"}
|
||||
request.url = Mock(path="/shop/products")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
mock_vendor = Mock()
|
||||
mock_vendor.id = 1
|
||||
mock_vendor.name = "Test Vendor"
|
||||
mock_vendor.subdomain = "vendor1"
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "vendor1"
|
||||
}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=mock_vendor), \
|
||||
patch.object(VendorContextManager, 'extract_clean_path', return_value="/shop/products"), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is mock_vendor
|
||||
assert request.state.vendor_context == vendor_context
|
||||
assert request.state.clean_path == "/shop/products"
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_vendor_not_found(self):
|
||||
"""Test middleware when vendor context detected but vendor not in database."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "nonexistent.platform.com"}
|
||||
request.url = Mock(path="/shop")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
vendor_context = {
|
||||
"detection_method": "subdomain",
|
||||
"subdomain": "nonexistent"
|
||||
}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=vendor_context), \
|
||||
patch.object(VendorContextManager, 'get_vendor_from_context', return_value=None), \
|
||||
patch('middleware.vendor_context.get_db', return_value=iter([mock_db])):
|
||||
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
assert request.state.vendor_context == vendor_context
|
||||
assert request.state.clean_path == "/shop"
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_no_vendor_context(self):
|
||||
"""Test middleware when no vendor context detected."""
|
||||
middleware = VendorContextMiddleware(app=None)
|
||||
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/random/path")
|
||||
request.state = Mock()
|
||||
|
||||
call_next = AsyncMock(return_value=Mock())
|
||||
|
||||
with patch.object(VendorContextManager, 'detect_vendor_context', return_value=None):
|
||||
await middleware.dispatch(request, call_next)
|
||||
|
||||
assert request.state.vendor is None
|
||||
assert request.state.vendor_context is None
|
||||
assert request.state.clean_path == "/random/path"
|
||||
call_next.assert_called_once_with(request)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.vendors
|
||||
class TestHelperFunctions:
|
||||
"""Test suite for helper functions."""
|
||||
|
||||
def test_get_current_vendor_exists(self):
|
||||
"""Test getting current vendor when it exists."""
|
||||
request = Mock(spec=Request)
|
||||
mock_vendor = Mock()
|
||||
request.state.vendor = mock_vendor
|
||||
|
||||
vendor = get_current_vendor(request)
|
||||
|
||||
assert vendor is mock_vendor
|
||||
|
||||
def test_get_current_vendor_not_exists(self):
|
||||
"""Test getting current vendor when it doesn't exist."""
|
||||
request = Mock(spec=Request)
|
||||
request.state = Mock(spec=[]) # vendor attribute doesn't exist
|
||||
|
||||
vendor = get_current_vendor(request)
|
||||
|
||||
assert vendor is None
|
||||
|
||||
def test_require_vendor_context_success(self):
|
||||
"""Test require_vendor_context dependency with vendor present."""
|
||||
request = Mock(spec=Request)
|
||||
mock_vendor = Mock()
|
||||
request.state.vendor = mock_vendor
|
||||
|
||||
dependency = require_vendor_context()
|
||||
result = dependency(request)
|
||||
|
||||
assert result is mock_vendor
|
||||
|
||||
def test_require_vendor_context_failure(self):
|
||||
"""Test require_vendor_context dependency raises HTTPException when no vendor."""
|
||||
request = Mock(spec=Request)
|
||||
request.state.vendor = None
|
||||
|
||||
dependency = require_vendor_context()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
dependency(request)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "Vendor not found" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.vendors
|
||||
class TestEdgeCases:
|
||||
"""Test suite for edge cases and error scenarios."""
|
||||
|
||||
def test_detect_vendor_context_empty_host(self):
|
||||
"""Test vendor detection with empty host header."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": ""}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_detect_vendor_context_missing_host(self):
|
||||
"""Test vendor detection with missing host header."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is None
|
||||
|
||||
def test_detect_vendor_path_with_trailing_slash(self):
|
||||
"""Test path detection with trailing slash."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/vendor/vendor1/")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "path"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
|
||||
def test_detect_vendor_path_without_trailing_slash(self):
|
||||
"""Test path detection without trailing slash."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "localhost"}
|
||||
request.url = Mock(path="/vendor/vendor1")
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
assert context["detection_method"] == "path"
|
||||
assert context["subdomain"] == "vendor1"
|
||||
|
||||
def test_detect_vendor_complex_subdomain(self):
|
||||
"""Test detection with multiple subdomain levels."""
|
||||
request = Mock(spec=Request)
|
||||
request.headers = {"host": "shop.vendor1.platform.com"}
|
||||
request.url = Mock(path="/")
|
||||
|
||||
with patch('middleware.vendor_context.settings') as mock_settings:
|
||||
mock_settings.platform_domain = "platform.com"
|
||||
|
||||
context = VendorContextManager.detect_vendor_context(request)
|
||||
|
||||
assert context is not None
|
||||
# Should detect 'shop' as subdomain since it's the first part
|
||||
assert context["detection_method"] == "subdomain"
|
||||
assert context["subdomain"] == "shop"
|
||||
|
||||
Reference in New Issue
Block a user