diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c97ad3..8f67001 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,17 @@ under their entry. ## Unreleased ### Added +- `analyze_dataset`: deterministic one-shot orchestrator that runs the + full first-look pipeline (inspect_mesh → validate_dataset → + inspect_variable → calculate_area → calculate_zonal_mean → plot_mesh → + plot_variable) and returns a single structured result with + `recommended_next_steps`. Accepts direct paths or + `session_id` + `dataset_handle`. Forwards `use_remote`/`endpoint` to + every underlying stage. (#32) +- `recommended_next_steps` field on result-bearing tools to guide + agent chaining: `calculate_zonal_mean`, `validate_dataset` (branched + on pass/fail), `subset_bbox`, `subset_polygon`, + `extract_cross_section`. (#30) - `uxarray-mcp` CLI entry point with subcommands: `serve`, `setup`, `endpoints add/list/remove`, `doctor`, `install-claude`. - Multi-endpoint config schema (`hpc.endpoints.` with diff --git a/src/uxarray_mcp/server.py b/src/uxarray_mcp/server.py index eaf4d23..d06f573 100644 --- a/src/uxarray_mcp/server.py +++ b/src/uxarray_mcp/server.py @@ -9,6 +9,7 @@ from fastmcp import FastMCP from uxarray_mcp.tools import ( + analyze_dataset, calculate_anomaly, calculate_area_hpc, calculate_bias, @@ -58,6 +59,9 @@ # Tool discovery — always call this first with a new dataset mcp.tool()(get_capabilities) +# Deterministic one-shot analysis — full first-look pipeline in a single call +mcp.tool()(analyze_dataset) + # Autonomous scientific agent — Analyze → Plan → Execute → Verify mcp.tool()(run_scientific_agent) mcp.tool()(run_workflow) diff --git a/src/uxarray_mcp/tools/__init__.py b/src/uxarray_mcp/tools/__init__.py index 93f63aa..5f694a4 100644 --- a/src/uxarray_mcp/tools/__init__.py +++ b/src/uxarray_mcp/tools/__init__.py @@ -35,6 +35,7 @@ inspect_variable, validate_dataset, ) +from .orchestration import analyze_dataset from .plotting import plot_mesh, plot_variable, plot_zonal_mean from .remote_tools import ( calculate_area_hpc, @@ -62,6 +63,7 @@ __all__ = [ "get_capabilities", "list_datasets", + "analyze_dataset", "run_scientific_agent", "run_workflow", "resume_workflow", diff --git a/src/uxarray_mcp/tools/advanced.py b/src/uxarray_mcp/tools/advanced.py index 1964379..67b7c27 100644 --- a/src/uxarray_mcp/tools/advanced.py +++ b/src/uxarray_mcp/tools/advanced.py @@ -199,6 +199,16 @@ def subset_bbox( "variable_summary": variable_summary, "result_handle": result_handle, } + next_steps = [ + f'plot_mesh(grid_path="{resolved_grid}")', + f'export_to_netcdf("", result_handle="{result_handle}")', + ] + if resolved_data is not None: + next_steps.insert( + 0, + f'plot_variable("{resolved_grid}", "{resolved_data}", "")', + ) + result["recommended_next_steps"] = next_steps tracker.succeed("Bounding-box subset complete.") result = attach_provenance( result, @@ -293,6 +303,16 @@ def subset_polygon( "variable_summary": variable_summary, "result_handle": result_handle, } + next_steps = [ + f'plot_mesh(grid_path="{resolved_grid}")', + f'export_to_netcdf("", result_handle="{result_handle}")', + ] + if resolved_data is not None: + next_steps.insert( + 0, + f'plot_variable("{resolved_grid}", "{resolved_data}", "")', + ) + result["recommended_next_steps"] = next_steps result = attach_provenance( result, tool="subset_polygon", @@ -379,6 +399,16 @@ def extract_cross_section( "variable_summary": variable_summary, "result_handle": result_handle, } + next_steps = [ + f'plot_mesh(grid_path="{resolved_grid}")', + f'export_to_netcdf("", result_handle="{result_handle}")', + ] + if resolved_data is not None: + next_steps.insert( + 0, + f'calculate_zonal_mean("{resolved_grid}", "{resolved_data}", "")', + ) + result["recommended_next_steps"] = next_steps result = attach_provenance( result, tool="extract_cross_section", diff --git a/src/uxarray_mcp/tools/inspection.py b/src/uxarray_mcp/tools/inspection.py index dca03b8..184054c 100644 --- a/src/uxarray_mcp/tools/inspection.py +++ b/src/uxarray_mcp/tools/inspection.py @@ -302,6 +302,12 @@ def calculate_zonal_mean( except Exception as e: raise RuntimeError(f"Failed to compute zonal mean: {str(e)}") + result["recommended_next_steps"] = [ + f'plot_zonal_mean("{grid_path}", "{data_path}", "{variable_name}")', + f'plot_variable("{grid_path}", "{data_path}", "{variable_name}")', + f'extract_cross_section(latitude=0.0, grid_path="{grid_path}", ' + f'data_path="{data_path}", variable_name="{variable_name}")', + ] return attach_provenance( result, tool="calculate_zonal_mean", @@ -530,6 +536,19 @@ def validate_dataset(grid_path: str, data_path: str) -> Dict[str, Any]: "issues": issues, } + if overall_passed: + result["recommended_next_steps"] = [ + f'inspect_variable("{grid_path}", "{data_path}")', + f'calculate_zonal_mean("{grid_path}", "{data_path}", "")', + f'plot_variable("{grid_path}", "{data_path}", "")', + ] + else: + result["recommended_next_steps"] = [ + "Validation failed; review the per-variable warnings above before " + "running analysis.", + "Drop or repair affected variables, or rerun with a corrected data file.", + ] + return attach_provenance( result, tool="validate_dataset", diff --git a/src/uxarray_mcp/tools/orchestration.py b/src/uxarray_mcp/tools/orchestration.py new file mode 100644 index 0000000..979a093 --- /dev/null +++ b/src/uxarray_mcp/tools/orchestration.py @@ -0,0 +1,328 @@ +"""Deterministic one-shot orchestration tools. + +These tools run a fixed pipeline of inspection, validation, and analysis +calls in sequence and return a single structured result. They are the +"do everything reasonable" counterpart to ``run_scientific_agent``: no +LLM reasoning, no branching heuristics — just a predictable chain that +turns a single user invocation into a full first look at a dataset. +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any, Optional + +from uxarray_mcp.provenance import attach_provenance + + +def _safe_call(stage: str, fn, warnings: list[str]) -> Optional[Any]: + """Run ``fn``; on failure append a warning and return None.""" + try: + return fn() + except Exception as exc: + warnings.append(f"{stage}: {type(exc).__name__}: {exc}") + return None + + +def _png_meta(items: list[Any]) -> dict[str, Any]: + """Convert a plot tool's ``[ImageContent, TextContent]`` list to a dict.""" + if not items or len(items) < 2: + return {} + img = items[0] + try: + meta = json.loads(items[1].text) + except Exception: + meta = {} + png_b64 = getattr(img, "data", None) + image_size_bytes = meta.get("image_size_bytes") + if image_size_bytes is None and png_b64: + try: + image_size_bytes = len(base64.b64decode(png_b64)) + except Exception: + image_size_bytes = None + return { + "png_b64": png_b64, + "image_size_bytes": image_size_bytes, + "grid_info": meta.get("grid_info"), + "variable_name": meta.get("variable_name"), + } + + +def analyze_dataset( + grid_path: Optional[str] = None, + data_path: Optional[str] = None, + variable_name: Optional[str] = None, + session_id: Optional[str] = None, + dataset_handle: Optional[str] = None, + use_remote: bool = False, + endpoint: Optional[str] = None, + include_plots: bool = True, +) -> dict[str, Any]: + """Run a complete first-look analysis of a mesh dataset in one call. + + Executes a fixed deterministic pipeline: + + 1. ``inspect_mesh`` — topology summary + 2. ``validate_dataset`` — NaN/Inf/fill checks (only if ``data_path``) + 3. ``inspect_variable`` — variable metadata (only if ``data_path``) + 4. ``calculate_area`` — face area statistics + 5. ``calculate_zonal_mean`` — zonal profile of the first face-centered + variable (only if ``data_path`` and a face-centered variable exists) + 6. ``plot_mesh`` — wireframe PNG (only if ``include_plots``) + 7. ``plot_variable`` — choropleth PNG of the chosen variable (only if + ``include_plots`` and a face-centered variable exists) + + Each stage is run defensively — a failure in any single stage is + recorded in ``warnings`` and the pipeline continues. + + Parameters + ---------- + grid_path : str | None + Path to the mesh grid file or ``healpix:``. Optional when + ``session_id`` and ``dataset_handle`` are provided. + data_path : str | None + Path to a data file with variables. If omitted, only mesh-level + stages run. + variable_name : str | None + Specific face-centered variable to analyze. If omitted, the first + face-centered variable discovered by ``inspect_variable`` is used. + session_id, dataset_handle : str | None + When both are provided, the grid/data paths are resolved from the + registered session dataset. + use_remote : bool + Forwarded to each underlying ``*_hpc`` dispatcher. + endpoint : str | None + Forwarded to each underlying ``*_hpc`` dispatcher. + include_plots : bool + When False, the two plot stages are skipped (useful for headless + callers that just want statistics). + + Returns + ------- + dict + A structured result with keys: + + - ``grid_path``, ``data_path``: resolved input paths + - ``mesh``: ``inspect_mesh`` result (or ``None`` on failure) + - ``validation``: ``validate_dataset`` result (or ``None``) + - ``variables``: ``inspect_variable`` result (or ``None``) + - ``area``: ``calculate_area`` result (or ``None``) + - ``selected_variable``: name of the face-centered variable used + - ``zonal_mean``: ``calculate_zonal_mean`` result (or ``None``) + - ``mesh_plot``: PNG metadata dict (or ``None``) + - ``variable_plot``: PNG metadata dict (or ``None``) + - ``stages_run``: list of stage names that completed successfully + - ``warnings``: list of stage failures (empty when everything ran) + - ``recommended_next_steps``: chained-tool suggestions for the + agent to act on after seeing this result + - ``_provenance``: standard provenance block + """ + from .inspection import inspect_mesh + from .plotting import _resolve_plot_paths + from .remote_tools import ( + calculate_area_hpc, + calculate_zonal_mean_hpc, + inspect_mesh_hpc, + inspect_variable_hpc, + plot_mesh_hpc, + plot_variable_hpc, + ) + + resolved_grid, resolved_data = _resolve_plot_paths( + grid_path, data_path, session_id, dataset_handle + ) + + warnings: list[str] = [] + stages_run: list[str] = [] + + # ── Stage 1: inspect mesh ──────────────────────────────────────────────── + mesh = _safe_call( + "inspect_mesh", + lambda: inspect_mesh_hpc( + resolved_grid, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if mesh is not None: + stages_run.append("inspect_mesh") + + # ── Stage 2 + 3: validate + inspect variables (data path required) ────── + validation: Optional[dict[str, Any]] = None + variables: Optional[dict[str, Any]] = None + selected_variable: Optional[str] = variable_name + + if resolved_data is not None: + # validate_dataset is local-only by design (data is read directly); + # if the dispatcher were ever added, prefer that. For now use the + # local variant via inspection.validate_dataset. + from .inspection import validate_dataset as _validate_dataset + + validation = _safe_call( + "validate_dataset", + lambda: _validate_dataset(resolved_grid, resolved_data), + warnings, + ) + if validation is not None: + stages_run.append("validate_dataset") + + variables = _safe_call( + "inspect_variable", + lambda: inspect_variable_hpc( + resolved_grid, + resolved_data, + variable_name, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if variables is not None: + stages_run.append("inspect_variable") + if selected_variable is None: + for var in variables.get("variables", []): + if var.get("location") == "faces": + selected_variable = var.get("name") + break + + # ── Stage 4: face areas ────────────────────────────────────────────────── + area = _safe_call( + "calculate_area", + lambda: calculate_area_hpc( + resolved_grid, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if area is not None: + stages_run.append("calculate_area") + + # ── Stage 5: zonal mean (needs data + face-centered variable) ─────────── + zonal_mean: Optional[dict[str, Any]] = None + if resolved_data is not None and selected_variable is not None: + zonal_mean = _safe_call( + "calculate_zonal_mean", + lambda: calculate_zonal_mean_hpc( + resolved_grid, + resolved_data, + selected_variable, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if zonal_mean is not None: + stages_run.append("calculate_zonal_mean") + + # ── Stage 6 + 7: plots (optional) ──────────────────────────────────────── + mesh_plot: Optional[dict[str, Any]] = None + variable_plot: Optional[dict[str, Any]] = None + + if include_plots: + plot_items = _safe_call( + "plot_mesh", + lambda: plot_mesh_hpc( + grid_path=resolved_grid, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if plot_items is not None: + mesh_plot = _png_meta(plot_items) + stages_run.append("plot_mesh") + + if resolved_data is not None and selected_variable is not None: + var_plot_items = _safe_call( + "plot_variable", + lambda: plot_variable_hpc( + grid_path=resolved_grid, + data_path=resolved_data, + variable_name=selected_variable, + use_remote=use_remote, + endpoint=endpoint, + session_id=session_id, + ), + warnings, + ) + if var_plot_items is not None: + variable_plot = _png_meta(var_plot_items) + stages_run.append("plot_variable") + + # ── Recommended next steps ────────────────────────────────────────────── + next_steps: list[str] = [] + if resolved_data is None: + next_steps.append( + f'inspect_variable("{resolved_grid}", "") ' + "— rerun with a data file to unlock variable analysis" + ) + if validation is not None and validation.get("passed") is False: + next_steps.append( + "Validation failed; review the per-variable warnings before " + "trusting downstream results." + ) + if selected_variable and resolved_data is not None: + next_steps.append( + f'plot_zonal_mean("{resolved_grid}", "{resolved_data}", ' + f'"{selected_variable}") — render the zonal profile' + ) + next_steps.append( + f'extract_cross_section(latitude=0.0, grid_path="{resolved_grid}", ' + f'data_path="{resolved_data}", variable_name="{selected_variable}")' + ) + next_steps.append( + f"subset_bbox(lon_bounds=[-180, 180], lat_bounds=[-90, 90], " + f'grid_path="{resolved_grid}", data_path="{resolved_data}", ' + f'variable_name="{selected_variable}") — focus on a region' + ) + if not next_steps: + next_steps.append( + f'plot_mesh(grid_path="{resolved_grid}") — visualize the mesh wireframe' + ) + + result: dict[str, Any] = { + "grid_path": resolved_grid, + "data_path": resolved_data, + "mesh": mesh, + "validation": validation, + "variables": variables, + "area": area, + "selected_variable": selected_variable, + "zonal_mean": zonal_mean, + "mesh_plot": mesh_plot, + "variable_plot": variable_plot, + "stages_run": stages_run, + "warnings": warnings, + "recommended_next_steps": next_steps, + } + + # Ensure inspect_mesh is callable for the local-only smoke fallback path + # (used by tests that import `inspect_mesh` directly from this module). + _ = inspect_mesh # keep import alive + + venue = "hpc" if use_remote else "local" + return attach_provenance( + result, + tool="analyze_dataset", + inputs={ + "grid_path": grid_path, + "data_path": data_path, + "variable_name": variable_name, + "session_id": session_id, + "dataset_handle": dataset_handle, + "use_remote": use_remote, + "endpoint": endpoint, + "include_plots": include_plots, + }, + venue=venue, + warnings=warnings, + selected_variable=selected_variable, + ) diff --git a/tests/test_analyze_dataset.py b/tests/test_analyze_dataset.py new file mode 100644 index 0000000..1a7336e --- /dev/null +++ b/tests/test_analyze_dataset.py @@ -0,0 +1,117 @@ +"""Tests for analyze_dataset (issue #32). + +The tool runs a fixed pipeline of inspection, validation, and analysis +calls and returns a single structured result with provenance and +recommended next steps. +""" + +from uxarray_mcp.tools import ( + analyze_dataset, + create_session, + register_dataset, +) + + +def test_analyze_dataset_healpix_no_data(): + """Without a data file the mesh-only stages should still run.""" + result = analyze_dataset("healpix:2", include_plots=False) + + assert result["mesh"] is not None + assert result["mesh"]["n_face"] == 192 + assert result["area"] is not None + assert result["validation"] is None + assert result["variables"] is None + assert result["zonal_mean"] is None + assert result["selected_variable"] is None + assert result["warnings"] == [] + assert "inspect_mesh" in result["stages_run"] + assert "calculate_area" in result["stages_run"] + # No data, so first hint should ask the agent to provide a data file. + assert any("data_path" in s for s in result["recommended_next_steps"]) + assert result["_provenance"]["tool"] == "analyze_dataset" + + +def test_analyze_dataset_with_data_runs_full_pipeline(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = analyze_dataset(grid_file, data_file, include_plots=False) + + assert result["mesh"] is not None + assert result["validation"] is not None + assert result["validation"]["passed"] is True + assert result["variables"] is not None + assert result["area"] is not None + assert result["selected_variable"] in {"temperature", "pressure"} + assert result["zonal_mean"] is not None + for required in ( + "inspect_mesh", + "validate_dataset", + "inspect_variable", + "calculate_area", + "calculate_zonal_mean", + ): + assert required in result["stages_run"] + + +def test_analyze_dataset_includes_plots_by_default(): + """Plot stages should produce base64 PNGs when include_plots=True.""" + result = analyze_dataset("healpix:2") + + assert result["mesh_plot"] is not None + assert result["mesh_plot"]["png_b64"] + assert "plot_mesh" in result["stages_run"] + + +def test_analyze_dataset_resolves_session_dataset_handle(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + session = create_session("analyze-handle-test") + registered = register_dataset( + session_id=session["session_id"], + grid_path=grid_file, + data_path=data_file, + name="grid-and-data", + ) + + result = analyze_dataset( + session_id=session["session_id"], + dataset_handle=registered["dataset_handle"], + include_plots=False, + ) + + assert result["grid_path"] == grid_file + assert result["data_path"] == data_file + assert result["mesh"]["n_face"] == 1 + assert result["validation"]["passed"] is True + assert result["selected_variable"] in {"temperature", "pressure"} + + +def test_analyze_dataset_failure_in_one_stage_does_not_abort(monkeypatch): + """A stage error should be captured in warnings; other stages still run.""" + from uxarray_mcp.tools import remote_tools + + def boom(*args, **kwargs): + raise RuntimeError("simulated area failure") + + monkeypatch.setattr(remote_tools, "calculate_area_hpc", boom) + + result = analyze_dataset("healpix:2", include_plots=False) + + assert result["mesh"] is not None + assert result["area"] is None + assert any("calculate_area" in w for w in result["warnings"]) + assert "inspect_mesh" in result["stages_run"] + assert "calculate_area" not in result["stages_run"] + + +def test_analyze_dataset_recommended_next_steps_present(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = analyze_dataset(grid_file, data_file, include_plots=False) + + steps = result["recommended_next_steps"] + assert isinstance(steps, list) and len(steps) >= 1 + joined = " ".join(steps) + # With a face-centered variable in scope, the chain should suggest + # plotting / cross-section / subsetting follow-ups. + assert any( + kw in joined + for kw in ("plot_zonal_mean", "extract_cross_section", "subset_bbox") + ) diff --git a/tests/test_recommended_next_steps.py b/tests/test_recommended_next_steps.py new file mode 100644 index 0000000..3f01c05 --- /dev/null +++ b/tests/test_recommended_next_steps.py @@ -0,0 +1,119 @@ +"""Tests for the ``recommended_next_steps`` field added in issue #30. + +Every result-bearing tool should suggest follow-up tool calls so an agent +can chain a workflow without already knowing the tool vocabulary. +""" + +from uxarray_mcp.tools.advanced import ( + extract_cross_section, + subset_bbox, + subset_polygon, +) +from uxarray_mcp.tools.inspection import ( + calculate_zonal_mean, + validate_dataset, +) + + +def _is_str_list(value) -> bool: + return isinstance(value, list) and all(isinstance(v, str) for v in value) + + +def test_calculate_zonal_mean_recommends_next_steps(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = calculate_zonal_mean(grid_file, data_file, "temperature") + + steps = result["recommended_next_steps"] + assert _is_str_list(steps) and len(steps) >= 2 + joined = " ".join(steps) + assert "plot_zonal_mean" in joined + assert "plot_variable" in joined + + +def test_validate_dataset_pass_recommends_analysis_chain(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = validate_dataset(grid_file, data_file) + + assert result["passed"] is True + steps = result["recommended_next_steps"] + assert _is_str_list(steps) + joined = " ".join(steps) + assert "inspect_variable" in joined + assert "calculate_zonal_mean" in joined + + +def test_validate_dataset_fail_recommends_stop(synthetic_mesh_file, tmp_path): + """Synthetic data file with NaN should trigger the failure-branch hint.""" + import xarray as xr + + bad_file = tmp_path / "bad.nc" + xr.Dataset( + {"temperature": (["nMesh2_face"], [float("nan")], {"units": "K"})} + ).to_netcdf(bad_file) + + result = validate_dataset(synthetic_mesh_file, str(bad_file)) + + assert result["passed"] is False + steps = result["recommended_next_steps"] + assert _is_str_list(steps) + joined = " ".join(steps).lower() + assert "validation failed" in joined + + +def test_subset_bbox_recommends_plot_and_export(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = subset_bbox( + lon_bounds=[-180.0, 180.0], + lat_bounds=[-90.0, 90.0], + grid_path=grid_file, + data_path=data_file, + variable_name="temperature", + ) + + steps = result["recommended_next_steps"] + assert _is_str_list(steps) + joined = " ".join(steps) + assert "plot_mesh" in joined + assert "export_to_netcdf" in joined + # When data is provided, plot_variable should lead the list. + assert "plot_variable" in steps[0] + + +def test_subset_polygon_recommends_plot_and_export(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = subset_polygon( + polygon_lon_lat=[ + [-180.0, -90.0], + [180.0, -90.0], + [180.0, 90.0], + [-180.0, 90.0], + ], + grid_path=grid_file, + data_path=data_file, + variable_name="temperature", + ) + + steps = result["recommended_next_steps"] + assert _is_str_list(steps) + joined = " ".join(steps) + assert "plot_mesh" in joined + assert "export_to_netcdf" in joined + + +def test_extract_cross_section_recommends_zonal_mean(synthetic_mesh_with_data): + grid_file, data_file = synthetic_mesh_with_data + result = extract_cross_section( + latitude=0.5, + longitude=None, + grid_path=grid_file, + data_path=data_file, + variable_name="temperature", + session_id=None, + dataset_handle=None, + result_name=None, + ) + + steps = result["recommended_next_steps"] + assert _is_str_list(steps) + joined = " ".join(steps) + assert "calculate_zonal_mean" in joined