|
23 | 23 | from starlette.requests import Request |
24 | 24 | from starlette.routing import Mount |
25 | 25 |
|
| 26 | +import mcp.client.streamable_http as streamable_http_module |
26 | 27 | from mcp import MCPError, types |
27 | 28 | from mcp.client.session import ClientSession |
28 | | -from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client |
| 29 | +from mcp.client.streamable_http import ( |
| 30 | + RequestContext as StreamableHTTPClientRequestContext, |
| 31 | +) |
| 32 | +from mcp.client.streamable_http import ( |
| 33 | + StreamableHTTPTransport, |
| 34 | + streamable_http_client, |
| 35 | +) |
29 | 36 | from mcp.server import Server, ServerRequestContext |
30 | 37 | from mcp.server.streamable_http import ( |
31 | 38 | MCP_PROTOCOL_VERSION_HEADER, |
|
41 | 48 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
42 | 49 | from mcp.server.transport_security import TransportSecuritySettings |
43 | 50 | from mcp.shared._context import RequestContext |
44 | | -from mcp.shared._context_streams import create_context_streams |
| 51 | +from mcp.shared._context_streams import ContextSendStream, create_context_streams |
| 52 | +from mcp.shared._httpx_utils import ( |
| 53 | + MCP_DEFAULT_SSE_READ_TIMEOUT, |
| 54 | + MCP_DEFAULT_TIMEOUT, |
| 55 | + create_mcp_http_client, |
| 56 | +) |
45 | 57 | from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage |
46 | 58 | from mcp.shared.session import RequestResponder |
47 | 59 | from mcp.types import ( |
@@ -1583,6 +1595,150 @@ async def test_handle_sse_event_skips_empty_data() -> None: |
1583 | 1595 | await read_stream.aclose() |
1584 | 1596 |
|
1585 | 1597 |
|
| 1598 | +class _FakeStreamResponse(httpx.Response): |
| 1599 | + def __init__(self) -> None: |
| 1600 | + super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp")) |
| 1601 | + self.closed_by_transport = False |
| 1602 | + |
| 1603 | + async def aclose(self) -> None: # pragma: no cover |
| 1604 | + self.closed_by_transport = True |
| 1605 | + await super().aclose() |
| 1606 | + |
| 1607 | + |
| 1608 | +def _response_sse(request_id: int | str) -> ServerSentEvent: |
| 1609 | + return ServerSentEvent( |
| 1610 | + event="message", |
| 1611 | + data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}), |
| 1612 | + id="response-event", |
| 1613 | + ) |
| 1614 | + |
| 1615 | + |
| 1616 | +def _make_streamable_http_request_context( |
| 1617 | + request_id: int | str, |
| 1618 | + client: httpx.AsyncClient, |
| 1619 | + write_stream: ContextSendStream[SessionMessage | Exception], |
| 1620 | +) -> StreamableHTTPClientRequestContext: |
| 1621 | + return StreamableHTTPClientRequestContext( |
| 1622 | + client=client, |
| 1623 | + session_id=None, |
| 1624 | + session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")), |
| 1625 | + metadata=None, |
| 1626 | + read_stream_writer=write_stream, |
| 1627 | + ) |
| 1628 | + |
| 1629 | + |
| 1630 | +@pytest.mark.anyio |
| 1631 | +async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): |
| 1632 | + """Terminal POST SSE responses are drained instead of force-closed.""" |
| 1633 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1634 | + response = _FakeStreamResponse() |
| 1635 | + |
| 1636 | + class FakeEventSource: |
| 1637 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1638 | + self.response = response |
| 1639 | + |
| 1640 | + async def aiter_sse(self): |
| 1641 | + yield _response_sse(1) |
| 1642 | + yield ServerSentEvent(event="message", data="", id="drained-event") |
| 1643 | + |
| 1644 | + async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover |
| 1645 | + raise AssertionError("terminal responses should not reconnect after draining") |
| 1646 | + |
| 1647 | + monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource) |
| 1648 | + monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect) |
| 1649 | + |
| 1650 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1651 | + async with httpx.AsyncClient() as client: |
| 1652 | + try: |
| 1653 | + ctx = _make_streamable_http_request_context(1, client, write_stream) |
| 1654 | + await transport._handle_sse_response(response, ctx) |
| 1655 | + |
| 1656 | + assert response.closed_by_transport is False |
| 1657 | + message = await read_stream.receive() |
| 1658 | + assert isinstance(message, SessionMessage) |
| 1659 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1660 | + assert message.message.id == 1 |
| 1661 | + finally: |
| 1662 | + await write_stream.aclose() |
| 1663 | + await read_stream.aclose() |
| 1664 | + |
| 1665 | + |
| 1666 | +@pytest.mark.anyio |
| 1667 | +async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): |
| 1668 | + """Resumed GET responses use EOF draining instead of response.aclose().""" |
| 1669 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1670 | + response = _FakeStreamResponse() |
| 1671 | + |
| 1672 | + class FakeReconnectionEventSource: |
| 1673 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1674 | + self.response = response |
| 1675 | + |
| 1676 | + async def aiter_sse(self): |
| 1677 | + yield _response_sse("abc") |
| 1678 | + yield ServerSentEvent(event="message", data="", id="drained-event") |
| 1679 | + |
| 1680 | + @asynccontextmanager |
| 1681 | + async def fake_aconnect_sse(*args: Any, **kwargs: Any): |
| 1682 | + yield FakeReconnectionEventSource(response) |
| 1683 | + |
| 1684 | + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) |
| 1685 | + |
| 1686 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1687 | + async with httpx.AsyncClient() as client: |
| 1688 | + try: |
| 1689 | + ctx = _make_streamable_http_request_context("abc", client, write_stream) |
| 1690 | + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) |
| 1691 | + |
| 1692 | + assert response.closed_by_transport is False |
| 1693 | + message = await read_stream.receive() |
| 1694 | + assert isinstance(message, SessionMessage) |
| 1695 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1696 | + assert message.message.id == "abc" |
| 1697 | + finally: |
| 1698 | + await write_stream.aclose() |
| 1699 | + await read_stream.aclose() |
| 1700 | + |
| 1701 | + |
| 1702 | +@pytest.mark.anyio |
| 1703 | +async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch): |
| 1704 | + """A failed resume attempt falls back to the next reconnection attempt.""" |
| 1705 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1706 | + response = _FakeStreamResponse() |
| 1707 | + attempts = 0 |
| 1708 | + |
| 1709 | + class FakeReconnectionEventSource: |
| 1710 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1711 | + self.response = response |
| 1712 | + |
| 1713 | + async def aiter_sse(self): |
| 1714 | + yield _response_sse("abc") |
| 1715 | + |
| 1716 | + @asynccontextmanager |
| 1717 | + async def fake_aconnect_sse(*args: Any, **kwargs: Any): |
| 1718 | + nonlocal attempts |
| 1719 | + attempts += 1 |
| 1720 | + if attempts == 1: |
| 1721 | + raise RuntimeError("resume failed") |
| 1722 | + yield FakeReconnectionEventSource(response) |
| 1723 | + |
| 1724 | + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) |
| 1725 | + |
| 1726 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1727 | + async with httpx.AsyncClient() as client: |
| 1728 | + try: |
| 1729 | + ctx = _make_streamable_http_request_context("abc", client, write_stream) |
| 1730 | + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) |
| 1731 | + |
| 1732 | + assert attempts == 2 |
| 1733 | + message = await read_stream.receive() |
| 1734 | + assert isinstance(message, SessionMessage) |
| 1735 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1736 | + assert message.message.id == "abc" |
| 1737 | + finally: |
| 1738 | + await write_stream.aclose() |
| 1739 | + await read_stream.aclose() |
| 1740 | + |
| 1741 | + |
1586 | 1742 | @pytest.mark.anyio |
1587 | 1743 | async def test_priming_event_not_sent_for_old_protocol_version() -> None: |
1588 | 1744 | """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" |
|
0 commit comments