Skip to content

Commit a0c5aba

Browse files
committed
Add default Origin for streamable HTTP client
1 parent 616476f commit a0c5aba

2 files changed

Lines changed: 71 additions & 3 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections.abc import AsyncGenerator, Awaitable, Callable
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass
10+
from urllib.parse import urlsplit
1011

1112
import anyio
1213
import httpx
@@ -50,6 +51,15 @@
5051
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
5152

5253

54+
def _get_default_origin(url: str) -> str | None:
55+
parsed_url = urlsplit(url)
56+
if parsed_url.scheme not in {"http", "https"} or not parsed_url.netloc:
57+
return None
58+
59+
authority = parsed_url.netloc.rsplit("@", 1)[-1]
60+
return f"{parsed_url.scheme}://{authority}"
61+
62+
5363
class StreamableHTTPError(Exception):
5464
"""Base exception for StreamableHTTP transport errors."""
5565

@@ -72,13 +82,16 @@ class RequestContext:
7282
class StreamableHTTPTransport:
7383
"""StreamableHTTP client transport implementation."""
7484

75-
def __init__(self, url: str) -> None:
85+
def __init__(self, url: str, default_origin: str | None = None) -> None:
7686
"""Initialize the StreamableHTTP transport.
7787
7888
Args:
7989
url: The endpoint URL.
90+
default_origin: Origin header to include when the caller has not
91+
configured one on the HTTP client.
8092
"""
8193
self.url = url
94+
self.default_origin = default_origin
8295
self.session_id: str | None = None
8396
self.protocol_version: str | None = None
8497

@@ -92,6 +105,8 @@ def _prepare_headers(self) -> dict[str, str]:
92105
"accept": "application/json, text/event-stream",
93106
"content-type": "application/json",
94107
}
108+
if self.default_origin:
109+
headers["Origin"] = self.default_origin
95110
# Add session headers if available
96111
if self.session_id:
97112
headers[MCP_SESSION_ID] = self.session_id
@@ -547,7 +562,8 @@ async def streamable_http_client(
547562
# Create default client with recommended MCP timeouts
548563
client = create_mcp_http_client()
549564

550-
transport = StreamableHTTPTransport(url)
565+
default_origin = None if "origin" in client.headers else _get_default_origin(url)
566+
transport = StreamableHTTPTransport(url, default_origin=default_origin)
551567

552568
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
553569

tests/shared/test_streamable_http.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from mcp import MCPError, types
3131
from mcp.client.session import ClientSession
32-
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
32+
from mcp.client.streamable_http import StreamableHTTPTransport, _get_default_origin, streamable_http_client
3333
from mcp.server import Server, ServerRequestContext
3434
from mcp.server.streamable_http import (
3535
MCP_PROTOCOL_VERSION_HEADER,
@@ -767,6 +767,58 @@ def test_streamable_http_transport_init_validation():
767767
StreamableHTTPServerTransport(mcp_session_id="test\n")
768768

769769

770+
def test_get_default_origin_derives_origin_from_url():
771+
assert _get_default_origin("https://example.com:8443/mcp?token=abc") == "https://example.com:8443"
772+
assert _get_default_origin("http://user:pass@[::1]:8080/mcp") == "http://[::1]:8080"
773+
774+
775+
@pytest.mark.anyio
776+
async def test_streamable_http_client_sets_default_origin_on_http_client():
777+
recorded_headers: list[httpx.Headers] = []
778+
779+
def handler(request: httpx.Request) -> httpx.Response:
780+
recorded_headers.append(request.headers)
781+
return httpx.Response(202, request=request)
782+
783+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
784+
async with streamable_http_client("https://mcp.example.com:8443/mcp", http_client=client) as (
785+
_read_stream,
786+
write_stream,
787+
):
788+
await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
789+
with anyio.fail_after(1):
790+
while not recorded_headers:
791+
await anyio.sleep(0.01)
792+
793+
assert recorded_headers[0]["origin"] == "https://mcp.example.com:8443"
794+
assert "origin" not in client.headers
795+
796+
797+
@pytest.mark.anyio
798+
async def test_streamable_http_client_preserves_custom_origin_header():
799+
recorded_headers: list[httpx.Headers] = []
800+
801+
def handler(request: httpx.Request) -> httpx.Response:
802+
recorded_headers.append(request.headers)
803+
return httpx.Response(202, request=request)
804+
805+
async with httpx.AsyncClient(
806+
headers={"Origin": "https://proxy.example"},
807+
transport=httpx.MockTransport(handler),
808+
) as client:
809+
async with streamable_http_client("https://mcp.example.com/mcp", http_client=client) as (
810+
_read_stream,
811+
write_stream,
812+
):
813+
await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")))
814+
with anyio.fail_after(1):
815+
while not recorded_headers:
816+
await anyio.sleep(0.01)
817+
818+
assert recorded_headers[0]["origin"] == "https://proxy.example"
819+
assert client.headers["origin"] == "https://proxy.example"
820+
821+
770822
def test_session_termination(basic_server: None, basic_server_url: str):
771823
"""Test session termination via DELETE and subsequent request handling."""
772824
response = requests.post(

0 commit comments

Comments
 (0)