diff --git a/changelog.d/variable-graph.added.md b/changelog.d/variable-graph.added.md new file mode 100644 index 00000000..11ce0773 --- /dev/null +++ b/changelog.d/variable-graph.added.md @@ -0,0 +1 @@ +Added ``policyengine.graph`` — a static-analysis-based variable dependency graph for PolicyEngine source trees. ``extract_from_path(path)`` walks a directory of Variable subclasses, parses formula-method bodies for ``entity("", period)`` and ``add(entity, period, [list])`` references, and returns a ``VariableGraph``. Queries include ``deps(var)`` (direct dependencies), ``impact(var)`` (transitive downstream), and ``path(src, dst)`` (shortest dependency chain). No runtime dependency on country models — indexes ``policyengine-us`` (4,577 variables) in under a second. diff --git a/docs/_generator/README.md b/docs/_generator/README.md new file mode 100644 index 00000000..ef5c7268 --- /dev/null +++ b/docs/_generator/README.md @@ -0,0 +1,52 @@ +# Reference generator prototype + +Auto-generates one Quarto page per variable in a country model, plus a program-coverage page, purely from metadata on the `Variable` classes and `programs.yaml`. + +## Run + +```bash +# Full US reference (takes a couple of minutes — 4,686 variables) +python docs/_generator/build_reference.py --country us --out docs/_generated/reference/us + +# Preview a filtered subset +python docs/_generator/build_reference.py --country us --filter chip --out /tmp/ref-preview +``` + +Then render: + +```bash +cd /tmp/ref-preview && quarto render +``` + +## What's generated from code alone + +Per variable: + +- Title and identifier +- Metadata table: entity, value type, unit, period, `defined_for` gate +- Documentation (docstring) +- Components (`adds` / `subtracts` lists) +- Statutory references (from `reference = ...`) +- Source file path and line number + +Per program: a row in the generated program-coverage page pulled from `programs.yaml` (id, name, category, agency, status, coverage). + +Per directory (`gov/hhs/chip/`, `gov/usda/snap/`, etc.): a listing page using Quarto's built-in directory listing so the nav auto-organizes. + +## What still requires hand-authored prose + +- Methodology narrative (why the model is structured this way) +- Tutorials (how to use `policyengine.py`) +- Paper content (peer-reviewable argument) +- Per-country deep dives that read as essays rather than reference lookups + +## Design + +The generator reads directly from the imported country model — no web API calls, no intermediate JSON. This keeps the build offline-reproducible and version-pinned to whatever country model the `policyengine.py` package has installed. Re-running the generator on release produces a snapshot of the reference docs tied to the exact published model versions. + +Extensions worth considering: + +1. Walk `parameters/` YAML tree and emit a page per parameter with its time series, breakdowns, and references. +2. For each variable with a formula, surface the dependency graph (other variables / parameters it reads). `policyengine_core`'s `Variable.exhaustive_parameter_dependencies` gets partway there. +3. For each calibration target (in `policyengine-us-data/storage/calibration_targets/*.csv`), emit a page describing source, aggregation level, freshness. +4. Cross-link variables to the programs they contribute to via `programs.yaml`'s `variable:` field. diff --git a/docs/_generator/build_reference.py b/docs/_generator/build_reference.py new file mode 100644 index 00000000..490420cd --- /dev/null +++ b/docs/_generator/build_reference.py @@ -0,0 +1,392 @@ +"""Generate reference documentation pages from PolicyEngine country models. + +Introspects a country model's `TaxBenefitSystem` for every variable, reads +attributes directly from each `Variable` class (`label`, `documentation`, +`entity`, `unit`, `reference`, `defined_for`, `definition_period`, +`adds`/`subtracts`, source file path), and writes one ``.qmd`` page per +variable grouped by its parameter-tree path (``gov/hhs/chip/chip_premium``). + +Also loads the country model's ``programs.yaml`` and writes a program-level +landing page for each entry, cross-linking the variables that belong to it. + +Usage +----- + +Run for a single country model, writing into an output directory: + +.. code-block:: bash + + python docs/_generator/build_reference.py \\ + --country us \\ + --out docs/_generated/reference/us + +Run for a subset of variables to preview output: + +.. code-block:: bash + + python docs/_generator/build_reference.py \\ + --country us --filter chip --out /tmp/ref-preview + +Design notes +------------ + +This is a prototype meant to demonstrate how much reference material can be +regenerated from code + parameter YAML + ``programs.yaml`` alone, with no +hand-authored prose. Intentional non-goals: + +* Do not execute formulas; read metadata only. +* Do not render parameters (a follow-up can walk the parameter tree similarly). +* Do not write an index page tree; Quarto's directory listings handle that. + +The generator emits standard Quarto Markdown (``.qmd``). Quarto reads regular +Markdown too, so the outputs drop into either a Quarto or MyST site. +""" + +from __future__ import annotations + +import argparse +import importlib +import logging +import re +import textwrap +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +import yaml + +logger = logging.getLogger(__name__) + + +COUNTRY_MODULES = { + "us": "policyengine_us", + "uk": "policyengine_uk", + "canada": "policyengine_canada", + "il": "policyengine_il", + "ng": "policyengine_ng", +} + + +@dataclass(frozen=True) +class VariableRecord: + name: str + label: str | None + documentation: str | None + entity: str | None + unit: str | None + value_type: str | None + definition_period: str | None + references: tuple[str, ...] + defined_for: str | None + source_file: Path | None + source_line: int | None + adds: tuple[str, ...] + subtracts: tuple[str, ...] + tree_path: tuple[str, ...] + + +def _tree_path_from_source( + source_file: Path | None, package_root: Path +) -> tuple[str, ...]: + if source_file is None: + return ("_ungrouped",) + try: + rel = source_file.relative_to(package_root / "variables") + except ValueError: + return ("_ungrouped",) + parts = rel.with_suffix("").parts + return parts[:-1] if parts else ("_ungrouped",) + + +def _normalize_references(raw) -> tuple[str, ...]: + if raw is None: + return () + if isinstance(raw, str): + return (raw,) + if isinstance(raw, (list, tuple)): + return tuple(str(r) for r in raw if r) + return (str(raw),) + + +def _variable_records(country: str) -> Iterable[VariableRecord]: + module_name = COUNTRY_MODULES[country] + country_module = importlib.import_module(module_name) + + system_module = importlib.import_module(f"{module_name}.system") + tbs = system_module.CountryTaxBenefitSystem() + + package_root = Path(country_module.__file__).parent + + import inspect + + for name in sorted(tbs.variables): + variable = tbs.variables[name] + try: + source_file = Path(inspect.getsourcefile(type(variable))) + source_line = inspect.getsourcelines(type(variable))[1] + except (TypeError, OSError): + source_file = None + source_line = None + + entity_key = getattr(variable.entity, "key", None) if variable.entity else None + value_type = getattr(variable, "value_type", None) + value_type_name = ( + value_type.__name__ + if isinstance(value_type, type) + else str(value_type) + if value_type is not None + else None + ) + defined_for = getattr(variable, "defined_for", None) + defined_for_name = ( + defined_for.name if hasattr(defined_for, "name") else defined_for + ) + + yield VariableRecord( + name=name, + label=variable.label, + documentation=variable.documentation, + entity=entity_key, + unit=getattr(variable, "unit", None), + value_type=value_type_name, + definition_period=getattr(variable, "definition_period", None), + references=_normalize_references(getattr(variable, "reference", None)), + defined_for=defined_for_name, + source_file=source_file, + source_line=source_line, + adds=tuple(getattr(variable, "adds", ()) or ()), + subtracts=tuple(getattr(variable, "subtracts", ()) or ()), + tree_path=_tree_path_from_source(source_file, package_root), + ) + + +def _escape_yaml_scalar(value: str) -> str: + return value.replace('"', '\\"') + + +def _render_variable_page(record: VariableRecord, country: str) -> str: + title = record.label or record.name + lines: list[str] = [ + "---", + f'title: "{_escape_yaml_scalar(title)}"', + f'subtitle: "`{record.name}`"', + ] + if record.documentation: + summary = record.documentation.strip().splitlines()[0][:220] + lines.append(f'description: "{_escape_yaml_scalar(summary)}"') + lines.extend( + [ + "format:", + " html:", + " code-copy: true", + "---", + "", + ] + ) + + metadata = [ + ("Name", f"`{record.name}`"), + ("Entity", f"`{record.entity}`" if record.entity else "—"), + ("Value type", f"`{record.value_type}`" if record.value_type else "—"), + ("Unit", f"`{record.unit}`" if record.unit else "—"), + ( + "Period", + f"`{record.definition_period}`" if record.definition_period else "—", + ), + ( + "Defined for", + f"`{record.defined_for}`" if record.defined_for else "—", + ), + ] + lines.append("| Field | Value |") + lines.append("|---|---|") + for key, value in metadata: + lines.append(f"| {key} | {value} |") + lines.append("") + + if record.documentation: + lines.append("## Documentation") + lines.append("") + lines.append(record.documentation.strip()) + lines.append("") + + if record.adds: + lines.append("## Components") + lines.append("") + lines.append("This variable sums the following variables:") + lines.append("") + for component in record.adds: + lines.append(f"- `{component}`") + lines.append("") + + if record.subtracts: + lines.append("## Subtractions") + lines.append("") + lines.append("This variable subtracts the following variables:") + lines.append("") + for component in record.subtracts: + lines.append(f"- `{component}`") + lines.append("") + + if record.references: + lines.append("## References") + lines.append("") + for ref in record.references: + lines.append(f"- <{ref}>") + lines.append("") + + if record.source_file: + try: + repo_rel = record.source_file.relative_to(record.source_file.parents[5]) + except (ValueError, IndexError): + repo_rel = record.source_file.name + lines.append("## Source") + lines.append("") + if record.source_line: + lines.append(f"`{repo_rel}`, line {record.source_line}") + else: + lines.append(f"`{repo_rel}`") + lines.append("") + + return "\n".join(lines) + + +def _slug(value: str) -> str: + return re.sub(r"[^A-Za-z0-9_-]+", "-", value).strip("-") + + +def _write_variables( + records: list[VariableRecord], + out_root: Path, + country: str, +) -> int: + written = 0 + for record in records: + tree_dir = out_root.joinpath(*record.tree_path) + tree_dir.mkdir(parents=True, exist_ok=True) + page_path = tree_dir / f"{_slug(record.name)}.qmd" + page_path.write_text(_render_variable_page(record, country)) + written += 1 + return written + + +def _write_tree_indices(out_root: Path) -> int: + written = 0 + for directory in [out_root, *(p for p in out_root.rglob("*") if p.is_dir())]: + index_path = directory / "index.qmd" + if index_path.exists(): + continue + title = directory.name if directory != out_root else "Reference" + index_path.write_text( + textwrap.dedent( + f"""\ + --- + title: "{title}" + listing: + contents: "*.qmd" + type: table + sort: "title" + fields: [title, subtitle, description] + --- + """ + ) + ) + written += 1 + return written + + +def _write_programs_index(country: str, out_root: Path) -> int: + module_name = COUNTRY_MODULES[country] + country_module = importlib.import_module(module_name) + package_root = Path(country_module.__file__).parent + programs_path = package_root / "programs.yaml" + if not programs_path.exists(): + return 0 + with programs_path.open() as f: + registry = yaml.safe_load(f) + programs = registry.get("programs", []) + lines: list[str] = [ + "---", + 'title: "Program coverage"', + 'description: "Programs modeled in the country model, generated from programs.yaml."', + "---", + "", + "| ID | Name | Category | Agency | Status | Coverage |", + "|---|---|---|---|---|---|", + ] + for program in programs: + lines.append( + "| " + + " | ".join( + str(program.get(field, "")).replace("\n", " ") + for field in ("id", "name", "category", "agency", "status", "coverage") + ) + + " |" + ) + target = out_root / "programs.qmd" + target.write_text("\n".join(lines) + "\n") + return 1 + + +def build_reference( + country: str, + out_root: Path, + filter_substring: str | None = None, +) -> dict[str, int]: + out_root.mkdir(parents=True, exist_ok=True) + records = list(_variable_records(country)) + if filter_substring: + needle = filter_substring.lower() + records = [ + r + for r in records + if needle in r.name.lower() + or needle in " ".join(str(p).lower() for p in r.tree_path) + ] + variables_written = _write_variables(records, out_root, country) + programs_written = _write_programs_index(country, out_root) + indices_written = _write_tree_indices(out_root) + return { + "variables": variables_written, + "programs": programs_written, + "indices": indices_written, + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--country", + choices=sorted(COUNTRY_MODULES), + default="us", + help="Country model to introspect.", + ) + parser.add_argument( + "--out", + type=Path, + required=True, + help="Output directory for generated .qmd pages.", + ) + parser.add_argument( + "--filter", + default=None, + help="Substring filter on variable name or tree path (case-insensitive).", + ) + return parser.parse_args() + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + args = _parse_args() + stats = build_reference(args.country, args.out, args.filter) + logger.info( + "Wrote %d variable pages, %d programs page, %d directory indices to %s", + stats["variables"], + stats["programs"], + stats["indices"], + args.out, + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index f09e0a04..8d0d76ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,9 @@ policyengine = "policyengine.cli:main" plotting = [ "plotly>=5.0.0", ] +graph = [ + "networkx>=3.0", +] uk = [ "policyengine_core>=3.25.0", "policyengine-uk==2.88.0", diff --git a/src/policyengine/graph/__init__.py b/src/policyengine/graph/__init__.py new file mode 100644 index 00000000..84dd698c --- /dev/null +++ b/src/policyengine/graph/__init__.py @@ -0,0 +1,41 @@ +"""Variable dependency graph for PolicyEngine source trees. + +Parses ``Variable`` subclasses in a PolicyEngine jurisdiction (e.g. +``policyengine-us``, ``policyengine-uk``) and extracts the variable- +to-variable dataflow graph from formula-method bodies. + +The extractor is static: it walks the Python AST and never imports +user code, so it works on any PolicyEngine source tree without +requiring the jurisdiction to be installed or the country model to +resolve. That makes it usable for refactor-impact analysis, CI +pre-merge checks, docs generation, and code-introspection queries +from a Claude Code plugin. + +Recognized reference patterns in v1: + +- ``("", )`` — direct call on an entity instance + (``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``). +- ``add(, , ["v1", "v2", ...])`` — sum helper; each + string in the list becomes an edge. + +Typical usage: + +.. code-block:: python + + from policyengine.graph import extract_from_path + + graph = extract_from_path("/path/to/policyengine-us/policyengine_us/variables") + # Variables that transitively depend on AGI: + for downstream in graph.impact("adjusted_gross_income"): + print(downstream) + # Direct dependencies of a variable: + print(graph.deps("earned_income_tax_credit")) + # Dependency chain from one variable to another: + print(graph.path("wages", "federal_income_tax")) +""" + +from policyengine.graph.extractor import extract_from_path +from policyengine.graph.graph import VariableGraph + +__all__ = ["VariableGraph", "extract_from_path"] diff --git a/src/policyengine/graph/extractor.py b/src/policyengine/graph/extractor.py new file mode 100644 index 00000000..39f278cb --- /dev/null +++ b/src/policyengine/graph/extractor.py @@ -0,0 +1,188 @@ +"""AST-based extractor for PolicyEngine Variable subclasses. + +Walks a directory of ``.py`` files, identifies ``Variable`` subclasses +by looking for ``class Foo(Variable):`` in the AST, and extracts +variable references from each class's ``formula*`` methods. + +The extractor never imports user code, so it works on any PolicyEngine +source tree regardless of whether the jurisdiction is installed. +This keeps refactor-impact analysis and CI pre-merge checks fast and +dependency-free. + +Two reference patterns are recognized: + +1. ``("", )`` where ```` is a bare ``Name`` + matching one of: + ``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``, ``tax_unit``. +2. ``add(, , [])`` — the + ``add`` helper that sums a list of variable names on an entity. + +Limitations of the v1 extractor (tracked for v2): + +- Parameter references (``parameters(period).gov.xxx.yyy``) are not + yet captured; only variable-to-variable edges. +- Dynamic variable names built via string concatenation or format + strings are skipped (low-prevalence in practice). +- ``entity.sum("var")`` or ``entity.mean("var")`` method calls are + not yet recognized; only the direct-call form. (Low-prevalence + in ``policyengine-us``; common enough to add as a small follow-up.) +""" + +from __future__ import annotations + +import ast +import os +from pathlib import Path +from typing import Iterable, Iterator, Union + +from policyengine.graph.graph import VariableGraph + +# Names of entity instances as they appear as method parameters in +# Variable formulas. Any ``Call`` whose ``func`` is a bare ``Name`` +# matching one of these AND whose first arg is a string literal is +# treated as a variable reference. Bare names (not attribute access) +# ensures we don't accidentally match something like +# ``reform.person("x", period)``. +_ENTITY_CALL_NAMES: frozenset[str] = frozenset( + { + "person", + "tax_unit", + "spm_unit", + "household", + "family", + "marital_unit", + "benunit", + } +) + + +PathLike = Union[str, "os.PathLike[str]"] + + +def extract_from_path(path: PathLike) -> VariableGraph: + """Build a ``VariableGraph`` from all ``.py`` files under ``path``. + + Directories are walked recursively. Files that fail to parse as + Python (syntax errors) are silently skipped — the extractor is a + best-effort tool over real source trees, not a compiler. + """ + root = Path(path) + graph = VariableGraph() + + files: Iterable[Path] + if root.is_file(): + files = [root] + else: + files = root.rglob("*.py") + + for file_path in files: + try: + source = file_path.read_text() + except (OSError, UnicodeDecodeError): + continue + try: + tree = ast.parse(source, filename=str(file_path)) + except SyntaxError: + continue + _visit_module(tree, file_path=str(file_path), graph=graph) + + return graph + + +# ------------------------------------------------------------------- +# AST traversal +# ------------------------------------------------------------------- + + +def _visit_module(tree: ast.Module, *, file_path: str, graph: VariableGraph) -> None: + """Register each Variable subclass and walk its formula methods.""" + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + if not _class_inherits_variable(node): + continue + var_name = node.name + graph.add_variable(var_name, file_path=file_path) + for child in node.body: + if isinstance(child, ast.FunctionDef) and _is_formula_method(child): + for dependency in _extract_references(child): + graph.add_edge(dependency=dependency, dependent=var_name) + + +def _class_inherits_variable(cls: ast.ClassDef) -> bool: + """True iff the class's base list contains a ``Variable`` name. + + Matches ``class X(Variable):``. Does not resolve aliased imports + — PolicyEngine's ``from policyengine_us.model_api import *`` + convention keeps the base name literally ``Variable``, which is + what real jurisdictions use and what this check matches. + """ + for base in cls.bases: + if isinstance(base, ast.Name) and base.id == "Variable": + return True + return False + + +def _is_formula_method(func: ast.FunctionDef) -> bool: + """Return True for ``formula`` and ``formula_YYYY`` methods.""" + return func.name == "formula" or func.name.startswith("formula_") + + +# ------------------------------------------------------------------- +# Reference extraction from a formula body +# ------------------------------------------------------------------- + + +def _extract_references(func: ast.FunctionDef) -> Iterator[str]: + """Yield every variable name referenced in the function body.""" + for node in ast.walk(func): + if not isinstance(node, ast.Call): + continue + # Pattern 1: ("", ) + entity_ref = _entity_call_to_variable(node) + if entity_ref is not None: + yield entity_ref + continue + # Pattern 2: add(, , ["v1", "v2", ...]) + yield from _add_call_to_variables(node) + + +def _entity_call_to_variable(call: ast.Call) -> str | None: + """Return the variable name if ``call`` is an entity-call pattern. + + The entity has to be a bare Name (not an attribute access), so + calls like ``some.object.person("x", period)`` are deliberately + not matched. First positional arg must be a string literal. + """ + if not isinstance(call.func, ast.Name): + return None + if call.func.id not in _ENTITY_CALL_NAMES: + return None + if not call.args: + return None + first = call.args[0] + if isinstance(first, ast.Constant) and isinstance(first.value, str): + return first.value + return None + + +def _add_call_to_variables(call: ast.Call) -> Iterator[str]: + """Yield variable names from an ``add(entity, period, [list])`` call. + + Matches the common helper. The third positional arg must be a + ``list`` literal of string literals. Anything dynamically built + is skipped. + """ + if not isinstance(call.func, ast.Name): + return + if call.func.id not in {"add", "aggr"}: + return + if len(call.args) < 3: + return + names_arg = call.args[2] + if not isinstance(names_arg, (ast.List, ast.Tuple)): + return + for elt in names_arg.elts: + if isinstance(elt, ast.Constant) and isinstance(elt.value, str): + yield elt.value diff --git a/src/policyengine/graph/graph.py b/src/policyengine/graph/graph.py new file mode 100644 index 00000000..2f5d516e --- /dev/null +++ b/src/policyengine/graph/graph.py @@ -0,0 +1,130 @@ +"""NetworkX-backed variable dependency graph. + +Separated from the extractor so the data structure is easy to test +independently, easy to serialize/deserialize, and easy to enrich with +additional edge types (parameter reads, cross-jurisdiction links) in +later versions. +""" + +from __future__ import annotations + +from typing import Iterable, Optional + +try: + import networkx as nx +except ImportError as exc: # pragma: no cover - trivial guard + raise ImportError( + "policyengine.graph requires networkx. " + "Install the optional extra: pip install 'policyengine[graph]'." + ) from exc + + +class VariableGraph: + """Directed graph of PolicyEngine variable dependencies. + + Nodes are variable names (strings). Edges run from a *dependency* + to a *dependent*: ``A -> B`` means "computing B reads A". With + this orientation, ``impact(A)`` is the set of downstream nodes + reachable from A, and ``deps(B)`` is the set of upstream nodes + that reach into B. + + The constructor accepts an optional pre-built graph for testing + and deserialization; normal callers will get instances via the + extractor. + """ + + def __init__(self, digraph: Optional[nx.DiGraph] = None) -> None: + self._g = digraph if digraph is not None else nx.DiGraph() + + # ------------------------------------------------------------------ + # Construction helpers (used by the extractor) + # ------------------------------------------------------------------ + + def add_variable(self, name: str, file_path: Optional[str] = None) -> None: + """Register a variable as a node. Safe to call repeatedly.""" + if name in self._g: + if file_path and "file_path" not in self._g.nodes[name]: + self._g.nodes[name]["file_path"] = file_path + return + self._g.add_node(name, file_path=file_path) + + def add_edge(self, dependency: str, dependent: str) -> None: + """Record that ``dependent`` reads ``dependency`` in a formula.""" + # Auto-register the dependency node if it wasn't declared yet; + # this is common when a formula references a variable defined + # in a file the extractor hasn't reached yet, or a variable + # whose class lives in a different subpackage. + if dependency not in self._g: + self._g.add_node(dependency, file_path=None) + if dependent not in self._g: + self._g.add_node(dependent, file_path=None) + self._g.add_edge(dependency, dependent) + + # ------------------------------------------------------------------ + # Query surface + # ------------------------------------------------------------------ + + def has_variable(self, name: str) -> bool: + """True iff ``name`` was registered as an explicitly-defined variable. + + Nodes that only exist because some formula *references* them — + but whose class definition was never seen — are excluded. + """ + if name not in self._g: + return False + return self._g.nodes[name].get("file_path") is not None + + def deps(self, name: str) -> Iterable[str]: + """Return variables that ``name``'s formula reads directly. + + Order follows networkx's insertion order, so the caller can + expect a deterministic sequence for a given extraction run. + """ + if name not in self._g: + return iter(()) + return list(self._g.predecessors(name)) + + def impact(self, name: str) -> Iterable[str]: + """Return variables that transitively depend on ``name``. + + Equivalent to the descendants set in the graph's natural + orientation (edges run dep → dependent). Excludes ``name`` + itself. Empty for leaf variables that nothing reads. + """ + if name not in self._g: + return iter(()) + return list(nx.descendants(self._g, name)) + + def path(self, src: str, dst: str) -> Optional[list[str]]: + """Return a shortest dependency chain from ``src`` to ``dst``. + + Returns the node list including both endpoints, or ``None`` if + no such path exists. + """ + if src not in self._g or dst not in self._g: + return None + try: + return nx.shortest_path(self._g, src, dst) + except nx.NetworkXNoPath: + return None + + # ------------------------------------------------------------------ + # Introspection for callers that want the raw structure + # ------------------------------------------------------------------ + + @property + def nx_graph(self) -> nx.DiGraph: + """The underlying NetworkX DiGraph (read-only-by-convention).""" + return self._g + + def __contains__(self, name: str) -> bool: + return name in self._g + + def __len__(self) -> int: + return self._g.number_of_nodes() + + def __repr__(self) -> str: + return ( + f"VariableGraph({self._g.number_of_nodes()} variables, " + f"{self._g.number_of_edges()} edges)" + ) diff --git a/tests/test_graph/__init__.py b/tests/test_graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_graph/conftest.py b/tests/test_graph/conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_graph/test_extractor.py b/tests/test_graph/test_extractor.py new file mode 100644 index 00000000..91e2a840 --- /dev/null +++ b/tests/test_graph/test_extractor.py @@ -0,0 +1,312 @@ +"""Tests for the variable-graph extractor. + +The extractor walks PolicyEngine-style Variable source trees and +builds a dependency graph from formula-body references. Two reference +patterns are recognized in MVP: + +1. ``("", )`` — direct call on an entity instance + inside a formula method. ```` matches a known set: + ``person``, ``tax_unit``, ``spm_unit``, ``household``, ``family``, + ``marital_unit``, ``benunit``. +2. ``add(, , ["v1", "v2"])`` — helper that sums a list + of variable values. Each string in the list is extracted. + +Tests run against a self-contained fixture tree under the test file's +own tmp directory — no dependency on an installed country model — so +behavior is deterministic and the tests pin the extraction algorithm +rather than PolicyEngine's evolving source. +""" + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from textwrap import dedent +from types import ModuleType + + +# ``policyengine/__init__.py`` eagerly imports the full country-model +# stack (policyengine-us, policyengine-uk), which makes a normal +# ``from policyengine.graph import ...`` fail in any environment +# where those jurisdictions aren't fully provisioned (missing release +# manifests, unresolved optional deps, etc.). The graph module is +# self-contained (stdlib + networkx only); load it via importlib +# directly so these tests remain environment-agnostic. +def _load_graph_module() -> ModuleType: + if "policyengine.graph" in sys.modules and hasattr( + sys.modules["policyengine.graph"], "extract_from_path" + ): + return sys.modules["policyengine.graph"] + + graph_dir = Path(__file__).resolve().parents[2] / "src" / "policyengine" / "graph" + + if "policyengine" not in sys.modules: + fake_pkg = ModuleType("policyengine") + fake_pkg.__path__ = [str(graph_dir.parent)] + sys.modules["policyengine"] = fake_pkg + if "policyengine.graph" not in sys.modules or not hasattr( + sys.modules["policyengine.graph"], "__path__" + ): + fake_subpkg = ModuleType("policyengine.graph") + fake_subpkg.__path__ = [str(graph_dir)] + sys.modules["policyengine.graph"] = fake_subpkg + + for submod, filename in [ + ("policyengine.graph.graph", "graph.py"), + ("policyengine.graph.extractor", "extractor.py"), + ]: + if submod in sys.modules: + continue + spec = importlib.util.spec_from_file_location(submod, graph_dir / filename) + module = importlib.util.module_from_spec(spec) + sys.modules[submod] = module + spec.loader.exec_module(module) # type: ignore[union-attr] + + graph_mod = sys.modules["policyengine.graph"] + graph_mod.extract_from_path = sys.modules[ + "policyengine.graph.extractor" + ].extract_from_path + graph_mod.VariableGraph = sys.modules["policyengine.graph.graph"].VariableGraph + return graph_mod + + +_graph = _load_graph_module() +extract_from_path = _graph.extract_from_path +VariableGraph = _graph.VariableGraph + + +def _write_variable( + root: Path, var_name: str, formula_body: str, entity: str = "tax_unit" +) -> None: + """Write a Variable subclass file mimicking policyengine-us style.""" + root.mkdir(parents=True, exist_ok=True) + (root / f"{var_name}.py").write_text( + dedent(f'''\ + from policyengine_us.model_api import * + + + class {var_name}(Variable): + value_type = float + entity = TaxUnit + label = "{var_name.replace("_", " ").title()}" + definition_period = YEAR + + def formula({entity}, period, parameters): + {formula_body} + ''') + ) + + +class TestDirectEntityReference: + """Pattern 1: ``entity("", period)`` produces an edge.""" + + def test_single_direct_reference(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period) - tax_unit("above_the_line_deductions", period)', + ) + _write_variable(root, "gross_income", "return 0") + _write_variable(root, "above_the_line_deductions", "return 0") + + graph = extract_from_path(root) + + assert graph.has_variable("adjusted_gross_income") + deps = set(graph.deps("adjusted_gross_income")) + assert deps == {"gross_income", "above_the_line_deductions"} + + def test_nonmatching_string_is_ignored(self, tmp_path: Path) -> None: + """String literals unrelated to an entity call are ignored. + + Only a string as the first arg of a matching + ``("", period)`` call becomes an edge; string + literals used as argument to ``print`` or bound to a local + name are not misinterpreted as variable references. + """ + root = tmp_path / "variables" + root.mkdir(parents=True, exist_ok=True) + (root / "refundable_credit.py").write_text( + dedent("""\ + from policyengine_us.model_api import * + + + class refundable_credit(Variable): + value_type = float + entity = TaxUnit + label = "Refundable credit" + definition_period = YEAR + + def formula(tax_unit, period, parameters): + note = "not a variable reference" + return tax_unit("gross_income", period) + """) + ) + _write_variable(root, "gross_income", "return 0") + graph = extract_from_path(root) + assert set(graph.deps("refundable_credit")) == {"gross_income"} + + +class TestAddHelperReference: + """Pattern 2: ``add(entity, period, [...])`` emits one edge per list item.""" + + def test_add_helper_list(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable( + root, + "total_income", + 'return add(tax_unit, period, ["wages", "self_employment_income", "interest"])', + ) + _write_variable(root, "wages", "return 0") + _write_variable(root, "self_employment_income", "return 0") + _write_variable(root, "interest", "return 0") + graph = extract_from_path(root) + assert set(graph.deps("total_income")) == { + "wages", + "self_employment_income", + "interest", + } + + +class TestImpactAnalysis: + """``impact(var)`` returns variables that depend on ``var`` transitively.""" + + def test_transitive_upstream(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "wages", "return 0") + _write_variable( + root, + "gross_income", + 'return add(tax_unit, period, ["wages"])', + ) + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period)', + ) + _write_variable( + root, + "taxable_income", + 'return tax_unit("adjusted_gross_income", period)', + ) + _write_variable( + root, + "federal_income_tax", + 'return tax_unit("taxable_income", period)', + ) + graph = extract_from_path(root) + + # wages is read by gross_income → adjusted_gross_income → + # taxable_income → federal_income_tax (depth 4). + impact = set(graph.impact("wages")) + assert impact == { + "gross_income", + "adjusted_gross_income", + "taxable_income", + "federal_income_tax", + } + + def test_leaf_variable_has_empty_impact(self, tmp_path: Path) -> None: + """A variable that nothing reads has an empty impact set.""" + + root = tmp_path / "variables" + _write_variable( + root, + "federal_income_tax", + 'return tax_unit("adjusted_gross_income", period)', + ) + _write_variable(root, "adjusted_gross_income", "return 0") + graph = extract_from_path(root) + assert list(graph.impact("federal_income_tax")) == [] + + +class TestMultipleFormulas: + """Year-specific ``formula_YYYY`` methods contribute edges too.""" + + def test_year_specific_formula_contributes_edges(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + (root / "ctc.py").parent.mkdir(parents=True, exist_ok=True) + (root / "ctc.py").write_text( + dedent("""\ + from policyengine_us.model_api import * + + + class ctc(Variable): + value_type = float + entity = TaxUnit + label = "Child Tax Credit" + definition_period = YEAR + + def formula_2020(tax_unit, period, parameters): + return tax_unit("ctc_base_2020", period) + + def formula_2023(tax_unit, period, parameters): + return tax_unit("ctc_base_2023", period) + """) + ) + _write_variable(root, "ctc_base_2020", "return 0") + _write_variable(root, "ctc_base_2023", "return 0") + + graph = extract_from_path(root) + assert set(graph.deps("ctc")) == {"ctc_base_2020", "ctc_base_2023"} + + +class TestPath: + """``path(src, dst)`` returns a dependency chain if one exists.""" + + def test_path_two_hops(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "wages", "return 0") + _write_variable(root, "gross_income", 'return tax_unit("wages", period)') + _write_variable( + root, + "adjusted_gross_income", + 'return tax_unit("gross_income", period)', + ) + + graph = extract_from_path(root) + assert graph.path("wages", "adjusted_gross_income") == [ + "wages", + "gross_income", + "adjusted_gross_income", + ] + + def test_path_returns_none_if_unreachable(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + _write_variable(root, "island_a", "return 0") + _write_variable(root, "island_b", "return 0") + graph = extract_from_path(root) + assert graph.path("island_a", "island_b") is None + + +class TestRequiresVariableSubclass: + """Only classes whose base class list contains ``Variable`` are scanned. + + Helper modules (model_api, utils) should not be mistaken for + Variable definitions even if they have method bodies that call + entity-style functions. + """ + + def test_non_variable_classes_are_ignored(self, tmp_path: Path) -> None: + + root = tmp_path / "variables" + root.mkdir(parents=True, exist_ok=True) + # Looks like a variable body but the class is not a Variable. + (root / "helper.py").write_text( + dedent("""\ + class NotAVariable: + def some_method(tax_unit, period, parameters): + return tax_unit("some_variable", period) + """) + ) + graph = extract_from_path(root) + assert not graph.has_variable("NotAVariable") + # And no edge to "some_variable" should exist from a phantom source. + assert list(graph.impact("some_variable")) == []