diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 80b50770a..78730536d 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -330,9 +330,9 @@ class BeforeNodeCallEvent(BaseHookEvent, _Interruptible): source: The multi-agent orchestrator instance node_id: ID of the node about to execute invocation_state: Configuration that user passes in - cancel_node: A user defined message that when set, will cancel the node execution with status FAILED. - The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the - node using a default cancel message. + cancel_node: A user defined message that when set, will skip the node and mark it as completed, allowing + downstream nodes to continue executing. The message will be emitted under a MultiAgentNodeCancel event. + If set to `True`, Strands will skip the node using a default cancel message. """ source: "MultiAgentBase" diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 8da8314ea..09e6f70f5 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -899,9 +899,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" ) - logger.debug("reason=<%s> | cancelling execution", cancel_message) + logger.debug("reason=<%s> | node skipped, graph continues", cancel_message) yield MultiAgentNodeCancelEvent(node.node_id, cancel_message) - raise RuntimeError(cancel_message) + node_result = NodeResult( + result=RuntimeError(cancel_message), + execution_time=0, + status=Status.COMPLETED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=0), + execution_count=0, + ) + node.result = node_result + node.execution_time = 0 + node.execution_status = Status.COMPLETED + self.state.completed_nodes.add(node) + self.state.results[node.node_id] = node_result + self.state.execution_order.append(node) + self._accumulate_metrics(node_result) + yield MultiAgentNodeStopEvent(node_id=node.node_id, node_result=node_result) + return # Build node input from satisfied dependencies node_input = self._build_node_input(node) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index a6085627c..d63f018c7 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -2113,20 +2113,53 @@ def cancel_callback(event): graph = builder.build() graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback) - stream = graph.stream_async("test task") - tru_cancel_event = None - with pytest.raises(RuntimeError, match=cancel_message): - async for event in stream: - if event.get("type") == "multiagent_node_cancel": - tru_cancel_event = event + async for event in graph.stream_async("test task"): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_event = event exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message) assert tru_cancel_event == exp_cancel_event - tru_status = graph.state.status - exp_status = Status.FAILED - assert tru_status == exp_status + assert graph.state.status == Status.COMPLETED + assert any(n.node_id == "test_agent" for n in graph.state.completed_nodes) + assert "test_agent" in graph.state.results + agent.__call__.assert_not_called() + + +@pytest.mark.asyncio +async def test_graph_cancel_node_downstream_executes(): + """Downstream nodes must run after an upstream node is skipped via cancel_node.""" + cancelled_nodes: list[str] = [] + + def cancel_step_a(event): + if event.node_id == "step_a": + event.cancel_node = "step_a skipped" + return event + + step_a = create_mock_agent("step_a", "Should not run") + step_b = create_mock_agent("step_b", "Step B completed") + + builder = GraphBuilder() + builder.add_node(step_a, "step_a") + builder.add_node(step_b, "step_b") + builder.add_edge("step_a", "step_b") + builder.set_entry_point("step_a") + graph = builder.build() + graph.hooks.add_callback(BeforeNodeCallEvent, cancel_step_a) + + async for event in graph.stream_async("test task"): + if event.get("type") == "multiagent_node_cancel": + cancelled_nodes.append(event["node_id"]) + + assert cancelled_nodes == ["step_a"] + assert graph.state.status == Status.COMPLETED + step_a.__call__.assert_not_called() + step_b.__call__.assert_not_called() # stream_async uses stream_async on agent, not __call__ + assert any(n.node_id == "step_a" for n in graph.state.completed_nodes) + assert any(n.node_id == "step_b" for n in graph.state.completed_nodes) + assert "step_a" in graph.state.results + assert "step_b" in graph.state.results def test_graph_interrupt_on_before_node_call_event(interrupt_hook):