Exception handling enhancement
This commit is contained in:
@@ -18,6 +18,13 @@ from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.exceptions import (
|
||||
AdminRequiredException,
|
||||
InvalidTokenException,
|
||||
TokenExpiredException,
|
||||
UserNotActiveException,
|
||||
InvalidCredentialsException
|
||||
)
|
||||
from models.database.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -46,7 +53,7 @@ class AuthManager:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def authenticate_user(
|
||||
self, db: Session, username: str, password: str
|
||||
self, db: Session, username: str, password: str
|
||||
) -> Optional[User]:
|
||||
"""Authenticate user and return user object if valid."""
|
||||
user = (
|
||||
@@ -101,17 +108,15 @@ class AuthManager:
|
||||
# Check if token has expired
|
||||
exp = payload.get("exp")
|
||||
if exp is None:
|
||||
raise HTTPException(status_code=401, detail="Token missing expiration")
|
||||
raise InvalidTokenException("Token missing expiration")
|
||||
|
||||
if datetime.utcnow() > datetime.fromtimestamp(exp):
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
raise TokenExpiredException()
|
||||
|
||||
# Extract user data
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Token missing user identifier"
|
||||
)
|
||||
raise InvalidTokenException("Token missing user identifier")
|
||||
|
||||
return {
|
||||
"user_id": int(user_id),
|
||||
@@ -121,28 +126,24 @@ class AuthManager:
|
||||
}
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
raise TokenExpiredException()
|
||||
except jwt.JWTError as e:
|
||||
logger.error(f"JWT decode error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Could not validate credentials"
|
||||
)
|
||||
raise InvalidTokenException("Could not validate credentials")
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
raise InvalidTokenException("Authentication failed")
|
||||
|
||||
def get_current_user(
|
||||
self, db: Session, credentials: HTTPAuthorizationCredentials
|
||||
) -> User:
|
||||
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
|
||||
"""Get current authenticated user from database."""
|
||||
user_data = self.verify_token(credentials.credentials)
|
||||
|
||||
user = db.query(User).filter(User.id == user_data["user_id"]).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
raise InvalidCredentialsException("User not found")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="User account is inactive")
|
||||
raise UserNotActiveException()
|
||||
|
||||
return user
|
||||
|
||||
@@ -165,7 +166,7 @@ class AuthManager:
|
||||
def require_admin(self, current_user: User):
|
||||
"""Require admin role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required")
|
||||
raise AdminRequiredException()
|
||||
return current_user
|
||||
|
||||
def create_default_admin_user(self, db: Session):
|
||||
|
||||
@@ -1,36 +1,35 @@
|
||||
# middleware/decorators.py
|
||||
"""Summary description ....
|
||||
"""
|
||||
FastAPI decorators for cross-cutting concerns.
|
||||
|
||||
This module provides classes and functions for:
|
||||
- ....
|
||||
- ....
|
||||
- ....
|
||||
- Rate limiting decorators for endpoint protection
|
||||
- Request throttling and abuse prevention
|
||||
- Consistent error handling for rate limit violations
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.exceptions.base import RateLimitException # Add this import
|
||||
from middleware.rate_limiter import RateLimiter
|
||||
|
||||
# Initialize rate limiter instance
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
|
||||
def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
|
||||
"""Rate limiting decorator for FastAPI endpoints."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract client IP or user ID for rate limiting
|
||||
client_id = "anonymous" # In production, extract from request
|
||||
|
||||
if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
|
||||
raise HTTPException(status_code=429, detail="Rate limit exceeded")
|
||||
# Use custom exception instead of HTTPException
|
||||
raise RateLimitException(
|
||||
message="Rate limit exceeded",
|
||||
retry_after=window_seconds
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# middleware/error_handler.py
|
||||
"""Summary description ....
|
||||
|
||||
This module provides classes and functions for:
|
||||
- ....
|
||||
- ....
|
||||
- ....
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def custom_http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""Handle HTTP exception."""
|
||||
logger.error(
|
||||
f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}"
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": exc.status_code,
|
||||
"message": exc.detail,
|
||||
"type": "http_exception",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle Pydantic validation errors."""
|
||||
logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"code": 422,
|
||||
"message": "Validation error",
|
||||
"type": "validation_error",
|
||||
"details": exc.errors(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle unexpected exceptions."""
|
||||
logger.error(
|
||||
f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": 500,
|
||||
"message": "Internal server error",
|
||||
"type": "server_error",
|
||||
}
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user