Cover all 12 methods: constructor, password hashing, authenticate_user, create_access_token, verify_token, get_current_user, RBAC helpers (require_role decorator, require_admin, require_store, require_customer), and create_default_admin_user. Achieves 96.45% coverage on auth.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
556 lines
21 KiB
Python
# tests/unit/middleware/test_auth.py
|
|
"""Unit tests for middleware/auth.py (AuthManager).
|
|
|
|
58 tests across 11 classes covering all 12 AuthManager methods:
|
|
- Constructor, password hashing, authenticate_user
|
|
- create_access_token, verify_token, get_current_user
|
|
- RBAC (require_role, require_admin, require_store, require_customer)
|
|
- create_default_admin_user
|
|
"""
|
|
|
|
import os
|
|
from datetime import UTC, datetime, timedelta
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from fastapi.security import HTTPAuthorizationCredentials
|
|
from jose import jwt as jose_jwt
|
|
|
|
from app.modules.tenancy.exceptions import (
|
|
AdminRequiredException,
|
|
InsufficientPermissionsException,
|
|
InvalidCredentialsException,
|
|
InvalidTokenException,
|
|
TokenExpiredException,
|
|
UserNotActiveException,
|
|
)
|
|
from app.modules.tenancy.models import User
|
|
from middleware.auth import AuthManager
|
|
|
|
# ─── Phase 1: Constructor ───────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestAuthManagerInit:
    """Check that AuthManager.__init__ picks up JWT settings from the environment.

    Each test builds a fresh manager inside ``patch.dict`` so the process
    environment is restored once the test finishes.
    """

    def test_default_configuration(self):
        """With an empty environment the built-in fallbacks apply (HS256, 30 min)."""
        # clear=True removes every variable, leaving only hard-coded defaults.
        with patch.dict(os.environ, {}, clear=True):
            manager = AuthManager()
            assert manager.secret_key == "your-secret-key-change-in-production-please"
            assert manager.algorithm == "HS256"
            assert manager.token_expire_minutes == 30

    def test_custom_secret_key_from_env(self):
        """JWT_SECRET_KEY replaces the fallback secret."""
        overrides = {"JWT_SECRET_KEY": "my-custom-secret"}
        with patch.dict(os.environ, overrides):
            manager = AuthManager()
            assert manager.secret_key == "my-custom-secret"

    def test_custom_expire_minutes_from_env(self):
        """JWT_EXPIRE_MINUTES (a string in the env) becomes a numeric expiry."""
        overrides = {"JWT_EXPIRE_MINUTES": "60"}
        with patch.dict(os.environ, overrides):
            manager = AuthManager()
            assert manager.token_expire_minutes == 60
|
|
|
|
|
|
# ─── Phase 2: Password Hashing ──────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestPasswordHashing:
    """Exercise AuthManager.hash_password and verify_password (bcrypt backend)."""

    def test_hash_returns_bcrypt_format(self, auth_manager):
        """Hashes carry the ``$2b$`` bcrypt identifier prefix."""
        hashed = auth_manager.hash_password("testpass")
        assert hashed.startswith("$2b$")

    def test_hash_different_salts(self, auth_manager):
        """Hashing one plaintext twice yields two distinct digests (random salt)."""
        first, second = (auth_manager.hash_password("same") for _ in range(2))
        assert first != second

    def test_verify_correct_password(self, auth_manager):
        """The original plaintext verifies against its own hash."""
        hashed = auth_manager.hash_password("correct")
        assert auth_manager.verify_password("correct", hashed) is True

    def test_verify_wrong_password(self, auth_manager):
        """A different plaintext is rejected."""
        hashed = auth_manager.hash_password("correct")
        assert auth_manager.verify_password("wrong", hashed) is False

    def test_verify_empty_password(self, auth_manager):
        """An empty string does not match the hash of a non-empty password."""
        hashed = auth_manager.hash_password("notempty")
        assert auth_manager.verify_password("", hashed) is False
