Files
orion/app/services/test_runner_service.py
Samir Boulahtit 0e6c9e3eea feat: run tests in background with progress polling
Improve the testing dashboard to run pytest in the background:

- Add background task execution using FastAPI's BackgroundTasks
- Create test_runner_tasks.py following existing background task pattern
- API now returns immediately after starting the test run
- Frontend polls for status every 2 seconds until completion
- Show running indicator with elapsed time counter
- Resume polling if user navigates away and returns while tests running
- Tests continue running even if user closes the page

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-12 23:20:26 +01:00

461 lines
16 KiB
Python

"""
Test Runner Service
Service for running pytest and storing results
"""
import contextlib
import json
import logging
import re
import subprocess
import sys
import tempfile
from datetime import UTC, datetime
from pathlib import Path

from sqlalchemy import desc, func
from sqlalchemy.orm import Session

from models.database.test_run import TestCollection, TestResult, TestRun
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class TestRunnerService:
"""Service for managing pytest test runs"""
def __init__(self):
self.project_root = Path(__file__).parent.parent.parent
def create_test_run(
self,
db: Session,
test_path: str = "tests",
triggered_by: str = "manual",
) -> TestRun:
"""Create a test run record without executing tests"""
test_run = TestRun(
timestamp=datetime.now(UTC),
triggered_by=triggered_by,
test_path=test_path,
status="running",
git_commit_hash=self._get_git_commit(),
git_branch=self._get_git_branch(),
)
db.add(test_run)
db.flush()
return test_run
def run_tests(
self,
db: Session,
test_path: str = "tests",
triggered_by: str = "manual",
extra_args: list[str] | None = None,
) -> TestRun:
"""
Run pytest synchronously and store results in database
Args:
db: Database session
test_path: Path to tests (relative to project root)
triggered_by: Who triggered the run
extra_args: Additional pytest arguments
Returns:
TestRun object with results
"""
test_run = self.create_test_run(db, test_path, triggered_by)
self._execute_tests(db, test_run, test_path, extra_args)
return test_run
def _execute_tests(
self,
db: Session,
test_run: TestRun,
test_path: str,
extra_args: list[str] | None,
) -> None:
"""Execute pytest and update the test run record"""
try:
# Build pytest command with JSON output
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json_report_path = f.name
pytest_args = [
"python", "-m", "pytest",
test_path,
"--json-report",
f"--json-report-file={json_report_path}",
"-v",
"--tb=short",
]
if extra_args:
pytest_args.extend(extra_args)
test_run.pytest_args = " ".join(pytest_args)
# Run pytest
start_time = datetime.now(UTC)
result = subprocess.run(
pytest_args,
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=600, # 10 minute timeout
)
end_time = datetime.now(UTC)
test_run.duration_seconds = (end_time - start_time).total_seconds()
# Parse JSON report
try:
with open(json_report_path, 'r') as f:
report = json.load(f)
self._process_json_report(db, test_run, report)
except (FileNotFoundError, json.JSONDecodeError) as e:
# Fallback to parsing stdout if JSON report failed
logger.warning(f"JSON report unavailable ({e}), parsing stdout")
self._parse_pytest_output(test_run, result.stdout, result.stderr)
finally:
# Clean up temp file
try:
Path(json_report_path).unlink()
except Exception:
pass
# Set final status
if test_run.failed > 0 or test_run.errors > 0:
test_run.status = "failed"
else:
test_run.status = "passed"
except subprocess.TimeoutExpired:
test_run.status = "error"
logger.error("Pytest run timed out")
except Exception as e:
test_run.status = "error"
logger.error(f"Error running tests: {e}")
def _process_json_report(self, db: Session, test_run: TestRun, report: dict):
"""Process pytest-json-report output"""
summary = report.get("summary", {})
test_run.total_tests = summary.get("total", 0)
test_run.passed = summary.get("passed", 0)
test_run.failed = summary.get("failed", 0)
test_run.errors = summary.get("error", 0)
test_run.skipped = summary.get("skipped", 0)
test_run.xfailed = summary.get("xfailed", 0)
test_run.xpassed = summary.get("xpassed", 0)
# Process individual test results
tests = report.get("tests", [])
for test in tests:
node_id = test.get("nodeid", "")
outcome = test.get("outcome", "unknown")
# Parse node_id to get file, class, function
test_file, test_class, test_name = self._parse_node_id(node_id)
# Get failure details
error_message = None
traceback = None
if outcome in ("failed", "error"):
call_info = test.get("call", {})
if "longrepr" in call_info:
traceback = call_info["longrepr"]
# Extract error message from traceback
if isinstance(traceback, str):
lines = traceback.strip().split('\n')
if lines:
error_message = lines[-1][:500] # Last line, limited length
test_result = TestResult(
run_id=test_run.id,
node_id=node_id,
test_name=test_name,
test_file=test_file,
test_class=test_class,
outcome=outcome,
duration_seconds=test.get("duration", 0.0),
error_message=error_message,
traceback=traceback,
markers=test.get("keywords", []),
)
db.add(test_result)
def _parse_node_id(self, node_id: str) -> tuple[str, str | None, str]:
"""Parse pytest node_id into file, class, function"""
# Format: tests/unit/test_foo.py::TestClass::test_method
# or: tests/unit/test_foo.py::test_function
parts = node_id.split("::")
test_file = parts[0] if parts else ""
test_class = None
test_name = parts[-1] if parts else ""
if len(parts) == 3:
test_class = parts[1]
elif len(parts) == 2:
# Could be Class::method or file::function
if parts[1].startswith("Test"):
test_class = parts[1]
test_name = parts[1]
# Handle parametrized tests
if "[" in test_name:
test_name = test_name.split("[")[0]
return test_file, test_class, test_name
def _parse_pytest_output(self, test_run: TestRun, stdout: str, stderr: str):
"""Fallback parser for pytest text output"""
# Parse summary line like: "10 passed, 2 failed, 1 skipped"
summary_pattern = r"(\d+)\s+(passed|failed|error|skipped|xfailed|xpassed)"
for match in re.finditer(summary_pattern, stdout):
count = int(match.group(1))
status = match.group(2)
if status == "passed":
test_run.passed = count
elif status == "failed":
test_run.failed = count
elif status == "error":
test_run.errors = count
elif status == "skipped":
test_run.skipped = count
elif status == "xfailed":
test_run.xfailed = count
elif status == "xpassed":
test_run.xpassed = count
test_run.total_tests = (
test_run.passed + test_run.failed + test_run.errors +
test_run.skipped + test_run.xfailed + test_run.xpassed
)
def _get_git_commit(self) -> str | None:
"""Get current git commit hash"""
try:
result = subprocess.run(
["git", "rev-parse", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip()[:40] if result.returncode == 0 else None
except:
return None
def _get_git_branch(self) -> str | None:
"""Get current git branch"""
try:
result = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=5,
)
return result.stdout.strip() if result.returncode == 0 else None
except:
return None
def get_run_history(self, db: Session, limit: int = 20) -> list[TestRun]:
"""Get recent test run history"""
return (
db.query(TestRun)
.order_by(desc(TestRun.timestamp))
.limit(limit)
.all()
)
def get_run_by_id(self, db: Session, run_id: int) -> TestRun | None:
"""Get a specific test run with results"""
return db.query(TestRun).filter(TestRun.id == run_id).first()
def get_failed_tests(self, db: Session, run_id: int) -> list[TestResult]:
"""Get failed tests from a run"""
return (
db.query(TestResult)
.filter(
TestResult.run_id == run_id,
TestResult.outcome.in_(["failed", "error"])
)
.all()
)
def get_run_results(
self, db: Session, run_id: int, outcome: str | None = None
) -> list[TestResult]:
"""Get test results for a specific run, optionally filtered by outcome"""
query = db.query(TestResult).filter(TestResult.run_id == run_id)
if outcome:
query = query.filter(TestResult.outcome == outcome)
return query.all()
def get_dashboard_stats(self, db: Session) -> dict:
"""Get statistics for the testing dashboard"""
# Get latest run
latest_run = (
db.query(TestRun)
.filter(TestRun.status != "running")
.order_by(desc(TestRun.timestamp))
.first()
)
# Get test collection info (or calculate from latest run)
collection = db.query(TestCollection).order_by(desc(TestCollection.collected_at)).first()
# Get trend data (last 10 runs)
trend_runs = (
db.query(TestRun)
.filter(TestRun.status != "running")
.order_by(desc(TestRun.timestamp))
.limit(10)
.all()
)
# Calculate stats by category from latest run
by_category = {}
if latest_run:
results = db.query(TestResult).filter(TestResult.run_id == latest_run.id).all()
for result in results:
# Categorize by test path
if "unit" in result.test_file:
category = "Unit Tests"
elif "integration" in result.test_file:
category = "Integration Tests"
elif "performance" in result.test_file:
category = "Performance Tests"
elif "system" in result.test_file:
category = "System Tests"
else:
category = "Other"
if category not in by_category:
by_category[category] = {"total": 0, "passed": 0, "failed": 0}
by_category[category]["total"] += 1
if result.outcome == "passed":
by_category[category]["passed"] += 1
elif result.outcome in ("failed", "error"):
by_category[category]["failed"] += 1
# Get top failing tests (across recent runs)
top_failing = (
db.query(
TestResult.test_name,
TestResult.test_file,
func.count(TestResult.id).label("failure_count")
)
.filter(TestResult.outcome.in_(["failed", "error"]))
.group_by(TestResult.test_name, TestResult.test_file)
.order_by(desc("failure_count"))
.limit(10)
.all()
)
return {
# Current run stats
"total_tests": latest_run.total_tests if latest_run else 0,
"passed": latest_run.passed if latest_run else 0,
"failed": latest_run.failed if latest_run else 0,
"errors": latest_run.errors if latest_run else 0,
"skipped": latest_run.skipped if latest_run else 0,
"pass_rate": round(latest_run.pass_rate, 1) if latest_run else 0,
"duration_seconds": round(latest_run.duration_seconds, 2) if latest_run else 0,
"coverage_percent": latest_run.coverage_percent if latest_run else None,
"last_run": latest_run.timestamp.isoformat() if latest_run else None,
"last_run_status": latest_run.status if latest_run else None,
# Collection stats
"total_test_files": collection.total_files if collection else 0,
# Trend data
"trend": [
{
"timestamp": run.timestamp.isoformat(),
"total": run.total_tests,
"passed": run.passed,
"failed": run.failed,
"pass_rate": round(run.pass_rate, 1),
"duration": round(run.duration_seconds, 1),
}
for run in reversed(trend_runs)
],
# By category
"by_category": by_category,
# Top failing tests
"top_failing": [
{
"test_name": t.test_name,
"test_file": t.test_file,
"failure_count": t.failure_count,
}
for t in top_failing
],
}
def collect_tests(self, db: Session) -> TestCollection:
"""Collect test information without running tests"""
collection = TestCollection(
collected_at=datetime.now(UTC),
)
try:
# Run pytest --collect-only
result = subprocess.run(
["python", "-m", "pytest", "--collect-only", "-q", "tests"],
cwd=str(self.project_root),
capture_output=True,
text=True,
timeout=60,
)
# Parse output
lines = result.stdout.strip().split('\n')
test_files = {}
for line in lines:
if "::" in line:
file_path = line.split("::")[0]
if file_path not in test_files:
test_files[file_path] = 0
test_files[file_path] += 1
# Count by category
for file_path, count in test_files.items():
collection.total_tests += count
collection.total_files += 1
if "unit" in file_path:
collection.unit_tests += count
elif "integration" in file_path:
collection.integration_tests += count
elif "performance" in file_path:
collection.performance_tests += count
elif "system" in file_path:
collection.system_tests += count
collection.test_files = [
{"file": f, "count": c}
for f, c in sorted(test_files.items(), key=lambda x: -x[1])
]
except Exception as e:
logger.error(f"Error collecting tests: {e}")
db.add(collection)
return collection
# Singleton instance at module level. TestRunnerService.__init__ only records
# project_root, so one shared instance carries no per-run state.
test_runner_service = TestRunnerService()