diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index c2c6e874f1..2df86a4202 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -548,19 +548,33 @@ async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: ctx: the function invocation context used **kwargs: only used to dynamically load the argument that is defined for this tool. """ - stream = self.run( - str(kwargs.get(arg_name, "")), - stream=True, - session=ctx.session if propagate_session else None, - function_invocation_kwargs=dict(ctx.kwargs), - ) - if stream_callback is not None: - stream.with_transform_hook(stream_callback) - final_response = await stream.get_final_response() - if final_response.user_input_requests: - raise UserInputRequiredException(contents=final_response.user_input_requests) - # TODO(Copilot): update once #4331 merges - return final_response.text + session = ctx.session if propagate_session else None + + # Isolate the child agent from the parent's server-side conversation. + # service_session_id would cause the child to send previous_response_id + # referencing the parent's pending tool_call, resulting in a 400 error. + saved_service_session_id = None + if session is not None and session.service_session_id is not None: + saved_service_session_id = session.service_session_id + session.service_session_id = None + + try: + stream = self.run( + str(kwargs.get(arg_name, "")), + stream=True, + session=session, + function_invocation_kwargs=dict(ctx.kwargs), + ) + if stream_callback is not None: + stream.with_transform_hook(stream_callback) + final_response = await stream.get_final_response() + if final_response.user_input_requests: + raise UserInputRequiredException(contents=final_response.user_input_requests) + # TODO(Copilot): update once #4331 merges + return final_response.text + finally: + if session is not None and saved_service_session_id is not None: + session.service_session_id = saved_service_session_id return FunctionTool( name=tool_name, diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index cab55196f8..9072669f50 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1009,6 +1009,99 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: assert parent_session.state["counter"] == 1 +async def test_chat_agent_as_tool_propagate_session_clears_service_session_id(client: SupportsChatGetResponse) -> None: + """Test that propagate_session=True clears service_session_id for the child and restores it after.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + parent_session = AgentSession(session_id="shared-session") + parent_session.service_session_id = "resp_parent_abc123" + parent_session.state["data"] = "shared" + + original_run = agent.run + captured_session = None + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_session + captured_session = kwargs.get("session") + # The child should see the same session object but with service_session_id cleared + assert captured_session is parent_session + assert captured_session.service_session_id is None + assert captured_session.state["data"] == "shared" + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) + + # After the child finishes, service_session_id is restored + assert parent_session.service_session_id == "resp_parent_abc123" + + +async def test_chat_agent_as_tool_propagate_session_restores_service_session_id_on_error( + client: SupportsChatGetResponse, +) -> None: + """Test that service_session_id is restored even if the child agent raises.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + parent_session = AgentSession(session_id="shared-session") + parent_session.service_session_id = "resp_parent_xyz789" + + def failing_run(*args: Any, **kwargs: Any) -> Any: + raise RuntimeError("Child agent failed") + + agent.run = failing_run # type: ignore[assignment, method-assign] + + with raises(RuntimeError, match="Child agent failed"): + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) + + # service_session_id must be restored even after failure + assert parent_session.service_session_id == "resp_parent_xyz789" + + +async def test_chat_agent_as_tool_propagate_session_no_service_session_id(client: SupportsChatGetResponse) -> None: + """Test that when service_session_id is None, no save/restore is needed.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + parent_session = AgentSession(session_id="shared-session") + parent_session.service_session_id = None + + original_run = agent.run + captured_session = None + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_session + captured_session = kwargs.get("session") + assert captured_session.service_session_id is None + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) + + assert parent_session.service_session_id is None + + async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP")