Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,37 @@ async def prepare_tools(
return toolgroups


def _build_provider_data_headers(
tools: Optional[list[dict[str, Any]]],
) -> Optional[dict[str, str]]:
"""Build extra HTTP headers containing MCP provider data for Llama Stack.

Extracts per-server auth headers from MCP tool definitions and encodes
them as a JSON ``x-llamastack-provider-data`` header that Llama Stack
uses to authenticate with downstream MCP servers.

Args:
tools: Prepared tool definitions (may include MCP and non-MCP tools).

Returns:
Dict with a single ``x-llamastack-provider-data`` key, or None when
no MCP tools carry headers.
"""
if not tools:
return None

mcp_headers: McpHeaders = {
tool["server_url"]: tool["headers"]
for tool in tools
if tool.get("type") == "mcp" and tool.get("headers") and tool.get("server_url")
}

if not mcp_headers:
return None

return {"x-llamastack-provider-data": json.dumps({"mcp_headers": mcp_headers})}


async def prepare_responses_params( # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
client: AsyncLlamaStackClient,
query_request: QueryRequest,
Expand Down Expand Up @@ -281,6 +312,9 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
llama_stack_conv_id,
)

# Build x-llamastack-provider-data header from MCP tool headers
extra_headers = _build_provider_data_headers(tools)

return ResponsesApiParams(
input=input_text,
model=llama_stack_model_id,
Expand All @@ -289,6 +323,7 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
conversation=llama_stack_conv_id,
stream=stream,
store=store,
extra_headers=extra_headers,
)


Expand Down
4 changes: 4 additions & 0 deletions src/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class ResponsesApiParams(BaseModel):
conversation: str = Field(description="The conversation ID in llama-stack format")
stream: bool = Field(description="Whether to stream the response")
store: bool = Field(description="Whether to store the response")
extra_headers: Optional[dict[str, str]] = Field(
default=None,
description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)",
)


class ToolCallSummary(BaseModel):
Expand Down
105 changes: 105 additions & 0 deletions tests/unit/utils/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,111 @@ async def test_prepare_responses_params_api_status_error_on_models(
await prepare_responses_params(mock_client, query_request, None, "token")
assert exc_info.value.status_code == 500

@pytest.mark.asyncio
async def test_prepare_responses_params_includes_mcp_provider_data_headers(
    self, mocker: MockerFixture
) -> None:
    """Verify MCP tool headers are forwarded as x-llamastack-provider-data."""
    client = mocker.AsyncMock()
    model = mocker.Mock()
    model.id = "provider1/model1"
    model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}
    client.models.list = mocker.AsyncMock(return_value=[model])

    conversation = mocker.Mock()
    conversation.id = "new_conv_id"
    client.conversations.create = mocker.AsyncMock(return_value=conversation)

    request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]

    # MCP tool definitions shaped like prepare_tools/get_mcp_tools output.
    tools = [
        {
            "type": "mcp",
            "server_label": "mcp::aap-controller",
            "server_url": "http://aap.foo.redhat.com:8004/sse",
            "require_approval": "never",
            "headers": {"X-Authorization": "client-token"},
        },
        {
            "type": "mcp",
            "server_label": "mcp::aap-lightspeed",
            "server_url": "http://aap.foo.redhat.com:8005/sse",
            "require_approval": "never",
            "headers": {"X-Authorization": "client-token-2"},
        },
    ]

    mocker.patch("utils.responses.configuration", mocker.Mock())
    mocker.patch(
        "utils.responses.select_model_and_provider_id",
        return_value=("provider1/model1", "model1", "provider1"),
    )
    mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
    mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
    mocker.patch("utils.responses.prepare_tools", return_value=tools)
    mocker.patch("utils.responses.prepare_input", return_value="test")

    params = await prepare_responses_params(client, request, None, "token")

    # The prepared params must carry the provider-data header.
    dumped = params.model_dump()
    extra = dumped["extra_headers"]
    assert (
        extra is not None
    ), "extra_headers should not be None when MCP tools have headers"
    assert "x-llamastack-provider-data" in extra

    provider_data = json.loads(extra["x-llamastack-provider-data"])
    assert "mcp_headers" in provider_data
    assert provider_data["mcp_headers"] == {
        "http://aap.foo.redhat.com:8004/sse": {"X-Authorization": "client-token"},
        "http://aap.foo.redhat.com:8005/sse": {"X-Authorization": "client-token-2"},
    }

@pytest.mark.asyncio
async def test_prepare_responses_params_no_extra_headers_without_mcp_tools(
    self, mocker: MockerFixture
) -> None:
    """Verify extra_headers stays None when prepare_tools yields no MCP tools."""
    client = mocker.AsyncMock()
    model = mocker.Mock()
    model.id = "provider1/model1"
    model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}
    client.models.list = mocker.AsyncMock(return_value=[model])

    conversation = mocker.Mock()
    conversation.id = "new_conv_id"
    client.conversations.create = mocker.AsyncMock(return_value=conversation)

    request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]

    mocker.patch("utils.responses.configuration", mocker.Mock())
    mocker.patch(
        "utils.responses.select_model_and_provider_id",
        return_value=("provider1/model1", "model1", "provider1"),
    )
    mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
    mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
    # No tools at all -> no provider-data header should be built.
    mocker.patch("utils.responses.prepare_tools", return_value=None)
    mocker.patch("utils.responses.prepare_input", return_value="test")

    params = await prepare_responses_params(client, request, None, "token")

    assert params.model_dump().get("extra_headers") is None

@pytest.mark.asyncio
async def test_prepare_responses_params_api_status_error_on_conversation(
self, mocker: MockerFixture
Expand Down
Loading