Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ class BeforeNodeCallEvent(BaseHookEvent, _Interruptible):
source: The multi-agent orchestrator instance
node_id: ID of the node about to execute
invocation_state: Configuration that user passes in
cancel_node: A user defined message that when set, will cancel the node execution with status FAILED.
The message will be emitted under a MultiAgentNodeCancel event. If set to `True`, Strands will cancel the
node using a default cancel message.
cancel_node: A user defined message that when set, will skip the node and mark it as completed, allowing
downstream nodes to continue executing. The message will be emitted under a MultiAgentNodeCancel event.
If set to `True`, Strands will skip the node using a default cancel message.
"""

source: "MultiAgentBase"
Expand Down
20 changes: 18 additions & 2 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,9 +899,25 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
cancel_message = (
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
)
logger.debug("reason=<%s> | cancelling execution", cancel_message)
logger.debug("reason=<%s> | node skipped, graph continues", cancel_message)
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
raise RuntimeError(cancel_message)
node_result = NodeResult(
result=RuntimeError(cancel_message),
execution_time=0,
status=Status.COMPLETED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=0),
execution_count=0,
)
node.result = node_result
node.execution_time = 0
node.execution_status = Status.COMPLETED
self.state.completed_nodes.add(node)
self.state.results[node.node_id] = node_result
self.state.execution_order.append(node)
self._accumulate_metrics(node_result)
yield MultiAgentNodeStopEvent(node_id=node.node_id, node_result=node_result)
return

# Build node input from satisfied dependencies
node_input = self._build_node_input(node)
Expand Down
51 changes: 42 additions & 9 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,20 +2113,53 @@ def cancel_callback(event):
graph = builder.build()
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_callback)

stream = graph.stream_async("test task")

tru_cancel_event = None
with pytest.raises(RuntimeError, match=cancel_message):
async for event in stream:
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event
async for event in graph.stream_async("test task"):
if event.get("type") == "multiagent_node_cancel":
tru_cancel_event = event

exp_cancel_event = MultiAgentNodeCancelEvent(node_id="test_agent", message=cancel_message)
assert tru_cancel_event == exp_cancel_event

tru_status = graph.state.status
exp_status = Status.FAILED
assert tru_status == exp_status
assert graph.state.status == Status.COMPLETED
assert any(n.node_id == "test_agent" for n in graph.state.completed_nodes)
assert "test_agent" in graph.state.results
agent.__call__.assert_not_called()


@pytest.mark.asyncio
async def test_graph_cancel_node_downstream_executes():
"""Downstream nodes must run after an upstream node is skipped via cancel_node."""
cancelled_nodes: list[str] = []

def cancel_step_a(event):
if event.node_id == "step_a":
event.cancel_node = "step_a skipped"
return event

step_a = create_mock_agent("step_a", "Should not run")
step_b = create_mock_agent("step_b", "Step B completed")

builder = GraphBuilder()
builder.add_node(step_a, "step_a")
builder.add_node(step_b, "step_b")
builder.add_edge("step_a", "step_b")
builder.set_entry_point("step_a")
graph = builder.build()
graph.hooks.add_callback(BeforeNodeCallEvent, cancel_step_a)

async for event in graph.stream_async("test task"):
if event.get("type") == "multiagent_node_cancel":
cancelled_nodes.append(event["node_id"])

assert cancelled_nodes == ["step_a"]
assert graph.state.status == Status.COMPLETED
step_a.__call__.assert_not_called()
step_b.__call__.assert_not_called() # stream_async uses stream_async on agent, not __call__
assert any(n.node_id == "step_a" for n in graph.state.completed_nodes)
assert any(n.node_id == "step_b" for n in graph.state.completed_nodes)
assert "step_a" in graph.state.results
assert "step_b" in graph.state.results


def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
Expand Down