diff --git a/changelog.d/v4-base-extraction.changed.md b/changelog.d/v4-base-extraction.changed.md new file mode 100644 index 00000000..572088a3 --- /dev/null +++ b/changelog.d/v4-base-extraction.changed.md @@ -0,0 +1 @@ +Extracted shared `MicrosimulationModelVersion` base class in `policyengine.tax_benefit_models.common`. Country subclasses now declare class-level metadata (`country_code`, `package_name`, `group_entities`) and implement a handful of thin hooks; `run()` stays per-country. Byte-level snapshot tests verify zero output drift. diff --git a/src/policyengine/tax_benefit_models/common/__init__.py b/src/policyengine/tax_benefit_models/common/__init__.py index 38c8a6e1..6f6efa25 100644 --- a/src/policyengine/tax_benefit_models/common/__init__.py +++ b/src/policyengine/tax_benefit_models/common/__init__.py @@ -6,6 +6,9 @@ """ from .extra_variables import dispatch_extra_variables as dispatch_extra_variables +from .model_version import ( + MicrosimulationModelVersion as MicrosimulationModelVersion, +) from .reform import compile_reform as compile_reform from .result import EntityResult as EntityResult from .result import HouseholdResult as HouseholdResult diff --git a/src/policyengine/tax_benefit_models/common/model_version.py b/src/policyengine/tax_benefit_models/common/model_version.py new file mode 100644 index 00000000..1cbe3df4 --- /dev/null +++ b/src/policyengine/tax_benefit_models/common/model_version.py @@ -0,0 +1,257 @@ +"""Base class for country ``TaxBenefitModelVersion`` implementations. + +The US and UK model-version classes share roughly 300 lines of loading logic: +manifest certification, the variable-copy loop over the country ``system``, +the parameter-copy loop, entity-relationship construction, and simple +``save`` / ``load`` passthroughs. Only ``run`` (and the country-specific +``managed_microsimulation`` helper) diverge enough to warrant per-country +implementations. 
class MicrosimulationModelVersion(TaxBenefitModelVersion):
    """Shared init / save / load logic for country microsim model versions.

    Subclasses must set the four class attributes below and implement the
    country-specific hooks (``_get_runtime_data_build_metadata``,
    ``_load_system``, ``_load_region_registry`` and ``_dataset_class``).
    ``run`` is intentionally left abstract: its country-specific logic
    (reform application, simulation builder, output post-processing) varies
    enough that a shared skeleton would hide real divergences.
    """

    # --- Subclass metadata -------------------------------------------------
    country_code: ClassVar[str] = ""
    """ISO-ish country identifier used by the release manifest ("us"/"uk")."""

    package_name: ClassVar[str] = ""
    """Distribution name used with ``importlib.metadata.version``."""

    group_entities: ClassVar[list[str]] = []
    """Group entities (non-person) for this country, in dataset order."""

    entity_variables: dict[str, list[str]] = {}
    """Variables to materialise per entity when writing output datasets."""

    # --- Construction ------------------------------------------------------
    def __init__(self, **kwargs: Any) -> None:
        """Certify the runtime against the release manifest and load metadata.

        Args:
            **kwargs: Forwarded to ``TaxBenefitModelVersion.__init__``. When
                ``version`` is missing or ``None`` it defaults to the model
                package version pinned by the release manifest.

        Raises:
            RuntimeError: If the subclass did not declare ``country_code``
                and ``package_name``.

        Warns:
            UserWarning: If the installed country package version differs
                from the version pinned by the manifest.
        """
        if not self.country_code or not self.package_name:
            raise RuntimeError(
                f"{type(self).__name__} must declare country_code and "
                "package_name class attributes"
            )

        manifest = get_release_manifest(self.country_code)
        pinned_version = manifest.model_package.version
        if kwargs.get("version") is None:
            kwargs["version"] = pinned_version

        installed_model_version = metadata.version(self.package_name)
        if installed_model_version != pinned_version:
            warnings.warn(
                f"Installed {self.package_name} version "
                f"({installed_model_version}) does not match the bundled "
                "policyengine.py manifest "
                f"({manifest.model_package.version}). Calculations will "
                "run against the installed version, but dataset "
                "compatibility is not guaranteed. To silence this "
                "warning, install the version pinned by the manifest.",
                UserWarning,
                # Attribute the warning to the code instantiating the model.
                stacklevel=2,
            )

        model_build_metadata = self._get_runtime_data_build_metadata()
        data_certification = certify_data_release_compatibility(
            self.country_code,
            runtime_model_version=installed_model_version,
            runtime_data_build_fingerprint=model_build_metadata.get(
                "data_build_fingerprint"
            ),
        )

        # Certification runs before the base initialiser, so an incompatible
        # data release fails before any instance state is populated.
        super().__init__(**kwargs)
        self.release_manifest = manifest
        self.model_package = manifest.model_package
        self.data_package = manifest.data_package
        self.default_dataset_uri = manifest.default_dataset_uri
        self.data_certification = data_certification
        self.region_registry = self._load_region_registry()
        # ``self.id`` prefixes every variable / parameter id created below,
        # so it must be assigned before the populate helpers run.
        self.id = f"{self.model.id}@{self.version}"

        system = self._load_system()
        self._populate_variables(system)
        self._populate_parameters(system)

    # --- Hooks ------------------------------------------------------------
    @classmethod
    def _get_runtime_data_build_metadata(cls) -> dict[str, Optional[str]]:
        """Return build metadata from the country package, if available.

        ``__init__`` reads the ``"data_build_fingerprint"`` key from the
        returned mapping. Subclasses must override this hook.
        """
        raise NotImplementedError

    def _load_system(self) -> Any:
        """Return the country package's ``system`` object."""
        raise NotImplementedError

    def _load_region_registry(self) -> Any:
        """Return the country's ``RegionRegistry``."""
        raise NotImplementedError

    @property
    def _dataset_class(self) -> Any:
        """Return the country's ``PolicyEngine{Country}Dataset`` class."""
        raise NotImplementedError

    # --- Shared loading helpers ------------------------------------------
    def _populate_variables(self, system: Any) -> None:
        """Register one ``Variable`` for every variable in ``system``.

        Enum and date defaults are converted to their ``name`` /
        ``isoformat()`` string forms, and ``adds`` / ``subtracts`` given as
        parameter paths are resolved to concrete lists.
        """
        from policyengine_core.enums import Enum
        from policyengine_core.parameters.operations.get_parameter import (
            get_parameter,
        )

        for var_obj in system.variables.values():
            default_val = var_obj.default_value
            if var_obj.value_type is Enum:
                default_val = default_val.name
            elif var_obj.value_type is datetime.date:
                default_val = default_val.isoformat()

            variable = Variable(
                id=self.id + "-" + var_obj.name,
                name=var_obj.name,
                label=getattr(var_obj, "label", None),
                tax_benefit_model_version=self,
                entity=var_obj.entity.key,
                description=var_obj.documentation,
                data_type=(
                    var_obj.value_type if var_obj.value_type is not Enum else str
                ),
                default_value=default_val,
                value_type=var_obj.value_type,
            )

            possible_values = getattr(var_obj, "possible_values", None)
            if possible_values is not None:
                variable.possible_values = [
                    member.name
                    for member in possible_values._value2member_map_.values()
                ]

            # Resolve parameter-path adds/subtracts to concrete lists so
            # consumers always see list[str].
            for attr in ("adds", "subtracts"):
                value = getattr(var_obj, attr, None)
                if value is None:
                    continue
                if isinstance(value, str):
                    try:
                        param = get_parameter(system.parameters, value)
                        setattr(variable, attr, list(param("2025-01-01")))
                    except Exception:
                        # Deliberate best-effort: an unresolvable parameter
                        # path leaves the attribute unset instead of
                        # aborting the whole model load.
                        setattr(variable, attr, None)
                else:
                    setattr(variable, attr, value)
            self.add_variable(variable)

    def _populate_parameters(self, system: Any) -> None:
        """Register every parameter and parameter node found in ``system``.

        Leaf parameters become ``Parameter`` objects (keeping a reference to
        the core parameter); intermediate nodes become ``ParameterNode``
        objects. Descendants of any other type are ignored.
        """
        from policyengine_core.parameters import Parameter as CoreParameter
        from policyengine_core.parameters import ParameterNode as CoreParameterNode

        scale_lookup = build_scale_lookup(system)

        for param_node in system.parameters.get_descendants():
            if isinstance(param_node, CoreParameter):
                parameter = Parameter(
                    id=self.id + "-" + param_node.name,
                    name=param_node.name,
                    label=generate_label_for_parameter(
                        param_node, system, scale_lookup
                    ),
                    tax_benefit_model_version=self,
                    description=param_node.description,
                    # The Python type of the value at 2025 is used as the
                    # parameter's data type.
                    data_type=type(param_node(2025)),
                    unit=param_node.metadata.get("unit"),
                    _core_param=param_node,
                )
                self.add_parameter(parameter)
            elif isinstance(param_node, CoreParameterNode):
                node = ParameterNode(
                    id=self.id + "-" + param_node.name,
                    name=param_node.name,
                    label=param_node.metadata.get("label"),
                    description=param_node.description,
                    tax_benefit_model_version=self,
                )
                self.add_parameter_node(node)

    # --- Shared run-surface helpers --------------------------------------
    def _build_entity_relationships(self, dataset: Any) -> pd.DataFrame:
        """Build a DataFrame mapping each person to their containing entities."""
        person_data = pd.DataFrame(dataset.data.person)
        return build_entity_relationships(person_data, self.group_entities)

    def save(self, simulation: Simulation) -> None:
        """Persist the simulation's output dataset to its bundled filepath."""
        simulation.output_dataset.save()

    def load(self, simulation: Simulation) -> None:
        """Rehydrate the simulation's output dataset from disk.

        The output file is expected next to the input dataset, named
        ``<simulation.id>.h5``. Timestamps are loaded from filesystem
        metadata when that file exists so serialised simulations round-trip
        ``created_at``/``updated_at``; when it does not exist they are left
        untouched.

        Note: ``os.path.getctime`` reports the last metadata change on Unix
        and the creation time on Windows, so ``created_at`` is
        platform-dependent.
        """
        filepath = str(
            Path(simulation.dataset.filepath).parent / (simulation.id + ".h5")
        )

        simulation.output_dataset = self._dataset_class(
            id=simulation.id,
            name=simulation.dataset.name,
            description=simulation.dataset.description,
            filepath=filepath,
            year=simulation.dataset.year,
            is_output_dataset=True,
        )

        if os.path.exists(filepath):
            simulation.created_at = datetime.datetime.fromtimestamp(
                os.path.getctime(filepath)
            )
            simulation.updated_at = datetime.datetime.fromtimestamp(
                os.path.getmtime(filepath)
            )
