# Source: orion/middleware/auth.py (183 lines, 6.0 KiB, Python)
# middleware/auth.py
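"""JWT authentication and authorization helpers.

This module provides classes and functions for:
- hashing and verifying passwords with bcrypt (via passlib),
- issuing and validating HS256 JWT access tokens,
- resolving the current user from a bearer token and enforcing role/admin
  requirements,
- bootstrapping a default admin account.
"""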
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from jose import jwt
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from app.exceptions import (
AdminRequiredException,
InvalidTokenException,
TokenExpiredException,
UserNotActiveException,
InvalidCredentialsException
)
from models.database.user import User
logger = logging.getLogger(name=__name__)

# Module-wide password hashing context (bcrypt only); AuthManager's
# hash_password / verify_password delegate to it.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
class AuthManager:
    """JWT-based authentication manager with bcrypt password hashing.

    Configuration is read from the environment when the manager is created:

    * ``JWT_SECRET_KEY`` -- HMAC key used to sign and verify tokens. If it is
      unset or empty, an insecure built-in placeholder is used and a warning
      is logged.
    * ``JWT_EXPIRE_MINUTES`` -- access-token lifetime in minutes (default 30).
    """

    # Placeholder signing key, used only when JWT_SECRET_KEY is not configured.
    _DEFAULT_SECRET_KEY = "your-secret-key-change-in-production-please"

    def __init__(self):
        """Load the signing key, algorithm and token lifetime from the environment.

        Raises:
            ValueError: if ``JWT_EXPIRE_MINUTES`` is not an integer string.
        """
        # An empty key is treated like a missing one: signing with "" is no
        # more secret than signing with a well-known placeholder.
        secret_key = os.getenv("JWT_SECRET_KEY")
        if not secret_key:
            logger.warning(
                "JWT_SECRET_KEY is not set; falling back to an insecure default "
                "signing key. Set JWT_SECRET_KEY in production."
            )
            secret_key = self._DEFAULT_SECRET_KEY
        self.secret_key = secret_key
        self.algorithm = "HS256"
        self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30"))

    def hash_password(self, password: str) -> str:
        """Return a bcrypt hash of ``password`` (via the module's passlib context)."""
        return pwd_context.hash(password)

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``.

        passlib raises (rather than returning False) when ``hashed_password``
        is not a recognised hash; callers that want a plain boolean must
        handle that.
        """
        return pwd_context.verify(plain_password, hashed_password)

    def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
        """Authenticate a user by username-or-email and password.

        Args:
            db: Active SQLAlchemy session.
            username: The account's username or its email address.
            password: Plain-text password to check.

        Returns:
            The matching ``User`` when the password is correct, otherwise
            ``None``. The account's ``is_active`` flag is not checked here.
        """
        user = (
            db.query(User)
            .filter((User.username == username) | (User.email == username))
            .first()
        )
        if user is None:
            # Burn roughly the time of a real bcrypt check so that response
            # timing does not reveal which usernames/emails exist.
            pwd_context.dummy_verify()
            return None
        try:
            password_ok = self.verify_password(password, user.hashed_password)
        except (ValueError, TypeError):
            # A missing or unrecognised stored hash is a data problem; fail the
            # login instead of crashing the request, but leave a trace.
            logger.warning("User %s has an invalid stored password hash", user.id)
            return None
        return user if password_ok else None

    def create_access_token(self, user: User) -> Dict[str, Any]:
        """Create a signed JWT access token for ``user``.

        The token carries the user's id (as the string ``sub`` claim),
        username, email, role, issue time (``iat``) and expiry (``exp``).

        Returns:
            A dict with ``access_token`` (the encoded JWT), ``token_type``
            (always ``"bearer"``) and ``expires_in`` (lifetime in seconds).
        """
        # One timestamp for both claims so iat and exp are exactly one
        # configured lifetime apart.
        issued_at = datetime.now(timezone.utc)
        expire = issued_at + timedelta(minutes=self.token_expire_minutes)
        payload = {
            "sub": str(user.id),
            "username": user.username,
            "email": user.email,
            "role": user.role,
            "exp": expire,
            "iat": issued_at,
        }
        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
        return {
            "access_token": token,
            "token_type": "bearer",
            "expires_in": self.token_expire_minutes * 60,  # seconds
        }

    def verify_token(self, token: str) -> Dict[str, Any]:
        """Verify a JWT and return the user data it carries.

        Returns:
            A dict with ``user_id`` (int, from ``sub``), ``username``,
            ``email`` and ``role`` (``"user"`` when the claim is absent).

        Raises:
            TokenExpiredException: the token's expiry has passed.
            InvalidTokenException: the token is malformed, has a bad
                signature, lacks the ``exp`` or ``sub`` claim, or cannot be
                processed for any other reason.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            # Require an expiry claim explicitly, and re-check it defensively.
            exp = payload.get("exp")
            if exp is None:
                raise InvalidTokenException("Token missing expiration")
            if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc):
                raise TokenExpiredException()
            user_id = payload.get("sub")
            if user_id is None:
                raise InvalidTokenException("Token missing user identifier")
            return {
                "user_id": int(user_id),
                "username": payload.get("username"),
                "email": payload.get("email"),
                "role": payload.get("role", "user"),
            }
        except (TokenExpiredException, InvalidTokenException):
            # Already specific: propagate unchanged instead of letting the
            # catch-all below rewrap them as a generic failure.
            raise
        except jwt.ExpiredSignatureError as e:
            raise TokenExpiredException() from e
        except jwt.JWTError as e:
            logger.error("JWT decode error: %s", e)
            raise InvalidTokenException("Could not validate credentials") from e
        except Exception as e:
            # e.g. a non-numeric ``sub`` failing int() conversion.
            logger.error("Token verification error: %s", e)
            raise InvalidTokenException("Authentication failed") from e

    def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
        """Resolve the active user identified by a bearer token.

        Raises:
            TokenExpiredException / InvalidTokenException: from ``verify_token``.
            InvalidCredentialsException: no user with the token's id exists.
            UserNotActiveException: the user exists but ``is_active`` is false.
        """
        user_data = self.verify_token(credentials.credentials)
        user = db.query(User).filter(User.id == user_data["user_id"]).first()
        if not user:
            raise InvalidCredentialsException("User not found")
        if not user.is_active:
            raise UserNotActiveException()
        return user

    def require_role(self, required_role: str):
        """Return a decorator that only runs the wrapped function for ``required_role``.

        The wrapped function must take the current ``User`` as its first
        positional argument. A mismatched role raises ``HTTPException(403)``.
        """
        # NOTE(review): the wrapper is not decorated with functools.wraps, so
        # the wrapped function's name/docstring/signature are not preserved.
        # Adding it would change what FastAPI introspects -- confirm usage first.
        def decorator(func):
            def wrapper(current_user: User, *args, **kwargs):
                if current_user.role != required_role:
                    raise HTTPException(
                        status_code=403,
                        detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'",
                    )
                return func(current_user, *args, **kwargs)
            return wrapper
        return decorator

    def require_admin(self, current_user: User):
        """Return ``current_user`` if it has role ``"admin"``.

        Raises:
            AdminRequiredException: for any other role.
        """
        if current_user.role != "admin":
            raise AdminRequiredException()
        return current_user

    def create_default_admin_user(self, db: Session):
        """Create the default ``admin`` account unless a user named ``admin`` exists.

        A new account gets role ``"admin"``, email ``admin@example.com`` and
        the well-known password ``admin123``, which must be changed after
        first login. An existing ``admin`` user is returned as-is, whatever
        its role.

        Returns:
            The existing or newly created ``User``.
        """
        admin_user = db.query(User).filter(User.username == "admin").first()
        if admin_user:
            return admin_user
        admin_user = User(
            email="admin@example.com",
            username="admin",
            hashed_password=self.hash_password("admin123"),
            role="admin",
            is_active=True,
        )
        db.add(admin_user)
        db.commit()
        db.refresh(admin_user)
        # Never write credentials to the logs; just flag the insecure default.
        logger.warning(
            "Default admin user created: username='admin'. "
            "Change the default password immediately."
        )
        return admin_user