Files
orion/app/modules/dev_tools/services/sql_query_service.py
Samir Boulahtit d9abb275a5 feat(dev_tools): expand SQL query tool presets and fix column headers
Add 45 new preset queries covering all database tables, reorganize into
platform-aligned sections (Infrastructure, Core, OMS, Loyalty, Hosting,
Internal) with search/filter input. Fix column headers not appearing on
SELECT * queries by capturing result.keys() before fetchmany().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 23:28:57 +02:00

197 lines
5.8 KiB
Python

# app/modules/dev_tools/services/sql_query_service.py
"""
SQL Query Service
Provides safe, read-only SQL query execution and saved query CRUD operations.
Security layers:
1. Regex-based DML/DDL rejection
2. SET TRANSACTION READ ONLY on PostgreSQL
3. Statement timeout (30s)
4. Automatic rollback after every execution
"""
import re
import time
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.exceptions.base import ResourceNotFoundException, ValidationException
from app.modules.dev_tools.models.saved_query import SavedQuery
# Forbidden SQL keywords — matches whole words, case-insensitive
_FORBIDDEN_PATTERN = re.compile(
r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|COPY|VACUUM|REINDEX)\b",
re.IGNORECASE,
)
class QueryValidationError(ValidationException):
    """Validation error for SQL rejected by the read-only query checks.

    The error is always attributed to the ``sql`` field.
    """

    def __init__(self, message: str):
        super().__init__(field="sql", message=message)
def _strip_sql_comments(sql: str) -> str:
"""Remove SQL comments (-- line comments and /* block comments */)."""
# Remove block comments
result = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
# Remove line comments
result = re.sub(r"--[^\n]*", " ", result)
return result
def validate_query(sql: str) -> None:
    """Validate that the SQL is a single SELECT-only statement (no DML/DDL).

    Raises:
        QueryValidationError: if the query is empty, contains more than one
            statement, or contains a forbidden keyword outside comments.
    """
    # Drop leading whitespace and any trailing run of whitespace/semicolons so
    # "SELECT 1;" and "SELECT 1 ;\n" are both treated as a single statement.
    stripped = re.sub(r"[\s;]+\Z", "", sql).strip()
    if not stripped:
        raise QueryValidationError("Query cannot be empty.")
    # Reject statement separators on the RAW text. The comment stripper does
    # not fully model string literals, so checking after stripping would let
    # e.g. "SELECT '--'; COMMIT; ..." smuggle in a COMMIT that ends the READ
    # ONLY transaction before the remaining statements run.
    # Trade-off: a ';' inside a string literal or comment is also rejected.
    if ";" in stripped:
        raise QueryValidationError(
            "Only a single SQL statement is allowed; remove any ';' other than a trailing one."
        )
    # Strip comments before checking for forbidden keywords
    code_only = _strip_sql_comments(stripped)
    match = _FORBIDDEN_PATTERN.search(code_only)
    if match:
        raise QueryValidationError(
            f"Forbidden SQL keyword: {match.group().upper()}. Only SELECT queries are allowed."
        )