policyengine.core import TaxBenefitModel from policyengine.provenance.manifest import ( - certify_data_release_compatibility, dataset_logical_name, - get_release_manifest, resolve_local_managed_dataset_source, resolve_managed_dataset_reference, ) -from policyengine.utils.entity_utils import build_entity_relationships -from policyengine.utils.parameter_labels import ( - build_scale_lookup, - generate_label_for_parameter, -) +from policyengine.tax_benefit_models.common import MicrosimulationModelVersion from .datasets import PolicyEngineUKDataset, UKYearData @@ -43,18 +29,11 @@ class PolicyEngineUK(TaxBenefitModel): uk_model = PolicyEngineUK() -def _get_runtime_data_build_metadata() -> dict[str, Optional[str]]: - try: - from policyengine_uk.build_metadata import get_data_build_metadata - except ModuleNotFoundError as exc: - if exc.name != "policyengine_uk.build_metadata": - raise - return {} - - return get_data_build_metadata() or {} - +class PolicyEngineUKLatest(MicrosimulationModelVersion): + country_code = "uk" + package_name = "policyengine-uk" + group_entities = UK_GROUP_ENTITIES -class PolicyEngineUKLatest(TaxBenefitModelVersion): model: TaxBenefitModel = uk_model version: str = None created_at: datetime.datetime = None @@ -137,147 +116,32 @@ class PolicyEngineUKLatest(TaxBenefitModelVersion): ], } - def __init__(self, **kwargs: dict): - manifest = get_release_manifest("uk") - if "version" not in kwargs or kwargs.get("version") is None: - kwargs["version"] = manifest.model_package.version - - installed_model_version = metadata.version("policyengine-uk") - if installed_model_version != manifest.model_package.version: - warnings.warn( - "Installed policyengine-uk version " - f"({installed_model_version}) does not match the bundled " - "policyengine.py manifest " - f"({manifest.model_package.version}). Calculations will " - "run against the installed version, but dataset " - "compatibility is not guaranteed. 
To silence this " - "warning, install the version pinned by the manifest.", - UserWarning, - stacklevel=2, - ) - - model_build_metadata = _get_runtime_data_build_metadata() - data_certification = certify_data_release_compatibility( - "uk", - runtime_model_version=installed_model_version, - runtime_data_build_fingerprint=model_build_metadata.get( - "data_build_fingerprint" - ), - ) - - super().__init__(**kwargs) - self.release_manifest = manifest - self.model_package = manifest.model_package - self.data_package = manifest.data_package - self.default_dataset_uri = manifest.default_dataset_uri - self.data_certification = data_certification - from policyengine_core.enums import Enum + # --- Hooks ----------------------------------------------------------- + @classmethod + def _get_runtime_data_build_metadata(cls) -> dict[str, Optional[str]]: + try: + from policyengine_uk.build_metadata import get_data_build_metadata + except ModuleNotFoundError as exc: + if exc.name != "policyengine_uk.build_metadata": + raise + return {} + return get_data_build_metadata() or {} + + def _load_system(self): from policyengine_uk.system import system - # Attach region registry + return system + + def _load_region_registry(self): from policyengine.countries.uk.regions import uk_region_registry - self.region_registry = uk_region_registry - - self.id = f"{self.model.id}@{self.version}" - - for var_obj in system.variables.values(): - # Serialize default_value for JSON compatibility - default_val = var_obj.default_value - if var_obj.value_type is Enum: - default_val = default_val.name - elif var_obj.value_type is datetime.date: - default_val = default_val.isoformat() - - variable = Variable( - id=self.id + "-" + var_obj.name, - name=var_obj.name, - label=getattr(var_obj, "label", None), - tax_benefit_model_version=self, - entity=var_obj.entity.key, - description=var_obj.documentation, - data_type=var_obj.value_type if var_obj.value_type is not Enum else str, - default_value=default_val, - 
value_type=var_obj.value_type, - ) - if ( - hasattr(var_obj, "possible_values") - and var_obj.possible_values is not None - ): - variable.possible_values = list( - map( - lambda x: x.name, - var_obj.possible_values._value2member_map_.values(), - ) - ) - # Extract and resolve adds/subtracts. - # Core stores these as either list[str] or a parameter path string. - # Resolve parameter paths to lists so consumers always get list[str]. - if hasattr(var_obj, "adds") and var_obj.adds is not None: - if isinstance(var_obj.adds, str): - try: - from policyengine_core.parameters.operations.get_parameter import ( - get_parameter, - ) - - param = get_parameter(system.parameters, var_obj.adds) - variable.adds = list(param("2025-01-01")) - except (ValueError, Exception): - variable.adds = None - else: - variable.adds = var_obj.adds - if hasattr(var_obj, "subtracts") and var_obj.subtracts is not None: - if isinstance(var_obj.subtracts, str): - try: - from policyengine_core.parameters.operations.get_parameter import ( - get_parameter, - ) - - param = get_parameter(system.parameters, var_obj.subtracts) - variable.subtracts = list(param("2025-01-01")) - except (ValueError, Exception): - variable.subtracts = None - else: - variable.subtracts = var_obj.subtracts - self.add_variable(variable) - - from policyengine_core.parameters import Parameter as CoreParameter - from policyengine_core.parameters import ParameterNode as CoreParameterNode - - scale_lookup = build_scale_lookup(system) - - for param_node in system.parameters.get_descendants(): - if isinstance(param_node, CoreParameter): - parameter = Parameter( - id=self.id + "-" + param_node.name, - name=param_node.name, - label=generate_label_for_parameter( - param_node, system, scale_lookup - ), - tax_benefit_model_version=self, - description=param_node.description, - data_type=type(param_node(2025)), - unit=param_node.metadata.get("unit"), - _core_param=param_node, - ) - self.add_parameter(parameter) - elif isinstance(param_node, 
CoreParameterNode): - node = ParameterNode( - id=self.id + "-" + param_node.name, - name=param_node.name, - label=param_node.metadata.get("label"), - description=param_node.description, - tax_benefit_model_version=self, - ) - self.add_parameter_node(node) - - def _build_entity_relationships( - self, dataset: PolicyEngineUKDataset - ) -> pd.DataFrame: - """Build a DataFrame mapping each person to their containing entities.""" - person_data = pd.DataFrame(dataset.data.person) - return build_entity_relationships(person_data, UK_GROUP_ENTITIES) + return uk_region_registry + + @property + def _dataset_class(self): + return PolicyEngineUKDataset + # --- run ------------------------------------------------------------- def run(self, simulation: "Simulation") -> "Simulation": from policyengine_uk import Microsimulation from policyengine_uk.data import UKSingleYearDataset @@ -370,36 +234,6 @@ def run(self, simulation: "Simulation") -> "Simulation": ), ) - def save(self, simulation: "Simulation"): - """Save the simulation's output dataset.""" - simulation.output_dataset.save() - - def load(self, simulation: "Simulation"): - """Load the simulation's output dataset.""" - import os - - filepath = str( - Path(simulation.dataset.filepath).parent / (simulation.id + ".h5") - ) - - simulation.output_dataset = PolicyEngineUKDataset( - id=simulation.id, - name=simulation.dataset.name, - description=simulation.dataset.description, - filepath=filepath, - year=simulation.dataset.year, - is_output_dataset=True, - ) - - # Load timestamps from file system metadata - if os.path.exists(filepath): - simulation.created_at = datetime.datetime.fromtimestamp( - os.path.getctime(filepath) - ) - simulation.updated_at = datetime.datetime.fromtimestamp( - os.path.getmtime(filepath) - ) - def _managed_release_bundle( dataset_uri: str, @@ -423,8 +257,8 @@ def managed_microsimulation( """Construct a country-package Microsimulation pinned to this bundle. 
By default this enforces the dataset selection from the bundled - `policyengine.py` release manifest. Arbitrary dataset URIs require - `allow_unmanaged=True`. + ``policyengine.py`` release manifest. Arbitrary dataset URIs require + ``allow_unmanaged=True``. """ from policyengine_uk import Microsimulation diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index ec3dd9e6..51463650 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -1,31 +1,17 @@ import datetime -import warnings -from importlib import metadata from pathlib import Path from typing import TYPE_CHECKING, Optional import pandas as pd from microdf import MicroDataFrame -from policyengine.core import ( - Parameter, - ParameterNode, - TaxBenefitModel, - TaxBenefitModelVersion, - Variable, -) +from policyengine.core import TaxBenefitModel from policyengine.provenance.manifest import ( - certify_data_release_compatibility, dataset_logical_name, - get_release_manifest, resolve_local_managed_dataset_source, resolve_managed_dataset_reference, ) -from policyengine.utils.entity_utils import build_entity_relationships -from policyengine.utils.parameter_labels import ( - build_scale_lookup, - generate_label_for_parameter, -) +from policyengine.tax_benefit_models.common import MicrosimulationModelVersion from .datasets import PolicyEngineUSDataset, USYearData @@ -49,18 +35,11 @@ class PolicyEngineUS(TaxBenefitModel): us_model = PolicyEngineUS() -def _get_runtime_data_build_metadata() -> dict[str, Optional[str]]: - try: - from policyengine_us.build_metadata import get_data_build_metadata - except ModuleNotFoundError as exc: - if exc.name != "policyengine_us.build_metadata": - raise - return {} - - return get_data_build_metadata() or {} - +class PolicyEngineUSLatest(MicrosimulationModelVersion): + country_code = "us" + package_name = "policyengine-us" + group_entities = US_GROUP_ENTITIES -class 
PolicyEngineUSLatest(TaxBenefitModelVersion): model: TaxBenefitModel = us_model version: str = None created_at: datetime.datetime = None @@ -129,147 +108,32 @@ class PolicyEngineUSLatest(TaxBenefitModelVersion): ], } - def __init__(self, **kwargs: dict): - manifest = get_release_manifest("us") - if "version" not in kwargs or kwargs.get("version") is None: - kwargs["version"] = manifest.model_package.version - - installed_model_version = metadata.version("policyengine-us") - if installed_model_version != manifest.model_package.version: - warnings.warn( - "Installed policyengine-us version " - f"({installed_model_version}) does not match the bundled " - "policyengine.py manifest " - f"({manifest.model_package.version}). Calculations will " - "run against the installed version, but dataset " - "compatibility is not guaranteed. To silence this " - "warning, install the version pinned by the manifest.", - UserWarning, - stacklevel=2, - ) - - model_build_metadata = _get_runtime_data_build_metadata() - data_certification = certify_data_release_compatibility( - "us", - runtime_model_version=installed_model_version, - runtime_data_build_fingerprint=model_build_metadata.get( - "data_build_fingerprint" - ), - ) - - super().__init__(**kwargs) - self.release_manifest = manifest - self.model_package = manifest.model_package - self.data_package = manifest.data_package - self.default_dataset_uri = manifest.default_dataset_uri - self.data_certification = data_certification - from policyengine_core.enums import Enum + # --- Hooks ----------------------------------------------------------- + @classmethod + def _get_runtime_data_build_metadata(cls) -> dict[str, Optional[str]]: + try: + from policyengine_us.build_metadata import get_data_build_metadata + except ModuleNotFoundError as exc: + if exc.name != "policyengine_us.build_metadata": + raise + return {} + return get_data_build_metadata() or {} + + def _load_system(self): from policyengine_us.system import system - # Attach region 
registry + return system + + def _load_region_registry(self): from policyengine.countries.us.regions import us_region_registry - self.region_registry = us_region_registry - - self.id = f"{self.model.id}@{self.version}" - - for var_obj in system.variables.values(): - # Serialize default_value for JSON compatibility - default_val = var_obj.default_value - if var_obj.value_type is Enum: - default_val = default_val.name - elif var_obj.value_type is datetime.date: - default_val = default_val.isoformat() - - variable = Variable( - id=self.id + "-" + var_obj.name, - name=var_obj.name, - label=getattr(var_obj, "label", None), - tax_benefit_model_version=self, - entity=var_obj.entity.key, - description=var_obj.documentation, - data_type=var_obj.value_type if var_obj.value_type is not Enum else str, - default_value=default_val, - value_type=var_obj.value_type, - ) - if ( - hasattr(var_obj, "possible_values") - and var_obj.possible_values is not None - ): - variable.possible_values = list( - map( - lambda x: x.name, - var_obj.possible_values._value2member_map_.values(), - ) - ) - # Extract and resolve adds/subtracts. - # Core stores these as either list[str] or a parameter path string. - # Resolve parameter paths to lists so consumers always get list[str]. 
- if hasattr(var_obj, "adds") and var_obj.adds is not None: - if isinstance(var_obj.adds, str): - try: - from policyengine_core.parameters.operations.get_parameter import ( - get_parameter, - ) - - param = get_parameter(system.parameters, var_obj.adds) - variable.adds = list(param("2025-01-01")) - except (ValueError, Exception): - variable.adds = None - else: - variable.adds = var_obj.adds - if hasattr(var_obj, "subtracts") and var_obj.subtracts is not None: - if isinstance(var_obj.subtracts, str): - try: - from policyengine_core.parameters.operations.get_parameter import ( - get_parameter, - ) - - param = get_parameter(system.parameters, var_obj.subtracts) - variable.subtracts = list(param("2025-01-01")) - except (ValueError, Exception): - variable.subtracts = None - else: - variable.subtracts = var_obj.subtracts - self.add_variable(variable) - - from policyengine_core.parameters import Parameter as CoreParameter - from policyengine_core.parameters import ParameterNode as CoreParameterNode - - scale_lookup = build_scale_lookup(system) - - for param_node in system.parameters.get_descendants(): - if isinstance(param_node, CoreParameter): - parameter = Parameter( - id=self.id + "-" + param_node.name, - name=param_node.name, - label=generate_label_for_parameter( - param_node, system, scale_lookup - ), - tax_benefit_model_version=self, - description=param_node.description, - data_type=type(param_node(2025)), - unit=param_node.metadata.get("unit"), - _core_param=param_node, - ) - self.add_parameter(parameter) - elif isinstance(param_node, CoreParameterNode): - node = ParameterNode( - id=self.id + "-" + param_node.name, - name=param_node.name, - label=param_node.metadata.get("label"), - description=param_node.description, - tax_benefit_model_version=self, - ) - self.add_parameter_node(node) - - def _build_entity_relationships( - self, dataset: PolicyEngineUSDataset - ) -> pd.DataFrame: - """Build a DataFrame mapping each person to their containing entities.""" - 
person_data = pd.DataFrame(dataset.data.person) - return build_entity_relationships(person_data, US_GROUP_ENTITIES) + return us_region_registry + + @property + def _dataset_class(self): + return PolicyEngineUSDataset + # --- run ------------------------------------------------------------- def run(self, simulation: "Simulation") -> "Simulation": from policyengine_us import Microsimulation from policyengine_us.system import system @@ -308,14 +172,12 @@ def run(self, simulation: "Simulation") -> "Simulation": ), ) - # Build reform dict from policy and dynamic parameter values. # US requires reforms at Microsimulation construction time # (unlike UK which supports p.update() after construction). policy_reform = build_reform_dict(simulation.policy) dynamic_reform = build_reform_dict(simulation.dynamic) reform_dict = merge_reform_dicts(policy_reform, dynamic_reform) - # Create Microsimulation with reform at construction time microsim = Microsimulation(reform=reform_dict) self._build_simulation_from_dataset(microsim, dataset, system) @@ -346,7 +208,7 @@ def run(self, simulation: "Simulation") -> "Simulation": "tax_unit_weight", } - # First, copy ID and weight columns from input dataset + # Copy ID and weight columns from input dataset. for entity in data.keys(): input_df = pd.DataFrame(getattr(dataset.data, entity)) entity_id_col = f"{entity}_id" @@ -357,16 +219,16 @@ def run(self, simulation: "Simulation") -> "Simulation": if entity_weight_col in input_df.columns: data[entity][entity_weight_col] = input_df[entity_weight_col].values - # For person entity, also copy person-level group ID columns + # Person entity also needs person-level group ID columns so that + # downstream joins (e.g. person->tax_unit) work. person_input_df = pd.DataFrame(dataset.data.person) for col in person_input_df.columns: if col.startswith("person_") and col.endswith("_id"): - # Map person_household_id -> household_id, etc. 
target_col = col.replace("person_", "") if target_col in id_columns: data["person"][target_col] = person_input_df[col].values - # Then calculate non-ID, non-weight variables from simulation + # Calculate non-ID, non-weight variables from simulation for entity, variables in self.entity_variables.items(): for var in variables: if var not in id_columns and var not in weight_columns: @@ -404,61 +266,23 @@ def run(self, simulation: "Simulation") -> "Simulation": ), ) - def save(self, simulation: "Simulation"): - """Save the simulation's output dataset.""" - simulation.output_dataset.save() - - def load(self, simulation: "Simulation"): - """Load the simulation's output dataset.""" - import os - - filepath = str( - Path(simulation.dataset.filepath).parent / (simulation.id + ".h5") - ) - - simulation.output_dataset = PolicyEngineUSDataset( - id=simulation.id, - name=simulation.dataset.name, - description=simulation.dataset.description, - filepath=filepath, - year=simulation.dataset.year, - is_output_dataset=True, - ) - - # Load timestamps from file system metadata - if os.path.exists(filepath): - simulation.created_at = datetime.datetime.fromtimestamp( - os.path.getctime(filepath) - ) - simulation.updated_at = datetime.datetime.fromtimestamp( - os.path.getmtime(filepath) - ) - def _build_simulation_from_dataset(self, microsim, dataset, system): """Build a PolicyEngine Core simulation from dataset entity IDs. - This follows the same pattern as policyengine-uk, initializing - entities from IDs first, then using set_input() for variables. - - Args: - microsim: The Microsimulation object to populate - dataset: The dataset containing entity data - system: The tax-benefit system + Mirrors the policyengine-uk pattern of instantiating entities from + IDs first and then setting variable inputs. Handles both the legacy + ``person_X_id`` and the ``X_id`` column-naming conventions. 
""" import numpy as np from policyengine_core.simulations.simulation_builder import ( SimulationBuilder, ) - # Create builder and instantiate entities builder = SimulationBuilder() builder.populations = system.instantiate_entities() - # Extract entity IDs from dataset person_data = pd.DataFrame(dataset.data.person) - # Determine column naming convention - # Support both person_X_id (from create_datasets) and X_id (from custom datasets) household_id_col = ( "person_household_id" if "person_household_id" in person_data.columns @@ -485,7 +309,6 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): else "tax_unit_id" ) - # Declare entities builder.declare_person_entity("person", person_data["person_id"].values) builder.declare_entity( "household", np.unique(person_data[household_id_col].values) @@ -501,7 +324,6 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): "marital_unit", np.unique(person_data[marital_unit_id_col].values) ) - # Join persons to group entities builder.join_with_persons( builder.populations["household"], person_data[household_id_col].values, @@ -528,12 +350,8 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): np.array(["member"] * len(person_data)), ) - # Build simulation from populations microsim.build_from_populations(builder.populations) - # Set input variables for each entity - # Skip ID columns as they're structural and already used in entity building - # Support both naming conventions id_columns = { "person_id", "household_id", @@ -558,7 +376,6 @@ def _build_simulation_from_dataset(self, microsim, dataset, system): ]: df = pd.DataFrame(entity_df) for column in df.columns: - # Skip ID columns and check if variable exists in system if column not in id_columns and column in system.variables: microsim.set_input(column, dataset.year, df[column].values) @@ -585,8 +402,8 @@ def managed_microsimulation( """Construct a country-package Microsimulation pinned to this bundle. 
By default this enforces the dataset selection from the bundled - `policyengine.py` release manifest. Arbitrary dataset URIs require - `allow_unmanaged=True`. + ``policyengine.py`` release manifest. Arbitrary dataset URIs require + ``allow_unmanaged=True``. """ from policyengine_us import Microsimulation diff --git a/tests/fixtures/base_extraction_snapshots/uk_couple_two_kids.json b/tests/fixtures/base_extraction_snapshots/uk_couple_two_kids.json new file mode 100644 index 00000000..49302124 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/uk_couple_two_kids.json @@ -0,0 +1,139 @@ +{ + "benunit.benunit_id": 0.0, + "benunit.benunit_weight": 1.0, + "benunit.child_benefit": 2328.16, + "benunit.child_tax_credit": 0.0, + "benunit.family_type": "COUPLE_WITH_CHILDREN", + "benunit.income_support": 0.0, + "benunit.pension_credit": 0.0, + "benunit.universal_credit": 0.0, + "benunit.working_tax_credit": 0.0, + "household.council_tax": 0.0, + "household.equiv_hbai_household_net_income": 52503.68, + "household.hbai_household_net_income": 73505.15, + "household.household_benefits": 5880.35, + "household.household_count_people": 4.0, + "household.household_gross_income": 95880.34, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 90000.0, + "household.household_net_income": 76898.3, + "household.household_tax": 18982.05, + "household.household_wealth_decile": 10.0, + "household.household_weight": 1.0, + "household.in_poverty_ahc": 0.0, + "household.in_poverty_bhc": 0.0, + "household.in_relative_poverty_ahc": 0.0, + "household.in_relative_poverty_bhc": 0.0, + "household.rent": 0.0, + "household.tenure_type": "RENT_PRIVATELY", + "household.vat": 0.0, + "person[0].age": 42.0, + "person[0].benunit_id": 0.0, + "person[0].child_benefit": 2328.16, + "person[0].child_tax_credit": 0.0, + "person[0].dividend_income": 0.0, + "person[0].earned_income": 55000.0, + "person[0].employment_income": 55000.0, + 
"person[0].gender": "MALE", + "person[0].household_id": 0.0, + "person[0].income_support": 0.0, + "person[0].income_tax": 9432.0, + "person[0].is_SP_age": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].national_insurance": 3110.6, + "person[0].pension_credit": 0.0, + "person[0].pension_income": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].private_pension_income": 0.0, + "person[0].property_income": 0.0, + "person[0].savings_interest_income": 0.0, + "person[0].self_employment_income": 0.0, + "person[0].total_income": 55000.0, + "person[0].universal_credit": 0.0, + "person[0].working_tax_credit": 0.0, + "person[1].age": 40.0, + "person[1].benunit_id": 0.0, + "person[1].child_benefit": 2328.16, + "person[1].child_tax_credit": 0.0, + "person[1].dividend_income": 0.0, + "person[1].earned_income": 35000.0, + "person[1].employment_income": 35000.0, + "person[1].gender": "MALE", + "person[1].household_id": 0.0, + "person[1].income_support": 0.0, + "person[1].income_tax": 4486.0, + "person[1].is_SP_age": 0.0, + "person[1].is_adult": 1.0, + "person[1].is_child": 0.0, + "person[1].is_male": 1.0, + "person[1].national_insurance": 1794.4, + "person[1].pension_credit": 0.0, + "person[1].pension_income": 0.0, + "person[1].person_id": 0.0, + "person[1].person_weight": 1.0, + "person[1].private_pension_income": 0.0, + "person[1].property_income": 0.0, + "person[1].savings_interest_income": 0.0, + "person[1].self_employment_income": 0.0, + "person[1].total_income": 35000.0, + "person[1].universal_credit": 0.0, + "person[1].working_tax_credit": 0.0, + "person[2].age": 8.0, + "person[2].benunit_id": 0.0, + "person[2].child_benefit": 2328.16, + "person[2].child_tax_credit": 0.0, + "person[2].dividend_income": 0.0, + "person[2].earned_income": 0.0, + "person[2].employment_income": 0.0, + "person[2].gender": "MALE", + "person[2].household_id": 0.0, + "person[2].income_support": 0.0, + 
"person[2].income_tax": 0.0, + "person[2].is_SP_age": 0.0, + "person[2].is_adult": 0.0, + "person[2].is_child": 1.0, + "person[2].is_male": 1.0, + "person[2].national_insurance": 0.0, + "person[2].pension_credit": 0.0, + "person[2].pension_income": 0.0, + "person[2].person_id": 0.0, + "person[2].person_weight": 1.0, + "person[2].private_pension_income": 0.0, + "person[2].property_income": 0.0, + "person[2].savings_interest_income": 0.0, + "person[2].self_employment_income": 0.0, + "person[2].total_income": 0.0, + "person[2].universal_credit": 0.0, + "person[2].working_tax_credit": 0.0, + "person[3].age": 3.0, + "person[3].benunit_id": 0.0, + "person[3].child_benefit": 2328.16, + "person[3].child_tax_credit": 0.0, + "person[3].dividend_income": 0.0, + "person[3].earned_income": 0.0, + "person[3].employment_income": 0.0, + "person[3].gender": "MALE", + "person[3].household_id": 0.0, + "person[3].income_support": 0.0, + "person[3].income_tax": 0.0, + "person[3].is_SP_age": 0.0, + "person[3].is_adult": 0.0, + "person[3].is_child": 1.0, + "person[3].is_male": 1.0, + "person[3].national_insurance": 0.0, + "person[3].pension_credit": 0.0, + "person[3].pension_income": 0.0, + "person[3].person_id": 0.0, + "person[3].person_weight": 1.0, + "person[3].private_pension_income": 0.0, + "person[3].property_income": 0.0, + "person[3].savings_interest_income": 0.0, + "person[3].self_employment_income": 0.0, + "person[3].total_income": 0.0, + "person[3].universal_credit": 0.0, + "person[3].working_tax_credit": 0.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/uk_model_surface.json b/tests/fixtures/base_extraction_snapshots/uk_model_surface.json new file mode 100644 index 00000000..161ef0ec --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/uk_model_surface.json @@ -0,0 +1,11 @@ +{ + "country_id": "uk", + "data_package_name": "policyengine-uk-data", + "has_employment_income": true, + "has_income_tax": true, + "has_region_registry": true, + "model_package_name": 
"policyengine-uk", + "num_parameters_bucketed_100s": 20, + "num_variables_bucketed_100s": 8, + "region_registry_country": "uk" +} diff --git a/tests/fixtures/base_extraction_snapshots/uk_single_adult_employment_income.json b/tests/fixtures/base_extraction_snapshots/uk_single_adult_employment_income.json new file mode 100644 index 00000000..5ec94094 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/uk_single_adult_employment_income.json @@ -0,0 +1,58 @@ +{ + "benunit.benunit_id": 0.0, + "benunit.benunit_weight": 1.0, + "benunit.child_benefit": 0.0, + "benunit.child_tax_credit": 0.0, + "benunit.family_type": "SINGLE", + "benunit.income_support": 0.0, + "benunit.pension_credit": 0.0, + "benunit.universal_credit": 0.0, + "benunit.working_tax_credit": 0.0, + "household.council_tax": 0.0, + "household.equiv_hbai_household_net_income": 37491.94, + "household.hbai_household_net_income": 25119.6, + "household.household_benefits": 0.0, + "household.household_count_people": 1.0, + "household.household_gross_income": 30000.0, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 30000.0, + "household.household_net_income": 24960.55, + "household.household_tax": 5039.45, + "household.household_wealth_decile": 10.0, + "household.household_weight": 1.0, + "household.in_poverty_ahc": 0.0, + "household.in_poverty_bhc": 0.0, + "household.in_relative_poverty_ahc": 0.0, + "household.in_relative_poverty_bhc": 0.0, + "household.rent": 0.0, + "household.tenure_type": "RENT_PRIVATELY", + "household.vat": 0.0, + "person[0].age": 35.0, + "person[0].benunit_id": 0.0, + "person[0].child_benefit": 0.0, + "person[0].child_tax_credit": 0.0, + "person[0].dividend_income": 0.0, + "person[0].earned_income": 30000.0, + "person[0].employment_income": 30000.0, + "person[0].gender": "MALE", + "person[0].household_id": 0.0, + "person[0].income_support": 0.0, + "person[0].income_tax": 3486.0, + "person[0].is_SP_age": 0.0, + 
"person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].national_insurance": 1394.4, + "person[0].pension_credit": 0.0, + "person[0].pension_income": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].private_pension_income": 0.0, + "person[0].property_income": 0.0, + "person[0].savings_interest_income": 0.0, + "person[0].self_employment_income": 0.0, + "person[0].total_income": 30000.0, + "person[0].universal_credit": 0.0, + "person[0].working_tax_credit": 0.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/uk_single_adult_no_income.json b/tests/fixtures/base_extraction_snapshots/uk_single_adult_no_income.json new file mode 100644 index 00000000..59657e2c --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/uk_single_adult_no_income.json @@ -0,0 +1,58 @@ +{ + "benunit.benunit_id": 0.0, + "benunit.benunit_weight": 1.0, + "benunit.child_benefit": 0.0, + "benunit.child_tax_credit": 0.0, + "benunit.family_type": "SINGLE", + "benunit.income_support": 0.0, + "benunit.pension_credit": 0.0, + "benunit.universal_credit": 5079.13, + "benunit.working_tax_credit": 0.0, + "household.council_tax": 0.0, + "household.equiv_hbai_household_net_income": 7580.79, + "household.hbai_household_net_income": 5079.13, + "household.household_benefits": 5079.13, + "household.household_count_people": 1.0, + "household.household_gross_income": 5079.13, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 0.0, + "household.household_net_income": 4920.09, + "household.household_tax": 159.04, + "household.household_wealth_decile": 10.0, + "household.household_weight": 1.0, + "household.in_poverty_ahc": 1.0, + "household.in_poverty_bhc": 1.0, + "household.in_relative_poverty_ahc": 0.0, + "household.in_relative_poverty_bhc": 0.0, + "household.rent": 0.0, + "household.tenure_type": "RENT_PRIVATELY", + "household.vat": 0.0, + "person[0].age": 35.0, + 
"person[0].benunit_id": 0.0, + "person[0].child_benefit": 0.0, + "person[0].child_tax_credit": 0.0, + "person[0].dividend_income": 0.0, + "person[0].earned_income": 0.0, + "person[0].employment_income": 0.0, + "person[0].gender": "MALE", + "person[0].household_id": 0.0, + "person[0].income_support": 0.0, + "person[0].income_tax": 0.0, + "person[0].is_SP_age": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].national_insurance": 0.0, + "person[0].pension_credit": 0.0, + "person[0].pension_income": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].private_pension_income": 0.0, + "person[0].property_income": 0.0, + "person[0].savings_interest_income": 0.0, + "person[0].self_employment_income": 0.0, + "person[0].total_income": 0.0, + "person[0].universal_credit": 5079.13, + "person[0].working_tax_credit": 0.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/uk_single_parent_one_child.json b/tests/fixtures/base_extraction_snapshots/uk_single_parent_one_child.json new file mode 100644 index 00000000..06e55db0 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/uk_single_parent_one_child.json @@ -0,0 +1,85 @@ +{ + "benunit.benunit_id": 0.0, + "benunit.benunit_weight": 1.0, + "benunit.child_benefit": 1400.66, + "benunit.child_tax_credit": 0.0, + "benunit.family_type": "LONE_PARENT", + "benunit.income_support": 0.0, + "benunit.pension_credit": 0.0, + "benunit.universal_credit": 1544.43, + "benunit.working_tax_credit": 0.0, + "household.council_tax": 0.0, + "household.equiv_hbai_household_net_income": 28120.33, + "household.hbai_household_net_income": 24464.69, + "household.household_benefits": 2945.09, + "household.household_count_people": 2.0, + "household.household_gross_income": 27945.09, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 25000.0, + "household.household_net_income": 24305.64, + 
"household.household_tax": 3639.45, + "household.household_wealth_decile": 10.0, + "household.household_weight": 1.0, + "household.in_poverty_ahc": 0.0, + "household.in_poverty_bhc": 0.0, + "household.in_relative_poverty_ahc": 0.0, + "household.in_relative_poverty_bhc": 0.0, + "household.rent": 0.0, + "household.tenure_type": "RENT_PRIVATELY", + "household.vat": 0.0, + "person[0].age": 32.0, + "person[0].benunit_id": 0.0, + "person[0].child_benefit": 1400.66, + "person[0].child_tax_credit": 0.0, + "person[0].dividend_income": 0.0, + "person[0].earned_income": 25000.0, + "person[0].employment_income": 25000.0, + "person[0].gender": "MALE", + "person[0].household_id": 0.0, + "person[0].income_support": 0.0, + "person[0].income_tax": 2486.0, + "person[0].is_SP_age": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].national_insurance": 994.4, + "person[0].pension_credit": 0.0, + "person[0].pension_income": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].private_pension_income": 0.0, + "person[0].property_income": 0.0, + "person[0].savings_interest_income": 0.0, + "person[0].self_employment_income": 0.0, + "person[0].total_income": 25000.0, + "person[0].universal_credit": 1544.43, + "person[0].working_tax_credit": 0.0, + "person[1].age": 5.0, + "person[1].benunit_id": 0.0, + "person[1].child_benefit": 1400.66, + "person[1].child_tax_credit": 0.0, + "person[1].dividend_income": 0.0, + "person[1].earned_income": 0.0, + "person[1].employment_income": 0.0, + "person[1].gender": "MALE", + "person[1].household_id": 0.0, + "person[1].income_support": 0.0, + "person[1].income_tax": 0.0, + "person[1].is_SP_age": 0.0, + "person[1].is_adult": 0.0, + "person[1].is_child": 1.0, + "person[1].is_male": 1.0, + "person[1].national_insurance": 0.0, + "person[1].pension_credit": 0.0, + "person[1].pension_income": 0.0, + "person[1].person_id": 0.0, + "person[1].person_weight": 1.0, + 
"person[1].private_pension_income": 0.0, + "person[1].property_income": 0.0, + "person[1].savings_interest_income": 0.0, + "person[1].self_employment_income": 0.0, + "person[1].total_income": 0.0, + "person[1].universal_credit": 1544.43, + "person[1].working_tax_credit": 0.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/us_married_two_kids_high_income.json b/tests/fixtures/base_extraction_snapshots/us_married_two_kids_high_income.json new file mode 100644 index 00000000..1d5e98ca --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/us_married_two_kids_high_income.json @@ -0,0 +1,97 @@ +{ + "family.family_id": 0.0, + "family.family_weight": 0.0, + "household.congressional_district_geoid": 0.0, + "household.household_benefits": 0.0, + "household.household_count_people": 4.0, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 240000.0, + "household.household_net_income": 175089.92, + "household.household_tax": 64910.07, + "household.household_weight": 1.0, + "marital_unit.marital_unit_id": 0.0, + "marital_unit.marital_unit_weight": 1.0, + "person[0].age": 42.0, + "person[0].employment_income": 150000.0, + "person[0].family_id": 0.0, + "person[0].household_id": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].marital_unit_id": 0.0, + "person[0].medicaid": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].race": 3.0, + "person[0].social_security": 0.0, + "person[0].spm_unit_id": 0.0, + "person[0].ssi": 0.0, + "person[0].tax_unit_id": 0.0, + "person[0].unemployment_compensation": 0.0, + "person[1].age": 40.0, + "person[1].employment_income": 90000.0, + "person[1].family_id": 0.0, + "person[1].household_id": 0.0, + "person[1].is_adult": 1.0, + "person[1].is_child": 0.0, + "person[1].is_male": 1.0, + "person[1].marital_unit_id": 0.0, + "person[1].medicaid": 0.0, + "person[1].person_id": 1.0, + 
"person[1].person_weight": 1.0, + "person[1].race": 3.0, + "person[1].social_security": 0.0, + "person[1].spm_unit_id": 0.0, + "person[1].ssi": 0.0, + "person[1].tax_unit_id": 0.0, + "person[1].unemployment_compensation": 0.0, + "person[2].age": 8.0, + "person[2].employment_income": 0.0, + "person[2].family_id": 0.0, + "person[2].household_id": 0.0, + "person[2].is_adult": 0.0, + "person[2].is_child": 1.0, + "person[2].is_male": 1.0, + "person[2].marital_unit_id": 0.0, + "person[2].medicaid": 0.0, + "person[2].person_id": 2.0, + "person[2].person_weight": 1.0, + "person[2].race": 3.0, + "person[2].social_security": 0.0, + "person[2].spm_unit_id": 0.0, + "person[2].ssi": 0.0, + "person[2].tax_unit_id": 0.0, + "person[2].unemployment_compensation": 0.0, + "person[3].age": 3.0, + "person[3].employment_income": 0.0, + "person[3].family_id": 0.0, + "person[3].household_id": 0.0, + "person[3].is_adult": 0.0, + "person[3].is_child": 1.0, + "person[3].is_male": 1.0, + "person[3].marital_unit_id": 0.0, + "person[3].medicaid": 0.0, + "person[3].person_id": 3.0, + "person[3].person_weight": 1.0, + "person[3].race": 3.0, + "person[3].social_security": 0.0, + "person[3].spm_unit_id": 0.0, + "person[3].ssi": 0.0, + "person[3].tax_unit_id": 0.0, + "person[3].unemployment_compensation": 0.0, + "spm_unit.snap": 0.0, + "spm_unit.spm_unit_id": 0.0, + "spm_unit.spm_unit_is_in_deep_spm_poverty": 0.0, + "spm_unit.spm_unit_is_in_spm_poverty": 0.0, + "spm_unit.spm_unit_net_income": 175089.92, + "spm_unit.spm_unit_weight": 1.0, + "spm_unit.tanf": 0.0, + "tax_unit.ctc": 4400.0, + "tax_unit.eitc": 0.0, + "tax_unit.employee_payroll_tax": 21480.0, + "tax_unit.household_state_income_tax": 12690.07, + "tax_unit.income_tax": 30740.0, + "tax_unit.tax_unit_id": 0.0, + "tax_unit.tax_unit_weight": 1.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/us_model_surface.json b/tests/fixtures/base_extraction_snapshots/us_model_surface.json new file mode 100644 index 00000000..eaf4352e --- 
/dev/null +++ b/tests/fixtures/base_extraction_snapshots/us_model_surface.json @@ -0,0 +1,11 @@ +{ + "country_id": "us", + "data_package_name": "policyengine-us-data", + "has_employment_income": true, + "has_income_tax": true, + "has_region_registry": true, + "model_package_name": "policyengine-us", + "num_parameters_bucketed_100s": 777, + "num_variables_bucketed_100s": 46, + "region_registry_country": "us" +} diff --git a/tests/fixtures/base_extraction_snapshots/us_single_adult_employment_income.json b/tests/fixtures/base_extraction_snapshots/us_single_adult_employment_income.json new file mode 100644 index 00000000..d94660a9 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/us_single_adult_employment_income.json @@ -0,0 +1,46 @@ +{ + "family.family_id": 0.0, + "family.family_weight": 0.0, + "household.congressional_district_geoid": 0.0, + "household.household_benefits": 0.0, + "household.household_count_people": 1.0, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 60000.0, + "household.household_net_income": 48007.14, + "household.household_tax": 11992.86, + "household.household_weight": 1.0, + "marital_unit.marital_unit_id": 0.0, + "marital_unit.marital_unit_weight": 1.0, + "person[0].age": 35.0, + "person[0].employment_income": 60000.0, + "person[0].family_id": 0.0, + "person[0].household_id": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].marital_unit_id": 0.0, + "person[0].medicaid": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].race": 3.0, + "person[0].social_security": 0.0, + "person[0].spm_unit_id": 0.0, + "person[0].ssi": 0.0, + "person[0].tax_unit_id": 0.0, + "person[0].unemployment_compensation": 0.0, + "spm_unit.snap": 0.0, + "spm_unit.spm_unit_id": 0.0, + "spm_unit.spm_unit_is_in_deep_spm_poverty": 0.0, + "spm_unit.spm_unit_is_in_spm_poverty": 0.0, + "spm_unit.spm_unit_net_income": 
48007.14, + "spm_unit.spm_unit_weight": 1.0, + "spm_unit.tanf": 0.0, + "tax_unit.ctc": 0.0, + "tax_unit.eitc": 0.0, + "tax_unit.employee_payroll_tax": 5370.0, + "tax_unit.household_state_income_tax": 1602.86, + "tax_unit.income_tax": 5020.0, + "tax_unit.tax_unit_id": 0.0, + "tax_unit.tax_unit_weight": 1.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/us_single_adult_no_income.json b/tests/fixtures/base_extraction_snapshots/us_single_adult_no_income.json new file mode 100644 index 00000000..258db6f1 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/us_single_adult_no_income.json @@ -0,0 +1,46 @@ +{ + "family.family_id": 0.0, + "family.family_weight": 0.0, + "household.congressional_district_geoid": 0.0, + "household.household_benefits": 3596.04, + "household.household_count_people": 1.0, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 0.0, + "household.household_net_income": 3596.04, + "household.household_tax": 0.0, + "household.household_weight": 1.0, + "marital_unit.marital_unit_id": 0.0, + "marital_unit.marital_unit_weight": 1.0, + "person[0].age": 35.0, + "person[0].employment_income": 0.0, + "person[0].family_id": 0.0, + "person[0].household_id": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].marital_unit_id": 0.0, + "person[0].medicaid": 6439.11, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].race": 3.0, + "person[0].social_security": 0.0, + "person[0].spm_unit_id": 0.0, + "person[0].ssi": 0.0, + "person[0].tax_unit_id": 0.0, + "person[0].unemployment_compensation": 0.0, + "spm_unit.snap": 3596.04, + "spm_unit.spm_unit_id": 0.0, + "spm_unit.spm_unit_is_in_deep_spm_poverty": 0.0, + "spm_unit.spm_unit_is_in_spm_poverty": 0.0, + "spm_unit.spm_unit_net_income": 3596.04, + "spm_unit.spm_unit_weight": 1.0, + "spm_unit.tanf": 0.0, + "tax_unit.ctc": 0.0, + "tax_unit.eitc": 0.0, + 
"tax_unit.employee_payroll_tax": 0.0, + "tax_unit.household_state_income_tax": 0.0, + "tax_unit.income_tax": 0.0, + "tax_unit.tax_unit_id": 0.0, + "tax_unit.tax_unit_weight": 1.0 +} diff --git a/tests/fixtures/base_extraction_snapshots/us_single_parent_one_child.json b/tests/fixtures/base_extraction_snapshots/us_single_parent_one_child.json new file mode 100644 index 00000000..78ba7237 --- /dev/null +++ b/tests/fixtures/base_extraction_snapshots/us_single_parent_one_child.json @@ -0,0 +1,63 @@ +{ + "family.family_id": 0.0, + "family.family_weight": 0.0, + "household.congressional_district_geoid": 0.0, + "household.household_benefits": 1003.27, + "household.household_count_people": 2.0, + "household.household_id": 0.0, + "household.household_income_decile": 10.0, + "household.household_market_income": 40000.0, + "household.household_net_income": 39890.89, + "household.household_tax": 1112.38, + "household.household_weight": 1.0, + "marital_unit.marital_unit_id": 0.0, + "marital_unit.marital_unit_weight": 1.0, + "person[0].age": 32.0, + "person[0].employment_income": 40000.0, + "person[0].family_id": 0.0, + "person[0].household_id": 0.0, + "person[0].is_adult": 1.0, + "person[0].is_child": 0.0, + "person[0].is_male": 1.0, + "person[0].marital_unit_id": 0.0, + "person[0].medicaid": 0.0, + "person[0].person_id": 0.0, + "person[0].person_weight": 1.0, + "person[0].race": 3.0, + "person[0].social_security": 0.0, + "person[0].spm_unit_id": 0.0, + "person[0].ssi": 0.0, + "person[0].tax_unit_id": 0.0, + "person[0].unemployment_compensation": 0.0, + "person[1].age": 5.0, + "person[1].employment_income": 0.0, + "person[1].family_id": 0.0, + "person[1].household_id": 0.0, + "person[1].is_adult": 0.0, + "person[1].is_child": 1.0, + "person[1].is_male": 1.0, + "person[1].marital_unit_id": 0.0, + "person[1].medicaid": 3258.31, + "person[1].person_id": 1.0, + "person[1].person_weight": 1.0, + "person[1].race": 3.0, + "person[1].social_security": 0.0, + "person[1].spm_unit_id": 
0.0, + "person[1].ssi": 0.0, + "person[1].tax_unit_id": 0.0, + "person[1].unemployment_compensation": 0.0, + "spm_unit.snap": 0.0, + "spm_unit.spm_unit_id": 0.0, + "spm_unit.spm_unit_is_in_deep_spm_poverty": 0.0, + "spm_unit.spm_unit_is_in_spm_poverty": 0.0, + "spm_unit.spm_unit_net_income": 39890.89, + "spm_unit.spm_unit_weight": 1.0, + "spm_unit.tanf": 0.0, + "tax_unit.ctc": 2200.0, + "tax_unit.eitc": 1852.62, + "tax_unit.employee_payroll_tax": 3580.0, + "tax_unit.household_state_income_tax": 0.0, + "tax_unit.income_tax": -2467.62, + "tax_unit.tax_unit_id": 0.0, + "tax_unit.tax_unit_weight": 1.0 +} diff --git a/tests/test_base_extraction_snapshot.py b/tests/test_base_extraction_snapshot.py new file mode 100644 index 00000000..aa2a9c51 --- /dev/null +++ b/tests/test_base_extraction_snapshot.py @@ -0,0 +1,217 @@ +"""Snapshot regression test for the MicrosimulationModelVersion extraction. + +These tests freeze the outputs of both the US and UK household calculators, +rounded to two decimal places, across a representative set of cases. The +intent is to fail loudly if the base-class extraction (PR F) lets any +country-specific behaviour drift during the refactor. + +Snapshots live in ``tests/fixtures/base_extraction_snapshots/``. To refresh +them, run with ``PE_UPDATE_SNAPSHOTS=1`` set. Do **not** refresh them as part +of a refactor meant to be behaviour-preserving. 
+""" + +from __future__ import annotations + +import json +import math +import os +from pathlib import Path + +import pytest + +SNAPSHOT_DIR = Path(__file__).parent / "fixtures" / "base_extraction_snapshots" +UPDATE = os.environ.get("PE_UPDATE_SNAPSHOTS") == "1" + + +def _flatten(prefix: str, value, out: dict[str, float]) -> None: + """Flatten a nested ``HouseholdResult`` into ``"path.name" -> scalar``.""" + if isinstance(value, list): + for idx, item in enumerate(value): + _flatten(f"{prefix}[{idx}]", item, out) + return + if isinstance(value, dict): + for key, sub in value.items(): + new_prefix = f"{prefix}.{key}" if prefix else str(key) + _flatten(new_prefix, sub, out) + return + if isinstance(value, bool): + out[prefix] = float(value) + elif isinstance(value, (int, float)): + out[prefix] = float(value) + else: + out[prefix] = str(value) + + +def _round(value, places: int = 2): + if isinstance(value, float): + if math.isnan(value): + return "nan" + if math.isinf(value): + return "inf" if value > 0 else "-inf" + return round(value, places) + return value + + +def _check_snapshot(name: str, data: dict) -> None: + path = SNAPSHOT_DIR / f"{name}.json" + rounded = {k: _round(v) for k, v in sorted(data.items())} + + if UPDATE or not path.exists(): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(rounded, indent=2, sort_keys=True) + "\n") + if not UPDATE: + pytest.skip(f"Created missing snapshot {path.name}; re-run to verify") + return + + expected = json.loads(path.read_text()) + diffs = [] + all_keys = set(expected) | set(rounded) + for key in sorted(all_keys): + if key not in expected: + diffs.append(f" new key: {key}={rounded[key]!r}") + elif key not in rounded: + diffs.append(f" removed key: {key}={expected[key]!r}") + elif expected[key] != rounded[key]: + diffs.append(f" {key}: expected {expected[key]!r}, got {rounded[key]!r}") + assert not diffs, f"Snapshot {name} drift:\n" + "\n".join(diffs[:40]) + + +# US cases 
------------------------------------------------------------------- + + +US_CASES = { + "us_single_adult_no_income": dict( + people=[{"age": 35}], + tax_unit={"filing_status": "SINGLE"}, + year=2026, + ), + "us_single_adult_employment_income": dict( + people=[{"age": 35, "employment_income": 60_000}], + tax_unit={"filing_status": "SINGLE"}, + year=2026, + ), + "us_single_parent_one_child": dict( + people=[ + {"age": 32, "employment_income": 40_000}, + {"age": 5}, + ], + tax_unit={"filing_status": "HEAD_OF_HOUSEHOLD"}, + year=2026, + ), + "us_married_two_kids_high_income": dict( + people=[ + {"age": 42, "employment_income": 150_000}, + {"age": 40, "employment_income": 90_000}, + {"age": 8}, + {"age": 3}, + ], + tax_unit={"filing_status": "JOINT"}, + year=2026, + ), +} + + +@pytest.mark.parametrize("case_name", sorted(US_CASES)) +def test_us_household_snapshot(case_name: str) -> None: + pytest.importorskip("policyengine_us") + import policyengine as pe + + kwargs = US_CASES[case_name] + result = pe.us.calculate_household(**kwargs) + out: dict[str, float] = {} + _flatten("", result.to_dict(), out) + _check_snapshot(case_name, out) + + +# UK cases ------------------------------------------------------------------- + + +UK_CASES = { + "uk_single_adult_no_income": dict( + people=[{"age": 35}], + year=2026, + ), + "uk_single_adult_employment_income": dict( + people=[{"age": 35, "employment_income": 30_000}], + year=2026, + ), + "uk_single_parent_one_child": dict( + people=[ + {"age": 32, "employment_income": 25_000}, + {"age": 5}, + ], + year=2026, + ), + "uk_couple_two_kids": dict( + people=[ + {"age": 42, "employment_income": 55_000}, + {"age": 40, "employment_income": 35_000}, + {"age": 8}, + {"age": 3}, + ], + year=2026, + ), +} + + +@pytest.mark.parametrize("case_name", sorted(UK_CASES)) +def test_uk_household_snapshot(case_name: str) -> None: + pytest.importorskip("policyengine_uk") + import policyengine as pe + + kwargs = UK_CASES[case_name] + result = 
pe.uk.calculate_household(**kwargs) + out: dict[str, float] = {} + _flatten("", result.to_dict(), out) + _check_snapshot(case_name, out) + + +# Model-version metadata snapshots ------------------------------------------- + + +def test_us_model_version_surface() -> None: + """Freeze the exposed surface of ``us_latest`` (variables, parameters). + + If the base-class extraction accidentally changes how variables or + parameters are loaded from ``policyengine_us.system``, these counts will + drift. The snapshot intentionally buckets the counts into hundreds rather than + dumping the full variable list so that unrelated upstream releases don't + churn the snapshot file. + """ + pytest.importorskip("policyengine_us") + from policyengine.tax_benefit_models.us import us_latest + + surface = { + "country_id": us_latest.release_manifest.country_id, + "model_package_name": us_latest.model_package.name, + "data_package_name": us_latest.data_package.name, + "has_region_registry": us_latest.region_registry is not None, + "region_registry_country": us_latest.region_registry.country_id, + "num_variables_bucketed_100s": len(us_latest.variables) // 100, + "num_parameters_bucketed_100s": len(us_latest.parameters) // 100, + "has_employment_income": any( + v.name == "employment_income" for v in us_latest.variables + ), + "has_income_tax": any(v.name == "income_tax" for v in us_latest.variables), + } + _check_snapshot("us_model_surface", surface) + + +def test_uk_model_version_surface() -> None: + pytest.importorskip("policyengine_uk") + from policyengine.tax_benefit_models.uk import uk_latest + + surface = { + "country_id": uk_latest.release_manifest.country_id, + "model_package_name": uk_latest.model_package.name, + "data_package_name": uk_latest.data_package.name, + "has_region_registry": uk_latest.region_registry is not None, + "region_registry_country": uk_latest.region_registry.country_id, + "num_variables_bucketed_100s": len(uk_latest.variables) // 100, + 
"num_parameters_bucketed_100s": len(uk_latest.parameters) // 100, + "has_employment_income": any( + v.name == "employment_income" for v in uk_latest.variables + ), + "has_income_tax": any(v.name == "income_tax" for v in uk_latest.variables), + } + _check_snapshot("uk_model_surface", surface) diff --git a/tests/test_manifest_version_mismatch.py b/tests/test_manifest_version_mismatch.py index 1c65230e..f5fd431a 100644 --- a/tests/test_manifest_version_mismatch.py +++ b/tests/test_manifest_version_mismatch.py @@ -34,6 +34,9 @@ def _pick_mismatched_version(manifest_version: str) -> str: return manifest_version + ".drift" +BASE_PATH = "policyengine.tax_benefit_models.common.model_version" + + def _run_init_version_check_branch( module_path: str, class_name: str, @@ -41,39 +44,35 @@ def _run_init_version_check_branch( ) -> list[warnings.WarningMessage]: """Exercise only the manifest-vs-installed version check in ``__init__``. - Patches ``metadata.version`` to return ``installed_version``, and - stubs everything the ``__init__`` calls after the version check so - we don't hit the network or do heavy work. Returns the list of - warnings emitted during the check. + The version-check logic lives on the shared + ``MicrosimulationModelVersion`` base; we patch names on that module + (not on the per-country ``model`` module) and stub everything the + ``__init__`` calls after the version check so we don't hit the + network or do heavy work. """ - with patch(f"{module_path}.metadata.version", return_value=installed_version): + with patch(f"{BASE_PATH}.metadata.version", return_value=installed_version): with patch( - f"{module_path}.certify_data_release_compatibility", + f"{BASE_PATH}.certify_data_release_compatibility", return_value=None, ): + # Prevent super().__init__ from actually running the + # parameter-loading pipeline — we only care that the + # version branch in __init__ emits a warning, not raises. 
with patch( - f"{module_path}._get_runtime_data_build_metadata", - return_value={}, + f"{BASE_PATH}.TaxBenefitModelVersion.__init__", + return_value=None, ): - # Prevent super().__init__ from actually running the - # parameter-loading pipeline — we only care that the - # version branch in our override emits a warning, not - # an exception. - with patch( - f"{module_path}.TaxBenefitModelVersion.__init__", - return_value=None, + import importlib + + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + # Stub the country-specific runtime-metadata hook so + # the version-check path doesn't import the country pkg. + with patch.object( + cls, "_get_runtime_data_build_metadata", return_value={} ): - # Import late so the patches above apply to the - # module-level names used by __init__. - import importlib - - module = importlib.import_module(module_path) - cls = getattr(module, class_name) with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") - # The class is a TaxBenefitModelVersion subclass - # that normally takes kwargs for the parameter - # tree. We're not exercising the parameter tree. try: cls() except Exception: