|
29 | 29 |
|
30 | 30 | from mcp import MCPError, types |
31 | 31 | from mcp.client.session import ClientSession |
32 | | -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
| 32 | +from mcp.client.streamable_http import StreamableHTTPTransport, _get_default_origin, streamable_http_client |
33 | 33 | from mcp.server import Server, ServerRequestContext |
34 | 34 | from mcp.server.streamable_http import ( |
35 | 35 | MCP_PROTOCOL_VERSION_HEADER, |
@@ -767,6 +767,58 @@ def test_streamable_http_transport_init_validation(): |
767 | 767 | StreamableHTTPServerTransport(mcp_session_id="test\n") |
768 | 768 |
|
769 | 769 |
|
| 770 | +def test_get_default_origin_derives_origin_from_url(): |
| 771 | + assert _get_default_origin("https://example.com:8443/mcp?token=abc") == "https://example.com:8443" |
| 772 | + assert _get_default_origin("http://user:pass@[::1]:8080/mcp") == "http://[::1]:8080" |
| 773 | + |
| 774 | + |
| 775 | +@pytest.mark.anyio |
| 776 | +async def test_streamable_http_client_sets_default_origin_on_http_client(): |
| 777 | + recorded_headers: list[httpx.Headers] = [] |
| 778 | + |
| 779 | + def handler(request: httpx.Request) -> httpx.Response: |
| 780 | + recorded_headers.append(request.headers) |
| 781 | + return httpx.Response(202, request=request) |
| 782 | + |
| 783 | + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client: |
| 784 | + async with streamable_http_client("https://mcp.example.com:8443/mcp", http_client=client) as ( |
| 785 | + _read_stream, |
| 786 | + write_stream, |
| 787 | + ): |
| 788 | + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) |
| 789 | + with anyio.fail_after(1): |
| 790 | + while not recorded_headers: |
| 791 | + await anyio.sleep(0.01) |
| 792 | + |
| 793 | + assert recorded_headers[0]["origin"] == "https://mcp.example.com:8443" |
| 794 | + assert "origin" not in client.headers |
| 795 | + |
| 796 | + |
| 797 | +@pytest.mark.anyio |
| 798 | +async def test_streamable_http_client_preserves_custom_origin_header(): |
| 799 | + recorded_headers: list[httpx.Headers] = [] |
| 800 | + |
| 801 | + def handler(request: httpx.Request) -> httpx.Response: |
| 802 | + recorded_headers.append(request.headers) |
| 803 | + return httpx.Response(202, request=request) |
| 804 | + |
| 805 | + async with httpx.AsyncClient( |
| 806 | + headers={"Origin": "https://proxy.example"}, |
| 807 | + transport=httpx.MockTransport(handler), |
| 808 | + ) as client: |
| 809 | + async with streamable_http_client("https://mcp.example.com/mcp", http_client=client) as ( |
| 810 | + _read_stream, |
| 811 | + write_stream, |
| 812 | + ): |
| 813 | + await write_stream.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) |
| 814 | + with anyio.fail_after(1): |
| 815 | + while not recorded_headers: |
| 816 | + await anyio.sleep(0.01) |
| 817 | + |
| 818 | + assert recorded_headers[0]["origin"] == "https://proxy.example" |
| 819 | + assert client.headers["origin"] == "https://proxy.example" |
| 820 | + |
| 821 | + |
770 | 822 | def test_session_termination(basic_server: None, basic_server_url: str): |
771 | 823 | """Test session termination via DELETE and subsequent request handling.""" |
772 | 824 | response = requests.post( |
|
0 commit comments