|
|
|
|
|
|
# ─── Phase 3: authenticate_user ─────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestAuthenticateUser:
    """authenticate_user against the real test-database fixtures."""

    # Plaintext password these tests use to log in as the ``test_user`` fixture.
    _FIXTURE_PASSWORD = "testpass123"

    def test_success_by_username(self, db, auth_manager, test_user):
        """Logging in with the username returns the matching user."""
        found = auth_manager.authenticate_user(
            db, test_user.username, self._FIXTURE_PASSWORD
        )
        assert found is not None
        assert found.id == test_user.id

    def test_success_by_email(self, db, auth_manager, test_user):
        """Logging in with the email address resolves the same user."""
        found = auth_manager.authenticate_user(
            db, test_user.email, self._FIXTURE_PASSWORD
        )
        assert found is not None
        assert found.id == test_user.id

    def test_user_not_found(self, db, auth_manager):
        """An unknown identifier yields None rather than raising."""
        assert auth_manager.authenticate_user(db, "nonexistent_user", "pass") is None

    def test_wrong_password(self, db, auth_manager, test_user):
        """A known user with the wrong password yields None."""
        assert (
            auth_manager.authenticate_user(db, test_user.username, "wrongpass")
            is None
        )

    def test_empty_credentials(self, db, auth_manager):
        """Empty identifier and password yield None."""
        assert auth_manager.authenticate_user(db, "", "") is None
|
|
|
|
|
|
# ─── Phase 4: create_access_token ────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestCreateAccessToken:
    """Test create_access_token with real User fixtures.

    Every test decodes the issued token with python-jose using the manager's
    own ``secret_key`` and ``algorithm``. The assertions therefore inspect
    exactly the claims AuthManager signed.
    """

    @staticmethod
    def _decode(auth_manager, result):
        """Decode ``result["access_token"]`` with *auth_manager*'s key and algorithm.

        Args:
            auth_manager: The AuthManager that issued the token.
            result: The dict returned by ``create_access_token``.

        Returns:
            The decoded JWT claims as a dict.
        """
        return jose_jwt.decode(
            result["access_token"],
            auth_manager.secret_key,
            algorithms=[auth_manager.algorithm],
        )

    def test_basic_user_claims(self, auth_manager, test_user):
        """Identity claims mirror the user; ``sub`` is the stringified id."""
        result = auth_manager.create_access_token(test_user)
        payload = self._decode(auth_manager, result)
        assert payload["sub"] == str(test_user.id)
        assert payload["username"] == test_user.username
        assert payload["email"] == test_user.email
        assert payload["role"] == test_user.role

    def test_return_structure(self, auth_manager, test_user):
        """The response carries the token, a bearer type and expiry in seconds."""
        result = auth_manager.create_access_token(test_user)
        assert "access_token" in result
        assert result["token_type"] == "bearer"
        assert result["expires_in"] == auth_manager.token_expire_minutes * 60

    def test_super_admin_claims(self, auth_manager, test_super_admin):
        """Super admins are flagged and carry no platform allow-list."""
        result = auth_manager.create_access_token(test_super_admin)
        payload = self._decode(auth_manager, result)
        assert payload["is_super_admin"] is True
        assert "accessible_platforms" not in payload

    def test_platform_admin_with_platforms(self, db, auth_manager, test_platform_admin):
        """A platform admin's accessible platform ids are embedded in the token."""
        with patch.object(
            test_platform_admin, "get_accessible_platform_ids", return_value=[1, 2, 3]
        ):
            result = auth_manager.create_access_token(test_platform_admin)
        payload = self._decode(auth_manager, result)
        assert payload["is_super_admin"] is False
        assert payload["accessible_platforms"] == [1, 2, 3]

    def test_platform_admin_without_platforms(self, db, auth_manager, test_platform_admin):
        """When the platform id lookup returns None, the claim is omitted."""
        with patch.object(
            test_platform_admin, "get_accessible_platform_ids", return_value=None
        ):
            result = auth_manager.create_access_token(test_platform_admin)
        payload = self._decode(auth_manager, result)
        assert payload["is_super_admin"] is False
        assert "accessible_platforms" not in payload

    def test_store_context(self, auth_manager, test_user):
        """Store id/code/role keyword arguments become token claims."""
        result = auth_manager.create_access_token(
            test_user, store_id=5, store_code="mystore", store_role="owner"
        )
        payload = self._decode(auth_manager, result)
        assert payload["store_id"] == 5
        assert payload["store_code"] == "mystore"
        assert payload["store_role"] == "owner"

    def test_platform_context(self, auth_manager, test_admin):
        """Platform id/code keyword arguments become token claims."""
        result = auth_manager.create_access_token(
            test_admin, platform_id=10, platform_code="platX"
        )
        payload = self._decode(auth_manager, result)
        assert payload["platform_id"] == 10
        assert payload["platform_code"] == "platX"

    def test_all_contexts_combined(self, auth_manager, test_admin):
        """Store and platform context coexist in one token without clobbering."""
        result = auth_manager.create_access_token(
            test_admin,
            store_id=5,
            store_code="store1",
            store_role="manager",
            platform_id=10,
            platform_code="plat1",
        )
        payload = self._decode(auth_manager, result)
        # Check every context claim, not just the ids, so a claim dropped
        # when both contexts are present is caught.
        assert payload["store_id"] == 5
        assert payload["store_code"] == "store1"
        assert payload["store_role"] == "manager"
        assert payload["platform_id"] == 10
        assert payload["platform_code"] == "plat1"

    def test_expiration_matches_config(self, auth_manager, test_user):
        """``exp`` lands ``token_expire_minutes`` after issuance."""
        lifetime = timedelta(minutes=auth_manager.token_expire_minutes)
        # JWT exp is stored as integer seconds, so allow 1s tolerance.
        slack = timedelta(seconds=1)
        before = datetime.now(UTC)
        result = auth_manager.create_access_token(test_user)
        after = datetime.now(UTC)
        exp = datetime.fromtimestamp(self._decode(auth_manager, result)["exp"], tz=UTC)
        assert before + lifetime - slack <= exp <= after + lifetime + slack

    def test_non_admin_has_no_admin_claims(self, auth_manager, test_user):
        """A regular user's token carries no admin-only claims."""
        result = auth_manager.create_access_token(test_user)
        payload = self._decode(auth_manager, result)
        assert "is_super_admin" not in payload
        assert "accessible_platforms" not in payload
