# app/modules/dev_tools/services/sql_query_service.py
"""
SQL Query Service

Provides safe, read-only SQL query execution and saved query CRUD operations.

Security layers:
1. Regex-based DML/DDL rejection (comments are stripped first; comment
   markers inside string literals / quoted identifiers are not comments)
2. Single-statement enforcement (';' allowed only at the end)
3. SET TRANSACTION READ ONLY on PostgreSQL
4. Statement timeout (30s)
5. Automatic rollback after every execution
"""

import re
import time
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID

from sqlalchemy import text
from sqlalchemy.orm import Session

from app.exceptions.base import ResourceNotFoundException, ValidationException
from app.modules.dev_tools.models.saved_query import SavedQuery

# Forbidden SQL keywords — matches whole words, case-insensitive.
# Searched in the query with comments removed but string literals kept, so a
# literal such as 'DELETE' is (conservatively) rejected as well.
_FORBIDDEN_PATTERN = re.compile(
    r"\b(INSERT|UPDATE|DELETE|MERGE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE"
    r"|COPY|VACUUM|REINDEX)\b",
    re.IGNORECASE,
)

# Maximum number of rows returned by execute_query(); one extra row is
# fetched to detect truncation.
_MAX_ROWS = 1000

# Characters that may surround the single allowed statement: whitespace and
# empty statements (e.g. a trailing ';').
_STATEMENT_PADDING = "; \t\r\n\f\v"

# A "--" comment ends at the first CR or LF.
_LINE_END = re.compile(r"[\r\n]")

# Opening delimiter of a dollar-quoted string: $$ or $tag$ (tag may not
# start with a digit).
_DOLLAR_QUOTE_DELIM = re.compile(r"\$(?:[^\W\d]\w*)?\$")


class QueryValidationError(ValidationException):
    """Raised when a query is rejected by validation or fails to execute."""

    def __init__(self, message: str):
        super().__init__(message=message, field="sql")


def _is_identifier_char(ch: str) -> bool:
    """Return True if ``ch`` can continue an unquoted identifier or keyword."""
    return ch.isalnum() or ch in ("_", "$")


def _quoted_end(sql: str, start: int, *, backslash_escapes: bool) -> int:
    """Return the index just past the quoted token that opens at ``sql[start]``.

    The token is closed by the same character that opened it (' or ");
    a doubled closing character is an escaped quote. When
    ``backslash_escapes`` is True (E'...' strings) a backslash also escapes
    the following character. An unterminated token runs to the end of ``sql``.
    """
    quote = sql[start]
    pos = start + 1
    length = len(sql)
    while pos < length:
        ch = sql[pos]
        if backslash_escapes and ch == "\\":
            pos += 2
        elif ch != quote:
            pos += 1
        elif pos + 1 < length and sql[pos + 1] == quote:
            pos += 2  # doubled quote -> literal quote character
        else:
            return pos + 1
    return length


def _block_comment_end(sql: str, start: int) -> int:
    """Return the index just past the block comment that opens at ``sql[start]``.

    Nested ``/* ... */`` pairs are balanced (PostgreSQL allows nesting).
    Returns -1 if the comment is never closed.
    """
    depth = 0
    pos = start
    length = len(sql)
    while pos + 1 < length:
        pair = sql[pos : pos + 2]
        if pair == "/*":
            depth += 1
            pos += 2
        elif pair == "*/":
            depth -= 1
            pos += 2
            if depth == 0:
                return pos
        else:
            pos += 1
    return -1


