# tests/unit/middleware/test_auth.py
"""Unit tests for middleware/auth.py (AuthManager).

Test classes are grouped by phase and cover all 12 AuthManager methods:
- Constructor, password hashing, authenticate_user
- create_access_token, verify_token, get_current_user
- RBAC (require_role, require_admin, require_store, require_customer)
- create_default_admin_user
"""

import os
from datetime import UTC, datetime, timedelta
from unittest.mock import Mock, patch

import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from jose import jwt as jose_jwt

from app.modules.tenancy.exceptions import (
    AdminRequiredException,
    InsufficientPermissionsException,
    InvalidCredentialsException,
    InvalidTokenException,
    TokenExpiredException,
    UserNotActiveException,
)
from app.modules.tenancy.models import User
from middleware.auth import AuthManager


# ─── Shared helpers ─────────────────────────────────────────────────


def _decode_access_token(auth_manager, token_response: dict) -> dict:
    """Decode the ``access_token`` from a create_access_token() response.

    Uses the manager's own secret key and configured algorithm, so the
    assertions stay in sync with whatever AuthManager was built with.
    """
    return jose_jwt.decode(
        token_response["access_token"],
        auth_manager.secret_key,
        algorithms=[auth_manager.algorithm],
    )


def _mock_user(**attrs) -> Mock:
    """Return a ``Mock(spec=User)`` with each keyword set as an attribute.

    Attributes are assigned with setattr instead of being passed to Mock(),
    so keys that Mock treats specially (e.g. ``name``) are not swallowed.
    """
    user = Mock(spec=User)
    for attr, value in attrs.items():
        setattr(user, attr, value)
    return user


# ─── Phase 1: Constructor ───────────────────────────────────────────


@pytest.mark.unit
class TestAuthManagerInit:
    """Test AuthManager.__init__ reads env vars and sets defaults."""

    def test_default_configuration(self):
        """Verify defaults: fallback secret key, HS256, 30 min expiry."""
        with patch.dict(os.environ, {}, clear=True):
            mgr = AuthManager()
        assert mgr.secret_key == "your-secret-key-change-in-production-please"
        assert mgr.algorithm == "HS256"
        assert mgr.token_expire_minutes == 30

    def test_custom_secret_key_from_env(self):
        """JWT_SECRET_KEY env var overrides the default secret."""
        with patch.dict(os.environ, {"JWT_SECRET_KEY": "my-custom-secret"}):
            mgr = AuthManager()
        assert mgr.secret_key == "my-custom-secret"

    def test_custom_expire_minutes_from_env(self):
        """JWT_EXPIRE_MINUTES env var overrides the default 30."""
        with patch.dict(os.environ, {"JWT_EXPIRE_MINUTES": "60"}):
            mgr = AuthManager()
        assert mgr.token_expire_minutes == 60


# ─── Phase 2: Password Hashing ──────────────────────────────────────


@pytest.mark.unit
class TestPasswordHashing:
    """Test hash_password and verify_password (bcrypt)."""

    def test_hash_returns_bcrypt_format(self, auth_manager):
        h = auth_manager.hash_password("testpass")
        assert h.startswith("$2b$")

    def test_hash_different_salts(self, auth_manager):
        """Same plaintext produces different hashes (random salts)."""
        h1 = auth_manager.hash_password("same")
        h2 = auth_manager.hash_password("same")
        assert h1 != h2

    def test_verify_correct_password(self, auth_manager):
        h = auth_manager.hash_password("correct")
        assert auth_manager.verify_password("correct", h) is True

    def test_verify_wrong_password(self, auth_manager):
        h = auth_manager.hash_password("correct")
        assert auth_manager.verify_password("wrong", h) is False

    def test_verify_empty_password(self, auth_manager):
        h = auth_manager.hash_password("notempty")
        assert auth_manager.verify_password("", h) is False


# ─── Phase 3: authenticate_user ─────────────────────────────────────