def _make_json_safe(value: Any) -> Any:
"""Convert a database value to a JSON-serializable representation."""
if value is None:
return None
if isinstance(value, int | float | bool):
return value
if isinstance(value, Decimal):
return float(value)
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, UUID):
return str(value)
if isinstance(value, bytes):
return f"<bytes({len(value)})>"
if isinstance(value, list | dict):
return value
return str(value)
def execute_query(db: Session, sql: str) -> dict:
    """
    Execute a read-only SQL query and return results.

    The SQL is checked with ``validate_query`` first, then run on the session's
    connection inside a READ ONLY transaction with a 30s statement timeout.
    At most 1000 rows are returned; one extra row is fetched to detect
    truncation. The session is ALWAYS rolled back afterwards, which also
    discards any other uncommitted work pending on ``db``.

    Returns:
        dict with columns, rows, row_count, truncated, execution_time_ms

    Raises:
        QueryValidationError: if validation fails, or if executing/fetching the
            query raises (the driver's error message is passed through).
    """
    validate_query(sql)
    max_rows = 1000
    connection = db.connection()
    try:
        # Read-only transaction + statement timeout. SET LOCAL scopes the
        # timeout to this transaction, so it cannot leak onto a pooled
        # connection even if the user's statement ends the transaction early.
        connection.execute(text("SET TRANSACTION READ ONLY"))
        connection.execute(text("SET LOCAL statement_timeout = '30s'"))
        start = time.perf_counter()
        result = connection.execute(text(sql))
        # Capture column names before fetching; statements that return no rows
        # have no keys.
        columns = list(result.keys()) if result.returns_rows else []
        # Fetch one row past the limit purely to detect truncation. Note: with
        # a client-side cursor the driver may already hold the full result set.
        rows_raw = result.fetchmany(max_rows + 1)
        elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
        truncated = len(rows_raw) > max_rows
        rows = [[_make_json_safe(cell) for cell in row] for row in rows_raw[:max_rows]]
        return {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
            "execution_time_ms": elapsed_ms,
        }
    except ValidationException:
        # Already a user-facing validation error (QueryValidationError included).
        raise
    except Exception as e:  # noqa: EXC003
        raise QueryValidationError(str(e)) from e
    finally:
        db.rollback()
# ---------------------------------------------------------------------------
# Saved Query CRUD
# ---------------------------------------------------------------------------
def list_saved_queries(db: Session) -> list[SavedQuery]:
    """Return every saved query, sorted alphabetically by ``name``."""
    saved = db.query(SavedQuery).order_by(SavedQuery.name).all()  # noqa: SVC-005
    return saved
def create_saved_query(
    db: Session,
    *,
    name: str,
    sql_text: str,
    description: str | None,
    created_by: int,
) -> SavedQuery:
    """Add a new saved query to the session and return it.

    The row is flushed and refreshed (not committed) so database-populated
    attributes are loaded; committing is left to the caller.
    """
    saved = SavedQuery(
        name=name,
        sql_text=sql_text,
        description=description,
        created_by=created_by,
    )
    db.add(saved)
    db.flush()
    db.refresh(saved)
    return saved
def update_saved_query(
    db: Session,
    query_id: int,
    *,
    name: str | None = None,
    sql_text: str | None = None,
    description: str | None = None,
) -> SavedQuery:
    """Apply a partial update to a saved query and return the refreshed row.

    Only arguments that are not ``None`` are written, so a field cannot be
    cleared to ``None`` through this function.

    Raises:
        ResourceNotFoundException: when no saved query has ``query_id``.
    """
    saved = db.query(SavedQuery).filter(SavedQuery.id == query_id).first()
    if not saved:
        raise ResourceNotFoundException("SavedQuery", str(query_id))
    changes = {"name": name, "sql_text": sql_text, "description": description}
    for attr, new_value in changes.items():
        if new_value is not None:
            setattr(saved, attr, new_value)
    db.flush()
    db.refresh(saved)
    return saved
def delete_saved_query(db: Session, query_id: int) -> None:
    """Delete the saved query with ``query_id`` and flush (no commit).

    Raises:
        ResourceNotFoundException: when no saved query has that id.
    """
    target = db.query(SavedQuery).filter(SavedQuery.id == query_id).first()
    if not target:
        raise ResourceNotFoundException("SavedQuery", str(query_id))
    db.delete(target)
    db.flush()
def record_query_run(db: Session, query_id: int) -> None:
    """Bump a saved query's run counter and set its last-run time to now (UTC).

    An unknown ``query_id`` is ignored: nothing is changed or flushed.
    """
    saved = db.query(SavedQuery).filter(SavedQuery.id == query_id).first()
    if not saved:
        return
    saved.run_count = (saved.run_count or 0) + 1
    saved.last_run_at = datetime.now(UTC)
    db.flush()