# app/utils/encryption.py
"""
Encryption utilities for sensitive data storage.

Uses Fernet symmetric encryption with key derivation from the JWT secret.
Provides secure storage for API keys and other sensitive credentials.
"""

import base64
import logging

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from app.core.config import settings

logger = logging.getLogger(__name__)

# Salt for key derivation - fixed to ensure consistent encryption/decryption.
# In production, this should be stored securely and not changed.
_ENCRYPTION_SALT = b"orion_encryption_salt_v1"

# PBKDF2 parameters. The derived key depends on these values, so changing
# either of them (or the salt above) makes previously encrypted data
# impossible to decrypt.
_KDF_ITERATIONS = 100_000
_KDF_KEY_LENGTH = 32


class EncryptionError(Exception):
    """Raised when encryption or decryption fails."""


class EncryptionService:
    """
    Service for encrypting and decrypting sensitive data.

    Uses Fernet symmetric encryption with a key derived from the
    application's JWT secret key. Data encrypted by one instance can only be
    decrypted by an instance constructed with the same secret; if the secret
    changes, existing ciphertexts fail to decrypt with ``EncryptionError``.
    """

    def __init__(self, secret_key: str | None = None):
        """
        Initialize the encryption service.

        Args:
            secret_key: The secret key to derive the encryption key from.
                Defaults to the JWT secret key from settings.
        """
        if secret_key is None:
            secret_key = settings.jwt_secret_key

        self._fernet = self._create_fernet(secret_key)

    def _create_fernet(self, secret_key: str) -> Fernet:
        """
        Create a Fernet instance with a derived key.

        Uses PBKDF2-HMAC-SHA256 with the module-level salt to derive a
        32-byte key from the secret, then encodes it as URL-safe base64 for
        Fernet.

        Args:
            secret_key: The secret to derive the Fernet key from.

        Returns:
            A Fernet instance bound to the derived key.
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=_KDF_KEY_LENGTH,
            salt=_ENCRYPTION_SALT,
            iterations=_KDF_ITERATIONS,
        )
        derived_key = kdf.derive(secret_key.encode())
        fernet_key = base64.urlsafe_b64encode(derived_key)
        return Fernet(fernet_key)

    def encrypt(self, plaintext: str) -> str:
        """
        Encrypt a plaintext string.

        Args:
            plaintext: The string to encrypt.

        Returns:
            Base64-encoded ciphertext string.

        Raises:
            EncryptionError: If ``plaintext`` is empty (or otherwise falsy)
                or if encryption fails.
        """
        if not plaintext:
            raise EncryptionError("Cannot encrypt empty string")

        try:
            ciphertext = self._fernet.encrypt(plaintext.encode())
            return ciphertext.decode()
        except Exception as e:
            # Any failure (including a non-string argument) is reported to
            # callers as the single EncryptionError type.
            logger.error("Encryption failed: %s", e)
            raise EncryptionError(f"Failed to encrypt data: {e}") from e

    def decrypt(self, ciphertext: str) -> str:
        """
        Decrypt a ciphertext string.

        Args:
            ciphertext: Base64-encoded ciphertext to decrypt.

        Returns:
            Decrypted plaintext string.

        Raises:
            EncryptionError: If ``ciphertext`` is empty (or otherwise falsy)
                or if decryption fails (invalid token or corrupted data).
        """
        if not ciphertext:
            raise EncryptionError("Cannot decrypt empty string")

        try:
            plaintext = self._fernet.decrypt(ciphertext.encode())
            return plaintext.decode()
        except InvalidToken as e:
            # InvalidToken gets a fixed message instead of echoing the
            # exception text.
            logger.error("Decryption failed: Invalid token")
            raise EncryptionError(
                "Failed to decrypt data: Invalid or corrupted ciphertext"
            ) from e
        except Exception as e:
            logger.error("Decryption failed: %s", e)
            raise EncryptionError(f"Failed to decrypt data: {e}") from e

    def is_valid_ciphertext(self, ciphertext: str) -> bool:
        """
        Check if a string is valid ciphertext that can be decrypted.

        This performs a full decryption via ``decrypt`` and discards the
        result, so a failed check also emits the error log that ``decrypt``
        writes.

        Args:
            ciphertext: String to validate.

        Returns:
            True if the string can be decrypted, False otherwise.
        """
        try:
            self.decrypt(ciphertext)
            return True
        except EncryptionError:
            return False


# Singleton instance using the JWT secret key
encryption_service = EncryptionService()


def encrypt_value(value: str) -> str:
    """
    Convenience function to encrypt a value using the default service.

    Args:
        value: The string to encrypt.

    Returns:
        Encrypted string.

    Raises:
        EncryptionError: If ``value`` is empty or encryption fails.
    """
    return encryption_service.encrypt(value)


def decrypt_value(value: str) -> str:
    """
    Convenience function to decrypt a value using the default service.

    Args:
        value: The encrypted string to decrypt.

    Returns:
        Decrypted string.

    Raises:
        EncryptionError: If ``value`` is empty or decryption fails.
    """
    return encryption_service.decrypt(value)


def mask_api_key(api_key: str, visible_chars: int = 4) -> str:
    """
    Mask an API key for display purposes.

    Shows only the first few characters, replacing the rest with asterisks.
    The result always has the same length as the input.

    Args:
        api_key: The API key to mask.
        visible_chars: Number of characters to show at the start. Negative
            values are treated as 0 (fully masked).

    Returns:
        Masked API key string (e.g., "sk-a***************"). An empty input
        yields an empty string, and a key no longer than ``visible_chars``
        is masked entirely.
    """
    if not api_key:
        return ""

    # A negative count would otherwise become a negative slice bound, which
    # reveals almost the whole key and produces too many asterisks.
    visible_chars = max(visible_chars, 0)

    if len(api_key) <= visible_chars:
        return "*" * len(api_key)

    return api_key[:visible_chars] + "*" * (len(api_key) - visible_chars)