@pytest.mark.unit
class TestAuthenticateUser:
    """Test authenticate_user with real DB fixtures."""

    def test_success_by_username(self, db, auth_manager, test_user):
        result = auth_manager.authenticate_user(db, test_user.username, "testpass123")
        assert result is not None
        assert result.id == test_user.id

    def test_success_by_email(self, db, auth_manager, test_user):
        result = auth_manager.authenticate_user(db, test_user.email, "testpass123")
        assert result is not None
        assert result.id == test_user.id

    def test_user_not_found(self, db, auth_manager):
        result = auth_manager.authenticate_user(db, "nonexistent_user", "pass")
        assert result is None

    def test_wrong_password(self, db, auth_manager, test_user):
        result = auth_manager.authenticate_user(db, test_user.username, "wrongpass")
        assert result is None

    def test_empty_credentials(self, db, auth_manager):
        result = auth_manager.authenticate_user(db, "", "")
        assert result is None


# ─── Phase 4: create_access_token ───────────────────────────────────


@pytest.mark.unit
class TestCreateAccessToken:
    """Test create_access_token with real User fixtures."""

    def test_basic_user_claims(self, auth_manager, test_user):
        result = auth_manager.create_access_token(test_user)
        payload = _decode_access_token(auth_manager, result)
        assert payload["sub"] == str(test_user.id)
        assert payload["username"] == test_user.username
        assert payload["email"] == test_user.email
        assert payload["role"] == test_user.role

    def test_return_structure(self, auth_manager, test_user):
        result = auth_manager.create_access_token(test_user)
        assert "access_token" in result
        assert result["token_type"] == "bearer"
        assert result["expires_in"] == auth_manager.token_expire_minutes * 60

    def test_super_admin_claims(self, auth_manager, test_super_admin):
        result = auth_manager.create_access_token(test_super_admin)
        payload = _decode_access_token(auth_manager, result)
        assert payload["role"] == "super_admin"
        assert "accessible_platforms" not in payload

    def test_platform_admin_with_platforms(self, db, auth_manager, test_platform_admin):
        with patch.object(
            test_platform_admin, "get_accessible_platform_ids", return_value=[1, 2, 3]
        ):
            result = auth_manager.create_access_token(test_platform_admin)
        payload = _decode_access_token(auth_manager, result)
        assert payload["role"] == "platform_admin"
        assert payload["accessible_platforms"] == [1, 2, 3]

    def test_platform_admin_without_platforms(self, db, auth_manager, test_platform_admin):
        with patch.object(
            test_platform_admin, "get_accessible_platform_ids", return_value=None
        ):
            result = auth_manager.create_access_token(test_platform_admin)
        payload = _decode_access_token(auth_manager, result)
        assert payload["role"] == "platform_admin"
        assert "accessible_platforms" not in payload

    def test_store_context(self, auth_manager, test_user):
        result = auth_manager.create_access_token(
            test_user, store_id=5, store_code="mystore", store_role="owner"
        )
        payload = _decode_access_token(auth_manager, result)
        assert payload["store_id"] == 5
        assert payload["store_code"] == "mystore"
        assert payload["store_role"] == "owner"

    def test_platform_context(self, auth_manager, test_admin):
        result = auth_manager.create_access_token(
            test_admin, platform_id=10, platform_code="platX"
        )
        payload = _decode_access_token(auth_manager, result)
        assert payload["platform_id"] == 10
        assert payload["platform_code"] == "platX"

    def test_all_contexts_combined(self, auth_manager, test_admin):
        result = auth_manager.create_access_token(
            test_admin,
            store_id=5,
            store_code="store1",
            store_role="manager",
            platform_id=10,
            platform_code="plat1",
        )
        payload = _decode_access_token(auth_manager, result)
        assert payload["store_id"] == 5
        assert payload["platform_id"] == 10

    def test_expiration_matches_config(self, auth_manager, test_user):
        before = datetime.now(UTC)
        result = auth_manager.create_access_token(test_user)
        after = datetime.now(UTC)
        payload = _decode_access_token(auth_manager, result)
        exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
        lifetime = timedelta(minutes=auth_manager.token_expire_minutes)
        # JWT exp is stored as integer seconds, so allow 1s tolerance
        tolerance = timedelta(seconds=1)
        expected_min = before + lifetime - tolerance
        expected_max = after + lifetime + tolerance
        assert expected_min <= exp <= expected_max

    def test_non_admin_has_no_admin_claims(self, auth_manager, test_user):
        result = auth_manager.create_access_token(test_user)
        payload = _decode_access_token(auth_manager, result)
        assert payload["role"] not in ("super_admin", "platform_admin")
        assert "accessible_platforms" not in payload


