Add 45 new preset queries covering all database tables, reorganize into platform-aligned sections (Infrastructure, Core, OMS, Loyalty, Hosting, Internal) with search/filter input. Fix column headers not appearing on SELECT * queries by capturing result.keys() before fetchmany(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
197 lines
5.8 KiB
Python
197 lines
5.8 KiB
Python
# app/modules/dev_tools/services/sql_query_service.py
|
|
"""
|
|
SQL Query Service
|
|
|
|
Provides safe, read-only SQL query execution and saved query CRUD operations.
|
|
Security layers:
|
|
1. Regex-based DML/DDL rejection
|
|
2. SET TRANSACTION READ ONLY on PostgreSQL
|
|
3. Statement timeout (30s)
|
|
4. Automatic rollback after every execution
|
|
"""
|
|
|
|
import math
import re
import time
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID

from sqlalchemy import text
from sqlalchemy.orm import Session

from app.exceptions.base import ResourceNotFoundException, ValidationException
from app.modules.dev_tools.models.saved_query import SavedQuery
|
|
|
|
# Forbidden SQL keywords — matches whole words, case-insensitive
|
|
_FORBIDDEN_PATTERN = re.compile(
|
|
r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE|COPY|VACUUM|REINDEX)\b",
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
|
|
class QueryValidationError(ValidationException):
    """Rejection of a user-submitted SQL query, reported against the ``sql`` field.

    Args:
        message: Human-readable reason the query was rejected.
    """

    def __init__(self, message: str):
        super().__init__(field="sql", message=message)
|
|
|
|
|
|
def _strip_sql_comments(sql: str) -> str:
|
|
"""Remove SQL comments (-- line comments and /* block comments */)."""
|
|
# Remove block comments
|
|
result = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
|
|
# Remove line comments
|
|
result = re.sub(r"--[^\n]*", " ", result)
|
|
return result
|
|
|
|
|
|
def validate_query(sql: str) -> None:
|
|
"""Validate that the SQL query is SELECT-only (no DML/DDL)."""
|
|
stripped = sql.strip().rstrip(";")
|
|
if not stripped:
|
|
raise QueryValidationError("Query cannot be empty.")
|
|
|
|
# Strip comments before checking for forbidden keywords
|
|
code_only = _strip_sql_comments(stripped)
|
|
match = _FORBIDDEN_PATTERN.search(code_only)
|
|
if match:
|
|
raise QueryValidationError(
|
|
f"Forbidden SQL keyword: {match.group().upper()}. Only SELECT queries are allowed."
|
|
)
|
|
|
|
|
|
def _make_json_safe(value: Any) -> Any:
|
|
"""Convert a database value to a JSON-serializable representation."""
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, int | float | bool):
|
|
return value
|
|
if isinstance(value, Decimal):
|
|
return float(value)
|
|
if isinstance(value, datetime):
|
|
return value.isoformat()
|
|
if isinstance(value, UUID):
|
|
return str(value)
|
|
if isinstance(value, bytes):
|
|
return f"<bytes({len(value)})>"
|
|
if isinstance(value, list | dict):
|
|
return value
|
|
return str(value)
|
|
|
|
|
|
def execute_query(db: Session, sql: str, *, max_rows: int = 1000) -> dict:
    """
    Execute a read-only SQL query and return results.

    The query is validated with ``validate_query`` first. It then runs on the
    session's connection inside a transaction marked READ ONLY with a 30s
    statement timeout. The session is rolled back afterwards in every case.
    That rollback covers the session's whole transaction, so anything pending
    on ``db`` from before this call is discarded too.

    Args:
        db: SQLAlchemy session whose connection runs the query.
        sql: Raw SQL text supplied by the user.
        max_rows: Maximum number of rows returned (default 1000). ``truncated``
            reports whether more rows were available. Must be >= 1.

    Returns:
        dict with columns, rows, row_count, truncated, execution_time_ms.
        A statement that produces no result set yields empty columns/rows.

    Raises:
        ValueError: if ``max_rows`` is less than 1.
        QueryValidationError: if validation fails, or wrapping any error
            raised while executing or fetching the query.
    """
    if max_rows < 1:
        raise ValueError("max_rows must be at least 1.")

    validate_query(sql)

    connection = db.connection()

    try:
        # Both settings are scoped to the current transaction (SET LOCAL, not
        # plain SET), so the timeout cannot persist on the pooled connection
        # even if the user's statement ends the transaction itself.
        connection.execute(text("SET TRANSACTION READ ONLY"))
        connection.execute(text("SET LOCAL statement_timeout = '30s'"))

        start = time.perf_counter()
        result = connection.execute(text(sql))
        if result.returns_rows:
            # Capture column names before fetchmany(); reading them afterwards
            # left headers missing for SELECT * queries.
            columns = list(result.keys())
            # Fetch one extra row so truncation can be detected.
            rows_raw = result.fetchmany(max_rows + 1)
        else:
            # SQLAlchemy does not allow fetching from a result without rows;
            # report such statements as an empty result instead of an error.
            columns, rows_raw = [], []
        elapsed_ms = round((time.perf_counter() - start) * 1000, 2)

        truncated = len(rows_raw) > max_rows
        rows = [
            [_make_json_safe(cell) for cell in row] for row in rows_raw[:max_rows]
        ]

        return {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
            "execution_time_ms": elapsed_ms,
        }
    except ValidationException:
        # QueryValidationError is a ValidationException; pass it through as-is.
        raise
    except Exception as e:  # noqa: EXC003
        # Any other failure (e.g. a database error) is reported to the caller
        # as a validation error on the sql field.
        raise QueryValidationError(str(e)) from e
    finally:
        # Always discard the transaction, whether the query succeeded or not.
        db.rollback()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Saved Query CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def list_saved_queries(db: Session) -> list[SavedQuery]:
    """Return every saved query, sorted alphabetically by name."""
    saved = db.query(SavedQuery).order_by(SavedQuery.name).all()  # noqa: SVC-005
    return saved