def _strip_sql_comments(sql: str) -> str:
    """Remove SQL comments (-- line comments and /* block comments */).

    Each comment is replaced by a single space. Everything else is copied
    through unchanged, including the contents of '...' / E'...' string
    literals, "..." quoted identifiers and $tag$...$tag$ dollar-quoted
    strings. Comment markers inside those tokens are not comments; treating
    them as such (as a plain regex does) would let e.g.
    ``SELECT '--'; DELETE FROM t`` hide the rest of the line from validation.

    Lexing follows PostgreSQL rules and assumes standard_conforming_strings
    is on (backslash escapes only inside E'...' strings) — NOTE(review):
    confirm for the target database. An unterminated block comment is kept
    verbatim so its text is still validated.
    """
    pieces: list[str] = []
    pos = 0
    length = len(sql)
    while pos < length:
        ch = sql[pos]
        two = sql[pos : pos + 2]
        prev = sql[pos - 1] if pos > 0 else ""
        if two == "--":
            newline = _LINE_END.search(sql, pos)
            end = newline.start() if newline else length  # keep the newline itself
            pieces.append(" ")
        elif two == "/*":
            end = _block_comment_end(sql, pos)
            if end == -1:
                pieces.append(sql[pos:])
                break
            pieces.append(" ")
        elif ch in ("'", '"'):
            # E'...' (where E is not the tail of a longer identifier) allows
            # backslash escapes such as \'.
            before_prefix = sql[pos - 2] if pos > 1 else ""
            is_escape_string = (
                ch == "'"
                and prev in ("E", "e")
                and not _is_identifier_char(before_prefix)
            )
            end = _quoted_end(sql, pos, backslash_escapes=is_escape_string)
            pieces.append(sql[pos:end])
        elif ch == "$" and not _is_identifier_char(prev):
            delim = _DOLLAR_QUOTE_DELIM.match(sql, pos)
            if delim is None:
                end = pos + 1
            else:
                close = sql.find(delim.group(), delim.end())
                end = length if close == -1 else close + len(delim.group())
            pieces.append(sql[pos:end])
        else:
            end = pos + 1
            pieces.append(ch)
        pos = end
    return "".join(pieces)


def validate_query(sql: str) -> None:
    """Validate that ``sql`` is a single SELECT-only statement (no DML/DDL).

    Comments are removed first (see _strip_sql_comments), then surrounding
    whitespace and ';' characters.

    Raises:
        QueryValidationError: if nothing but whitespace, comments or ';'
            remains; if a keyword matched by _FORBIDDEN_PATTERN appears; or
            if a ';' remains, i.e. the text contains more than one statement.
    """
    code_only = _strip_sql_comments(sql).strip(_STATEMENT_PADDING)
    if not code_only:
        raise QueryValidationError("Query cannot be empty.")

    match = _FORBIDDEN_PATTERN.search(code_only)
    if match:
        raise QueryValidationError(
            f"Forbidden SQL keyword: {match.group().upper()}. Only SELECT queries are allowed."
        )

    # A second statement (e.g. COMMIT or END) could end the READ ONLY
    # transaction opened in execute_query() and let later statements in the
    # same string run outside it. String literals are not exempted, so a ';'
    # inside a literal is rejected too (conservative by design).
    if ";" in code_only:
        raise QueryValidationError(
            "Only a single SQL statement is allowed; ';' may appear only at the end of the query."
        )


def _make_json_safe(value: Any) -> Any:
    """Convert a database value to a JSON-serializable representation.

    - None, int, float, bool: returned unchanged.
    - Decimal: converted to float (may lose precision).
    - datetime: ISO 8601 string via ``isoformat()``.
    - UUID: ``str(value)``.
    - bytes / bytearray / memoryview: placeholder ``"<N bytes>"``.
    - list: new list with every element converted recursively.
    - dict: new dict with every value converted recursively (keys unchanged).
    - anything else: ``str(value)``.
    """
    if value is None:
        return None
    if isinstance(value, int | float | bool):
        return value
    if isinstance(value, Decimal):
        return float(value)
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, UUID):
        return str(value)
    if isinstance(value, bytes | bytearray | memoryview):
        # Raw binary is not JSON-serializable; report only its size.
        return f"<{memoryview(value).nbytes} bytes>"
    if isinstance(value, list):
        return [_make_json_safe(item) for item in value]
    if isinstance(value, dict):
        return {key: _make_json_safe(item) for key, item in value.items()}
    return str(value)


