Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions hamilton/function_modifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ class NodeExpander(SubDAGModifier):

EXPAND_NODES = "expand_nodes"

@classmethod
def runs_before_regular_expanders(cls) -> bool:
return False

def transform_dag(
self, nodes: Collection[node.Node], config: dict[str, Any], fn: Callable
) -> Collection[node.Node]:
Expand Down Expand Up @@ -794,6 +798,35 @@ def _resolve_nodes_error(fn: Callable) -> str:
return f"Exception occurred while compiling function: {fn.__name__} to nodes"


def _parameterized_targets(expander: NodeExpander) -> set[str]:
"""Return original function parameters replaced by parameterization-style expanders."""
parameterization = getattr(expander, "parameterization", {})
targets = set()
for mapping in parameterization.values():
if isinstance(mapping, tuple):
mapping = mapping[0]
if isinstance(mapping, dict):
targets.update(mapping)
return targets


def _validate_expander_composition(
pre_expanders: list[NodeExpander], regular_expanders: list[NodeExpander]
) -> None:
injected_targets = set().union(
*(_parameterized_targets(expander) for expander in pre_expanders)
)
regular_targets = set().union(
*(_parameterized_targets(expander) for expander in regular_expanders)
)
overlap = injected_targets & regular_targets
if overlap:
raise InvalidDecoratorException(
"Cannot combine @inject and @parameterize replacements for the same parameter(s): "
f"{', '.join(sorted(overlap))}"
)


def resolve_nodes(fn: Callable, config: dict[str, Any]) -> Collection[node.Node]:
"""Gets a list of nodes from a function. This is meant to be an abstraction between the node
and the function that it implements. This will end up coordinating with the decorators we build
Expand All @@ -810,8 +843,8 @@ def resolve_nodes(fn: Callable, config: dict[str, Any]) -> Collection[node.Node]
-- this is determined in the node creator class. Apply that to get
the initial node.

3. If there is a list of node expanders, apply them. Otherwise apply the default
node expander This must be a list of length one. This gives out a list of nodes.
3. If there is a list of node expanders, apply pre-expanders before at most one regular
expander. This gives out a list of nodes.

4. If there is a node transformer, apply that. Note that the node transformer
gets applied individually to just the sink nodes in the subdag. It subclasses
Expand All @@ -837,8 +870,19 @@ def resolve_nodes(fn: Callable, config: dict[str, Any]) -> Collection[node.Node]
for node_injector in node_injectors:
nodes = node_injector.transform_dag(nodes, filter_config(config, node_injector), fn)
node_expanders = function_decorators[NodeExpander.get_lifecycle_name()]
if len(node_expanders) > 0:
(node_expander,) = node_expanders
pre_expanders = [
expander for expander in node_expanders if expander.runs_before_regular_expanders()
]
regular_expanders = [
expander for expander in node_expanders if not expander.runs_before_regular_expanders()
]
if len(regular_expanders) > 1:
raise InvalidDecoratorException(
"Cannot combine multiple regular node expanders on one function: "
f"{', '.join(expander.name for expander in regular_expanders)}"
)
_validate_expander_composition(pre_expanders, regular_expanders)
for node_expander in pre_expanders + regular_expanders:
nodes = node_expander.transform_dag(nodes, filter_config(config, node_expander), fn)
node_transformers = function_decorators[NodeTransformer.get_lifecycle_name()]
for dag_modifier in node_transformers:
Expand Down
4 changes: 4 additions & 0 deletions hamilton/function_modifiers/expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,3 +1228,7 @@ def __init__(self, **key_mapping: ParametrizedDependency):
This is the same as the input mapping in `@parameterize`.
"""
super(inject, self).__init__(**{parameterize.PLACEHOLDER_PARAM_NAME: key_mapping})

@classmethod
def runs_before_regular_expanders(cls) -> bool:
return True
44 changes: 44 additions & 0 deletions tests/function_modifiers/test_expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,50 @@ def contrived_function(
assert node_(int_value_not_injected=8, three_source=3, six_source=6) == 8 * (8 + 1) // 2


def test_inject_above_parameterize_resolves_nodes():
@function_modifiers.inject(params=source("my_func__params"))
@function_modifiers.parameterize(
my_func_a={"date_range": source("my_func_a_date_range")},
my_func_b={"date_range": source("my_func_b_date_range")},
)
def my_func(date_range: tuple[int, int], params: int) -> int:
return params + date_range[0] + date_range[1]

nodes = {node_.name: node_ for node_ in base.resolve_nodes(my_func, {})}

assert set(nodes) == {"my_func_a", "my_func_b"}
assert nodes["my_func_a"].callable(my_func_a_date_range=(2, 10), my_func__params=1) == 13
assert nodes["my_func_b"].callable(my_func_b_date_range=(3, 10), my_func__params=1) == 14


def test_parameterize_above_inject_resolves_nodes():
@function_modifiers.parameterize(
my_func_a={"date_range": source("my_func_a_date_range")},
my_func_b={"date_range": source("my_func_b_date_range")},
)
@function_modifiers.inject(params=source("my_func__params"))
def my_func(date_range: tuple[int, int], params: int) -> int:
return params + date_range[0] + date_range[1]

nodes = {node_.name: node_ for node_ in base.resolve_nodes(my_func, {})}

assert set(nodes) == {"my_func_a", "my_func_b"}
assert nodes["my_func_a"].callable(my_func_a_date_range=(2, 10), my_func__params=1) == 13
assert nodes["my_func_b"].callable(my_func_b_date_range=(3, 10), my_func__params=1) == 14


def test_inject_parameterize_overlap_fails():
@function_modifiers.inject(date_range=source("injected_date_range"))
@function_modifiers.parameterize(
my_func_a={"date_range": source("my_func_a_date_range")},
)
def my_func(date_range: tuple[int, int], params: int) -> int:
return params + date_range[0] + date_range[1]

with pytest.raises(base.InvalidDecoratorException, match="date_range"):
base.resolve_nodes(my_func, {})


@pytest.mark.parametrize(
("annotated_type", "cls", "expected"),
[
Expand Down