diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 35ccb1d58a..d1c9ea8dfb 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -1592,7 +1592,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: Returns: An async context manager for the streamable HTTP client transport. """ - from httpx import AsyncClient, Request, Timeout + from httpx import URL, AsyncClient, Request, Timeout http_client = self._httpx_client if self._header_provider is not None: @@ -1604,8 +1604,17 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: self._httpx_client = http_client if not hasattr(self, "_inject_headers_hook"): + mcp_origin = URL(self.url) + + def _origin(url: URL) -> tuple[str, str, int | None]: + port = url.port + if port is None: + port = {"http": 80, "https": 443}.get(url.scheme) + return (url.scheme, url.host or "", port) async def _inject_headers(request: Request) -> None: # noqa: RUF029 + if _origin(request.url) != _origin(mcp_origin): + return headers = _mcp_call_headers.get({}) for key, value in headers.items(): request.headers[key] = value diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 0fc5867d79..b973645844 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -4497,6 +4497,49 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook(): await tool._httpx_client.aclose() +async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect(): + """header_provider headers should not be re-applied to redirected origins.""" + import httpx + + from agent_framework._mcp import _mcp_call_headers + + captured: list[tuple[str, httpx.Headers]] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + captured.append((str(request.url), request.headers.copy())) + if request.url.host == "example.com": + return httpx.Response(302, headers={"Location": "http://attacker.example/capture"}) + return httpx.Response(204) + + user_client = httpx.AsyncClient(transport=httpx.MockTransport(handler), follow_redirects=True) + tool = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + http_client=user_client, + header_provider=lambda kw: {"Authorization": kw.get("auth", "")}, + ) + + try: + with patch("agent_framework._mcp.streamable_http_client"): + tool.get_mcp_client() + + token = _mcp_call_headers.set({"Authorization": "Bearer test-token"}) + try: + response = await user_client.post("http://example.com/mcp") + finally: + _mcp_call_headers.reset(token) + + assert response.status_code == 204 + assert [url for url, _ in captured] == [ + "http://example.com/mcp", + "http://attacker.example/capture", + ] + assert captured[0][1].get("Authorization") == "Bearer test-token" + assert captured[1][1].get("Authorization") is None + finally: + await user_client.aclose() + + async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client(): """Test that header_provider works when the user provides their own httpx client.""" import httpx