def _build_provider_data_headers(
    tools: Optional[list[dict[str, Any]]],
) -> Optional[dict[str, str]]:
    """Assemble the ``x-llamastack-provider-data`` header from MCP tool headers.

    Walks the prepared tool definitions, keeps only entries whose ``type`` is
    ``"mcp"`` and that carry both a non-empty ``server_url`` and non-empty
    ``headers``, and maps each server URL to its headers. The mapping is then
    serialized as JSON under the ``mcp_headers`` key so Llama Stack can use it
    to authenticate with downstream MCP servers.

    Args:
        tools: Prepared tool definitions; MCP and non-MCP entries may be
            mixed. ``None`` or an empty list is accepted.

    Returns:
        A one-entry dict whose ``x-llamastack-provider-data`` value is the
        JSON-encoded ``{"mcp_headers": {server_url: headers}}`` payload, or
        ``None`` when no MCP tool contributes headers.
    """
    per_server: McpHeaders = {}

    for entry in tools or []:
        if entry.get("type") != "mcp":
            continue

        url = entry.get("server_url")
        auth_headers = entry.get("headers")
        if url and auth_headers:
            # A repeated server URL keeps only the headers of its last entry.
            per_server[url] = auth_headers

    if not per_server:
        return None

    payload = json.dumps({"mcp_headers": per_server})
    return {"x-llamastack-provider-data": payload}
ResponsesApiParams(BaseModel): conversation: str = Field(description="The conversation ID in llama-stack format") stream: bool = Field(description="Whether to stream the response") store: bool = Field(description="Whether to store the response") + extra_headers: Optional[dict[str, str]] = Field( + default=None, + description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)", + ) class ToolCallSummary(BaseModel): diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index a17dca05c..23c45e3ad 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -874,6 +874,111 @@ async def test_prepare_responses_params_api_status_error_on_models( await prepare_responses_params(mock_client, query_request, None, "token") assert exc_info.value.status_code == 500 + @pytest.mark.asyncio + async def test_prepare_responses_params_includes_mcp_provider_data_headers( + self, mocker: MockerFixture + ) -> None: + """Test that extra_headers with x-llamastack-provider-data is set when MCP tools have headers.""" + mock_client = mocker.AsyncMock() + mock_model = mocker.Mock() + mock_model.id = "provider1/model1" + mock_model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"} + mock_client.models.list = mocker.AsyncMock(return_value=[mock_model]) + + mock_conversation = mocker.Mock() + mock_conversation.id = "new_conv_id" + mock_client.conversations.create = mocker.AsyncMock( + return_value=mock_conversation + ) + + query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue] + + # Simulate MCP tools with headers (as returned by prepare_tools/get_mcp_tools) + mcp_tools_with_headers = [ + { + "type": "mcp", + "server_label": "mcp::aap-controller", + "server_url": "http://aap.foo.redhat.com:8004/sse", + "require_approval": "never", + "headers": {"X-Authorization": "client-token"}, + }, + { + "type": "mcp", + "server_label": "mcp::aap-lightspeed", + "server_url": 
"http://aap.foo.redhat.com:8005/sse", + "require_approval": "never", + "headers": {"X-Authorization": "client-token-2"}, + }, + ] + + mocker.patch("utils.responses.configuration", mocker.Mock()) + mocker.patch( + "utils.responses.select_model_and_provider_id", + return_value=("provider1/model1", "model1", "provider1"), + ) + mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None)) + mocker.patch("utils.responses.get_system_prompt", return_value="System prompt") + mocker.patch( + "utils.responses.prepare_tools", return_value=mcp_tools_with_headers + ) + mocker.patch("utils.responses.prepare_input", return_value="test") + + result = await prepare_responses_params( + mock_client, query_request, None, "token" + ) + + # The result should contain extra_headers with x-llamastack-provider-data + dumped = result.model_dump() + assert ( + dumped["extra_headers"] is not None + ), "extra_headers should not be None when MCP tools have headers" + assert "x-llamastack-provider-data" in dumped["extra_headers"] + + provider_data = json.loads( + dumped["extra_headers"]["x-llamastack-provider-data"] + ) + assert "mcp_headers" in provider_data + assert provider_data["mcp_headers"] == { + "http://aap.foo.redhat.com:8004/sse": {"X-Authorization": "client-token"}, + "http://aap.foo.redhat.com:8005/sse": {"X-Authorization": "client-token-2"}, + } + + @pytest.mark.asyncio + async def test_prepare_responses_params_no_extra_headers_without_mcp_tools( + self, mocker: MockerFixture + ) -> None: + """Test that extra_headers is None when no MCP tools have headers.""" + mock_client = mocker.AsyncMock() + mock_model = mocker.Mock() + mock_model.id = "provider1/model1" + mock_model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"} + mock_client.models.list = mocker.AsyncMock(return_value=[mock_model]) + + mock_conversation = mocker.Mock() + mock_conversation.id = "new_conv_id" + mock_client.conversations.create = mocker.AsyncMock( + 
return_value=mock_conversation + ) + + query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue] + + mocker.patch("utils.responses.configuration", mocker.Mock()) + mocker.patch( + "utils.responses.select_model_and_provider_id", + return_value=("provider1/model1", "model1", "provider1"), + ) + mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None)) + mocker.patch("utils.responses.get_system_prompt", return_value="System prompt") + mocker.patch("utils.responses.prepare_tools", return_value=None) + mocker.patch("utils.responses.prepare_input", return_value="test") + + result = await prepare_responses_params( + mock_client, query_request, None, "token" + ) + + dumped = result.model_dump() + assert dumped.get("extra_headers") is None + @pytest.mark.asyncio async def test_prepare_responses_params_api_status_error_on_conversation( self, mocker: MockerFixture