Skip to content

Commit edeb3e0

Browse files
committed
fix: drain completed streamable HTTP SSE responses
1 parent 616476f commit edeb3e0

3 files changed

Lines changed: 175 additions & 12 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
244247
is_complete = await self._handle_sse_event(
245248
sse,
246249
ctx.read_stream_writer,
247250
original_request_id,
248251
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249252
)
250253
if is_complete:
251-
await event_source.response.aclose()
252-
break
254+
response_complete = True
253255

254256
async def _handle_post_request(self, ctx: RequestContext) -> None:
255257
"""Handle a POST request with response processing."""
@@ -342,6 +344,7 @@ async def _handle_sse_response(
342344

343345
try:
344346
event_source = EventSource(response)
347+
response_complete = False
345348
async for sse in event_source.aiter_sse(): # pragma: no branch
346349
# Track last event ID for potential reconnection
347350
if sse.id:
@@ -351,6 +354,9 @@ async def _handle_sse_response(
351354
if sse.retry is not None:
352355
retry_interval_ms = sse.retry
353356

357+
if response_complete:
358+
continue
359+
354360
is_complete = await self._handle_sse_event(
355361
sse,
356362
ctx.read_stream_writer,
@@ -359,10 +365,11 @@ async def _handle_sse_response(
359365
is_initialization=is_initialization,
360366
)
361367
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
368+
# keep draining the stream so the underlying HTTP connection remains reusable.
363369
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
370+
response_complete = True
371+
if response_complete:
372+
return # Normal completion, no reconnect needed
366373
except Exception:
367374
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368375

@@ -404,27 +411,32 @@ async def _handle_reconnection(
404411
# Track for potential further reconnection
405412
reconnect_last_event_id: str = last_event_id
406413
reconnect_retry_ms = retry_interval_ms
414+
response_complete = False
407415

408416
async for sse in event_source.aiter_sse():
409417
if sse.id: # pragma: no branch
410418
reconnect_last_event_id = sse.id
411419
if sse.retry is not None:
412420
reconnect_retry_ms = sse.retry
413421

422+
if response_complete:
423+
continue
424+
414425
is_complete = await self._handle_sse_event(
415426
sse,
416427
ctx.read_stream_writer,
417428
original_request_id,
418429
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419430
)
420431
if is_complete:
421-
await event_source.response.aclose()
422-
return
432+
response_complete = True
433+
if response_complete:
434+
return
423435

424436
# Stream ended again without response - reconnect again (reset attempt counter)
425437
logger.info("SSE stream disconnected, reconnecting...")
426438
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
427-
except Exception as e: # pragma: no cover
439+
except Exception as e:
428440
logger.debug(f"Reconnection failed: {e}")
429441
# Try to reconnect again if we still have an event ID
430442
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

tests/interaction/transports/test_hosting_resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
357357
http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
358358
tg.cancel_scope.cancel()
359359

360-
with anyio.fail_after(5): # pragma: no branch
361-
release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
360+
with anyio.fail_after(5): # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
361+
release.set()
362362
# init priming + init response + call priming + "first" + "second" + result = 6 stored events.
363363
await store.wait_until_stored(6)
364364
async with ( # pragma: no branch

tests/shared/test_streamable_http.py

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,16 @@
2727
from starlette.requests import Request
2828
from starlette.routing import Mount
2929

30+
import mcp.client.streamable_http as streamable_http_module
3031
from mcp import MCPError, types
3132
from mcp.client.session import ClientSession
32-
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
33+
from mcp.client.streamable_http import (
34+
RequestContext as StreamableHTTPClientRequestContext,
35+
)
36+
from mcp.client.streamable_http import (
37+
StreamableHTTPTransport,
38+
streamable_http_client,
39+
)
3340
from mcp.server import Server, ServerRequestContext
3441
from mcp.server.streamable_http import (
3542
MCP_PROTOCOL_VERSION_HEADER,
@@ -45,7 +52,7 @@
4552
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4653
from mcp.server.transport_security import TransportSecuritySettings
4754
from mcp.shared._context import RequestContext
48-
from mcp.shared._context_streams import create_context_streams
55+
from mcp.shared._context_streams import ContextSendStream, create_context_streams
4956
from mcp.shared._httpx_utils import (
5057
MCP_DEFAULT_SSE_READ_TIMEOUT,
5158
MCP_DEFAULT_TIMEOUT,
@@ -1803,6 +1810,150 @@ async def test_handle_sse_event_skips_empty_data():
18031810
await read_stream.aclose()
18041811

18051812

1813+
class _FakeStreamResponse(httpx.Response):
1814+
def __init__(self) -> None:
1815+
super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp"))
1816+
self.closed_by_transport = False
1817+
1818+
async def aclose(self) -> None: # pragma: no cover
1819+
self.closed_by_transport = True
1820+
await super().aclose()
1821+
1822+
1823+
def _response_sse(request_id: int | str) -> ServerSentEvent:
1824+
return ServerSentEvent(
1825+
event="message",
1826+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
1827+
id="response-event",
1828+
)
1829+
1830+
1831+
def _make_streamable_http_request_context(
1832+
request_id: int | str,
1833+
client: httpx.AsyncClient,
1834+
write_stream: ContextSendStream[SessionMessage | Exception],
1835+
) -> StreamableHTTPClientRequestContext:
1836+
return StreamableHTTPClientRequestContext(
1837+
client=client,
1838+
session_id=None,
1839+
session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")),
1840+
metadata=None,
1841+
read_stream_writer=write_stream,
1842+
)
1843+
1844+
1845+
@pytest.mark.anyio
1846+
async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1847+
"""Terminal POST SSE responses are drained instead of force-closed."""
1848+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1849+
response = _FakeStreamResponse()
1850+
1851+
class FakeEventSource:
1852+
def __init__(self, response: _FakeStreamResponse) -> None:
1853+
self.response = response
1854+
1855+
async def aiter_sse(self):
1856+
yield _response_sse(1)
1857+
yield ServerSentEvent(event="message", data="", id="drained-event")
1858+
1859+
async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover
1860+
raise AssertionError("terminal responses should not reconnect after draining")
1861+
1862+
monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource)
1863+
monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect)
1864+
1865+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1866+
async with httpx.AsyncClient() as client:
1867+
try:
1868+
ctx = _make_streamable_http_request_context(1, client, write_stream)
1869+
await transport._handle_sse_response(response, ctx)
1870+
1871+
assert response.closed_by_transport is False
1872+
message = await read_stream.receive()
1873+
assert isinstance(message, SessionMessage)
1874+
assert isinstance(message.message, types.JSONRPCResponse)
1875+
assert message.message.id == 1
1876+
finally:
1877+
await write_stream.aclose()
1878+
await read_stream.aclose()
1879+
1880+
1881+
@pytest.mark.anyio
1882+
async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1883+
"""Resumed GET responses use EOF draining instead of response.aclose()."""
1884+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1885+
response = _FakeStreamResponse()
1886+
1887+
class FakeReconnectionEventSource:
1888+
def __init__(self, response: _FakeStreamResponse) -> None:
1889+
self.response = response
1890+
1891+
async def aiter_sse(self):
1892+
yield _response_sse("abc")
1893+
yield ServerSentEvent(event="message", data="", id="drained-event")
1894+
1895+
@asynccontextmanager
1896+
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
1897+
yield FakeReconnectionEventSource(response)
1898+
1899+
monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)
1900+
1901+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1902+
async with httpx.AsyncClient() as client:
1903+
try:
1904+
ctx = _make_streamable_http_request_context("abc", client, write_stream)
1905+
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)
1906+
1907+
assert response.closed_by_transport is False
1908+
message = await read_stream.receive()
1909+
assert isinstance(message, SessionMessage)
1910+
assert isinstance(message.message, types.JSONRPCResponse)
1911+
assert message.message.id == "abc"
1912+
finally:
1913+
await write_stream.aclose()
1914+
await read_stream.aclose()
1915+
1916+
1917+
@pytest.mark.anyio
1918+
async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch):
1919+
"""A failed resume attempt falls back to the next reconnection attempt."""
1920+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1921+
response = _FakeStreamResponse()
1922+
attempts = 0
1923+
1924+
class FakeReconnectionEventSource:
1925+
def __init__(self, response: _FakeStreamResponse) -> None:
1926+
self.response = response
1927+
1928+
async def aiter_sse(self):
1929+
yield _response_sse("abc")
1930+
1931+
@asynccontextmanager
1932+
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
1933+
nonlocal attempts
1934+
attempts += 1
1935+
if attempts == 1:
1936+
raise RuntimeError("resume failed")
1937+
yield FakeReconnectionEventSource(response)
1938+
1939+
monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)
1940+
1941+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1942+
async with httpx.AsyncClient() as client:
1943+
try:
1944+
ctx = _make_streamable_http_request_context("abc", client, write_stream)
1945+
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)
1946+
1947+
assert attempts == 2
1948+
message = await read_stream.receive()
1949+
assert isinstance(message, SessionMessage)
1950+
assert isinstance(message.message, types.JSONRPCResponse)
1951+
assert message.message.id == "abc"
1952+
finally:
1953+
await write_stream.aclose()
1954+
await read_stream.aclose()
1955+
1956+
18061957
@pytest.mark.anyio
18071958
async def test_priming_event_not_sent_for_old_protocol_version():
18081959
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)