|
|
|
|
|
|
def create_saved_query(
    db: Session,
    *,
    name: str,
    sql_text: str,
    description: str | None,
    created_by: int,
) -> SavedQuery:
    """Add a new saved query to the session and return it.

    The row is flushed and refreshed (not committed), so database-generated
    column values are loaded on the returned instance. Committing is left to
    the caller.
    """
    saved = SavedQuery(
        name=name,
        sql_text=sql_text,
        description=description,
        created_by=created_by,
    )
    db.add(saved)
    db.flush()  # emit the INSERT inside the current transaction
    db.refresh(saved)  # reload the row so generated values are populated
    return saved
|
|
|
|
|
|
def update_saved_query(
    db: Session,
    query_id: int,
    *,
    name: str | None = None,
    sql_text: str | None = None,
    description: str | None = None,
) -> SavedQuery:
    """Apply a partial update to a saved query and return the refreshed row.

    Only arguments that are not ``None`` are written, so ``None`` means
    "leave unchanged" (a field cannot be cleared to NULL through this call).

    Raises:
        ResourceNotFoundException: if no saved query has id ``query_id``.
    """
    saved = db.query(SavedQuery).filter_by(id=query_id).first()
    if not saved:
        raise ResourceNotFoundException("SavedQuery", str(query_id))

    changes = {"name": name, "sql_text": sql_text, "description": description}
    for attribute, new_value in changes.items():
        if new_value is not None:
            setattr(saved, attribute, new_value)

    db.flush()
    db.refresh(saved)
    return saved
|
|
|
|
|
|
def delete_saved_query(db: Session, query_id: int) -> None:
    """Remove the saved query with id ``query_id`` (flushed, not committed).

    Raises:
        ResourceNotFoundException: if no saved query has that id.
    """
    target = db.query(SavedQuery).filter_by(id=query_id).first()
    if not target:
        raise ResourceNotFoundException("SavedQuery", str(query_id))
    db.delete(target)
    db.flush()
|
|
|
|
|
|
def record_query_run(db: Session, query_id: int) -> None:
    """Bump a saved query's run counter and stamp its last-run time (UTC).

    Unknown ids are silently ignored: nothing is changed or flushed.
    """
    saved = db.query(SavedQuery).filter_by(id=query_id).first()
    if not saved:
        return
    # A NULL run_count is treated as 0 before incrementing.
    saved.run_count = (saved.run_count or 0) + 1
    saved.last_run_at = datetime.now(UTC)
    db.flush()
|