Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,19 +548,33 @@ async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str:
ctx: the function invocation context used
**kwargs: only used to dynamically load the argument that is defined for this tool.
"""
stream = self.run(
str(kwargs.get(arg_name, "")),
stream=True,
session=ctx.session if propagate_session else None,
function_invocation_kwargs=dict(ctx.kwargs),
)
if stream_callback is not None:
stream.with_transform_hook(stream_callback)
final_response = await stream.get_final_response()
if final_response.user_input_requests:
raise UserInputRequiredException(contents=final_response.user_input_requests)
# TODO(Copilot): update once #4331 merges
return final_response.text
session = ctx.session if propagate_session else None

# Isolate the child agent from the parent's server-side conversation.
# service_session_id would cause the child to send previous_response_id
# referencing the parent's pending tool_call, resulting in a 400 error.
saved_service_session_id = None
if session is not None and session.service_session_id is not None:
saved_service_session_id = session.service_session_id
session.service_session_id = None

Comment on lines +556 to +560
try:
stream = self.run(
str(kwargs.get(arg_name, "")),
stream=True,
session=session,
function_invocation_kwargs=dict(ctx.kwargs),
)
if stream_callback is not None:
stream.with_transform_hook(stream_callback)
final_response = await stream.get_final_response()
if final_response.user_input_requests:
raise UserInputRequiredException(contents=final_response.user_input_requests)
# TODO(Copilot): update once #4331 merges
return final_response.text
finally:
if session is not None and saved_service_session_id is not None:
session.service_session_id = saved_service_session_id

return FunctionTool(
name=tool_name,
Expand Down
93 changes: 93 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,99 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any:
assert parent_session.state["counter"] == 1


async def test_chat_agent_as_tool_propagate_session_clears_service_session_id(client: SupportsChatGetResponse) -> None:
"""Test that propagate_session=True clears service_session_id for the child and restores it after."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)

parent_session = AgentSession(session_id="shared-session")
parent_session.service_session_id = "resp_parent_abc123"
parent_session.state["data"] = "shared"

original_run = agent.run
captured_session = None

def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_session
captured_session = kwargs.get("session")
# The child should see the same session object but with service_session_id cleared
assert captured_session is parent_session
assert captured_session.service_session_id is None
assert captured_session.state["data"] == "shared"
return original_run(*args, **kwargs)

agent.run = capturing_run # type: ignore[assignment, method-assign]

await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)

# After the child finishes, service_session_id is restored
assert parent_session.service_session_id == "resp_parent_abc123"


async def test_chat_agent_as_tool_propagate_session_restores_service_session_id_on_error(
client: SupportsChatGetResponse,
) -> None:
"""Test that service_session_id is restored even if the child agent raises."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)

parent_session = AgentSession(session_id="shared-session")
parent_session.service_session_id = "resp_parent_xyz789"

def failing_run(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError("Child agent failed")

agent.run = failing_run # type: ignore[assignment, method-assign]

with raises(RuntimeError, match="Child agent failed"):
await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)

# service_session_id must be restored even after failure
assert parent_session.service_session_id == "resp_parent_xyz789"


async def test_chat_agent_as_tool_propagate_session_no_service_session_id(client: SupportsChatGetResponse) -> None:
"""Test that when service_session_id is None, no save/restore is needed."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)

Comment on lines +1075 to +1079
parent_session = AgentSession(session_id="shared-session")
parent_session.service_session_id = None

original_run = agent.run
captured_session = None

def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_session
captured_session = kwargs.get("session")
assert captured_session.service_session_id is None
return original_run(*args, **kwargs)

agent.run = capturing_run # type: ignore[assignment, method-assign]

await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)

assert parent_session.service_session_id is None


async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None:
"""Test basic as_mcp_server functionality."""
agent = Agent(client=client, name="TestAgent", description="Test agent for MCP")
Expand Down
Loading