Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/fix-eternity-input-cache-preservation.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve user-provided ETERNITY inputs across cache invalidation when they were set for an ordinary period.
38 changes: 29 additions & 9 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import tempfile
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Type, Union

import numpy as np
Expand Down Expand Up @@ -57,6 +58,14 @@ def _stable_hash_to_seed(value: str) -> int:
)


@dataclass(frozen=True)
class PreservedUserInput:
variable_name: str
branch_name: str
period: Period
value: object


class Simulation:
"""
Represents a simulation, and handles the calculation logic
Expand Down Expand Up @@ -273,16 +282,23 @@ def _invalidate_all_caches(self) -> None:
self._fast_cache = {}
self.invalidated_caches = set()
# Snapshot user-provided inputs before wiping so they can be
# replayed into the fresh storage. Storage keys each entry as
# f"{branch_name}:{period}"; preserve exactly those keys.
preserved: dict[str, dict[str, object]] = {}
# replayed into the fresh storage. Use the storage API instead of
# hand-building keys, since ETERNITY variables canonicalize every
# period to the single ETERNITY storage key.
preserved: list[PreservedUserInput] = []
user_input_keys = getattr(self, "_user_input_keys", None) or set()
for variable_name, branch_name, period in user_input_keys:
holder = self.get_holder(variable_name)
storage_key = f"{branch_name}:{period}"
stored_value = holder._memory_storage._arrays.get(storage_key)
stored_value = holder._memory_storage.get(period, branch_name)
if stored_value is not None:
preserved.setdefault(variable_name, {})[storage_key] = stored_value
preserved.append(
PreservedUserInput(
variable_name=variable_name,
branch_name=branch_name,
period=period,
value=stored_value,
)
)
# Iterate only over holders that already exist on each population —
# lazy-creating a holder for every variable in the tax-benefit
# system (thousands in policyengine-us) inflated the cost of
Expand All @@ -295,9 +311,13 @@ def _invalidate_all_caches(self) -> None:
if holder._disk_storage is not None:
holder._disk_storage._files = {}
# Replay preserved user inputs so ``calculate`` still sees them.
for variable_name, key_to_array in preserved.items():
holder = self.get_holder(variable_name)
holder._memory_storage._arrays.update(key_to_array)
for user_input in preserved:
holder = self.get_holder(user_input.variable_name)
holder._memory_storage.put(
user_input.value,
user_input.period,
user_input.branch_name,
)
for branch in self.branches.values():
branch._invalidate_all_caches()

Expand Down
54 changes: 54 additions & 0 deletions tests/core/test_apply_reform_preserves_user_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,60 @@ def apply(self):
assert sim.calculate("age", period=period)[0] == 27


def test_apply_reform_preserves_eternity_inputs_set_for_a_period(
tax_benefit_system,
):
"""ETERNITY inputs set for ordinary periods must survive reform apply."""
sim = SimulationBuilder().build_from_entities(
tax_benefit_system, situation_examples.single
)
period = "2017"
expected_person_id = np.array([123], dtype=np.int32)

sim.set_input("person_id", period, expected_person_id)

class NoOpReform(Reform):
def apply(self):
pass

sim.apply_reform(NoOpReform)

result = sim.calculate("person_id", period=period)
assert np.array_equal(result, expected_person_id), (
"apply_reform lost an ETERNITY input set through Simulation.set_input "
f"for {period}; got {result} instead of {expected_person_id}."
)


def test_apply_reform_preserves_eternity_inputs_set_through_holder(
tax_benefit_system,
):
"""ETERNITY preservation must also cover direct ``holder.set_input``."""
sim = SimulationBuilder().build_from_entities(
tax_benefit_system, situation_examples.single
)
period = "2017"
expected_person_id = np.array([456], dtype=np.int32)

sim.get_holder("person_id").set_input(
period,
expected_person_id,
sim.branch_name,
)

class NoOpReform(Reform):
def apply(self):
pass

sim.apply_reform(NoOpReform)

result = sim.calculate("person_id", period=period)
assert np.array_equal(result, expected_person_id), (
"apply_reform lost an ETERNITY input set through Holder.set_input "
f"for {period}; got {result} instead of {expected_person_id}."
)


def test_apply_reform_preserves_situation_dict_inputs(tax_benefit_system):
"""Situation-dict inputs must survive ``apply_reform`` too.

Expand Down