code quality run

This commit is contained in:
2025-09-13 21:58:54 +02:00
parent 0dfd885847
commit 3eb18ef91e
63 changed files with 1802 additions and 1289 deletions

View File

@@ -1,14 +1,16 @@
# middleware/auth.py
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from passlib.context import CryptContext
from jose import jwt
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from models.database_models import User
import os
import logging
logger = logging.getLogger(__name__)
@@ -20,7 +22,9 @@ class AuthManager:
"""JWT-based authentication manager with bcrypt password hashing"""
def __init__(self):
self.secret_key = os.getenv("JWT_SECRET_KEY", "your-secret-key-change-in-production-please")
self.secret_key = os.getenv(
"JWT_SECRET_KEY", "your-secret-key-change-in-production-please"
)
self.algorithm = "HS256"
self.token_expire_minutes = int(os.getenv("JWT_EXPIRE_MINUTES", "30"))
@@ -32,11 +36,15 @@ class AuthManager:
"""Verify password against hash"""
return pwd_context.verify(plain_password, hashed_password)
def authenticate_user(self, db: Session, username: str, password: str) -> Optional[User]:
def authenticate_user(
self, db: Session, username: str, password: str
) -> Optional[User]:
"""Authenticate user and return user object if valid"""
user = db.query(User).filter(
(User.username == username) | (User.email == username)
).first()
user = (
db.query(User)
.filter((User.username == username) | (User.email == username))
.first()
)
if not user:
return None
@@ -65,7 +73,7 @@ class AuthManager:
"email": user.email,
"role": user.role,
"exp": expire,
"iat": datetime.utcnow()
"iat": datetime.utcnow(),
}
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
@@ -73,7 +81,7 @@ class AuthManager:
return {
"access_token": token,
"token_type": "bearer",
"expires_in": self.token_expire_minutes * 60 # Return in seconds
"expires_in": self.token_expire_minutes * 60, # Return in seconds
}
def verify_token(self, token: str) -> Dict[str, Any]:
@@ -92,25 +100,31 @@ class AuthManager:
# Extract user data
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="Token missing user identifier")
raise HTTPException(
status_code=401, detail="Token missing user identifier"
)
return {
"user_id": int(user_id),
"username": payload.get("username"),
"email": payload.get("email"),
"role": payload.get("role", "user")
"role": payload.get("role", "user"),
}
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.JWTError as e:
logger.error(f"JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Could not validate credentials")
raise HTTPException(
status_code=401, detail="Could not validate credentials"
)
except Exception as e:
logger.error(f"Token verification error: {e}")
raise HTTPException(status_code=401, detail="Authentication failed")
def get_current_user(self, db: Session, credentials: HTTPAuthorizationCredentials) -> User:
def get_current_user(
self, db: Session, credentials: HTTPAuthorizationCredentials
) -> User:
"""Get current authenticated user from database"""
user_data = self.verify_token(credentials.credentials)
@@ -131,7 +145,7 @@ class AuthManager:
if current_user.role != required_role:
raise HTTPException(
status_code=403,
detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'"
detail=f"Required role '{required_role}' not found. Current role: '{current_user.role}'",
)
return func(current_user, *args, **kwargs)
@@ -142,10 +156,7 @@ class AuthManager:
def require_admin(self, current_user: User):
"""Require admin role"""
if current_user.role != "admin":
raise HTTPException(
status_code=403,
detail="Admin privileges required"
)
raise HTTPException(status_code=403, detail="Admin privileges required")
return current_user
def create_default_admin_user(self, db: Session):
@@ -159,11 +170,13 @@ class AuthManager:
username="admin",
hashed_password=hashed_password,
role="admin",
is_active=True
is_active=True,
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
logger.info("Default admin user created: username='admin', password='admin123'")
logger.info(
"Default admin user created: username='admin', password='admin123'"
)
return admin_user

View File

@@ -1,6 +1,8 @@
# middleware/decorators.py
from functools import wraps
from fastapi import HTTPException
from middleware.rate_limiter import RateLimiter
# Initialize rate limiter instance
@@ -17,10 +19,7 @@ def rate_limit(max_requests: int = 100, window_seconds: int = 3600):
client_id = "anonymous" # In production, extract from request
if not rate_limiter.allow_request(client_id, max_requests, window_seconds):
raise HTTPException(
status_code=429,
detail="Rate limit exceeded"
)
raise HTTPException(status_code=429, detail="Rate limit exceeded")
return await func(*args, **kwargs)

View File

@@ -1,31 +1,35 @@
# middleware/error_handler.py
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
import logging
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
logger = logging.getLogger(__name__)
async def custom_http_exception_handler(request: Request, exc: HTTPException):
"""Custom HTTP exception handler"""
logger.error(f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}")
logger.error(
f"HTTP {exc.status_code}: {exc.detail} - {request.method} {request.url}"
)
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.status_code,
"message": exc.detail,
"type": "http_exception"
"type": "http_exception",
}
}
},
)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle Pydantic validation errors"""
logger.error(f"Validation error: {exc.errors()} - {request.method} {request.url}")
return JSONResponse(
status_code=422,
content={
@@ -33,23 +37,25 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
"code": 422,
"message": "Validation error",
"type": "validation_error",
"details": exc.errors()
"details": exc.errors(),
}
}
},
)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle unexpected exceptions"""
logger.error(f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True)
logger.error(
f"Unexpected error: {str(exc)} - {request.method} {request.url}", exc_info=True
)
return JSONResponse(
status_code=500,
content={
"error": {
"code": 500,
"message": "Internal server error",
"type": "server_error"
"type": "server_error",
}
}
},
)

View File

@@ -1,9 +1,10 @@
# middleware/logging_middleware.py
import logging
import time
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
logger = logging.getLogger(__name__)
@@ -43,4 +44,4 @@ class LoggingMiddleware(BaseHTTPMiddleware):
f"Error: {str(e)} for {request.method} {request.url.path} "
f"({duration:.3f}s)"
)
raise
raise

View File

@@ -1,8 +1,8 @@
# middleware/rate_limiter.py
from typing import Dict
from datetime import datetime, timedelta
import logging
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Dict
logger = logging.getLogger(__name__)
@@ -16,7 +16,9 @@ class RateLimiter:
self.cleanup_interval = 3600 # Clean up old entries every hour
self.last_cleanup = datetime.utcnow()
def allow_request(self, client_id: str, max_requests: int, window_seconds: int) -> bool:
def allow_request(
self, client_id: str, max_requests: int, window_seconds: int
) -> bool:
"""
Check if client is allowed to make a request
Uses sliding window algorithm
@@ -41,7 +43,9 @@ class RateLimiter:
client_requests.append(now)
return True
logger.warning(f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}")
logger.warning(
f"Rate limit exceeded for client {client_id}: {len(client_requests)}/{max_requests}"
)
return False
def _cleanup_old_entries(self):
@@ -62,7 +66,9 @@ class RateLimiter:
for client_id in clients_to_remove:
del self.clients[client_id]
logger.info(f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients")
logger.info(
f"Rate limiter cleanup completed. Removed {len(clients_to_remove)} inactive clients"
)
def get_client_stats(self, client_id: str) -> Dict[str, int]:
"""Get statistics for a specific client"""
@@ -72,11 +78,13 @@ class RateLimiter:
hour_ago = now - timedelta(hours=1)
day_ago = now - timedelta(days=1)
requests_last_hour = sum(1 for req_time in client_requests if req_time > hour_ago)
requests_last_hour = sum(
1 for req_time in client_requests if req_time > hour_ago
)
requests_last_day = sum(1 for req_time in client_requests if req_time > day_ago)
return {
"requests_last_hour": requests_last_hour,
"requests_last_day": requests_last_day,
"total_tracked_requests": len(client_requests)
}
"total_tracked_requests": len(client_requests),
}