|
|
|
|
|
|
# ─── Phase 5: verify_token ───────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestVerifyToken:
    """Test verify_token with tokens crafted directly via ``jose_jwt.encode``.

    Hand-built tokens let these tests present claim sets that AuthManager
    would never issue itself: expired, missing required claims, or signed
    with the wrong key.
    """

    def _encode(self, payload, secret):
        """Sign *payload* with *secret* using HS256 and return the compact JWT."""
        return jose_jwt.encode(payload, secret, algorithm="HS256")

    def _base_payload(self, **overrides):
        """Return a minimal valid claim set, updated with *overrides*.

        ``iat`` and ``exp`` come from a single timestamp, so the token is
        valid for exactly one hour after issuance.
        """
        now = datetime.now(UTC)
        payload = {
            "sub": "42",
            "username": "alice",
            "email": "alice@example.com",
            "role": "user",
            "exp": now + timedelta(hours=1),
            "iat": now,
        }
        payload.update(overrides)
        return payload

    def test_valid_basic_token(self, auth_manager):
        """Core claims are returned, with ``sub`` exposed as the integer ``user_id``."""
        token = self._encode(self._base_payload(), auth_manager.secret_key)
        data = auth_manager.verify_token(token)
        assert data["user_id"] == 42
        assert data["username"] == "alice"
        assert data["email"] == "alice@example.com"
        assert data["role"] == "user"

    def test_valid_store_token(self, auth_manager):
        """Store claims are passed through."""
        token = self._encode(
            self._base_payload(store_id=7, store_code="shop1", store_role="owner"),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["store_id"] == 7
        assert data["store_code"] == "shop1"
        assert data["store_role"] == "owner"

    def test_valid_admin_token(self, auth_manager):
        """The super-admin flag is passed through."""
        token = self._encode(
            self._base_payload(is_super_admin=True), auth_manager.secret_key
        )
        data = auth_manager.verify_token(token)
        assert data["is_super_admin"] is True

    def test_valid_platform_token(self, auth_manager):
        """Platform claims are passed through."""
        token = self._encode(
            self._base_payload(platform_id=3, platform_code="eu"),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["platform_id"] == 3
        assert data["platform_code"] == "eu"

    def test_valid_all_claims(self, auth_manager):
        """Admin, platform and store claims can all be present at once."""
        token = self._encode(
            self._base_payload(
                is_super_admin=False,
                accessible_platforms=[1, 2],
                platform_id=1,
                platform_code="us",
                store_id=5,
                store_code="s1",
                store_role="manager",
            ),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["accessible_platforms"] == [1, 2]
        assert data["platform_id"] == 1
        assert data["platform_code"] == "us"
        assert data["store_id"] == 5
        assert data["store_code"] == "s1"
        assert data["store_role"] == "manager"

    def test_default_role_when_missing(self, auth_manager):
        """A token without a ``role`` claim is treated as role ``user``."""
        payload = self._base_payload()
        del payload["role"]
        token = self._encode(payload, auth_manager.secret_key)
        data = auth_manager.verify_token(token)
        assert data["role"] == "user"

    def test_expired_token(self, auth_manager):
        """An ``exp`` in the past raises TokenExpiredException."""
        token = self._encode(
            self._base_payload(exp=datetime.now(UTC) - timedelta(hours=1)),
            auth_manager.secret_key,
        )
        with pytest.raises(TokenExpiredException):
            auth_manager.verify_token(token)

    def test_missing_sub(self, auth_manager):
        """A token without ``sub`` is rejected as invalid."""
        payload = self._base_payload()
        del payload["sub"]
        token = self._encode(payload, auth_manager.secret_key)
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_missing_exp(self, auth_manager):
        """A token without ``exp`` is rejected as invalid."""
        payload = self._base_payload()
        del payload["exp"]
        token = self._encode(payload, auth_manager.secret_key)
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_wrong_secret_key(self, auth_manager):
        """A signature made with a different key is rejected."""
        token = self._encode(self._base_payload(), "wrong-secret-key")
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_malformed_token(self, auth_manager):
        """A string that is not a decodable JWT is rejected."""
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token("not.a.valid.jwt")
|
|
|
|
|
|
# ─── Phase 6: get_current_user ───────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestGetCurrentUser:
    """Round-trip: create_access_token -> bearer credentials -> get_current_user."""

    @staticmethod
    def _bearer(token: str) -> HTTPAuthorizationCredentials:
        """Wrap *token* in Bearer-scheme ``HTTPAuthorizationCredentials``."""
        return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

    def _issue(self, auth_manager, user, **context):
        """Mint a token for *user* (plus optional store/platform *context*) as credentials."""
        issued = auth_manager.create_access_token(user, **context)
        return self._bearer(issued["access_token"])

    def test_success(self, db, auth_manager, test_user):
        """A valid token resolves to the user it was issued for."""
        creds = self._issue(auth_manager, test_user)
        assert auth_manager.get_current_user(db, creds).id == test_user.id

    def test_inactive_user(self, db, auth_manager, test_user):
        """A still-valid token for a deactivated account is refused."""
        creds = self._issue(auth_manager, test_user)
        # Deactivate only after issuing, so the token itself remains valid.
        test_user.is_active = False
        db.commit()
        with pytest.raises(UserNotActiveException):
            auth_manager.get_current_user(db, creds)

    def test_nonexistent_user(self, db, auth_manager, test_user):
        """A token whose user row was deleted is refused."""
        creds = self._issue(auth_manager, test_user)
        db.delete(test_user)
        db.commit()
        with pytest.raises(InvalidCredentialsException):
            auth_manager.get_current_user(db, creds)

    def test_invalid_token(self, db, auth_manager):
        """Garbage credentials surface as InvalidTokenException."""
        creds = self._bearer("invalid.token.here")
        with pytest.raises(InvalidTokenException):
            auth_manager.get_current_user(db, creds)

    def test_attaches_admin_attrs(self, db, auth_manager, test_super_admin):
        """The super-admin claim is copied onto the user as ``token_is_super_admin``."""
        user = auth_manager.get_current_user(
            db, self._issue(auth_manager, test_super_admin)
        )
        assert user.token_is_super_admin is True

    def test_attaches_platform_attrs(self, db, auth_manager, test_admin):
        """Platform context claims are copied onto the user."""
        creds = self._issue(auth_manager, test_admin, platform_id=10, platform_code="eu")
        user = auth_manager.get_current_user(db, creds)
        assert user.token_platform_id == 10
        assert user.token_platform_code == "eu"

    def test_attaches_store_attrs(self, db, auth_manager, test_store_user):
        """Store context claims are copied onto the user."""
        creds = self._issue(
            auth_manager,
            test_store_user,
            store_id=5,
            store_code="shop1",
            store_role="owner",
        )
        user = auth_manager.get_current_user(db, creds)
        assert user.token_store_id == 5
        assert user.token_store_code == "shop1"
        assert user.token_store_role == "owner"

    def test_no_optional_attrs_for_basic_user(self, db, auth_manager, test_user):
        """Without admin/platform/store claims, no ``token_*`` extras are attached."""
        user = auth_manager.get_current_user(db, self._issue(auth_manager, test_user))
        for attr in ("token_is_super_admin", "token_platform_id", "token_store_id"):
            assert not hasattr(user, attr), attr
|
|
|
|
|
|
# ─── Phase 7: RBAC ──────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestRequireRole:
    """require_role(role): a decorator factory that gates a callable on user role."""

    @staticmethod
    def _user(role):
        """Return a ``User``-spec mock whose ``role`` attribute is *role*."""
        return Mock(spec=User, role=role)

    def test_matching_role(self, auth_manager):
        """A caller with the required role reaches the wrapped function."""
        @auth_manager.require_role("admin")
        def protected(current_user):
            return "ok"

        assert protected(self._user("admin")) == "ok"

    def test_non_matching_role(self, auth_manager):
        """A caller with another role gets HTTP 403."""
        @auth_manager.require_role("admin")
        def protected(current_user):
            return "ok"

        caller = self._user("user")
        with pytest.raises(HTTPException) as exc_info:
            protected(caller)
        assert exc_info.value.status_code == 403

    def test_args_pass_through(self, auth_manager):
        """Extra positional and keyword arguments are forwarded unchanged."""
        @auth_manager.require_role("admin")
        def protected(current_user, x, y=10):
            return x + y

        assert protected(self._user("admin"), 5, y=20) == 25

    def test_error_message(self, auth_manager):
        """The HTTPException detail names both the required and the actual role."""
        @auth_manager.require_role("store")
        def protected(current_user):
            pass

        caller = self._user("customer")
        with pytest.raises(HTTPException) as exc_info:
            protected(caller)
        detail = exc_info.value.detail
        assert "store" in detail
        assert "customer" in detail
|
|
|
|
|
|
@pytest.mark.unit
class TestRequireAdmin:
    """require_admin: admits the ``admin`` role and rejects a plain user."""

    def test_admin_accepted(self, auth_manager):
        """An admin is returned as the very same object."""
        admin = Mock(spec=User, role="admin")
        assert auth_manager.require_admin(admin) is admin

    def test_non_admin_rejected(self, auth_manager):
        """A ``user``-role caller triggers AdminRequiredException."""
        regular = Mock(spec=User, role="user")
        with pytest.raises(AdminRequiredException):
            auth_manager.require_admin(regular)
|
|
|
|
|
|
@pytest.mark.unit
class TestRequireStore:
    """require_store: admits ``store`` and ``admin`` roles, rejects customers."""

    def test_store_accepted(self, auth_manager):
        """A store user passes through unchanged."""
        store_user = Mock(spec=User, role="store")
        assert auth_manager.require_store(store_user) is store_user

    def test_admin_accepted(self, auth_manager):
        """Admins are also allowed into store-only routes."""
        admin = Mock(spec=User, role="admin")
        assert auth_manager.require_store(admin) is admin

    def test_customer_rejected(self, auth_manager):
        """A customer triggers InsufficientPermissionsException."""
        customer = Mock(spec=User, role="customer")
        with pytest.raises(InsufficientPermissionsException):
            auth_manager.require_store(customer)
|
|
|
|
|
|
@pytest.mark.unit
class TestRequireCustomer:
    """require_customer: admits ``customer`` and ``admin`` roles, rejects store users."""

    def test_customer_accepted(self, auth_manager):
        """A customer passes through unchanged."""
        customer = Mock(spec=User, role="customer")
        assert auth_manager.require_customer(customer) is customer

    def test_admin_accepted(self, auth_manager):
        """Admins are also allowed into customer-only routes."""
        admin = Mock(spec=User, role="admin")
        assert auth_manager.require_customer(admin) is admin

    def test_store_rejected(self, auth_manager):
        """A store user triggers InsufficientPermissionsException."""
        store_user = Mock(spec=User, role="store")
        with pytest.raises(InsufficientPermissionsException):
            auth_manager.require_customer(store_user)
|
|
|
|
|
|
# ─── Phase 8: create_default_admin_user ──────────────────────────────
|
|
|
|
|
|
@pytest.mark.unit
class TestCreateDefaultAdminUser:
    """create_default_admin_user against the real test database."""

    def test_creates_admin_when_none_exists(self, db, auth_manager):
        """With no admin present, an active super-admin named ``admin`` is created."""
        admin = auth_manager.create_default_admin_user(db)
        assert admin.username == "admin"
        assert admin.role == "admin"
        assert admin.is_super_admin is True
        assert admin.is_active is True

    def test_skips_when_admin_exists(self, db, auth_manager):
        """A repeat call hands back the existing admin rather than a new row."""
        original = auth_manager.create_default_admin_user(db)
        repeat = auth_manager.create_default_admin_user(db)
        assert repeat.id == original.id

    def test_password_is_verifiable(self, db, auth_manager):
        """The stored hash verifies against the default password ``admin123``."""
        admin = auth_manager.create_default_admin_user(db)
        assert auth_manager.verify_password("admin123", admin.hashed_password) is True

    def test_user_persisted_to_db(self, db, auth_manager):
        """The admin can be read back with a separate query by username."""
        auth_manager.create_default_admin_user(db)
        stored = db.query(User).filter(User.username == "admin").first()
        assert stored is not None
        assert stored.email == "admin@example.com"
|