diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..6df0dd0718 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -344,36 +344,45 @@ async def _aggregate_components(self, server_info: types.Implementation, session tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + # Per the lifecycle spec, only invoke methods for capabilities the + # server advertised during initialization. If the initialize result + # is missing, fall back to the prior unconditional behavior so the + # existing MCPError handler can still cope with servers that misbehave. + capabilities = session.initialize_result.capabilities if session.initialize_result is not None else None + # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is None or capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is None or capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") + if capabilities is None or capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..49eed91e5e 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -385,3 +385,46 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +@pytest.mark.anyio +@pytest.mark.parametrize("advertised", ["tools", "prompts", "resources"]) +async def test_client_session_group_skips_unsupported_capabilities(advertised: str): + """Only the capability the server advertised is queried during aggregation.""" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_session.initialize_result = types.InitializeResult( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ServerCapabilities( + tools=types.ToolsCapability() if advertised == "tools" else None, + prompts=types.PromptsCapability() if advertised == "prompts" else None, + resources=types.ResourcesCapability() if advertised == "resources" else None, + ), + server_info=types.Implementation(name="srv", version="1"), + ) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "tool_a" + mock_resource = mock.Mock(spec=types.Resource) + mock_resource.name = "resource_b" + mock_prompt = mock.Mock(spec=types.Prompt) + mock_prompt.name = "prompt_c" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt]) + + group = ClientSessionGroup() + await group.connect_with_session(types.Implementation(name="srv", version="1"), mock_session) + + list_methods = { + "tools": mock_session.list_tools, + "prompts": mock_session.list_prompts, + "resources": mock_session.list_resources, + } + for capability, list_method in list_methods.items(): + if capability == advertised: + list_method.assert_awaited_once() + else: + list_method.assert_not_awaited() + + assert group.tools == ({"tool_a": mock_tool} if advertised == "tools" else {}) + assert group.prompts == ({"prompt_c": mock_prompt} if advertised == "prompts" else {}) + assert group.resources == ({"resource_b": mock_resource} if advertised == "resources" else {})