Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def __init__(
# Create internal cancel signal for graceful cancellation using threading.Event
self._cancel_signal = threading.Event()

self._invocations = 0
self._invocations_lock = threading.Lock()

self.tool_registry = ToolRegistry()

# Process tool list if provided
Expand Down Expand Up @@ -368,6 +371,14 @@ def __init__(

self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))

def _track_invocation(self) -> None:
with self._invocations_lock:
self._invocations += 1

def _untrack_invocation(self) -> None:
with self._invocations_lock:
self._invocations -= 1

def cancel(self) -> None:
"""Cancel the currently running agent invocation.

Expand Down Expand Up @@ -397,7 +408,9 @@ def cancel(self) -> None:
Note:
Multiple calls to cancel() are safe and idempotent.
"""
self._cancel_signal.set()
with self._invocations_lock:
if self._invocations > 0:
self._cancel_signal.set()

@property
def system_prompt(self) -> str | None:
Expand Down Expand Up @@ -982,6 +995,7 @@ async def _execute_event_loop_cycle(
structured_output_context.register_tool(self.tool_registry)

try:
self._track_invocation()
events = event_loop_cycle(
agent=self,
invocation_state=invocation_state,
Expand All @@ -1003,6 +1017,7 @@ async def _execute_event_loop_cycle(
yield event

finally:
self._untrack_invocation()
if structured_output_context:
structured_output_context.cleanup(self.tool_registry)

Expand Down
69 changes: 36 additions & 33 deletions tests/strands/agent/test_agent_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@
}


class DelayedModelProvider(MockedModelProvider):
"""Model provider that blocks streaming until signaled, for cancel timing tests."""

def __init__(self, responses: list) -> None:
super().__init__(responses)
self.streaming_started = asyncio.Event()
self.cancel_ready = asyncio.Event()

async def stream(self, *args, **kwargs):
self.streaming_started.set()
await self.cancel_ready.wait()
async for event in super().stream(*args, **kwargs):
yield event


@pytest.mark.asyncio
async def test_agent_cancel_before_invocation():
"""Test agent.cancel() before invocation starts.

Verifies that calling cancel() before invoke_async() results in
immediate cancellation without any model calls.
Verifies that calling cancel() before invoke_async() (agent runs) has no effect.
"""
agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE]))

Expand All @@ -31,8 +45,8 @@ async def test_agent_cancel_before_invocation():

result = await agent.invoke_async("Hello")

assert result.stop_reason == "cancelled"
assert result.message == {"role": "assistant", "content": [{"text": "Cancelled by user"}], "metadata": ANY}
assert result.stop_reason == "end_turn"
assert result.message == {**DEFAULT_RESPONSE, "metadata": ANY}


@pytest.mark.asyncio
Expand All @@ -42,23 +56,13 @@ async def test_agent_cancel_during_execution():
Verifies that calling cancel() while the agent is running
stops execution at the next checkpoint.
"""
streaming_started = asyncio.Event()
cancel_ready = asyncio.Event()

class DelayedModelProvider(MockedModelProvider):
async def stream(self, *args, **kwargs):
streaming_started.set()
# Block until cancel has been called
await cancel_ready.wait()
async for event in super().stream(*args, **kwargs):
yield event

agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE]))
model = DelayedModelProvider([DEFAULT_RESPONSE])
agent = Agent(model=model)

async def cancel_when_ready():
await streaming_started.wait()
await model.streaming_started.wait()
agent.cancel()
cancel_ready.set()
model.cancel_ready.set()

cancel_task = asyncio.create_task(cancel_when_ready())
result = await agent.invoke_async("Hello")
Expand Down Expand Up @@ -128,7 +132,7 @@ async def test_agent_cancel_idempotent():

result = await agent.invoke_async("Hello")

assert result.stop_reason == "cancelled"
assert result.stop_reason == "end_turn"


@pytest.mark.asyncio
Expand All @@ -138,24 +142,16 @@ async def test_agent_cancel_from_thread():
Verifies thread-safety of the cancel() method when called
from a background thread.
"""
streaming_started = asyncio.Event()
cancel_ready = asyncio.Event()
loop = asyncio.get_running_loop()

class DelayedModelProvider(MockedModelProvider):
async def stream(self, *args, **kwargs):
streaming_started.set()
await cancel_ready.wait()
async for event in super().stream(*args, **kwargs):
yield event

agent = Agent(model=DelayedModelProvider([DEFAULT_RESPONSE]))
model = DelayedModelProvider([DEFAULT_RESPONSE])
agent = Agent(model=model)

def cancel_from_thread():
# Wait for streaming to start before cancelling
asyncio.run_coroutine_threadsafe(streaming_started.wait(), loop).result()
asyncio.run_coroutine_threadsafe(model.streaming_started.wait(), loop).result()
agent.cancel()
loop.call_soon_threadsafe(cancel_ready.set)
loop.call_soon_threadsafe(model.cancel_ready.set)

thread = threading.Thread(target=cancel_from_thread)
thread.start()
Expand Down Expand Up @@ -279,10 +275,17 @@ async def test_agent_cancel_continue_after():
Verifies that the cancel signal is cleared after an invocation completes,
allowing subsequent invocations to run normally.
"""
agent = Agent(model=MockedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE]))
model = DelayedModelProvider([DEFAULT_RESPONSE, DEFAULT_RESPONSE])
agent = Agent(model=model)

agent.cancel()
async def cancel_when_ready():
await model.streaming_started.wait()
agent.cancel()
model.cancel_ready.set()

cancel_task = asyncio.create_task(cancel_when_ready())
result1 = await agent.invoke_async("Hello")
await cancel_task
assert result1.stop_reason == "cancelled"

# Second invocation should work normally
Expand Down