# ─── Phase 5: verify_token ──────────────────────────────────────────


@pytest.mark.unit
class TestVerifyToken:
    """Tests that craft tokens directly with jose_jwt.encode()."""

    def _encode(self, payload, secret):
        """Encode a JWT with the given payload and secret."""
        return jose_jwt.encode(payload, secret, algorithm="HS256")

    def _base_payload(self, **overrides):
        """Produce a minimal valid payload (expires one hour from now)."""
        now = datetime.now(UTC)  # single clock read so exp and iat agree
        base = {
            "sub": "42",
            "username": "alice",
            "email": "alice@example.com",
            "role": "user",
            "exp": now + timedelta(hours=1),
            "iat": now,
        }
        base.update(overrides)
        return base

    def test_valid_basic_token(self, auth_manager):
        token = self._encode(self._base_payload(), auth_manager.secret_key)
        data = auth_manager.verify_token(token)
        assert data["user_id"] == 42
        assert data["username"] == "alice"
        assert data["email"] == "alice@example.com"
        assert data["role"] == "user"

    def test_valid_store_token(self, auth_manager):
        token = self._encode(
            self._base_payload(store_id=7, store_code="shop1", store_role="owner"),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["store_id"] == 7
        assert data["store_code"] == "shop1"
        assert data["store_role"] == "owner"

    def test_valid_admin_token(self, auth_manager):
        token = self._encode(
            self._base_payload(role="super_admin"), auth_manager.secret_key
        )
        data = auth_manager.verify_token(token)
        assert data["role"] == "super_admin"

    def test_valid_platform_token(self, auth_manager):
        token = self._encode(
            self._base_payload(platform_id=3, platform_code="eu"),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["platform_id"] == 3
        assert data["platform_code"] == "eu"

    def test_valid_all_claims(self, auth_manager):
        token = self._encode(
            self._base_payload(
                role="platform_admin",
                accessible_platforms=[1, 2],
                platform_id=1,
                platform_code="us",
                store_id=5,
                store_code="s1",
                store_role="manager",
            ),
            auth_manager.secret_key,
        )
        data = auth_manager.verify_token(token)
        assert data["accessible_platforms"] == [1, 2]
        assert data["platform_id"] == 1
        assert data["store_id"] == 5

    def test_default_role_when_missing(self, auth_manager):
        payload = self._base_payload()
        del payload["role"]
        token = self._encode(payload, auth_manager.secret_key)
        data = auth_manager.verify_token(token)
        assert data["role"] == "user"

    def test_expired_token(self, auth_manager):
        token = self._encode(
            self._base_payload(exp=datetime.now(UTC) - timedelta(hours=1)),
            auth_manager.secret_key,
        )
        with pytest.raises(TokenExpiredException):
            auth_manager.verify_token(token)

    def test_missing_sub(self, auth_manager):
        payload = self._base_payload()
        del payload["sub"]
        token = self._encode(payload, auth_manager.secret_key)
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_missing_exp(self, auth_manager):
        payload = self._base_payload()
        del payload["exp"]
        token = self._encode(payload, auth_manager.secret_key)
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_wrong_secret_key(self, auth_manager):
        token = self._encode(self._base_payload(), "wrong-secret-key")
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token(token)

    def test_malformed_token(self, auth_manager):
        with pytest.raises(InvalidTokenException):
            auth_manager.verify_token("not.a.valid.jwt")


# ─── Phase 6: get_current_user ──────────────────────────────────────


@pytest.mark.unit
class TestGetCurrentUser:
    """Round-trip tests using create_access_token → get_current_user."""

    def _make_credentials(self, token: str) -> HTTPAuthorizationCredentials:
        """Wrap a raw JWT string as Bearer credentials."""
        return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)

    def test_success(self, db, auth_manager, test_user):
        token_data = auth_manager.create_access_token(test_user)
        creds = self._make_credentials(token_data["access_token"])
        user = auth_manager.get_current_user(db, creds)
        assert user.id == test_user.id

    def test_inactive_user(self, db, auth_manager, test_user):
        token_data = auth_manager.create_access_token(test_user)
        # Deactivate after issuing the token: the token itself is still valid.
        test_user.is_active = False
        db.commit()
        creds = self._make_credentials(token_data["access_token"])
        with pytest.raises(UserNotActiveException):
            auth_manager.get_current_user(db, creds)

    def test_nonexistent_user(self, db, auth_manager, test_user):
        token_data = auth_manager.create_access_token(test_user)
        # Delete after issuing the token so the token's "sub" points nowhere.
        db.delete(test_user)
        db.commit()
        creds = self._make_credentials(token_data["access_token"])
        with pytest.raises(InvalidCredentialsException):
            auth_manager.get_current_user(db, creds)

    def test_invalid_token(self, db, auth_manager):
        creds = self._make_credentials("invalid.token.here")
        with pytest.raises(InvalidTokenException):
            auth_manager.get_current_user(db, creds)

    def test_attaches_admin_attrs(self, db, auth_manager, test_super_admin):
        token_data = auth_manager.create_access_token(test_super_admin)
        creds = self._make_credentials(token_data["access_token"])
        user = auth_manager.get_current_user(db, creds)
        assert user.is_super_admin is True

    def test_attaches_platform_attrs(self, db, auth_manager, test_admin):
        token_data = auth_manager.create_access_token(
            test_admin, platform_id=10, platform_code="eu"
        )
        creds = self._make_credentials(token_data["access_token"])
        user = auth_manager.get_current_user(db, creds)
        assert user.token_platform_id == 10
        assert user.token_platform_code == "eu"

    def test_attaches_store_attrs(self, db, auth_manager, test_store_user):
        token_data = auth_manager.create_access_token(
            test_store_user, store_id=5, store_code="shop1", store_role="owner"
        )
        creds = self._make_credentials(token_data["access_token"])
        user = auth_manager.get_current_user(db, creds)
        assert user.token_store_id == 5
        assert user.token_store_code == "shop1"
        assert user.token_store_role == "owner"

    def test_no_optional_attrs_for_basic_user(self, db, auth_manager, test_user):
        token_data = auth_manager.create_access_token(test_user)
        creds = self._make_credentials(token_data["access_token"])
        user = auth_manager.get_current_user(db, creds)
        assert user.is_super_admin is False
        assert not hasattr(user, "token_platform_id")
        assert not hasattr(user, "token_store_id")


# ─── Phase 7: RBAC ──────────────────────────────────────────────────


@pytest.mark.unit
class TestRequireRole:
    """Test the require_role decorator factory."""

    def test_matching_role(self, auth_manager):
        user = _mock_user(role="admin")

        @auth_manager.require_role("admin")
        def protected(current_user):
            return "ok"

        assert protected(user) == "ok"

    def test_non_matching_role(self, auth_manager):
        user = _mock_user(role="user")

        @auth_manager.require_role("admin")
        def protected(current_user):
            return "ok"

        with pytest.raises(HTTPException) as exc_info:
            protected(user)
        assert exc_info.value.status_code == 403

    def test_args_pass_through(self, auth_manager):
        user = _mock_user(role="admin")

        @auth_manager.require_role("admin")
        def protected(current_user, x, y=10):
            return x + y

        assert protected(user, 5, y=20) == 25

    def test_error_message(self, auth_manager):
        user = _mock_user(role="customer")

        @auth_manager.require_role("store")
        def protected(current_user):
            pass

        with pytest.raises(HTTPException) as exc_info:
            protected(user)
        assert "store" in exc_info.value.detail
        assert "customer" in exc_info.value.detail


@pytest.mark.unit
class TestRequireAdmin:
    """Test require_admin method (accepts super_admin and platform_admin).

    Each mock carries both the ``is_admin`` flag and the role named by the
    test, so the scenarios stay distinct rather than being identical copies.
    """

    def test_super_admin_accepted(self, auth_manager):
        user = _mock_user(role="super_admin", is_admin=True)
        result = auth_manager.require_admin(user)
        assert result is user

    def test_platform_admin_accepted(self, auth_manager):
        user = _mock_user(role="platform_admin", is_admin=True)
        result = auth_manager.require_admin(user)
        assert result is user

    def test_non_admin_rejected(self, auth_manager):
        user = _mock_user(role="store_member", is_admin=False)
        with pytest.raises(AdminRequiredException):
            auth_manager.require_admin(user)


@pytest.mark.unit
class TestRequireStore:
    """Test require_store method (accepts merchant_owner and store_member).

    Each mock carries both the ``is_store_user`` flag and the role named by
    the test, so the scenarios stay distinct rather than being identical copies.
    """

    def test_merchant_owner_accepted(self, auth_manager):
        user = _mock_user(role="merchant_owner", is_store_user=True)
        assert auth_manager.require_store(user) is user

    def test_store_member_accepted(self, auth_manager):
        user = _mock_user(role="store_member", is_store_user=True)
        assert auth_manager.require_store(user) is user

    def test_admin_rejected(self, auth_manager):
        user = _mock_user(role="super_admin", is_store_user=False)
        with pytest.raises(InsufficientPermissionsException):
            auth_manager.require_store(user)

    def test_customer_rejected(self, auth_manager):
        user = _mock_user(role="customer", is_store_user=False)
        with pytest.raises(InsufficientPermissionsException):
            auth_manager.require_store(user)


@pytest.mark.unit
class TestRequireCustomer:
    """Test require_customer method (accepts customer and admin roles)."""

    def test_customer_accepted(self, auth_manager):
        user = _mock_user(role="customer")
        assert auth_manager.require_customer(user) is user

    def test_admin_accepted(self, auth_manager):
        user = _mock_user(role="admin")
        assert auth_manager.require_customer(user) is user

    def test_store_rejected(self, auth_manager):
        user = _mock_user(role="merchant_owner")
        with pytest.raises(InsufficientPermissionsException):
            auth_manager.require_customer(user)


# ─── Phase 8: create_default_admin_user ─────────────────────────────


@pytest.mark.unit
class TestCreateDefaultAdminUser:
    """Test create_default_admin_user with real DB."""

    def test_creates_admin_when_none_exists(self, db, auth_manager):
        user = auth_manager.create_default_admin_user(db)
        assert user.username == "admin"
        assert user.role == "super_admin"
        assert user.is_super_admin is True
        assert user.is_active is True

    def test_skips_when_admin_exists(self, db, auth_manager):
        first = auth_manager.create_default_admin_user(db)
        second = auth_manager.create_default_admin_user(db)
        assert first.id == second.id

    def test_password_is_verifiable(self, db, auth_manager):
        user = auth_manager.create_default_admin_user(db)
        assert auth_manager.verify_password("admin123", user.hashed_password) is True

    def test_user_persisted_to_db(self, db, auth_manager):
        auth_manager.create_default_admin_user(db)
        found = db.query(User).filter(User.username == "admin").first()
        assert found is not None
        assert found.email == "admin@example.com"