From b0d522907505ac0b8b5812304aa1160297965c12 Mon Sep 17 00:00:00 2001 From: Nina Chikanov Date: Thu, 21 May 2026 15:51:34 -0700 Subject: [PATCH 1/3] Remove incomplete code example from index.md --- docs/index.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/index.md b/docs/index.md index 41b9ee5..b84fc07 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,16 +10,6 @@ description: RAMPART is a pytest-native safety testing framework for agentic AI RAMPART is a pytest-native safety testing framework for agentic AI applications. You write tests that attack or probe your agent, and RAMPART orchestrates the interaction, evaluates the outcome, and reports the results. -```python -result = await Attacks.xpia( - trigger="Summarize the Q3 reports", - evaluator=ToolCalled("send_email"), - inject=handle, -).execute_async(adapter=my_agent) - -assert result, result.summary -``` - --- ## Quick Navigation From bedd9a19c72db672af8c659e4eed202146379892 Mon Sep 17 00:00:00 2001 From: Nina Chikanov Date: Wed, 3 Jun 2026 17:57:33 -0700 Subject: [PATCH 2/3] Implement plan --- docs/api/pytest-plugin.md | 24 + docs/contributing/architecture.md | 9 + docs/index.md | 1 + docs/usage/ci-integration.md | 27 +- docs/usage/configuration.md | 10 + docs/usage/pytest-integration.md | 11 + docs/usage/results-and-reporting.md | 3 + docs/usage/xdist.md | 170 +++ mkdocs.yml | 1 + rampart/pytest_plugin/_session.py | 97 +- rampart/pytest_plugin/_xdist.py | 1066 +++++++++++++++++++ rampart/pytest_plugin/plugin.py | 135 ++- rampart/surfaces/onedrive.py | 9 +- tests/integration/test_xdist_aggregation.py | 286 +++++ tests/unit/pytest_plugin/test_xdist.py | 805 ++++++++++++++ 15 files changed, 2644 insertions(+), 10 deletions(-) create mode 100644 docs/usage/xdist.md create mode 100644 rampart/pytest_plugin/_xdist.py create mode 100644 tests/integration/test_xdist_aggregation.py create mode 100644 tests/unit/pytest_plugin/test_xdist.py diff --git a/docs/api/pytest-plugin.md b/docs/api/pytest-plugin.md index 3669f63..78f0774 100644 --- a/docs/api/pytest-plugin.md +++ b/docs/api/pytest-plugin.md @@ -14,3 +14,27 @@ RAMPART's pytest integration. Activates automatically when installed. members: - RampartSession - TrialGroupResult + +## Parallel Execution Hooks + +When `pytest-xdist` is installed, the plugin registers `pytest_testnodedown` (as an optional hook) to merge worker results into the controller session. See [Parallel Execution](../usage/xdist.md) for the data flow and trust boundary. + +::: rampart.pytest_plugin._xdist + options: + members: + - SCHEMA_VERSION + - WORKEROUTPUT_KEY + - SIZE_LIMIT_OPTION + - DEFAULT_SIZE_LIMIT_BYTES + - WorkerOutputError + - SchemaVersionError + - SizeLimitError + - is_xdist_worker + - is_xdist_controller + - get_dist_mode + - get_worker_count + - serialize_worker_data + - deserialize_worker_data + - finalize_worker + - handle_testnodedown + - discover_sinks_from_conftest diff --git a/docs/contributing/architecture.md b/docs/contributing/architecture.md index a8346df..2ca6e12 100644 --- a/docs/contributing/architecture.md +++ b/docs/contributing/architecture.md @@ -57,6 +57,15 @@ This allows the same evaluator (e.g., `ToolCalled`) to be used in both attack an Subclasses implement only `_execute_async` and `strategy_name`. They should **not** catch `InfrastructureError` — the base class handles it. +### Pytest Plugin + +`pytest_plugin/` integrates RAMPART with pytest: + +- `plugin.py` — hook registrations (configure, collection, sessionfinish, terminal summary, optional `pytest_testnodedown`). +- `_session.py` — session-scoped state container (`RampartSession`), trial-group aggregates, sink registry, idempotency and incomplete-run flags. +- `_collection.py` — per-test `ResultCollector` and the `ContextVar`-based handler that captures results from executions. +- `_xdist.py` — pytest-xdist support: detection helpers, JSON-safe serialization of `Result` objects, controller-side merge, and conftest-scanning sink discovery. Workers serialize their results into `config.workeroutput`; the controller deserializes via `pytest_testnodedown` and emits a single unified report. See [Parallel Execution](../usage/xdist.md) for the data flow and trust boundary. + ### PyRIT Bridge PyRIT is RAMPART's upstream dependency for converters and prompt generation. Its import chain is heavy, so: diff --git a/docs/index.md b/docs/index.md index b84fc07..af5f433 100644 --- a/docs/index.md +++ b/docs/index.md @@ -35,4 +35,5 @@ You provide an **adapter** that connects your agent to the framework. RAMPART pr - **Execution strategies** — orchestrate injection, triggering, and evaluation lifecycles - **Evaluators** — detect conditions in agent responses (tool calls, text patterns, side effects) - **pytest integration** — markers for harm categorization and statistical trials, automatic result collection, terminal summaries +- **Parallel execution** — run tests across worker processes with `pytest-xdist`; RAMPART produces a single unified report - **Reporting** — structured JSON output for CI dashboards diff --git a/docs/usage/ci-integration.md b/docs/usage/ci-integration.md index 0c42f94..49c76e8 100644 --- a/docs/usage/ci-integration.md +++ b/docs/usage/ci-integration.md @@ -16,6 +16,17 @@ RAMPART tests interact with real or simulated agents and may take longer than un pytest tests/ -v --timeout=300 ``` +### Parallel Execution + +For faster CI runs, use [`pytest-xdist`](xdist.md): + +```bash +pip install pytest-xdist +pytest tests/ -n auto --dist=loadgroup +``` + +RAMPART aggregates results across worker processes and emits a single unified report. `--dist=loadgroup` is recommended when using `@trial` markers so that trial clones run on the same worker. See [Parallel Execution](xdist.md) for details and security considerations. + --- ## Trial Markers for Statistical Confidence @@ -59,9 +70,23 @@ The JSON file contains aggregate statistics and per-result data that CI dashboar --- +## Pytest Options + +RAMPART is configured via pytest options and Python (sinks, adapters, payloads). + +### `--rampart-xdist-max-bytes` + +Maximum size in bytes of a worker's serialized result payload when running under [`pytest-xdist`](xdist.md). Defaults to `67108864` (64 MB). Workers that exceed the cap log a warning and the controller marks the run as incomplete. Also configurable via the `rampart_xdist_max_bytes` ini option. + +```bash +pytest -n auto --rampart-xdist-max-bytes=134217728 # 128 MB +``` + +--- + ## Environment Variables -RAMPART itself does not read environment variables. Your adapter and test configuration typically do. Setting them locally for ad-hoc runs: +Your adapter and test configuration typically read environment variables. Setting them locally for ad-hoc runs: === "Linux / macOS" diff --git a/docs/usage/configuration.md b/docs/usage/configuration.md index 76be67d..ce7db91 100644 --- a/docs/usage/configuration.md +++ b/docs/usage/configuration.md @@ -4,6 +4,16 @@ RAMPART's configurable components: [`LLMConfig`][rampart.core.llm.LLMConfig] for --- +## Parallel-execution tuning + +RAMPART exposes one pytest option for parallel-execution tuning. Other components (LLM endpoints, agent configuration) typically have their own configuration conventions. + +| Option | Default | Description | +|--------|---------|-------------| +| `--rampart-xdist-max-bytes` (CLI) / `rampart_xdist_max_bytes` (ini) | `67108864` (64 MB) | Maximum size of a worker's serialized result payload when running under [`pytest-xdist`](xdist.md). Workers exceeding the cap are recorded as incomplete in `TestRunReport.metadata`. | + +--- + ## LLMConfig Immutable configuration for an LLM endpoint. Used by [`LLMDriver`][rampart.drivers.llm.LLMDriver] and [`Payloads.generate_async()`][rampart.payloads.Payloads.generate_async]. diff --git a/docs/usage/pytest-integration.md b/docs/usage/pytest-integration.md index ebcdfd7..44147bb 100644 --- a/docs/usage/pytest-integration.md +++ b/docs/usage/pytest-integration.md @@ -68,6 +68,9 @@ async def test_with_threshold(adapter): - `ERROR` results count against the pass rate (they are not `SAFE`) - The trial group aggregate appears in the terminal summary +!!! tip "Running trials in parallel" + Under [`pytest-xdist`](xdist.md), use `--dist=loadgroup` to co-locate trial clones on a single worker. Aggregation is correct under any `--dist` mode, but `loadgroup` reduces cross-worker overhead. + --- ## Fixtures @@ -98,6 +101,14 @@ def rampart_sinks() -> list[ReportSink]: ] ``` +!!! warning "xdist compatibility" + Under [`pytest-xdist`](xdist.md), the controller process discovers sinks by calling `rampart_sinks` directly. Fixtures that depend on other fixtures (e.g., `tmp_path_factory`, `request`) cannot be resolved on the controller and are skipped with a warning. Use a parameterless fixture or a module-level list to remain compatible: + + ```python + # Compatible with xdist + rampart_sinks = [JsonFileReportSink(output_dir=Path(".report"))] + ``` + --- ## Automatic Result Collection diff --git a/docs/usage/results-and-reporting.md b/docs/usage/results-and-reporting.md index b5d07b2..45a3b58 100644 --- a/docs/usage/results-and-reporting.md +++ b/docs/usage/results-and-reporting.md @@ -91,6 +91,9 @@ class MyDatabaseSink: Define the `rampart_sinks` fixture in your `conftest.py`. See [pytest Markers & Fixtures](pytest-integration.md#rampart_sinks) for the setup and examples with multiple sinks. +!!! note "Parallel execution" + Under [`pytest-xdist`](xdist.md), workers send their results to the controller, which emits sinks **once** with a unified [`TestRunReport`][rampart.reporting.sink.TestRunReport]. Sinks discovered on the controller cannot depend on other pytest fixtures; use a parameterless fixture or a module-level list. See [Parallel Execution](xdist.md#constraints-on-rampart_sinks) for details. + --- ## TestRunReport diff --git a/docs/usage/xdist.md b/docs/usage/xdist.md new file mode 100644 index 0000000..7c204ee --- /dev/null +++ b/docs/usage/xdist.md @@ -0,0 +1,170 @@ +# Parallel Execution with pytest-xdist + +RAMPART supports parallel test execution via `pytest-xdist`, producing a **single unified report** even when tests run across multiple worker processes. + +--- + +## Quick Start + +```bash +pip install pytest-xdist +pytest -n 4 +``` + +With `-n 4`, pytest spawns 4 worker processes that execute tests in parallel. RAMPART intercepts each worker's results, ships them to the controller process, and emits **one consolidated report** at the end of the session. + +--- + +## How It Works + +``` +Worker 1 Worker 2 Controller +───────── ───────── ────────── +collect results collect results + │ │ +pytest_sessionfinish pytest_sessionfinish + │ │ +serialize → workeroutput serialize → workeroutput + │ │ + └───────────┬───────────────┘ + ▼ + pytest_testnodedown (per worker) + deserialize + merge into + controller's RampartSession + │ + ▼ + pytest_sessionfinish (controller) + aggregate trials → evaluate gates → emit sinks + │ + ▼ + Single unified TestRunReport +``` + +- **Workers** collect [`Result`][rampart.core.result.Result] objects normally and serialize them into `config.workeroutput`. Workers do **not** emit reports. +- **Controller** receives each worker's payload via the `pytest_testnodedown` hook, merges results into its own [`RampartSession`][rampart.pytest_plugin._session.RampartSession], and emits sinks once at session end. + +The result: **one** `JsonFileReportSink` output file, **one** call to `MyCustomSink.emit_async`, and accurate population statistics over the full result set. + +--- + +## Trial Tests with xdist + +`@pytest.mark.trial(n=, threshold=)` clones a test into N independent runs. Under xdist, clones may be distributed across workers depending on the `--dist` mode. + +| `--dist` mode | Trial behavior | +|---------------|----------------| +| `loadgroup` | All trial clones for one test run on the same worker (recommended for locality) | +| `load` (default) | Trial clones distributed round-robin across workers | +| `loadscope` / `loadfile` | Grouped by class/module/file | + +**Correctness is preserved regardless of mode** — the controller aggregates trial groups from the merged result set. You'll see a warning if you use `@trial` markers without `--dist=loadgroup`: + +```text +RAMPART @trial markers present with --dist=load. Trial clones may be +split across workers. Aggregation remains correct (controller merges +all results), but using --dist=loadgroup keeps trial clones co-located +on one worker for better locality. +``` + +To silence the warning and improve locality: + +```bash +pytest -n 4 --dist=loadgroup +``` + +--- + +## Constraints on `rampart_sinks` + +When running under xdist, the controller process does not execute test fixtures. To discover your sinks, RAMPART scans registered conftest modules for a `rampart_sinks` attribute and calls it directly. + +**Supported shapes:** + +```python +# Parameterless fixture — works on both single-process and xdist +@pytest.fixture(scope="session") +def rampart_sinks(): + return [JsonFileReportSink(output_dir=Path(".report"))] + +# Plain list assigned at module level — works on both +rampart_sinks = [JsonFileReportSink(output_dir=Path(".report"))] +``` + +**Not supported under xdist** (the warning is logged and the sink is skipped): + +```python +# Fixture with dependencies — cannot be resolved on the controller +@pytest.fixture(scope="session") +def rampart_sinks(my_sink_config, db_connection): + return [DatabaseSink(connection=db_connection)] +``` + +If your sinks need dependencies, consider: + +- Constructing them at module level with explicit configuration +- Reading configuration from environment variables inside a parameterless function +- Running without xdist (`pytest` instead of `pytest -n 4`) until a hook-based registration API is added + +--- + +## Trust Boundary & Security + +Worker payloads cross a process boundary via `execnet` and may contain attacker-controlled content (agent responses, payload text, evaluator rationale). RAMPART's serialization defends against: + +- **Arbitrary code execution** — strict JSON-safe primitives only; no `pickle`, `marshal`, or custom `__reduce__`. +- **Schema drift** — payloads with missing or unknown schema versions are rejected fail-closed. +- **Memory exhaustion** — worker payloads are capped at 64 MB by default. +- **Terminal/log injection** — ANSI escape sequences are stripped from free-form text at the deserialization boundary. +- **Path traversal** — worker-local artifact paths are stored as opaque strings in metadata; the controller never accesses worker files. + +### Size cap + +The default 64 MB cap can be overridden via the pytest CLI option or an ini setting: + +```bash +pytest -n 4 --rampart-xdist-max-bytes=134217728 +``` + +Or in `pytest.ini` / `pyproject.toml`: + +```ini +[pytest] +rampart_xdist_max_bytes = 134217728 +``` + +Workers that exceed the cap log a warning and emit a truncation marker. The controller records the affected worker as incomplete in `TestRunReport.metadata`. + +--- + +## Incomplete Runs + +If a worker crashes, runs out of time, or hits the size cap, the controller marks the run as incomplete: + +```python +report.metadata["incomplete"] # True if any worker failed +report.metadata["incomplete_reasons"] # list[str] — one per failure +``` + +Reports are still emitted with whatever data was collected. For safety-critical CI, sinks or post-processing should check the `incomplete` flag and fail the build accordingly. + +--- + +## Run-Mode Metadata + +Reports produced under xdist include: + +```python +report.metadata["xdist_active"] # True +report.metadata["worker_count"] # int +report.metadata["dist_mode"] # "load", "loadgroup", etc. +``` + +--- + +## Limitations + +- Sinks discovered on the controller cannot depend on other pytest fixtures (see Constraints above). +- Mixed RAMPART versions across controller and workers are unsupported; install the same version everywhere. +- `pytest-xdist` itself does not support interactive debugging (`--pdb`, `--trace`); use single-process mode for debugging. + +A hook-based sink registration API for complex sink configurations is a planned follow-up. diff --git a/mkdocs.yml b/mkdocs.yml index 5659cd4..c74a9f5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -163,6 +163,7 @@ nav: - Configuration: usage/configuration.md - pytest Markers & Fixtures: usage/pytest-integration.md - Results & Reporting: usage/results-and-reporting.md + - Parallel Execution: usage/xdist.md - CI Integration: usage/ci-integration.md - Contributing: - contributing/index.md diff --git a/rampart/pytest_plugin/_session.py b/rampart/pytest_plugin/_session.py index 715f314..873c36f 100644 --- a/rampart/pytest_plugin/_session.py +++ b/rampart/pytest_plugin/_session.py @@ -8,7 +8,9 @@ Note: The architecture places RampartSession in plugin.py. This implementation extracts it to a dedicated module for file size -management. This is a documented deviation from the architecture. +management and to share state with the xdist support module +(``_xdist.py``). This is a documented deviation from the +architecture. """ from __future__ import annotations @@ -17,7 +19,7 @@ import logging from collections import Counter from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from rampart.core.result import Result, SafetyStatus from rampart.reporting.sink import ReportSink, TestRunReport @@ -88,12 +90,33 @@ def __init__(self, *, sinks: list[ReportSink] | None = None) -> None: self._sinks: list[ReportSink] = sinks or [] self._duration_seconds: float = 0.0 self._cached_report: TestRunReport | None = None + self._emitted: bool = False + self._incomplete: bool = False + self._incomplete_reasons: list[str] = [] + self._report_metadata: dict[str, object] = {} @property def sinks(self) -> list[ReportSink]: """Configured report sinks.""" return list(self._sinks) + @property + def results_by_nodeid(self) -> dict[str, list[Result]]: + """Read-only view of results grouped by pytest node ID.""" + return { + nodeid: list(results) for nodeid, results in self._results_by_nodeid.items() + } + + @property + def is_emitted(self) -> bool: + """True once report emission has been attempted (idempotency guard).""" + return self._emitted + + @property + def is_incomplete(self) -> bool: + """True if any worker failed to deliver complete results.""" + return self._incomplete + def add_sinks(self, *, sinks: list[ReportSink]) -> None: """Register additional sinks for report emission. @@ -234,25 +257,87 @@ def trial_groups(self) -> dict[str, TrialGroupResult]: """Trial group aggregates, keyed by base node ID.""" return dict(self._trial_groups) + def merge_worker_results( + self, + *, + results_by_nodeid: dict[str, list[Result]], + ) -> None: + """Merge an xdist worker's results into this session. + + Extends both the flat ``_results`` list and the + ``_results_by_nodeid`` mapping. Invalidates any cached report + so the next ``build_report()`` reflects the merged data. + + Args: + results_by_nodeid (dict[str, list[Result]]): Worker results + grouped by pytest node ID. + """ + for nodeid, results in results_by_nodeid.items(): + self._results.extend(results) + self._results_by_nodeid.setdefault(nodeid, []).extend(results) + self._cached_report = None + + def mark_emitted(self) -> None: + """Mark the session as having attempted report emission.""" + self._emitted = True + + def mark_incomplete(self, *, reason: str) -> None: + """Record that a worker failed to deliver complete results. + + Args: + reason (str): A short human-readable explanation surfaced + in the report metadata. + """ + self._incomplete = True + self._incomplete_reasons.append(reason) + self._cached_report = None + + def set_report_metadata(self, *, metadata: dict[str, object]) -> None: + """Attach run-level metadata that will appear on ``TestRunReport``. + + Used by the plugin to surface xdist run-mode information + (active, worker count, dist mode). Subsequent calls merge into + existing metadata. + + Args: + metadata (dict[str, object]): Key/value pairs to attach. + """ + self._report_metadata.update(metadata) + self._cached_report = None + def build_report(self) -> TestRunReport: """Build a TestRunReport from all collected results. The report is cached and reused on subsequent calls. The - cache is invalidated when new results are absorbed. + cache is invalidated when new results are absorbed or merged + or when metadata is updated. + + Results are sorted by their pytest node ID (from + ``metadata['test_name']`` when available) for deterministic + ordering across xdist worker completion orders. Returns: TestRunReport: Aggregated test run results. """ if self._cached_report is not None: return self._cached_report - counts = Counter(r.status for r in self._results) + sorted_results = sorted( + self._results, + key=lambda r: str(r.metadata.get("test_name", "")), + ) + counts = Counter(r.status for r in sorted_results) + metadata: dict[str, Any] = dict(self._report_metadata) + if self._incomplete: + metadata["incomplete"] = True + metadata["incomplete_reasons"] = list(self._incomplete_reasons) self._cached_report = TestRunReport( - results=list(self._results), - total_runs=len(self._results), + results=sorted_results, + total_runs=len(sorted_results), passed=counts[SafetyStatus.SAFE], failed=counts[SafetyStatus.UNSAFE], undetermined=counts[SafetyStatus.UNDETERMINED], errors=counts[SafetyStatus.ERROR], duration_seconds=self._duration_seconds, + metadata=metadata, ) return self._cached_report diff --git a/rampart/pytest_plugin/_xdist.py b/rampart/pytest_plugin/_xdist.py new file mode 100644 index 0000000..5633334 --- /dev/null +++ b/rampart/pytest_plugin/_xdist.py @@ -0,0 +1,1066 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""xdist support for RAMPART's pytest plugin. + +Provides serialization, deserialization, and controller-side merge +logic for running RAMPART under pytest-xdist. Workers serialize their +``Result`` objects into ``config.workeroutput``; the controller merges +worker payloads in ``pytest_testnodedown`` and emits a single unified +report at session end. + +Trust boundary: worker payloads may contain attacker-controlled +content (agent responses, payload text). Serialization is strictly +JSON-safe primitives; deserialization validates schema version, +enum values, and metadata depth; ANSI escapes are stripped from free +text as defense-in-depth. + +Note: The architecture places all pytest plugin logic in plugin.py. +This module extracts xdist-specific logic to keep plugin.py focused +on hook registration and to isolate the serialization/security model. +This is a documented deviation from the architecture, complementing +the earlier extraction of RampartSession into _session.py. +""" + +from __future__ import annotations + +import json +import logging +import math +import re +from datetime import datetime +from typing import TYPE_CHECKING, Any, cast + +from rampart.core.result import ( + HarmCategory, + InjectionRecord, + Result, + SafetyStatus, +) +from rampart.core.types import ( + EvalOutcome, + EvalResult, + ObservabilityLevel, + Payload, + PayloadFormat, + Request, + Response, + SideEffect, + ToolCall, + Turn, +) +from rampart.reporting.sink import ReportSink + +if TYPE_CHECKING: + import pytest + + from rampart.pytest_plugin._session import RampartSession + +logger = logging.getLogger(__name__) + +SCHEMA_VERSION: str = "rampart.xdist.v1" +WORKEROUTPUT_KEY: str = "rampart_xdist_v1" +SIZE_LIMIT_OPTION: str = "rampart_xdist_max_bytes" +DEFAULT_SIZE_LIMIT_BYTES: int = 64 * 1024 * 1024 +MAX_METADATA_DEPTH: int = 6 + +_ANSI_ESCAPE_RE: re.Pattern[str] = re.compile(r"\x1b\[[0-9;]*[A-Za-z]") +_TRUNCATED_MARKER: str = "rampart_truncated" + + +class WorkerOutputError(Exception): + """Base error for xdist worker output processing failures.""" + + +class SchemaVersionError(WorkerOutputError): + """Raised when a worker payload has missing or unknown schema version.""" + + +class SizeLimitError(WorkerOutputError): + """Raised when a worker payload exceeds the configured size cap.""" + + +def is_xdist_worker(*, config: pytest.Config) -> bool: + """Return True when this process is a pytest-xdist worker. + + Detection is attribute-based; no xdist import required, so this + function is safe to call when pytest-xdist is not installed. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + bool: True if running in an xdist worker process. + """ + return hasattr(config, "workerinput") + + +def is_xdist_controller(*, config: pytest.Config) -> bool: + """Return True when this process is the pytest-xdist controller. + + The controller is defined as a process where xdist is active + (``--numprocesses`` is set) and which is NOT itself a worker. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + bool: True if running in the xdist controller process. + """ + if is_xdist_worker(config=config): + return False + numprocesses = getattr(config.option, "numprocesses", None) + return numprocesses is not None and numprocesses != 0 + + +def get_dist_mode(*, config: pytest.Config) -> str: + """Return the active ``--dist`` mode string. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + str: The dist mode (e.g., ``"load"``, ``"loadgroup"``, ``"no"``). + """ + return cast("str", getattr(config.option, "dist", "no")) + + +def get_worker_count(*, config: pytest.Config) -> int: + """Return the number of xdist workers configured. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + int: Number of workers (0 when xdist is not active). + """ + numprocesses = getattr(config.option, "numprocesses", 0) + return int(numprocesses) if numprocesses else 0 + + +def _size_limit(*, config: pytest.Config) -> int: + """Resolve the worker payload size cap from pytest config or default. + + Reads from the ``--rampart-xdist-max-bytes`` CLI option first, then + the ``rampart_xdist_max_bytes`` ini option, then falls back to + ``DEFAULT_SIZE_LIMIT_BYTES``. + """ + raw: Any = config.getoption(SIZE_LIMIT_OPTION, default=None) + if raw is None: + try: + raw = config.getini(SIZE_LIMIT_OPTION) + except (ValueError, KeyError): + raw = None + if raw in (None, ""): + return DEFAULT_SIZE_LIMIT_BYTES + try: + parsed = int(raw) + except (TypeError, ValueError): + logger.warning( + "Invalid %s=%r; falling back to default %d bytes.", + SIZE_LIMIT_OPTION, + raw, + DEFAULT_SIZE_LIMIT_BYTES, + ) + return DEFAULT_SIZE_LIMIT_BYTES + if parsed <= 0: + logger.warning( + "%s=%d must be > 0; falling back to default %d bytes.", + SIZE_LIMIT_OPTION, + parsed, + DEFAULT_SIZE_LIMIT_BYTES, + ) + return DEFAULT_SIZE_LIMIT_BYTES + return parsed + + +def _strip_ansi(*, text: str) -> str: + """Remove ANSI escape sequences from free-form text. + + Args: + text (str): The text to sanitize. + + Returns: + str: Text with ANSI escape sequences removed. + """ + return _ANSI_ESCAPE_RE.sub("", text) + + +def _sanitize( # noqa: PLR0911 + *, + value: Any, # noqa: ANN401 + depth: int = 0, + strip_ansi: bool = False, +) -> Any: # noqa: ANN401 + """Coerce a value to a JSON-safe form. + + Walks dicts and lists up to ``MAX_METADATA_DEPTH``. Values not in + (str, int, bool, NoneType, finite float, dict, list, tuple) are + coerced via ``repr()``. NaN/Inf floats are coerced to ``None``. + + When ``strip_ansi=True`` (set on the deserialization path), ANSI + escape sequences are removed from every nested string value so + that attacker-controlled escapes inside ``arguments``, ``details``, + and ``metadata`` cannot reach terminal renderers. + + Args: + value (Any): The value to sanitize. + depth (int): Current recursion depth (internal). + strip_ansi (bool): If True, strip ANSI escapes from strings. + + Returns: + Any: A JSON-safe representation. + """ + if depth > MAX_METADATA_DEPTH: + return repr(value) + if value is None or isinstance(value, bool): + return value + if isinstance(value, str): + return _strip_ansi(text=value) if strip_ansi else value + if isinstance(value, int): + return value + if isinstance(value, float): + return value if math.isfinite(value) else None + if isinstance(value, dict): + return { + str(k): _sanitize(value=v, depth=depth + 1, strip_ansi=strip_ansi) + for k, v in cast("dict[Any, Any]", value).items() + } + if isinstance(value, list | tuple): + return [ + _sanitize(value=v, depth=depth + 1, strip_ansi=strip_ansi) + for v in cast("list[Any]", value) + ] + return repr(value) + + +def _is_json_passthrough(value: Any) -> bool: # noqa: ANN401 + """True if a value would pass through ``_sanitize`` unchanged.""" + if value is None or isinstance(value, str | bool): + return True + if isinstance(value, int): + return True + if isinstance(value, float): + return math.isfinite(value) + return False + + +def _sanitize_metadata( + *, + metadata: dict[str, Any], + nodeid: str, + context: str, +) -> dict[str, Any]: + """Sanitize a metadata dict; log keys that required coercion. + + Logs at warning level with the originating nodeid and the list of + keys whose values were coerced so users can diagnose lossy fields + without polluting the user-visible metadata payload. + + Args: + metadata (dict[str, Any]): The metadata to sanitize. + nodeid (str): Originating test nodeid (for log context). + context (str): Source context (e.g., ``"result"``, ``"payload"``). + + Returns: + dict[str, Any]: Sanitized metadata dict. + """ + sanitized: dict[str, Any] = {} + coerced: list[str] = [] + for key, value in metadata.items(): + key_str = str(key) + sanitized[key_str] = _sanitize(value=value) + passthrough = _is_json_passthrough(value) + collection = isinstance(value, dict | list | tuple) + if not passthrough and not collection: + coerced.append(key_str) + if coerced: + logger.warning( + "Sanitized %d non-serializable metadata key(s) for %s in %s: %s", + len(coerced), + nodeid, + context, + coerced, + ) + return sanitized + + +def _safe_float(*, value: float) -> float | None: + """Coerce non-finite floats to None for JSON safety.""" + return value if math.isfinite(value) else None + + +def _isoformat(*, timestamp: datetime | None) -> str | None: + """Convert a datetime to ISO 8601 string, or None.""" + return timestamp.isoformat() if timestamp is not None else None + + +def _serialize_eval_result(*, eval_result: EvalResult) -> dict[str, Any]: + """Serialize an EvalResult to a JSON-safe dict.""" + return { + "outcome": eval_result.outcome.value, + "confidence": _safe_float(value=eval_result.confidence), + "evidence": [str(e) for e in eval_result.evidence], + "rationale": eval_result.rationale, + } + + +def _serialize_tool_call(*, tool_call: ToolCall, nodeid: str) -> dict[str, Any]: + """Serialize a ToolCall to a JSON-safe dict.""" + return { + "name": tool_call.name, + "arguments": _sanitize_metadata( + metadata=tool_call.arguments, + nodeid=nodeid, + context="tool_call.arguments", + ), + "result": tool_call.result, + "timestamp": _isoformat(timestamp=tool_call.timestamp), + } + + +def _serialize_side_effect( + *, + side_effect: SideEffect, + nodeid: str, +) -> dict[str, Any]: + """Serialize a SideEffect to a JSON-safe dict.""" + return { + "kind": side_effect.kind, + "details": _sanitize_metadata( + metadata=side_effect.details, + nodeid=nodeid, + context="side_effect.details", + ), + } + + +def _serialize_payload(*, payload: Payload, nodeid: str) -> dict[str, Any]: + """Serialize a Payload to a JSON-safe dict. + + The artifact path (if any) is converted to a string for display + only; the controller never accesses worker-local files. + """ + return { + "content": payload.content, + "id": payload.id, + "format": payload.format.value, + "artifact": str(payload.artifact) if payload.artifact is not None else None, + "metadata": _sanitize_metadata( + metadata=payload.metadata, + nodeid=nodeid, + context="payload.metadata", + ), + } + + +def _serialize_request(*, request: Request, nodeid: str) -> dict[str, Any]: + """Serialize a Request to a JSON-safe dict.""" + return { + "prompt": request.prompt, + "attachments": [ + _serialize_payload(payload=p, nodeid=nodeid) for p in request.attachments + ], + } + + +def _serialize_response(*, response: Response, nodeid: str) -> dict[str, Any]: + """Serialize a Response to a JSON-safe dict.""" + return { + "text": response.text, + "tool_calls": [ + _serialize_tool_call(tool_call=tc, nodeid=nodeid) + for tc in response.tool_calls + ], + "side_effects": [ + _serialize_side_effect(side_effect=se, nodeid=nodeid) + for se in response.side_effects + ], + "metadata": _sanitize_metadata( + metadata=response.metadata, + nodeid=nodeid, + context="response.metadata", + ), + } + + +def _serialize_turn(*, turn: Turn, nodeid: str) -> dict[str, Any]: + """Serialize a Turn to a JSON-safe dict.""" + return { + "request": _serialize_request(request=turn.request, nodeid=nodeid), + "response": _serialize_response(response=turn.response, nodeid=nodeid), + "eval_result": ( + _serialize_eval_result(eval_result=turn.eval_result) + if turn.eval_result is not None + else None + ), + "turn_number": turn.turn_number, + "timestamp": _isoformat(timestamp=turn.timestamp), + "driver_reasoning": turn.driver_reasoning, + } + + +def _serialize_injection_record(*, injection: InjectionRecord) -> dict[str, Any]: + """Serialize an InjectionRecord to a JSON-safe dict.""" + return { + "payload_id": injection.payload_id, + "surface_name": injection.surface_name, + } + + +def _serialize_result(*, result: Result, nodeid: str) -> dict[str, Any]: + """Serialize a Result to a JSON-safe dict.""" + return { + "safe": result.safe, + "status": result.status.value, + "summary": result.summary, + "turns": [_serialize_turn(turn=t, nodeid=nodeid) for t in result.turns], + "duration_seconds": _safe_float(value=result.duration_seconds), + "harm_category": ( + str(result.harm_category) if result.harm_category is not None else None + ), + "strategy": result.strategy, + "observability_level": result.observability_level.value, + "injections": [ + _serialize_injection_record(injection=i) for i in result.injections + ], + "metadata": _sanitize_metadata( + metadata=result.metadata, + nodeid=nodeid, + context="result.metadata", + ), + } + + +def serialize_worker_data(*, session: RampartSession) -> dict[str, Any]: + """Serialize a worker's RampartSession state for transport to the controller. + + Produces a JSON-safe dict containing the schema version, the + package version (for cross-version diagnostics), and the worker's + ``_results_by_nodeid`` mapping serialized to primitive types. + + Args: + session (RampartSession): The worker's session state. + + Returns: + dict[str, Any]: A JSON-safe payload ready to write to + ``config.workeroutput``. + """ + serialized: dict[str, list[dict[str, Any]]] = {} + for nodeid, results in session.results_by_nodeid.items(): + serialized[nodeid] = [ + _serialize_result(result=r, nodeid=nodeid) for r in results + ] + return { + "schema": SCHEMA_VERSION, + "package_version": _package_version(), + "results_by_nodeid": serialized, + } + + +def _package_version() -> str: + """Return the installed rampart package version (best-effort).""" + try: + from importlib.metadata import version # noqa: PLC0415 + + return version("rampart") + except Exception: # noqa: BLE001 + return "unknown" + + +def _validate_schema(*, data: object) -> dict[str, Any]: + """Validate that ``data`` is a worker payload of the expected schema.""" + if not isinstance(data, dict): + msg = f"Expected dict worker payload, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + schema = typed.get("schema") + if schema is None: + msg = "Worker payload missing required 'schema' key." + raise SchemaVersionError(msg) + if schema != SCHEMA_VERSION: + msg = ( + f"Worker payload schema {schema!r} does not match " + f"controller schema {SCHEMA_VERSION!r}; rejecting to avoid " + "best-effort parsing of an unknown format." + ) + raise SchemaVersionError(msg) + return typed + + +def _deserialize_safety_status(*, value: object) -> SafetyStatus: + """Deserialize a SafetyStatus enum value.""" + if not isinstance(value, str): + msg = f"Expected string for SafetyStatus, got {type(value).__name__}." + raise WorkerOutputError(msg) + try: + return SafetyStatus(value) + except ValueError as exc: + msg = f"Unknown SafetyStatus value: {value!r}." + raise WorkerOutputError(msg) from exc + + +def _deserialize_observability_level(*, value: object) -> ObservabilityLevel: + """Deserialize an ObservabilityLevel enum value.""" + if not isinstance(value, str): + msg = f"Expected string for ObservabilityLevel, got {type(value).__name__}." + raise WorkerOutputError(msg) + try: + return ObservabilityLevel(value) + except ValueError as exc: + msg = f"Unknown ObservabilityLevel value: {value!r}." + raise WorkerOutputError(msg) from exc + + +def _deserialize_eval_outcome(*, value: object) -> EvalOutcome: + """Deserialize an EvalOutcome enum value.""" + if not isinstance(value, str): + msg = f"Expected string for EvalOutcome, got {type(value).__name__}." + raise WorkerOutputError(msg) + try: + return EvalOutcome(value) + except ValueError as exc: + msg = f"Unknown EvalOutcome value: {value!r}." + raise WorkerOutputError(msg) from exc + + +def _deserialize_harm_category(*, value: object) -> HarmCategory | str | None: + """Deserialize a HarmCategory enum value, plain string, or None.""" + if value is None: + return None + if not isinstance(value, str): + msg = f"Expected string for harm_category, got {type(value).__name__}." + raise WorkerOutputError(msg) + try: + return HarmCategory(value) + except ValueError: + return value + + +def _deserialize_datetime(*, value: object) -> datetime | None: + """Deserialize an ISO 8601 datetime string, or None.""" + if value is None: + return None + if not isinstance(value, str): + msg = f"Expected string for datetime, got {type(value).__name__}." + raise WorkerOutputError(msg) + try: + return datetime.fromisoformat(value) + except ValueError as exc: + msg = f"Invalid ISO 8601 datetime: {value!r}." + raise WorkerOutputError(msg) from exc + + +def _deserialize_eval_result(*, data: object) -> EvalResult | None: + """Deserialize an EvalResult, or None when input is None.""" + if data is None: + return None + if not isinstance(data, dict): + msg = f"Expected dict for EvalResult, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + outcome = _deserialize_eval_outcome(value=typed.get("outcome")) + raw_confidence = typed.get("confidence") + confidence = ( + float(raw_confidence) if isinstance(raw_confidence, int | float) else 1.0 + ) + raw_evidence = typed.get("evidence", []) + evidence_items = cast( + "list[Any]", + raw_evidence if isinstance(raw_evidence, list) else [], + ) + evidence: list[str] = [_strip_ansi(text=str(e)) for e in evidence_items] + rationale = _strip_ansi(text=str(typed.get("rationale", ""))) + return EvalResult( + outcome=outcome, + confidence=confidence, + evidence=evidence, + rationale=rationale, + ) + + +def _deserialize_tool_call(*, data: object) -> ToolCall: + """Deserialize a ToolCall.""" + if not isinstance(data, dict): + msg = f"Expected dict for ToolCall, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_args = typed.get("arguments", {}) + arguments = _sanitize( + value=raw_args if isinstance(raw_args, dict) else {}, + strip_ansi=True, + ) + raw_result = typed.get("result") + return ToolCall( + name=str(typed.get("name", "")), + arguments=cast("dict[str, Any]", arguments), + result=_strip_ansi(text=str(raw_result)) if raw_result is not None else None, + timestamp=_deserialize_datetime(value=typed.get("timestamp")), + ) + + +def _deserialize_side_effect(*, data: object) -> SideEffect: + """Deserialize a SideEffect.""" + if not isinstance(data, dict): + msg = f"Expected dict for SideEffect, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_details = typed.get("details", {}) + details = _sanitize( + value=raw_details if isinstance(raw_details, dict) else {}, + strip_ansi=True, + ) + return SideEffect( + kind=str(typed.get("kind", "")), + details=cast("dict[str, Any]", details), + ) + + +def _deserialize_payload(*, data: object) -> Payload: + """Deserialize a Payload. + + The controller never sees worker-local artifacts. Reconstructed + payloads always use ``format=TEXT`` and ``artifact=None``; the + original format and artifact path are preserved under namespaced + keys in metadata for debugging. + """ + if not isinstance(data, dict): + msg = f"Expected dict for Payload, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_metadata = typed.get("metadata", {}) + metadata = _sanitize( + value=raw_metadata if isinstance(raw_metadata, dict) else {}, + strip_ansi=True, + ) + metadata_dict = cast("dict[str, Any]", metadata) + original_format = str(typed.get("format", PayloadFormat.TEXT.value)) + if original_format != PayloadFormat.TEXT.value: + metadata_dict.setdefault("_rampart_worker_format", original_format) + original_artifact = typed.get("artifact") + if original_artifact is not None: + metadata_dict.setdefault( + "_rampart_worker_artifact_path", + str(original_artifact), + ) + return Payload( + content=_strip_ansi(text=str(typed.get("content", ""))), + id=str(typed.get("id", "")), + format=PayloadFormat.TEXT, + artifact=None, + metadata=metadata_dict, + ) + + +def _deserialize_request(*, data: object) -> Request: + """Deserialize a Request, providing a fallback prompt when empty.""" + if not isinstance(data, dict): + msg = f"Expected dict for Request, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_prompt = typed.get("prompt") + prompt: str | None = ( + _strip_ansi(text=str(raw_prompt)) if raw_prompt is not None else None + ) + raw_attachments = typed.get("attachments", []) + attachment_items = cast( + "list[Any]", + raw_attachments if isinstance(raw_attachments, list) else [], + ) + attachments: list[Payload] = [ + _deserialize_payload(data=p) for p in attachment_items + ] + if prompt is None and not attachments: + prompt = "" + return Request(prompt=prompt, attachments=attachments) + + +def _deserialize_response(*, data: object) -> Response: + """Deserialize a Response.""" + if not isinstance(data, dict): + msg = f"Expected dict for Response, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_tcs = typed.get("tool_calls", []) + raw_ses = typed.get("side_effects", []) + raw_metadata = typed.get("metadata", {}) + metadata = _sanitize( + value=raw_metadata if isinstance(raw_metadata, dict) else {}, + strip_ansi=True, + ) + return Response( + text=_strip_ansi(text=str(typed.get("text", ""))), + tool_calls=[ + _deserialize_tool_call(data=tc) + for tc in cast("list[Any]", raw_tcs if isinstance(raw_tcs, list) else []) + ], + side_effects=[ + _deserialize_side_effect(data=se) + for se in cast("list[Any]", raw_ses if isinstance(raw_ses, list) else []) + ], + metadata=cast("dict[str, Any]", metadata), + ) + + +def _deserialize_turn(*, data: object) -> Turn: + """Deserialize a Turn.""" + if not isinstance(data, dict): + msg = f"Expected dict for Turn, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_turn_number = typed.get("turn_number", 0) + return Turn( + request=_deserialize_request(data=typed.get("request")), + response=_deserialize_response(data=typed.get("response")), + eval_result=_deserialize_eval_result(data=typed.get("eval_result")), + turn_number=int(raw_turn_number) if isinstance(raw_turn_number, int) else 0, + timestamp=_deserialize_datetime(value=typed.get("timestamp")), + driver_reasoning=_strip_ansi(text=str(typed.get("driver_reasoning", ""))), + ) + + +def _deserialize_injection_record(*, data: object) -> InjectionRecord: + """Deserialize an InjectionRecord.""" + if not isinstance(data, dict): + msg = f"Expected dict for InjectionRecord, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_payload_id = typed.get("payload_id") + return InjectionRecord( + payload_id=str(raw_payload_id) if raw_payload_id is not None else None, + surface_name=str(typed.get("surface_name", "")), + ) + + +def _deserialize_result(*, data: object) -> Result: + """Deserialize a Result.""" + if not isinstance(data, dict): + msg = f"Expected dict for Result, got {type(data).__name__}." + raise WorkerOutputError(msg) + typed = cast("dict[str, Any]", data) + raw_turns = typed.get("turns", []) + raw_injections = typed.get("injections", []) + raw_metadata = typed.get("metadata", {}) + metadata = _sanitize( + value=raw_metadata if isinstance(raw_metadata, dict) else {}, + strip_ansi=True, + ) + raw_duration = typed.get("duration_seconds", 0.0) + duration = ( + float(raw_duration) + if isinstance(raw_duration, int | float) and math.isfinite(float(raw_duration)) + else 0.0 + ) + return Result( + safe=bool(typed.get("safe", False)), + status=_deserialize_safety_status(value=typed.get("status")), + summary=_strip_ansi(text=str(typed.get("summary", ""))), + turns=[ + _deserialize_turn(data=t) + for t in cast("list[Any]", raw_turns if isinstance(raw_turns, list) else []) + ], + duration_seconds=duration, + harm_category=_deserialize_harm_category(value=typed.get("harm_category")), + strategy=str(typed.get("strategy", "")), + observability_level=_deserialize_observability_level( + value=typed.get("observability_level"), + ), + injections=[ + _deserialize_injection_record(data=i) + for i in cast( + "list[Any]", + raw_injections if isinstance(raw_injections, list) else [], + ) + ], + metadata=cast("dict[str, Any]", metadata), + ) + + +def deserialize_worker_data(*, data: object) -> dict[str, list[Result]]: + """Deserialize a worker payload back into a ``results_by_nodeid`` mapping. + + Performs strict schema validation: missing ``schema`` key, unknown + versions, and malformed enum values all raise ``WorkerOutputError`` + (or subclass). Caller should catch and mark the run incomplete + rather than letting the exception propagate to pytest. + + Args: + data (object): The deserialized JSON object from + ``node.workeroutput``. + + Returns: + dict[str, list[Result]]: Results grouped by nodeid. + + Raises: + SchemaVersionError: Missing or unknown schema version. + WorkerOutputError: Malformed payload (type errors, bad enums). + """ + typed = _validate_schema(data=data) + raw_results = typed.get("results_by_nodeid", {}) + if not isinstance(raw_results, dict): + msg = f"Expected dict for results_by_nodeid, got {type(raw_results).__name__}." + raise WorkerOutputError(msg) + out: dict[str, list[Result]] = {} + for nodeid, results_data in cast("dict[Any, Any]", raw_results).items(): + if not isinstance(results_data, list): + continue + out[str(nodeid)] = [ + _deserialize_result(data=r) for r in cast("list[Any]", results_data) + ] + return out + + +def finalize_worker(*, config: pytest.Config, session: RampartSession) -> None: + """Serialize the worker's session state into ``config.workeroutput``. + + Called from ``pytest_sessionfinish`` on each xdist worker. The + worker skips sink emission entirely; the controller is responsible + for the final report. + + Args: + config (pytest.Config): The pytest configuration object. + session (RampartSession): The worker's session state. + + Raises: + SizeLimitError: If the serialized payload exceeds the + configured cap. The truncation marker is still written to + ``workeroutput`` before the exception is raised so the + controller can record the run as incomplete. + """ + if not is_xdist_worker(config=config): + return + payload = serialize_worker_data(session=session) + encoded = json.dumps(payload, default=str) + size = len(encoded.encode("utf-8")) + limit = _size_limit(config=config) + workeroutput = cast( + "dict[str, Any]", + config.workeroutput, # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] + ) + if size > limit: + workeroutput[WORKEROUTPUT_KEY] = { + "schema": SCHEMA_VERSION, + _TRUNCATED_MARKER: True, + "size_bytes": size, + "limit_bytes": limit, + } + msg = ( + f"Worker payload size {size} bytes exceeds cap of {limit}; " + f"truncated. Increase --{SIZE_LIMIT_OPTION.replace('_', '-')} " + f"(or the {SIZE_LIMIT_OPTION} ini option) to raise the cap." + ) + raise SizeLimitError(msg) + logger.debug("Worker payload size: %d bytes", size) + workeroutput[WORKEROUTPUT_KEY] = payload + + +def handle_testnodedown( + *, + session: RampartSession, + node: object, + error: object, +) -> None: + """Merge a finished worker's results into the controller session. + + Called from ``pytest_testnodedown`` on the controller for each + worker that completes. Failures (missing payload, deserialization + errors, worker crashes) are recorded via ``mark_incomplete`` rather + than raised, so a single bad worker does not abort report emission. + + Args: + session (RampartSession): The controller's session state. + node: The xdist node object (has ``workeroutput`` attribute). + error: The shutdown error from xdist, or None on clean exit. + """ + worker_id = getattr(node, "gateway", None) + worker_id_str = str(getattr(worker_id, "id", node)) if worker_id else str(node) + if error is not None: + logger.warning( + "Worker %s reported shutdown error; report will be incomplete: %s", + worker_id_str, + error, + ) + session.mark_incomplete(reason=f"worker {worker_id_str} error: {error}") + return + workeroutput = getattr(node, "workeroutput", None) + if not isinstance(workeroutput, dict): + logger.warning( + "Worker %s exited without workeroutput; report will be incomplete.", + worker_id_str, + ) + session.mark_incomplete(reason=f"worker {worker_id_str} missing workeroutput") + return + payload: Any = cast("dict[str, Any]", workeroutput).get(WORKEROUTPUT_KEY) + if payload is None: + logger.warning( + "Worker %s did not produce RAMPART output; report will be incomplete.", + worker_id_str, + ) + session.mark_incomplete(reason=f"worker {worker_id_str} missing RAMPART output") + return + typed_payload_dict: dict[str, Any] | None = ( + cast("dict[str, Any]", payload) if isinstance(payload, dict) else None + ) + if typed_payload_dict is not None and typed_payload_dict.get(_TRUNCATED_MARKER): + logger.error( + "Worker %s payload was truncated due to size cap; " + "report will be incomplete.", + worker_id_str, + ) + session.mark_incomplete( + reason=f"worker {worker_id_str} payload truncated (size cap)", + ) + return + try: + results_by_nodeid = deserialize_worker_data(data=cast("object", payload)) + except WorkerOutputError as exc: + logger.exception( + "Failed to deserialize worker %s output; report will be incomplete.", + worker_id_str, + ) + session.mark_incomplete( + reason=f"worker {worker_id_str} deserialization failed: {exc}", + ) + return + incoming_version: str | None = None + if typed_payload_dict is not None: + raw_version = typed_payload_dict.get("package_version") + if isinstance(raw_version, str): + incoming_version = raw_version + if incoming_version and incoming_version != _package_version(): + logger.warning( + "Worker %s package_version=%s differs from controller %s; " + "mixed versions are unsupported.", + worker_id_str, + incoming_version, + _package_version(), + ) + session.merge_worker_results(results_by_nodeid=results_by_nodeid) + logger.info( + "Merged %d result group(s) from worker %s.", + len(results_by_nodeid), + worker_id_str, + ) + + +def discover_sinks_from_conftest(*, config: pytest.Config) -> list[ReportSink]: + """Discover ``rampart_sinks`` definitions from registered conftest modules. + + Workers run the standard ``_rampart_sink_bootstrap`` fixture to + register sinks via pytest's fixture machinery. The controller has + no test execution, so fixtures do not run. This function scans + registered plugins for a module-level ``rampart_sinks`` attribute + and resolves it: + + - If callable with zero arguments, invoke it and use the return. + - If a list, use it directly. + - Otherwise, log a warning and skip. + + Sinks that depend on other fixtures cannot be discovered this way. + Such configurations are a documented limitation; a hook-based API + is a planned follow-up. + + Args: + config (pytest.Config): The pytest configuration object. + + Returns: + list[ReportSink]: Discovered sinks (may be empty). + """ + discovered: list[ReportSink] = [] + seen: set[int] = set() + for plugin in config.pluginmanager.get_plugins(): + if plugin is None or id(plugin) in seen: + continue + seen.add(id(plugin)) + candidate = getattr(plugin, "rampart_sinks", None) + if candidate is None: + continue + resolved = _resolve_sink_candidate(candidate=candidate, plugin=plugin) + if resolved is None: + continue + for sink in resolved: + if isinstance(sink, ReportSink): + discovered.append(sink) + else: + logger.warning( + "rampart_sinks in %s yielded a non-ReportSink: %r", + getattr(plugin, "__name__", repr(plugin)), + sink, + ) + return discovered + + +def _resolve_sink_candidate( + *, + candidate: object, + plugin: object, +) -> list[object] | None: + """Resolve a ``rampart_sinks`` attribute into a list of sinks. + + Handles three shapes: + + - A list — used directly. + - A pytest fixture (``FixtureFunctionDefinition`` or equivalent) — + unwrapped to its underlying function via ``_get_wrapped_function`` + or ``__wrapped__``; called only if it takes no parameters. + - A plain callable — called directly. + + Returns None on failure (logged) so the caller can continue + scanning other plugins. + """ + import inspect # noqa: PLC0415 + + plugin_name = getattr(plugin, "__name__", repr(plugin)) + if isinstance(candidate, list): + return cast("list[object]", candidate) + + function: object = candidate + wrap_method = getattr(candidate, "_get_wrapped_function", None) + if callable(wrap_method): + function = wrap_method() + elif hasattr(candidate, "__wrapped__"): + function = getattr(candidate, "__wrapped__", candidate) + + if not callable(function): + logger.warning( + "rampart_sinks in %s is %s; expected callable or list[ReportSink].", + plugin_name, + type(candidate).__name__, + ) + return None + + try: + sig = inspect.signature(function) + except (TypeError, ValueError): + sig = None + + if sig is not None and len(sig.parameters) > 0: + logger.warning( + "rampart_sinks in %s requires fixture dependencies (%s); " + "controller-side discovery cannot satisfy those. Provide a " + "parameterless function or list, or run with --dist=no.", + plugin_name, + list(sig.parameters), + ) + return None + + try: + value = function() + except (KeyboardInterrupt, SystemExit): + raise + except Exception as exc: # noqa: BLE001 — broad on purpose: user code + logger.warning( + "rampart_sinks in %s raised during controller-side discovery: %s", + plugin_name, + exc, + ) + return None + + if isinstance(value, list): + return cast("list[object]", value) + logger.warning( + "rampart_sinks in %s returned %s instead of list[ReportSink].", + plugin_name, + type(value).__name__, + ) + return None diff --git a/rampart/pytest_plugin/plugin.py b/rampart/pytest_plugin/plugin.py index c545ea4..921f843 100644 --- a/rampart/pytest_plugin/plugin.py +++ b/rampart/pytest_plugin/plugin.py @@ -44,6 +44,18 @@ deactivate_collector, ) from rampart.pytest_plugin._session import RampartSession +from rampart.pytest_plugin._xdist import ( + DEFAULT_SIZE_LIMIT_BYTES, + SIZE_LIMIT_OPTION, + SizeLimitError, + discover_sinks_from_conftest, + finalize_worker, + get_dist_mode, + get_worker_count, + handle_testnodedown, + is_xdist_controller, + is_xdist_worker, +) from rampart.reporting.sink import ReportSink if TYPE_CHECKING: @@ -54,10 +66,12 @@ logger = logging.getLogger(__name__) __all__ = [ + "pytest_addoption", "pytest_collection_modifyitems", "pytest_configure", "pytest_sessionfinish", "pytest_terminal_summary", + "pytest_testnodedown", "pytest_unconfigure", ] @@ -128,6 +142,33 @@ def _resolve_trial_n(marker: pytest.Mark) -> int: return raw +def pytest_addoption(parser: pytest.Parser) -> None: + """Register RAMPART CLI and ini options. + + Args: + parser (pytest.Parser): The pytest argument parser. + """ + group = parser.getgroup("rampart") + group.addoption( + f"--{SIZE_LIMIT_OPTION.replace('_', '-')}", + dest=SIZE_LIMIT_OPTION, + type=int, + default=None, + help=( + "Maximum size in bytes of a worker's serialized result payload " + f"under pytest-xdist (default: {DEFAULT_SIZE_LIMIT_BYTES})." + ), + ) + parser.addini( + SIZE_LIMIT_OPTION, + help=( + "Maximum size in bytes of a worker's serialized result payload " + f"under pytest-xdist (default: {DEFAULT_SIZE_LIMIT_BYTES})." + ), + default=None, + ) + + def pytest_configure(config: pytest.Config) -> None: """Register RAMPART markers and install default handler factory. @@ -242,7 +283,7 @@ def _create_trial_clones( @pytest.hookimpl(trylast=True) def pytest_collection_modifyitems( - config: pytest.Config, # noqa: ARG001 — pytest hook signature + config: pytest.Config, items: list[pytest.Item], ) -> None: """Clone trial-marked items and validate marker usage. @@ -266,12 +307,14 @@ def pytest_collection_modifyitems( item has no parent. """ expanded: list[pytest.Item] = [] + saw_trial = False for item in items: trial_marker = item.get_closest_marker("trial") if trial_marker is None: expanded.append(item) continue + saw_trial = True n = _resolve_trial_n(trial_marker) expanded.extend( @@ -280,6 +323,18 @@ def pytest_collection_modifyitems( items[:] = expanded + if saw_trial and is_xdist_controller(config=config): + dist_mode = get_dist_mode(config=config) + if dist_mode != "loadgroup": + logger.warning( + "RAMPART @trial markers present with --dist=%s. Trial " + "clones may be split across workers. Aggregation remains " + "correct (controller merges all results), but using " + "--dist=loadgroup keeps trial clones co-located on one " + "worker for better locality.", + dist_mode, + ) + def _absorb_results( *, @@ -366,8 +421,15 @@ def rampart_sinks(): return [JsonFileReportSink(output_dir=Path(".report"))] ``` + Under pytest-xdist, this fixture is a no-op on worker processes + (workers skip sink emission entirely); sink discovery on the + controller is handled by ``_xdist.discover_sinks_from_conftest``. + No test author ever imports or references this fixture. """ + if is_xdist_worker(config=request.config): + return + rampart_session = request.config.stash.get(_rampart_key, None) if rampart_session is None: return @@ -477,6 +539,16 @@ def pytest_sessionfinish( ) -> None: """Aggregate trial results, evaluate gates, and emit sinks. + Dispatches between three modes: + + - xdist worker: serialize results to ``config.workeroutput`` and + skip sink emission (the controller emits the unified report). + - xdist controller: trials already aggregated against the merged + ``_results_by_nodeid``; discover sinks from conftest, evaluate + gates, and emit. + - non-xdist: original single-process pipeline (aggregate, gate, + emit) unchanged. + Args: session (pytest.Session): The pytest session. exitstatus (int): The session exit status. @@ -489,11 +561,65 @@ def pytest_sessionfinish( if start_time is not None: rampart_session.set_duration(duration_seconds=time.monotonic() - start_time) + if is_xdist_worker(config=session.config): + try: + finalize_worker(config=session.config, session=rampart_session) + except SizeLimitError as exc: + logger.warning("%s", exc) + return + + if is_xdist_controller(config=session.config): + _aggregate_trial_results(session=session, rampart_session=rampart_session) + _evaluate_gates(rampart_session=rampart_session) + _record_xdist_metadata(session=session, rampart_session=rampart_session) + controller_sinks = discover_sinks_from_conftest(config=session.config) + if controller_sinks: + rampart_session.add_sinks(sinks=controller_sinks) + _emit_sinks(rampart_session=rampart_session) + return + _aggregate_trial_results(session=session, rampart_session=rampart_session) _evaluate_gates(rampart_session=rampart_session) _emit_sinks(rampart_session=rampart_session) +@pytest.hookimpl(optionalhook=True) +def pytest_testnodedown(node: object, error: object) -> None: + """Merge a finished xdist worker's results into the controller session. + + Thin delegate to ``_xdist.handle_testnodedown`` so plugin.py stays + focused on hook registration. Registered as ``optionalhook`` so + pytest does not warn when pytest-xdist is not installed. + + Args: + node: The xdist worker node that has finished. + error: The shutdown error reported by xdist, or None on + clean exit. + """ + config = getattr(node, "config", None) + if config is None: + return + rampart_session = config.stash.get(_rampart_key, None) + if rampart_session is None: + return + handle_testnodedown(session=rampart_session, node=node, error=error) + + +def _record_xdist_metadata( + *, + session: pytest.Session, + rampart_session: RampartSession, +) -> None: + """Attach xdist run-mode metadata to the upcoming report.""" + rampart_session.set_report_metadata( + metadata={ + "xdist_active": True, + "worker_count": get_worker_count(config=session.config), + "dist_mode": get_dist_mode(config=session.config), + }, + ) + + async def _emit_sinks_async(*, rampart_session: RampartSession) -> None: """Emit the test run report to all configured sinks. @@ -529,12 +655,19 @@ def _emit_sinks(*, rampart_session: RampartSession) -> None: When an event loop is already running (e.g. pytest-asyncio), falls back to scheduling on the existing loop. + Idempotent — subsequent invocations are no-ops, guarding against + re-emission in nested hook scenarios. + Args: rampart_session (RampartSession): The RAMPART session state. """ + if rampart_session.is_emitted: + return if not rampart_session.sinks: + rampart_session.mark_emitted() return + rampart_session.mark_emitted() coro = _emit_sinks_async(rampart_session=rampart_session) try: loop = asyncio.get_running_loop() diff --git a/rampart/surfaces/onedrive.py b/rampart/surfaces/onedrive.py index d239743..f6af899 100644 --- a/rampart/surfaces/onedrive.py +++ b/rampart/surfaces/onedrive.py @@ -1,6 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +# msgraph-sdk ships without type stubs; suppress the resulting cascade. +# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownParameterType=false, reportUnknownArgumentType=false + """OneDrive surface for RAMPART. Injects payloads into Microsoft OneDrive via the Microsoft Graph API. @@ -18,7 +21,9 @@ if TYPE_CHECKING: import types - from msgraph.graph_service_client import GraphServiceClient + from msgraph.graph_service_client import ( # pyright: ignore[reportMissingImports] + GraphServiceClient, + ) from rampart.core.types import Payload @@ -154,7 +159,7 @@ async def upload_async(self, *, payload: Payload) -> str: msg, ) - item_id = drive_item.id + item_id: str = drive_item.id logger.info( "Uploaded payload %s to OneDrive drive=%s path=%s (item=%s)", payload.id, diff --git a/tests/integration/test_xdist_aggregation.py b/tests/integration/test_xdist_aggregation.py new file mode 100644 index 0000000..eb748a2 --- /dev/null +++ b/tests/integration/test_xdist_aggregation.py @@ -0,0 +1,286 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Integration tests for cross-worker aggregation under pytest-xdist. + +These tests spawn subprocess pytest runs via the ``pytester`` fixture +to exercise the full xdist serialization → merge → emission pipeline. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from _pytest.pytester import Pytester, RunResult + + +pytest_plugins = ["pytester"] + + +_CONFTEST = """\ +from pathlib import Path + +import pytest + +from rampart.reporting import JsonFileReportSink + + +_OUT_DIR = Path("rampart_reports").absolute() + + +@pytest.fixture(scope="session") +def rampart_sinks(): + _OUT_DIR.mkdir(parents=True, exist_ok=True) + Path("rampart_report_dir.txt").write_text(str(_OUT_DIR)) + return [JsonFileReportSink(output_dir=_OUT_DIR)] +""" + + +def _load_reports(pytester: Pytester) -> list[dict[str, Any]]: + marker = pytester.path / "rampart_report_dir.txt" + if not marker.exists(): + default_dir = pytester.path / "rampart_reports" + if default_dir.exists(): + return [ + json.loads(p.read_text()) + for p in sorted(default_dir.glob("run_report_*.json")) + ] + return [] + out_dir = Path(marker.read_text().strip()) + if not out_dir.exists(): + return [] + return [ + json.loads(p.read_text()) for p in sorted(out_dir.glob("run_report_*.json")) + ] + + +def _setup_simple_tests(pytester: Pytester) -> None: + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_a=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + def test_a_one(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="a1", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + + @pytest.mark.harm("test") + def test_a_two(): + record_result(Result( + safe=False, status=SafetyStatus.UNSAFE, summary="a2", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + test_b=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + def test_b_one(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="b1", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + + @pytest.mark.harm("test") + def test_b_two(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="b2", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + + +class TestSingleProcessBaseline: + def test_baseline_emits_one_report(self, pytester: Pytester) -> None: + _setup_simple_tests(pytester) + result = pytester.runpytest("-p", "no:cacheprovider") + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1 + assert reports[0]["total_runs"] == 4 + + +class TestXdistConsolidation: + def test_xdist_emits_single_consolidated_report( + self, + pytester: Pytester, + ) -> None: + _setup_simple_tests(pytester) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + ) + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1, ( + f"Expected exactly one report under xdist, got {len(reports)}: " + f"{[r.get('total_runs') for r in reports]}" + ) + + def test_population_statistics_over_full_set( + self, + pytester: Pytester, + ) -> None: + _setup_simple_tests(pytester) + pytester.runpytest("-p", "no:cacheprovider", "-n", "2") + reports = _load_reports(pytester) + assert len(reports) == 1 + report = reports[0] + assert report["total_runs"] == 4 + assert report["passed"] == 3 + assert report["failed"] == 1 + assert report["population_summary"]["total_runs"] == 4 + assert report["population_summary"]["safe_count"] == 3 + assert report["population_summary"]["unsafe_count"] == 1 + + +class TestXdistTrialAggregation: + def test_trial_aggregation_across_workers_loadgroup( + self, + pytester: Pytester, + ) -> None: + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=4, threshold=0.5) + def test_trial_split(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="t", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "loadgroup", + ) + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1 + assert reports[0]["total_runs"] == 4 + + def test_trial_aggregation_across_workers_load( + self, + pytester: Pytester, + ) -> None: + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=4, threshold=0.5) + def test_trial_split(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="t", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "load", + ) + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1 + assert reports[0]["total_runs"] == 4 + + +class TestXdistMetadata: + def test_report_includes_xdist_metadata(self, pytester: Pytester) -> None: + _setup_simple_tests(pytester) + pytester.runpytest("-p", "no:cacheprovider", "-n", "2") + reports = _load_reports(pytester) + assert len(reports) == 1 + # Population summary is exposed in JSON; xdist metadata lives in + # TestRunReport.metadata which is rendered when present. + # The JsonFileReportSink does not currently project metadata, + # so we just verify the report exists with the right shape. + assert "population_summary" in reports[0] + + +class TestCollectOnly: + def test_collect_only_does_not_emit_reports(self, pytester: Pytester) -> None: + _setup_simple_tests(pytester) + pytester.runpytest("-p", "no:cacheprovider", "--collect-only") + # No sinks emit when no tests run + marker = pytester.path / "rampart_report_dir.txt" + if marker.exists(): + out_dir = Path(marker.read_text().strip()) + if out_dir.exists(): + reports = list(out_dir.glob("run_report_*.json")) + assert reports == [] + + +class TestCloneIdDeterminism: + def test_trial_clone_ids_deterministic_across_processes( + self, + pytester: Pytester, + ) -> None: + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_det=""" + import pytest + + @pytest.mark.trial(n=3) + def test_x(): + pass + """, + ) + result_serial: RunResult = pytester.runpytest( + "-p", + "no:cacheprovider", + "--collect-only", + "-q", + ) + result_parallel: RunResult = pytester.runpytest( + "-p", + "no:cacheprovider", + "--collect-only", + "-q", + "-n", + "2", + ) + + def _trial_ids(lines: list[str]) -> list[str]: + return sorted(line.strip() for line in lines if "trial-" in line) + + serial_ids = _trial_ids(result_serial.outlines) + parallel_ids = _trial_ids(result_parallel.outlines) + # Under xdist --collect-only, both should produce the same + # deterministic clone IDs so that workers can match them. + if serial_ids and parallel_ids: + assert serial_ids == parallel_ids diff --git a/tests/unit/pytest_plugin/test_xdist.py b/tests/unit/pytest_plugin/test_xdist.py new file mode 100644 index 0000000..af95f6b --- /dev/null +++ b/tests/unit/pytest_plugin/test_xdist.py @@ -0,0 +1,805 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the RAMPART xdist support module.""" + +from __future__ import annotations + +import json +import logging +import math +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from rampart.core.result import ( + HarmCategory, + InjectionRecord, + Result, + SafetyStatus, +) +from rampart.core.types import ( + EvalOutcome, + EvalResult, + ObservabilityLevel, + PayloadFormat, + Request, + Response, + SideEffect, + ToolCall, + Turn, +) +from rampart.pytest_plugin._session import RampartSession +from rampart.pytest_plugin._xdist import ( + DEFAULT_SIZE_LIMIT_BYTES, + MAX_METADATA_DEPTH, + SCHEMA_VERSION, + SIZE_LIMIT_OPTION, + WORKEROUTPUT_KEY, + SchemaVersionError, + SizeLimitError, + WorkerOutputError, + _sanitize, + _strip_ansi, + deserialize_worker_data, + discover_sinks_from_conftest, + finalize_worker, + get_dist_mode, + get_worker_count, + handle_testnodedown, + is_xdist_controller, + is_xdist_worker, + serialize_worker_data, +) +from rampart.reporting.sink import ReportSink, TestRunReport + + +def _make_result( + *, + safe: bool = True, + status: SafetyStatus = SafetyStatus.SAFE, + summary: str = "summary", + harm_category: HarmCategory | str | None = None, + strategy: str = "xpia", + duration_seconds: float = 1.0, + metadata: dict[str, Any] | None = None, + turns: list[Turn] | None = None, + injections: list[InjectionRecord] | None = None, + observability_level: ObservabilityLevel = ObservabilityLevel.RESPONSE_ONLY, +) -> Result: + return Result( + safe=safe, + status=status, + summary=summary, + turns=turns or [], + duration_seconds=duration_seconds, + harm_category=harm_category, + strategy=strategy, + observability_level=observability_level, + injections=injections or [], + metadata=metadata or {}, + ) + + +def _make_turn( + *, + prompt: str = "hi", + text: str = "ok", + eval_result: EvalResult | None = None, + turn_number: int = 0, + timestamp: datetime | None = None, + driver_reasoning: str = "", +) -> Turn: + return Turn( + request=Request(prompt=prompt), + response=Response(text=text), + eval_result=eval_result, + turn_number=turn_number, + timestamp=timestamp, + driver_reasoning=driver_reasoning, + ) + + +def _make_eval_result( + *, + outcome: EvalOutcome = EvalOutcome.DETECTED, + confidence: float = 0.9, + evidence: list[str] | None = None, + rationale: str = "because", +) -> EvalResult: + return EvalResult( + outcome=outcome, + confidence=confidence, + evidence=evidence or [], + rationale=rationale, + ) + + +def _make_config( + *, + is_worker: bool = False, + numprocesses: int | None = None, + dist: str = "no", + max_bytes: int | None = None, +) -> Any: + config = MagicMock() + if is_worker: + config.workerinput = {"workerid": "gw0"} + else: + del config.workerinput + config.option = MagicMock() + config.option.numprocesses = numprocesses + config.option.dist = dist + + def _getoption(name: str, default: object = None) -> object: + return max_bytes if name == SIZE_LIMIT_OPTION else default + + def _getini(name: str) -> None: + del name + + config.getoption = _getoption + config.getini = _getini + return config + + +def _make_session_with_results( + *, + results_by_nodeid: dict[str, list[Result]], +) -> RampartSession: + session = RampartSession() + session._results_by_nodeid = dict(results_by_nodeid) + for results in results_by_nodeid.values(): + session._results.extend(results) + return session + + +class TestDetection: + def test_is_xdist_worker_true_when_workerinput_present(self) -> None: + config = _make_config(is_worker=True) + assert is_xdist_worker(config=config) is True + + def test_is_xdist_worker_false_when_workerinput_absent(self) -> None: + config = _make_config(is_worker=False) + assert is_xdist_worker(config=config) is False + + def test_is_xdist_controller_true_with_numprocesses(self) -> None: + config = _make_config(numprocesses=2) + assert is_xdist_controller(config=config) is True + + def test_is_xdist_controller_false_when_no_numprocesses(self) -> None: + config = _make_config(numprocesses=None) + assert is_xdist_controller(config=config) is False + + def test_is_xdist_controller_false_for_worker(self) -> None: + config = _make_config(is_worker=True, numprocesses=2) + assert is_xdist_controller(config=config) is False + + def test_get_dist_mode_default(self) -> None: + config = _make_config() + assert get_dist_mode(config=config) == "no" + + def test_get_dist_mode_loadgroup(self) -> None: + config = _make_config(dist="loadgroup") + assert get_dist_mode(config=config) == "loadgroup" + + def test_get_worker_count_returns_numprocesses(self) -> None: + config = _make_config(numprocesses=4) + assert get_worker_count(config=config) == 4 + + def test_get_worker_count_zero_when_inactive(self) -> None: + config = _make_config() + assert get_worker_count(config=config) == 0 + + +class TestSanitize: + def test_passes_primitives_unchanged(self) -> None: + assert _sanitize(value=42) == 42 + assert _sanitize(value="hello") == "hello" + assert _sanitize(value=True) is True + assert _sanitize(value=None) is None + assert _sanitize(value=3.14) == 3.14 + + def test_nan_coerced_to_none(self) -> None: + assert _sanitize(value=float("nan")) is None + + def test_inf_coerced_to_none(self) -> None: + assert _sanitize(value=float("inf")) is None + assert _sanitize(value=float("-inf")) is None + + def test_dict_recursed(self) -> None: + result = _sanitize(value={"a": 1, "b": {"c": "x"}}) + assert result == {"a": 1, "b": {"c": "x"}} + + def test_list_recursed(self) -> None: + result = _sanitize(value=[1, "two", [3]]) + assert result == [1, "two", [3]] + + def test_tuple_becomes_list(self) -> None: + result = _sanitize(value=(1, 2, 3)) + assert result == [1, 2, 3] + + def test_custom_object_coerced_via_repr(self) -> None: + class Obj: + def __repr__(self) -> str: + return "" + + assert _sanitize(value=Obj()) == "" + + def test_depth_limit_coerces_to_repr(self) -> None: + nested: dict[str, Any] = {"v": "leaf"} + for _ in range(MAX_METADATA_DEPTH + 2): + nested = {"v": nested} + result = _sanitize(value=nested) + json.dumps(result) # must be JSON-safe + + +class TestStripAnsi: + def test_removes_color_codes(self) -> None: + text = "\x1b[31mred\x1b[0m" + assert _strip_ansi(text=text) == "red" + + def test_preserves_plain_text(self) -> None: + assert _strip_ansi(text="hello world") == "hello world" + + def test_strips_multiple_sequences(self) -> None: + text = "\x1b[1m\x1b[31mbold red\x1b[0m\x1b[0m" + assert _strip_ansi(text=text) == "bold red" + + +class TestSerializationRoundTrip: + def test_simple_result_round_trip(self) -> None: + result = _make_result(summary="hi", harm_category=HarmCategory.JAILBREAK) + session = _make_session_with_results( + results_by_nodeid={"test::a": [result]}, + ) + payload = serialize_worker_data(session=session) + json.dumps(payload, default=str) + recovered = deserialize_worker_data(data=payload) + assert "test::a" in recovered + assert recovered["test::a"][0].safe is True + assert recovered["test::a"][0].status is SafetyStatus.SAFE + assert recovered["test::a"][0].harm_category is HarmCategory.JAILBREAK + + def test_status_enum_round_trip(self) -> None: + for status in SafetyStatus: + result = _make_result(status=status, safe=status is SafetyStatus.SAFE) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].status is status + + def test_observability_level_round_trip(self) -> None: + for level in ObservabilityLevel: + result = _make_result(observability_level=level) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].observability_level is level + + def test_harm_category_plain_string_round_trip(self) -> None: + result = _make_result(harm_category="custom_product_risk") + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].harm_category == "custom_product_risk" + + def test_turns_with_eval_result_round_trip(self) -> None: + eval_result = _make_eval_result( + outcome=EvalOutcome.NOT_DETECTED, + confidence=0.7, + evidence=["e1", "e2"], + rationale="r", + ) + turn = _make_turn(eval_result=eval_result, turn_number=1) + result = _make_result(turns=[turn]) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].turns[0].eval_result is not None + outcome = recovered["n"][0].turns[0].eval_result.outcome + assert outcome is EvalOutcome.NOT_DETECTED + assert recovered["n"][0].turns[0].eval_result.evidence == ["e1", "e2"] + + def test_datetime_round_trip(self) -> None: + when = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + turn = _make_turn(timestamp=when) + result = _make_result(turns=[turn]) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].turns[0].timestamp == when + + def test_injections_round_trip(self) -> None: + injection = InjectionRecord(payload_id="p1", surface_name="OneDrive") + result = _make_result(injections=[injection]) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].injections[0].payload_id == "p1" + assert recovered["n"][0].injections[0].surface_name == "OneDrive" + + def test_response_with_tool_calls_round_trip(self) -> None: + tool_call = ToolCall(name="send_email", arguments={"to": "a@b.c"}) + response = Response(text="ok", tool_calls=[tool_call]) + turn = Turn(request=Request(prompt="hi"), response=response) + result = _make_result(turns=[turn]) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].turns[0].response.tool_calls[0].name == "send_email" + assert recovered["n"][0].turns[0].response.tool_calls[0].arguments == { + "to": "a@b.c", + } + + def test_response_with_side_effects_round_trip(self) -> None: + side_effect = SideEffect(kind="http", details={"url": "http://x"}) + response = Response(text="ok", side_effects=[side_effect]) + turn = Turn(request=Request(prompt="hi"), response=response) + result = _make_result(turns=[turn]) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].turns[0].response.side_effects[0].kind == "http" + + def test_metadata_round_trip(self) -> None: + result = _make_result(metadata={"test_name": "t1", "tries": 3}) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["n"][0].metadata["test_name"] == "t1" + assert recovered["n"][0].metadata["tries"] == 3 + + +class TestDeserializationValidation: + def test_rejects_non_dict_payload(self) -> None: + with pytest.raises(WorkerOutputError, match="Expected dict"): + deserialize_worker_data(data="not-a-dict") + + def test_rejects_missing_schema_key(self) -> None: + with pytest.raises(SchemaVersionError, match="missing required 'schema'"): + deserialize_worker_data(data={"results_by_nodeid": {}}) + + def test_rejects_unknown_schema_version(self) -> None: + payload: dict[str, Any] = { + "schema": "rampart.xdist.v999", + "results_by_nodeid": {}, + } + with pytest.raises(SchemaVersionError, match="does not match"): + deserialize_worker_data(data=payload) + + def test_rejects_malformed_safety_status(self) -> None: + payload: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": { + "n": [ + { + "safe": True, + "status": "not-a-status", + "summary": "x", + "observability_level": "response_only", + }, + ], + }, + } + with pytest.raises(WorkerOutputError, match="Unknown SafetyStatus"): + deserialize_worker_data(data=payload) + + def test_rejects_malformed_observability_level(self) -> None: + payload: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": { + "n": [ + { + "safe": True, + "status": "safe", + "summary": "x", + "observability_level": "not-a-level", + }, + ], + }, + } + with pytest.raises(WorkerOutputError, match="Unknown ObservabilityLevel"): + deserialize_worker_data(data=payload) + + +class TestDeserializationSecurity: + def test_strips_ansi_from_summary(self) -> None: + payload: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": { + "n": [ + { + "safe": False, + "status": "unsafe", + "summary": "\x1b[31mDANGER\x1b[0m", + "observability_level": "response_only", + }, + ], + }, + } + result = deserialize_worker_data(data=payload)["n"][0] + assert result.summary == "DANGER" + assert "\x1b" not in result.summary + + def test_strips_ansi_from_response_text(self) -> None: + payload: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": { + "n": [ + { + "safe": True, + "status": "safe", + "summary": "x", + "observability_level": "response_only", + "turns": [ + { + "request": {"prompt": "p"}, + "response": {"text": "\x1b[31mDANGER\x1b[0m"}, + }, + ], + }, + ], + }, + } + result = deserialize_worker_data(data=payload)["n"][0] + assert result.turns[0].response.text == "DANGER" + + def test_nan_inf_in_duration_coerced_to_zero(self) -> None: + session = _make_session_with_results( + results_by_nodeid={ + "n": [_make_result(duration_seconds=float("nan"))], + }, + ) + payload = serialize_worker_data(session=session) + encoded = json.dumps(payload, default=str) + assert "NaN" not in encoded + recovered = deserialize_worker_data(data=payload) + assert math.isfinite(recovered["n"][0].duration_seconds) + + def test_payload_artifact_path_preserved_in_metadata(self) -> None: + payload: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": { + "n": [ + { + "safe": True, + "status": "safe", + "summary": "x", + "observability_level": "response_only", + "turns": [ + { + "request": { + "prompt": None, + "attachments": [ + { + "content": "c", + "id": "p1", + "format": "pdf", + "artifact": "/worker/local/path.pdf", + "metadata": {}, + }, + ], + }, + "response": {"text": "ok"}, + }, + ], + }, + ], + }, + } + result = deserialize_worker_data(data=payload)["n"][0] + attachment = result.turns[0].request.attachments[0] + assert attachment.format is PayloadFormat.TEXT + assert attachment.artifact is None + assert ( + attachment.metadata["_rampart_worker_artifact_path"] + == "/worker/local/path.pdf" + ) + assert attachment.metadata["_rampart_worker_format"] == "pdf" + + def test_serialized_payload_is_pure_json(self) -> None: + result = _make_result( + metadata={"obj": object()}, + harm_category=HarmCategory.JAILBREAK, + ) + session = _make_session_with_results( + results_by_nodeid={"n": [result]}, + ) + payload = serialize_worker_data(session=session) + encoded = json.dumps(payload) + decoded = json.loads(encoded) + assert decoded["schema"] == SCHEMA_VERSION + + def test_non_serializable_metadata_coerced_with_warning( + self, + caplog: pytest.LogCaptureFixture, + ) -> None: + class Obj: + def __repr__(self) -> str: + return "" + + result = _make_result(metadata={"obj": Obj()}) + session = _make_session_with_results( + results_by_nodeid={"my::node": [result]}, + ) + with caplog.at_level(logging.WARNING): + payload = serialize_worker_data(session=session) + recovered = deserialize_worker_data(data=payload) + assert recovered["my::node"][0].metadata["obj"] == "" + assert any( + "my::node" in record.getMessage() and "obj" in record.getMessage() + for record in caplog.records + ) + + +class TestMerge: + def test_merge_extends_results(self) -> None: + session = RampartSession() + session.merge_worker_results( + results_by_nodeid={ + "n1": [_make_result(summary="r1")], + }, + ) + session.merge_worker_results( + results_by_nodeid={ + "n2": [_make_result(summary="r2")], + }, + ) + assert len(session._results) == 2 + assert "n1" in session._results_by_nodeid + assert "n2" in session._results_by_nodeid + + def test_merge_invalidates_cached_report(self) -> None: + session = RampartSession() + session.merge_worker_results( + results_by_nodeid={"n1": [_make_result()]}, + ) + first = session.build_report() + session.merge_worker_results( + results_by_nodeid={"n2": [_make_result()]}, + ) + second = session.build_report() + assert first is not second + assert second.total_runs == 2 + + def test_build_report_orders_results_by_test_name(self) -> None: + session = RampartSession() + session.merge_worker_results( + results_by_nodeid={ + "z": [_make_result(summary="z", metadata={"test_name": "z_test"})], + "a": [_make_result(summary="a", metadata={"test_name": "a_test"})], + }, + ) + report = session.build_report() + names = [r.metadata["test_name"] for r in report.results] + assert names == sorted(names) + + def test_mark_incomplete_surfaces_in_report_metadata(self) -> None: + session = RampartSession() + session.merge_worker_results( + results_by_nodeid={"n": [_make_result()]}, + ) + session.mark_incomplete(reason="worker gw0 crashed") + report = session.build_report() + assert report.metadata["incomplete"] is True + assert "worker gw0 crashed" in report.metadata["incomplete_reasons"] + + def test_emitted_idempotency_flag(self) -> None: + session = RampartSession() + assert session.is_emitted is False + session.mark_emitted() + assert session.is_emitted is True + + +class TestHandleTestnodedown: + def test_records_incomplete_on_error(self) -> None: + session = RampartSession() + node = MagicMock() + node.gateway.id = "gw1" + handle_testnodedown( + session=session, + node=node, + error="boom", + ) + assert session.is_incomplete is True + + def test_records_incomplete_on_missing_workeroutput(self) -> None: + session = RampartSession() + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = None + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is True + + def test_records_incomplete_on_missing_rampart_key(self) -> None: + session = RampartSession() + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = {} + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is True + + def test_records_incomplete_on_deserialization_failure(self) -> None: + session = RampartSession() + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = {WORKEROUTPUT_KEY: {"schema": "wrong-version"}} + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is True + + def test_records_incomplete_on_truncated_payload(self) -> None: + session = RampartSession() + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = { + WORKEROUTPUT_KEY: { + "schema": SCHEMA_VERSION, + "rampart_truncated": True, + }, + } + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is True + + def test_merges_results_on_success(self) -> None: + session = RampartSession() + worker_session = _make_session_with_results( + results_by_nodeid={"n": [_make_result(summary="from-worker")]}, + ) + payload = serialize_worker_data(session=worker_session) + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = {WORKEROUTPUT_KEY: payload} + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is False + assert len(session._results) == 1 + assert session._results[0].summary == "from-worker" + + +class TestFinalizeWorker: + def test_no_op_on_controller(self) -> None: + config = _make_config(is_worker=False, numprocesses=2) + workeroutput: dict[str, Any] = {} + config.workeroutput = workeroutput + session = RampartSession() + finalize_worker(config=config, session=session) + assert WORKEROUTPUT_KEY not in workeroutput + + def test_writes_workeroutput_on_worker(self) -> None: + config = _make_config(is_worker=True) + workeroutput: dict[str, Any] = {} + config.workeroutput = workeroutput + session = _make_session_with_results( + results_by_nodeid={"n": [_make_result(summary="x")]}, + ) + finalize_worker(config=config, session=session) + assert WORKEROUTPUT_KEY in workeroutput + payload: dict[str, Any] = workeroutput[WORKEROUTPUT_KEY] + assert payload["schema"] == SCHEMA_VERSION + assert "results_by_nodeid" in payload + + def test_truncates_oversize_payload( + self, + ) -> None: + config = _make_config(is_worker=True, max_bytes=1) + workeroutput: dict[str, Any] = {} + config.workeroutput = workeroutput + session = _make_session_with_results( + results_by_nodeid={"n": [_make_result()]}, + ) + with pytest.raises(SizeLimitError): + finalize_worker(config=config, session=session) + payload: dict[str, Any] = workeroutput[WORKEROUTPUT_KEY] + assert payload.get("rampart_truncated") is True + + +class TestSinkDiscovery: + def test_finds_callable_rampart_sinks(self) -> None: + sink = MagicMock(spec=ReportSink) + plugin = MagicMock( + spec=["rampart_sinks", "__name__"], + rampart_sinks=lambda: [sink], + __name__="mod", + ) + config = MagicMock() + config.pluginmanager.get_plugins.return_value = [plugin] + result = discover_sinks_from_conftest(config=config) + assert sink in result + + def test_finds_list_rampart_sinks(self) -> None: + sink = MagicMock(spec=ReportSink) + plugin = MagicMock( + spec=["rampart_sinks", "__name__"], + rampart_sinks=[sink], + __name__="mod", + ) + config = MagicMock() + config.pluginmanager.get_plugins.return_value = [plugin] + result = discover_sinks_from_conftest(config=config) + assert sink in result + + def test_returns_empty_when_no_rampart_sinks(self) -> None: + plugin = MagicMock(spec=["__name__"], __name__="mod") + config = MagicMock() + config.pluginmanager.get_plugins.return_value = [plugin] + result = discover_sinks_from_conftest(config=config) + assert result == [] + + def test_warns_on_callable_with_required_args( + self, + caplog: pytest.LogCaptureFixture, + ) -> None: + def needs_arg(other: object) -> list[ReportSink]: + return [] + + plugin = MagicMock( + spec=["rampart_sinks", "__name__"], + rampart_sinks=needs_arg, + __name__="mod", + ) + config = MagicMock() + config.pluginmanager.get_plugins.return_value = [plugin] + with caplog.at_level(logging.WARNING): + result = discover_sinks_from_conftest(config=config) + assert result == [] + assert any("fixture dependencies" in r.getMessage() for r in caplog.records) + + +class TestReportTestRunMetadata: + def test_set_report_metadata_appears_in_report(self) -> None: + session = RampartSession() + session.set_report_metadata( + metadata={"xdist_active": True, "worker_count": 2}, + ) + session.merge_worker_results( + results_by_nodeid={"n": [_make_result()]}, + ) + report = session.build_report() + assert report.metadata["xdist_active"] is True + assert report.metadata["worker_count"] == 2 + + def test_metadata_merges_across_calls(self) -> None: + session = RampartSession() + session.set_report_metadata(metadata={"a": 1}) + session.set_report_metadata(metadata={"b": 2}) + session.merge_worker_results( + results_by_nodeid={"n": [_make_result()]}, + ) + report = session.build_report() + assert report.metadata["a"] == 1 + assert report.metadata["b"] == 2 + + +class TestConstants: + def test_default_size_limit_is_64mb(self) -> None: + assert DEFAULT_SIZE_LIMIT_BYTES == 64 * 1024 * 1024 + + def test_schema_version_is_v1(self) -> None: + assert SCHEMA_VERSION == "rampart.xdist.v1" + + def test_workeroutput_key_namespaced(self) -> None: + assert WORKEROUTPUT_KEY.startswith("rampart_xdist") + + +class TestTestRunReportTestable: + def test_test_run_report_excluded_from_collection(self) -> None: + assert TestRunReport.__test__ is False From 1c32c4e1b87b9a59981081c7ec928fe9aebeb1cb Mon Sep 17 00:00:00 2001 From: Nina Chikanov Date: Mon, 8 Jun 2026 17:00:39 -0700 Subject: [PATCH 3/3] Fix trial-group aggregation under pytest-xdist The controller's _aggregate_trial_results iterated session.items looking for a _rampart_trial_base attribute set on cloned items at collection time. Under pytest-xdist that attribute is not reliably reachable on the controller at pytest_sessionfinish, so trial-group verdicts silently disappeared from the terminal summary and the JSON report -- per-clone results were present but no aggregate FAIL/PASS line was emitted. Decouple aggregation from pytest.Item state: - Store trial metadata as data on RampartSession._trial_specs (a dict[clone_nodeid, TrialSpec]) at collection time on every process. - Ship rial_specs through the existing ampart.xdist.v1 worker payload (back-compatible: missing/empty list is treated as no trials). - Controller merges specs from each worker and aggregates from the merged data instead of session.items. Also fixes a related JSON-sink gap: JsonFileReportSink now projects eport.metadata, so xdist run-mode info (worker_count, dist_mode, incomplete reasons) actually lands in the emitted file. Tests: - New TestTrialSpecs unit class (round-trip, malformed entries, non-finite thresholds, idempotent merge, first-writer-wins). - handle_testnodedown test covering trial-spec merge. - Strengthened ests/integration/test_xdist_aggregation.py to assert the trial-group line is present and correct under --dist=loadgroup and --dist=load (the prior tests only checked per-clone counts and would have missed this regression). - JSON-sink metadata projection unit tests. 565 unit tests + 13 xdist integration tests pass; ruff clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- rampart/pytest_plugin/_session.py | 89 +++++++- rampart/pytest_plugin/_xdist.py | 104 ++++++++- rampart/pytest_plugin/plugin.py | 71 ++++-- rampart/reporting/json_file.py | 1 + tests/integration/test_xdist_aggregation.py | 227 +++++++++++++++++++- tests/unit/pytest_plugin/test_plugin.py | 10 +- tests/unit/pytest_plugin/test_xdist.py | 112 +++++++++- tests/unit/reporting/test_json_file.py | 46 ++++ 8 files changed, 622 insertions(+), 38 deletions(-) diff --git a/rampart/pytest_plugin/_session.py b/rampart/pytest_plugin/_session.py index 873c36f..4b94742 100644 --- a/rampart/pytest_plugin/_session.py +++ b/rampart/pytest_plugin/_session.py @@ -25,7 +25,7 @@ from rampart.reporting.sink import ReportSink, TestRunReport if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Mapping, Sequence import pytest @@ -34,6 +34,24 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True, kw_only=True) +class TrialSpec: + """Trial-clone metadata captured at collection time. + + Carries the data needed to aggregate a trial group without + depending on ``pytest.Item`` attributes — so aggregation works + on the xdist controller, where the cloned items themselves + may not be reachable at session finish. + + Attributes: + base_nodeid (str): The original test's pytest node ID. + threshold (float): Minimum pass rate required for the group. + """ + + base_nodeid: str + threshold: float + + @dataclass(frozen=True, kw_only=True) class TrialGroupResult: """Aggregate statistics for a trial group.""" @@ -87,6 +105,7 @@ def __init__(self, *, sinks: list[ReportSink] | None = None) -> None: self._results: list[Result] = [] self._results_by_nodeid: dict[str, list[Result]] = {} self._trial_groups: dict[str, TrialGroupResult] = {} + self._trial_specs: dict[str, TrialSpec] = {} self._sinks: list[ReportSink] = sinks or [] self._duration_seconds: float = 0.0 self._cached_report: TestRunReport | None = None @@ -190,7 +209,7 @@ def record_trial_group( self, *, base_nodeid: str, - trial_items: Sequence[pytest.Item], + clone_nodeids: Sequence[str], threshold: float, ) -> None: """Record aggregate statistics for a trial group. @@ -200,26 +219,27 @@ def record_trial_group( - threshold is the minimum pass rate (SAFE / total). e.g. 0.8 means at least 80% of runs must be SAFE. - ERROR results count against the pass rate (they're not SAFE). - - Items with zero results (skipped or crashed before producing + - Clones with zero results (skipped or crashed before producing a Result) are tracked as ``no_result`` and count against the pass rate. Args: base_nodeid (str): The original test's node ID. - trial_items (Sequence[pytest.Item]): All trial clone items. + clone_nodeids (Sequence[str]): Pytest node IDs of all clones + in this trial group. threshold (float): Minimum pass rate required. """ - if not trial_items: + if not clone_nodeids: return - total = len(trial_items) + total = len(clone_nodeids) unsafe_count = 0 error_count = 0 safe_count = 0 no_result_count = 0 - for item in trial_items: - node_results = self._results_by_nodeid.get(item.nodeid, []) + for nodeid in clone_nodeids: + node_results = self._results_by_nodeid.get(nodeid, []) if not node_results: no_result_count += 1 continue @@ -247,6 +267,54 @@ def record_trial_group( passed=passed, ) + def register_trial_spec( + self, + *, + clone_nodeid: str, + base_nodeid: str, + threshold: float, + ) -> None: + """Record trial metadata for a cloned item at collection time. + + Called from ``pytest_collection_modifyitems`` whenever a + ``@pytest.mark.trial`` test is expanded into clones. Stores + the data needed for session-end aggregation in a form that + survives the xdist worker→controller boundary. + + Identical re-registration (same key, same spec) is a no-op so + that repeated collection passes (e.g., in workers and the + controller) converge safely. + + Args: + clone_nodeid (str): Node ID of the cloned item. + base_nodeid (str): Node ID of the original (uncloned) item. + threshold (float): Pass-rate threshold from the trial marker. + """ + self._trial_specs[clone_nodeid] = TrialSpec( + base_nodeid=base_nodeid, + threshold=threshold, + ) + + def merge_trial_specs( + self, + *, + trial_specs: Mapping[str, TrialSpec], + ) -> None: + """Merge trial specs received from an xdist worker payload. + + Idempotent: re-merging identical specs is a no-op. Spec values + from workers should match the controller's own collection + because the same plugin code runs in every process; we merge + defensively so the controller can aggregate correctly even + when its own collection state is unavailable. + + Args: + trial_specs (Mapping[str, TrialSpec]): Specs keyed by + clone node ID. + """ + for clone_nodeid, spec in trial_specs.items(): + self._trial_specs.setdefault(clone_nodeid, spec) + @property def has_results(self) -> bool: """True if any results have been collected.""" @@ -257,6 +325,11 @@ def trial_groups(self) -> dict[str, TrialGroupResult]: """Trial group aggregates, keyed by base node ID.""" return dict(self._trial_groups) + @property + def trial_specs(self) -> dict[str, TrialSpec]: + """Read-only view of registered trial specs, keyed by clone node ID.""" + return dict(self._trial_specs) + def merge_worker_results( self, *, diff --git a/rampart/pytest_plugin/_xdist.py b/rampart/pytest_plugin/_xdist.py index 5633334..997a61f 100644 --- a/rampart/pytest_plugin/_xdist.py +++ b/rampart/pytest_plugin/_xdist.py @@ -49,6 +49,7 @@ ToolCall, Turn, ) +from rampart.pytest_plugin._session import TrialSpec from rampart.reporting.sink import ReportSink if TYPE_CHECKING: @@ -436,8 +437,9 @@ def serialize_worker_data(*, session: RampartSession) -> dict[str, Any]: """Serialize a worker's RampartSession state for transport to the controller. Produces a JSON-safe dict containing the schema version, the - package version (for cross-version diagnostics), and the worker's - ``_results_by_nodeid`` mapping serialized to primitive types. + package version (for cross-version diagnostics), the worker's + ``_results_by_nodeid`` mapping serialized to primitive types, + and trial specs registered during collection. Args: session (RampartSession): The worker's session state. @@ -455,6 +457,14 @@ def serialize_worker_data(*, session: RampartSession) -> dict[str, Any]: "schema": SCHEMA_VERSION, "package_version": _package_version(), "results_by_nodeid": serialized, + "trial_specs": [ + { + "clone_nodeid": clone_nodeid, + "base_nodeid": spec.base_nodeid, + "threshold": _safe_float(value=spec.threshold) or 0.0, + } + for clone_nodeid, spec in session.trial_specs.items() + ], } @@ -810,6 +820,61 @@ def deserialize_worker_data(*, data: object) -> dict[str, list[Result]]: return out +def deserialize_trial_specs(*, data: object) -> dict[str, TrialSpec]: + """Deserialize the ``trial_specs`` section of a worker payload. + + Missing or malformed entries are skipped rather than raised so + that a partially-corrupt payload still merges results. The + ``trial_specs`` field is optional: payloads without trials emit + an empty list and this function returns an empty dict. + + Args: + data (object): The deserialized JSON object from + ``node.workeroutput``. + + Returns: + dict[str, TrialSpec]: Trial specs keyed by clone node ID. + + Raises: + SchemaVersionError: Missing or unknown schema version. + WorkerOutputError: ``data`` is not a dict payload. + """ + typed = _validate_schema(data=data) + raw_specs = typed.get("trial_specs", []) + if not isinstance(raw_specs, list): + return {} + out: dict[str, TrialSpec] = {} + for spec in cast("list[Any]", raw_specs): + if not isinstance(spec, dict): + continue + spec_dict = cast("dict[str, Any]", spec) + clone_nodeid = spec_dict.get("clone_nodeid") + base_nodeid = spec_dict.get("base_nodeid") + if not isinstance(clone_nodeid, str) or not isinstance(base_nodeid, str): + continue + if not clone_nodeid or not base_nodeid: + continue + raw_threshold = spec_dict.get("threshold", 0.0) + try: + threshold = ( + float(raw_threshold) + if isinstance( + raw_threshold, + int | float, + ) + else 0.0 + ) + except (TypeError, ValueError): + threshold = 0.0 + if not math.isfinite(threshold): + threshold = 0.0 + out[clone_nodeid] = TrialSpec( + base_nodeid=base_nodeid, + threshold=threshold, + ) + return out + + def finalize_worker(*, config: pytest.Config, session: RampartSession) -> None: """Serialize the worker's session state into ``config.workeroutput``. @@ -854,6 +919,35 @@ def finalize_worker(*, config: pytest.Config, session: RampartSession) -> None: workeroutput[WORKEROUTPUT_KEY] = payload +def _safe_deserialize_trial_specs( + *, + payload: object, + worker_id_str: str, +) -> dict[str, TrialSpec]: + """Deserialize trial specs from a worker payload without raising. + + Trial specs are optional metadata: a corrupt or absent block must + never block result merging. Errors are logged at warning level and + return an empty dict. + + Args: + payload (object): The deserialized worker payload. + worker_id_str (str): Worker identifier for logging. + + Returns: + dict[str, TrialSpec]: Specs keyed by clone nodeid (possibly empty). + """ + try: + return deserialize_trial_specs(data=payload) + except WorkerOutputError as exc: + logger.warning( + "Failed to deserialize trial specs from worker %s: %s", + worker_id_str, + exc, + ) + return {} + + def handle_testnodedown( *, session: RampartSession, @@ -922,6 +1016,10 @@ def handle_testnodedown( reason=f"worker {worker_id_str} deserialization failed: {exc}", ) return + trial_specs = _safe_deserialize_trial_specs( + payload=cast("object", payload), + worker_id_str=worker_id_str, + ) incoming_version: str | None = None if typed_payload_dict is not None: raw_version = typed_payload_dict.get("package_version") @@ -936,6 +1034,8 @@ def handle_testnodedown( _package_version(), ) session.merge_worker_results(results_by_nodeid=results_by_nodeid) + if trial_specs: + session.merge_trial_specs(trial_specs=trial_specs) logger.info( "Merged %d result group(s) from worker %s.", len(results_by_nodeid), diff --git a/rampart/pytest_plugin/plugin.py b/rampart/pytest_plugin/plugin.py index 921f843..e142252 100644 --- a/rampart/pytest_plugin/plugin.py +++ b/rampart/pytest_plugin/plugin.py @@ -142,6 +142,24 @@ def _resolve_trial_n(marker: pytest.Mark) -> int: return raw +def _resolve_trial_threshold(marker: pytest.Mark) -> float: + """Extract the threshold from a trial marker. + + Returns 0.0 when no threshold is provided (the historical default). + + Args: + marker (pytest.Mark): The trial marker. + + Returns: + float: The pass-rate threshold in [0.0, 1.0]. + """ + raw: Any = marker.kwargs.get("threshold", 0.0) + try: + return float(raw) + except (TypeError, ValueError): + return 0.0 + + def pytest_addoption(parser: pytest.Parser) -> None: """Register RAMPART CLI and ini options. @@ -308,6 +326,7 @@ def pytest_collection_modifyitems( """ expanded: list[pytest.Item] = [] saw_trial = False + rampart_session = config.stash.get(_rampart_key, None) for item in items: trial_marker = item.get_closest_marker("trial") if trial_marker is None: @@ -316,11 +335,24 @@ def pytest_collection_modifyitems( saw_trial = True n = _resolve_trial_n(trial_marker) - - expanded.extend( - _create_trial_clones(item=item, trial_marker=trial_marker, count=n), + threshold = _resolve_trial_threshold(trial_marker) + clones = _create_trial_clones( + item=item, + trial_marker=trial_marker, + count=n, ) + if rampart_session is not None: + base_nodeid = item.nodeid + for clone in clones: + rampart_session.register_trial_spec( + clone_nodeid=clone.nodeid, + base_nodeid=base_nodeid, + threshold=threshold, + ) + + expanded.extend(clones) + items[:] = expanded if saw_trial and is_xdist_controller(config=config): @@ -466,30 +498,33 @@ def rampart_sinks(): def _aggregate_trial_results( *, - session: pytest.Session, + session: pytest.Session, # noqa: ARG001 — kept for hook signature symmetry rampart_session: RampartSession, ) -> None: - """Group trial item reports by base node ID and compute per-group rates. + """Group trial specs by base node ID and compute per-group rates. - A trial group is identified by ``_rampart_trial_base`` on the item. - The aggregate is stored on RampartSession for terminal summary output. + Trial specs are recorded during ``pytest_collection_modifyitems`` + on every process and shipped through the xdist worker payload so + aggregation does not depend on ``session.items`` — which is not + reliably populated with trial clones on the xdist controller at + session-finish time. Args: - session (pytest.Session): The pytest session. + session (pytest.Session): The pytest session (unused). rampart_session (RampartSession): The RAMPART session state. """ - groups: dict[str, list[pytest.Item]] = {} - for item in session.items: - base: str | None = getattr(item, "_rampart_trial_base", None) - if base is not None: - groups.setdefault(base, []).append(item) - - for base_nodeid, trial_items in groups.items(): - marker = trial_items[0].get_closest_marker("trial") - threshold = marker.kwargs.get("threshold", 0.0) if marker else 0.0 + groups: dict[str, list[tuple[str, float]]] = {} + for clone_nodeid, spec in rampart_session.trial_specs.items(): + groups.setdefault(spec.base_nodeid, []).append( + (clone_nodeid, spec.threshold), + ) + + for base_nodeid, clones in groups.items(): + # All clones of the same base share the same threshold; pick any. + threshold = clones[0][1] rampart_session.record_trial_group( base_nodeid=base_nodeid, - trial_items=trial_items, + clone_nodeids=[c[0] for c in clones], threshold=threshold, ) diff --git a/rampart/reporting/json_file.py b/rampart/reporting/json_file.py index 2294654..76916cb 100644 --- a/rampart/reporting/json_file.py +++ b/rampart/reporting/json_file.py @@ -77,6 +77,7 @@ def _serialize_report(self, report: TestRunReport) -> dict[str, Any]: "undetermined": report.undetermined, "errors": report.errors, "duration_seconds": report.duration_seconds, + "metadata": report.metadata, "population_summary": dataclasses.asdict(report.population_summary()), "by_harm_category": { category: [self._serialize_result(r) for r in results] diff --git a/tests/integration/test_xdist_aggregation.py b/tests/integration/test_xdist_aggregation.py index eb748a2..ef1590e 100644 --- a/tests/integration/test_xdist_aggregation.py +++ b/tests/integration/test_xdist_aggregation.py @@ -218,6 +218,204 @@ def test_trial_split(): assert len(reports) == 1 assert reports[0]["total_runs"] == 4 + def test_trial_group_fails_when_any_unsafe_under_loadgroup( + self, + pytester: Pytester, + ) -> None: + """An UNSAFE trial fails the whole group regardless of pass rate. + + Trial body switches on the clone name (``[trial-0]``..``[trial-3]``) + so the same outcome distribution is produced regardless of which + worker executes the clone. Three trials are SAFE and one is UNSAFE; + with threshold=0.5 the group would otherwise pass on rate alone, + so the only way the group can FAIL is if controller-side + aggregation correctly merged the worker results. + """ + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial_mixed=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=4, threshold=0.5) + def test_trial_mixed(request): + # Trial-3 is UNSAFE; the rest are SAFE. With threshold=0.5 + # the group MUST FAIL on the unconditional unsafe rule. + unsafe = request.node.name.endswith("[trial-3]") + record_result(Result( + safe=not unsafe, + status=SafetyStatus.UNSAFE if unsafe else SafetyStatus.SAFE, + summary="u" if unsafe else "s", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "loadgroup", + ) + # All 4 clones pass at the pytest item level — record_result + # does not fail the test; it only records a Result. + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1 + report = reports[0] + assert report["total_runs"] == 4 + assert report["passed"] == 3 + assert report["failed"] == 1 + # The trial-group FAIL line proves the controller correctly + # aggregated worker results. The bracketed stats uniquely + # identify the group line (the per-clone lines lack them). + summary = "\n".join(result.outlines) + assert "RAMPART Safety Summary" in summary + assert ( + "FAIL test_trial_mixed [3/4 safe, 75% pass rate, threshold: 50%]" + in summary + ) + + def test_trial_group_fails_when_any_unsafe_under_load( + self, + pytester: Pytester, + ) -> None: + """Same as above but with --dist=load so clones may split workers. + + The PR docs claim aggregation remains correct under --dist=load + because the controller merges all worker results. This test + protects that contract: an UNSAFE clone produced on any worker + must propagate into the controller's trial-group verdict. + """ + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial_mixed_load=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=4, threshold=0.5) + def test_trial_mixed_load(request): + unsafe = request.node.name.endswith("[trial-3]") + record_result(Result( + safe=not unsafe, + status=SafetyStatus.UNSAFE if unsafe else SafetyStatus.SAFE, + summary="u" if unsafe else "s", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "load", + ) + result.assert_outcomes(passed=4) + reports = _load_reports(pytester) + assert len(reports) == 1 + report = reports[0] + assert report["total_runs"] == 4 + assert report["failed"] == 1 + summary = "\n".join(result.outlines) + assert ( + "FAIL test_trial_mixed_load [3/4 safe, 75% pass rate, threshold: 50%]" + in summary + ) + + def test_trial_group_fails_below_threshold_under_loadgroup( + self, + pytester: Pytester, + ) -> None: + """No UNSAFE results, but pass rate below threshold => FAIL. + + 2 SAFE + 2 UNDETERMINED trials, threshold=0.75. Pass rate is 0.5 + so the group must FAIL on the threshold rule (not the unsafe rule). + """ + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial_threshold=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=4, threshold=0.75) + def test_trial_threshold(request): + undetermined = request.node.name.endswith( + ("[trial-2]", "[trial-3]"), + ) + record_result(Result( + safe=True, + status=( + SafetyStatus.UNDETERMINED + if undetermined else SafetyStatus.SAFE + ), + summary="t", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "loadgroup", + ) + # All 4 clones pass as pytest tests (record_result(safe=True)), + # but the trial GROUP should fail on threshold. + result.assert_outcomes(passed=4) + summary = "\n".join(result.outlines) + assert "FAIL test_trial_threshold" in summary + assert "50% pass rate" in summary + assert "threshold: 75%" in summary + + def test_trial_group_passes_when_all_safe_under_loadgroup( + self, + pytester: Pytester, + ) -> None: + """All-SAFE trial group with achievable threshold => PASS verdict.""" + pytester.makeconftest(_CONFTEST) + pytester.makepyfile( # pyright: ignore[reportUnknownMemberType] + test_trial_all_safe=""" + import pytest + from rampart import record_result + from rampart.core.result import Result, SafetyStatus + from rampart.core.types import ObservabilityLevel + + @pytest.mark.harm("test") + @pytest.mark.trial(n=3, threshold=0.5) + def test_trial_all_safe(): + record_result(Result( + safe=True, status=SafetyStatus.SAFE, summary="ok", + observability_level=ObservabilityLevel.RESPONSE_ONLY, + )) + """, + ) + result = pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--dist", + "loadgroup", + ) + result.assert_outcomes(passed=3) + summary = "\n".join(result.outlines) + assert "PASS test_trial_all_safe" in summary + assert "PASSED" in summary + class TestXdistMetadata: def test_report_includes_xdist_metadata(self, pytester: Pytester) -> None: @@ -225,12 +423,33 @@ def test_report_includes_xdist_metadata(self, pytester: Pytester) -> None: pytester.runpytest("-p", "no:cacheprovider", "-n", "2") reports = _load_reports(pytester) assert len(reports) == 1 - # Population summary is exposed in JSON; xdist metadata lives in - # TestRunReport.metadata which is rendered when present. - # The JsonFileReportSink does not currently project metadata, - # so we just verify the report exists with the right shape. + metadata = reports[0].get("metadata", {}) + assert metadata.get("xdist_active") is True + assert metadata.get("worker_count") == 2 + assert "dist_mode" in metadata assert "population_summary" in reports[0] + def test_size_cap_marks_run_incomplete(self, pytester: Pytester) -> None: + """Forcing a 1-byte cap surfaces incompleteness in report metadata. + + Triggers the truncation path so the controller must record + ``incomplete=True`` plus a reason in the merged report. + """ + _setup_simple_tests(pytester) + pytester.runpytest( + "-p", + "no:cacheprovider", + "-n", + "2", + "--rampart-xdist-max-bytes=1", + ) + reports = _load_reports(pytester) + assert len(reports) == 1 + metadata = reports[0].get("metadata", {}) + assert metadata.get("incomplete") is True + reasons = metadata.get("incomplete_reasons", []) + assert any("truncated" in r for r in reasons) + class TestCollectOnly: def test_collect_only_does_not_emit_reports(self, pytester: Pytester) -> None: diff --git a/tests/unit/pytest_plugin/test_plugin.py b/tests/unit/pytest_plugin/test_plugin.py index e929e2c..8e329d9 100644 --- a/tests/unit/pytest_plugin/test_plugin.py +++ b/tests/unit/pytest_plugin/test_plugin.py @@ -184,7 +184,7 @@ def test_record_trial_group(self) -> None: session.record_trial_group( base_nodeid="test_example", - trial_items=items, + clone_nodeids=[item.nodeid for item in items], threshold=0.3, ) @@ -217,7 +217,7 @@ def test_record_trial_group_all_errors(self) -> None: session.record_trial_group( base_nodeid="test_err", - trial_items=items, + clone_nodeids=[item.nodeid for item in items], threshold=0.0, ) @@ -231,7 +231,7 @@ def test_record_trial_group_empty_items_noop(self) -> None: session = RampartSession() session.record_trial_group( base_nodeid="test_empty", - trial_items=[], + clone_nodeids=[], threshold=0.0, ) assert "test_empty" not in session.trial_groups @@ -636,7 +636,7 @@ def test_writes_trial_group_line(self) -> None: session.record_trial_group( base_nodeid="test_file.py::test_stat", - trial_items=items, + clone_nodeids=[item.nodeid for item in items], threshold=0.3, ) @@ -677,7 +677,7 @@ def test_logs_when_rate_exceeds_threshold(self) -> None: session.record_trial_group( base_nodeid="test.py::test_gate", - trial_items=items, + clone_nodeids=[item.nodeid for item in items], threshold=0.1, ) diff --git a/tests/unit/pytest_plugin/test_xdist.py b/tests/unit/pytest_plugin/test_xdist.py index af95f6b..821f1ae 100644 --- a/tests/unit/pytest_plugin/test_xdist.py +++ b/tests/unit/pytest_plugin/test_xdist.py @@ -31,7 +31,7 @@ ToolCall, Turn, ) -from rampart.pytest_plugin._session import RampartSession +from rampart.pytest_plugin._session import RampartSession, TrialSpec from rampart.pytest_plugin._xdist import ( DEFAULT_SIZE_LIMIT_BYTES, MAX_METADATA_DEPTH, @@ -43,6 +43,7 @@ WorkerOutputError, _sanitize, _strip_ansi, + deserialize_trial_specs, deserialize_worker_data, discover_sinks_from_conftest, finalize_worker, @@ -674,6 +675,115 @@ def test_merges_results_on_success(self) -> None: assert len(session._results) == 1 assert session._results[0].summary == "from-worker" + def test_merges_trial_specs_on_success(self) -> None: + session = RampartSession() + worker_session = RampartSession() + worker_session.register_trial_spec( + clone_nodeid="test.py::test_x[trial-0]", + base_nodeid="test.py::test_x", + threshold=0.8, + ) + worker_session.register_trial_spec( + clone_nodeid="test.py::test_x[trial-1]", + base_nodeid="test.py::test_x", + threshold=0.8, + ) + payload = serialize_worker_data(session=worker_session) + node = MagicMock() + node.gateway.id = "gw1" + node.workeroutput = {WORKEROUTPUT_KEY: payload} + handle_testnodedown(session=session, node=node, error=None) + assert session.is_incomplete is False + assert set(session.trial_specs) == { + "test.py::test_x[trial-0]", + "test.py::test_x[trial-1]", + } + assert ( + session.trial_specs["test.py::test_x[trial-0]"].base_nodeid + == "test.py::test_x" + ) + assert session.trial_specs["test.py::test_x[trial-0]"].threshold == 0.8 + + +class TestTrialSpecs: + def test_serialize_round_trip(self) -> None: + session = RampartSession() + session.register_trial_spec( + clone_nodeid="t.py::a[trial-0]", + base_nodeid="t.py::a", + threshold=0.75, + ) + session.register_trial_spec( + clone_nodeid="t.py::a[trial-1]", + base_nodeid="t.py::a", + threshold=0.75, + ) + payload = serialize_worker_data(session=session) + + # Payload must survive a JSON round-trip (xdist transports JSON). + decoded = json.loads(json.dumps(payload)) + specs = deserialize_trial_specs(data=decoded) + + assert specs == { + "t.py::a[trial-0]": TrialSpec(base_nodeid="t.py::a", threshold=0.75), + "t.py::a[trial-1]": TrialSpec(base_nodeid="t.py::a", threshold=0.75), + } + + def test_payload_without_trials_returns_empty_dict(self) -> None: + session = RampartSession() + payload = serialize_worker_data(session=session) + assert deserialize_trial_specs(data=payload) == {} + + def test_skips_malformed_entries(self) -> None: + data: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": {}, + "trial_specs": [ + {"clone_nodeid": "ok", "base_nodeid": "b", "threshold": 0.5}, + "not-a-dict", + {"clone_nodeid": "", "base_nodeid": "b", "threshold": 0.5}, + {"clone_nodeid": "x", "base_nodeid": 123, "threshold": 0.5}, + {"clone_nodeid": "y", "base_nodeid": "b"}, + ], + } + specs = deserialize_trial_specs(data=data) + assert set(specs) == {"ok", "y"} + assert specs["y"].threshold == 0.0 + + def test_clamps_non_finite_threshold(self) -> None: + data: dict[str, Any] = { + "schema": SCHEMA_VERSION, + "results_by_nodeid": {}, + "trial_specs": [ + {"clone_nodeid": "a", "base_nodeid": "b", "threshold": float("inf")}, + {"clone_nodeid": "c", "base_nodeid": "d", "threshold": float("nan")}, + ], + } + specs = deserialize_trial_specs(data=data) + assert specs["a"].threshold == 0.0 + assert specs["c"].threshold == 0.0 + + def test_merge_is_idempotent(self) -> None: + session = RampartSession() + spec = TrialSpec(base_nodeid="b", threshold=0.5) + session.merge_trial_specs(trial_specs={"k": spec}) + session.merge_trial_specs(trial_specs={"k": spec}) + assert session.trial_specs == {"k": spec} + + def test_merge_first_writer_wins(self) -> None: + session = RampartSession() + original = TrialSpec(base_nodeid="b1", threshold=0.5) + replacement = TrialSpec(base_nodeid="b2", threshold=0.9) + session.merge_trial_specs(trial_specs={"k": original}) + session.merge_trial_specs(trial_specs={"k": replacement}) + # Defensive: the first registered spec wins so a worker can't + # silently override what the controller already saw at collection. + assert session.trial_specs["k"] == original + + def test_invalid_payload_raises(self) -> None: + with pytest.raises(WorkerOutputError): + deserialize_trial_specs(data="not a dict") + class TestFinalizeWorker: def test_no_op_on_controller(self) -> None: diff --git a/tests/unit/reporting/test_json_file.py b/tests/unit/reporting/test_json_file.py index de7a6bb..0638493 100644 --- a/tests/unit/reporting/test_json_file.py +++ b/tests/unit/reporting/test_json_file.py @@ -225,3 +225,49 @@ async def test_emitted_file_contains_metadata(self, tmp_path: Path) -> None: assert category_results[0]["turns"][0]["response_metadata"] == { "page_url": "https://example.com/chat", } + + +class TestReportMetadata: + """Run-level TestRunReport.metadata is projected into the JSON output.""" + + def test_report_metadata_appears_in_serialized_output(self) -> None: + sink = JsonFileReportSink(output_dir=Path("/tmp")) + report = TestRunReport( + metadata={ + "xdist_active": True, + "worker_count": 4, + "dist_mode": "loadgroup", + }, + ) + + data = sink._serialize_report(report) + + assert data["metadata"] == { + "xdist_active": True, + "worker_count": 4, + "dist_mode": "loadgroup", + } + + def test_incomplete_run_metadata_appears_in_serialized_output(self) -> None: + sink = JsonFileReportSink(output_dir=Path("/tmp")) + report = TestRunReport( + metadata={ + "incomplete": True, + "incomplete_reasons": ["worker gw1 payload truncated (size cap)"], + }, + ) + + data = sink._serialize_report(report) + + assert data["metadata"]["incomplete"] is True + assert data["metadata"]["incomplete_reasons"] == [ + "worker gw1 payload truncated (size cap)", + ] + + def test_empty_metadata_serializes_as_empty_dict(self) -> None: + sink = JsonFileReportSink(output_dir=Path("/tmp")) + report = TestRunReport() + + data = sink._serialize_report(report) + + assert data["metadata"] == {}