def execute_query(db: Session, sql: str) -> dict:
    """
    Execute a read-only SQL query and return results.

    The query is checked with validate_query(), then run on the session's
    connection inside a transaction marked READ ONLY with a 30s
    statement_timeout. ``db.rollback()`` is always called afterwards, which
    also discards any other pending changes in ``db``.

    The SQL is passed to the driver unchanged (no SQLAlchemy ``:name``
    bind-parameter parsing), so text like '{"a":1}' or arr[:2] works.

    Returns:
        dict with columns, rows, row_count, truncated, execution_time_ms.
        At most _MAX_ROWS rows are returned; ``truncated`` is True when more
        were available. A statement that returns no rows yields empty
        ``columns`` and ``rows``. ``execution_time_ms`` covers execution plus
        fetching.

    Raises:
        QueryValidationError: if validation fails, or with the database
            error's message if executing/fetching raises.
    """
    validate_query(sql)

    connection = db.connection()
    try:
        # Both settings are transaction-scoped (SET LOCAL makes that explicit);
        # the rollback in `finally` discards them.
        connection.execute(text("SET TRANSACTION READ ONLY"))
        connection.execute(text("SET LOCAL statement_timeout = '30s'"))

        start = time.perf_counter()
        # no_parameters=True calls cursor.execute(sql) without a parameter
        # collection, so drivers such as psycopg2 leave '%' untouched.
        result = connection.exec_driver_sql(
            sql, execution_options={"no_parameters": True}
        )
        if result.returns_rows:
            columns = list(result.keys())
            fetched = result.fetchmany(_MAX_ROWS + 1)
        else:
            columns, fetched = [], []
        elapsed_ms = round((time.perf_counter() - start) * 1000, 2)

        truncated = len(fetched) > _MAX_ROWS
        rows = [[_make_json_safe(cell) for cell in row] for row in fetched[:_MAX_ROWS]]

        return {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
            "execution_time_ms": elapsed_ms,
        }
    except ValidationException:
        # Already a user-facing validation error; pass it through unchanged.
        raise
    except Exception as e:  # noqa: EXC003
        raise QueryValidationError(str(e)) from e
    finally:
        db.rollback()


# ---------------------------------------------------------------------------
# Saved Query CRUD
# ---------------------------------------------------------------------------


def _find_saved_query(db: Session, query_id: int) -> SavedQuery | None:
    """Return the SavedQuery with id ``query_id``, or None if there is none."""
    return db.query(SavedQuery).filter(SavedQuery.id == query_id).first()


def _get_saved_query_or_raise(db: Session, query_id: int) -> SavedQuery:
    """Return the SavedQuery with id ``query_id``; raise ResourceNotFoundException if missing."""
    query = _find_saved_query(db, query_id)
    if query is None:
        raise ResourceNotFoundException("SavedQuery", str(query_id))
    return query


def list_saved_queries(db: Session) -> list[SavedQuery]:
    """List all saved queries ordered by name."""
    return db.query(SavedQuery).order_by(SavedQuery.name).all()  # noqa: SVC-005


def create_saved_query(
    db: Session,
    *,
    name: str,
    sql_text: str,
    description: str | None,
    created_by: int,
) -> SavedQuery:
    """Create a new saved query, flush it, and return it refreshed from the DB.

    The SQL text is stored as given; it is not validated here.
    """
    query = SavedQuery(
        name=name,
        sql_text=sql_text,
        description=description,
        created_by=created_by,
    )
    db.add(query)
    db.flush()
    db.refresh(query)
    return query


def update_saved_query(
    db: Session,
    query_id: int,
    *,
    name: str | None = None,
    sql_text: str | None = None,
    description: str | None = None,
) -> SavedQuery:
    """Update an existing saved query. Raises ResourceNotFoundException if not found.

    Only arguments that are not None are applied, so passing None leaves a
    field unchanged (this function therefore cannot clear ``description``).
    """
    query = _get_saved_query_or_raise(db, query_id)
    if name is not None:
        query.name = name
    if sql_text is not None:
        query.sql_text = sql_text
    if description is not None:
        query.description = description
    db.flush()
    db.refresh(query)
    return query


def delete_saved_query(db: Session, query_id: int) -> None:
    """Delete a saved query. Raises ResourceNotFoundException if not found."""
    query = _get_saved_query_or_raise(db, query_id)
    db.delete(query)
    db.flush()


def record_query_run(db: Session, query_id: int) -> None:
    """Increment run_count and update last_run_at for a saved query.

    Does nothing if the query does not exist. The increment is done in Python
    (read, add one, write), so it is not atomic across concurrent runs.
    Note: execute_query() ends with db.rollback(); when both are used with
    the same session, call this afterwards or the flushed update is discarded.
    """
    query = _find_saved_query(db, query_id)
    if query is not None:
        query.run_count = (query.run_count or 0) + 1
        query.last_run_at = datetime.now(UTC)
        db.flush()