|
49 | 49 | from mcp.server.transport_security import TransportSecuritySettings |
50 | 50 | from mcp.shared._context import RequestContext |
51 | 51 | from mcp.shared._context_streams import ContextSendStream, create_context_streams |
52 | | -from mcp.shared._httpx_utils import ( |
53 | | - MCP_DEFAULT_SSE_READ_TIMEOUT, |
54 | | - MCP_DEFAULT_TIMEOUT, |
55 | | - create_mcp_http_client, |
56 | | -) |
57 | 52 | from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage |
58 | 53 | from mcp.shared.session import RequestResponder |
59 | 54 | from mcp.types import ( |
@@ -1663,6 +1658,76 @@ async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover |
1663 | 1658 | await read_stream.aclose() |
1664 | 1659 |
|
1665 | 1660 |
|
| 1661 | +@pytest.mark.anyio |
| 1662 | +async def test_sse_response_does_not_reconnect_after_terminal_then_drain_error(monkeypatch: pytest.MonkeyPatch): |
| 1663 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1664 | + response = _FakeStreamResponse() |
| 1665 | + |
| 1666 | + class FakeEventSource: |
| 1667 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1668 | + self.response = response |
| 1669 | + |
| 1670 | + async def aiter_sse(self): |
| 1671 | + yield _response_sse(1) |
| 1672 | + raise RuntimeError("drain failed after terminal response") |
| 1673 | + |
| 1674 | + async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover |
| 1675 | + raise AssertionError("completed responses should not reconnect after drain errors") |
| 1676 | + |
| 1677 | + monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource) |
| 1678 | + monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect) |
| 1679 | + |
| 1680 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1681 | + async with httpx.AsyncClient() as client: |
| 1682 | + try: |
| 1683 | + ctx = _make_streamable_http_request_context(1, client, write_stream) |
| 1684 | + await transport._handle_sse_response(response, ctx) |
| 1685 | + |
| 1686 | + message = await read_stream.receive() |
| 1687 | + assert isinstance(message, SessionMessage) |
| 1688 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1689 | + assert message.message.id == 1 |
| 1690 | + finally: |
| 1691 | + await write_stream.aclose() |
| 1692 | + await read_stream.aclose() |
| 1693 | + |
| 1694 | + |
| 1695 | +@pytest.mark.anyio |
| 1696 | +async def test_sse_response_reconnects_after_pre_terminal_drain_error(monkeypatch: pytest.MonkeyPatch): |
| 1697 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1698 | + response = _FakeStreamResponse() |
| 1699 | + reconnects: list[tuple[str, int | None]] = [] |
| 1700 | + |
| 1701 | + class FakeEventSource: |
| 1702 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1703 | + self.response = response |
| 1704 | + |
| 1705 | + async def aiter_sse(self): |
| 1706 | + yield ServerSentEvent(event="message", data="", id="resume-from-here") |
| 1707 | + raise RuntimeError("stream failed before terminal response") |
| 1708 | + |
| 1709 | + async def record_reconnect( |
| 1710 | + ctx: StreamableHTTPClientRequestContext, |
| 1711 | + last_event_id: str, |
| 1712 | + retry_interval_ms: int | None = None, |
| 1713 | + ) -> None: |
| 1714 | + reconnects.append((last_event_id, retry_interval_ms)) |
| 1715 | + |
| 1716 | + monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource) |
| 1717 | + monkeypatch.setattr(transport, "_handle_reconnection", record_reconnect) |
| 1718 | + |
| 1719 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1720 | + async with httpx.AsyncClient() as client: |
| 1721 | + try: |
| 1722 | + ctx = _make_streamable_http_request_context(1, client, write_stream) |
| 1723 | + await transport._handle_sse_response(response, ctx) |
| 1724 | + |
| 1725 | + assert reconnects == [("resume-from-here", None)] |
| 1726 | + finally: |
| 1727 | + await write_stream.aclose() |
| 1728 | + await read_stream.aclose() |
| 1729 | + |
| 1730 | + |
1666 | 1731 | @pytest.mark.anyio |
1667 | 1732 | async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch): |
1668 | 1733 | """Resumed GET responses use EOF draining instead of response.aclose().""" |
@@ -1699,6 +1764,44 @@ async def fake_aconnect_sse(*args: Any, **kwargs: Any): |
1699 | 1764 | await read_stream.aclose() |
1700 | 1765 |
|
1701 | 1766 |
|
| 1767 | +@pytest.mark.anyio |
| 1768 | +async def test_reconnection_does_not_retry_after_terminal_then_drain_error(monkeypatch: pytest.MonkeyPatch): |
| 1769 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 1770 | + response = _FakeStreamResponse() |
| 1771 | + attempts = 0 |
| 1772 | + |
| 1773 | + class FakeReconnectionEventSource: |
| 1774 | + def __init__(self, response: _FakeStreamResponse) -> None: |
| 1775 | + self.response = response |
| 1776 | + |
| 1777 | + async def aiter_sse(self): |
| 1778 | + yield _response_sse("abc") |
| 1779 | + raise RuntimeError("drain failed after terminal response") |
| 1780 | + |
| 1781 | + @asynccontextmanager |
| 1782 | + async def fake_aconnect_sse(*args: Any, **kwargs: Any): |
| 1783 | + nonlocal attempts |
| 1784 | + attempts += 1 |
| 1785 | + yield FakeReconnectionEventSource(response) |
| 1786 | + |
| 1787 | + monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse) |
| 1788 | + |
| 1789 | + write_stream, read_stream = create_context_streams[SessionMessage | Exception](2) |
| 1790 | + async with httpx.AsyncClient() as client: |
| 1791 | + try: |
| 1792 | + ctx = _make_streamable_http_request_context("abc", client, write_stream) |
| 1793 | + await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0) |
| 1794 | + |
| 1795 | + assert attempts == 1 |
| 1796 | + message = await read_stream.receive() |
| 1797 | + assert isinstance(message, SessionMessage) |
| 1798 | + assert isinstance(message.message, types.JSONRPCResponse) |
| 1799 | + assert message.message.id == "abc" |
| 1800 | + finally: |
| 1801 | + await write_stream.aclose() |
| 1802 | + await read_stream.aclose() |
| 1803 | + |
| 1804 | + |
1702 | 1805 | @pytest.mark.anyio |
1703 | 1806 | async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch): |
1704 | 1807 | """A failed resume attempt falls back to the next reconnection attempt.""" |
|
0 commit comments