Skip to content

Commit 27f6789

Browse files
committed
fix: drain completed streamable HTTP SSE responses
1 parent 19fe9fa commit 27f6789

3 files changed

Lines changed: 180 additions & 12 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
240240
event_source.response.raise_for_status()
241241
logger.debug("Resumption GET SSE connection established")
242242

243+
response_complete = False
243244
async for sse in event_source.aiter_sse(): # pragma: no branch
245+
if response_complete:
246+
continue
244247
is_complete = await self._handle_sse_event(
245248
sse,
246249
ctx.read_stream_writer,
247250
original_request_id,
248251
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249252
)
250253
if is_complete:
251-
await event_source.response.aclose()
252-
break
254+
response_complete = True
253255

254256
async def _handle_post_request(self, ctx: RequestContext) -> None:
255257
"""Handle a POST request with response processing."""
@@ -342,6 +344,7 @@ async def _handle_sse_response(
342344

343345
try:
344346
event_source = EventSource(response)
347+
response_complete = False
345348
async for sse in event_source.aiter_sse(): # pragma: no branch
346349
# Track last event ID for potential reconnection
347350
if sse.id:
@@ -351,6 +354,9 @@ async def _handle_sse_response(
351354
if sse.retry is not None:
352355
retry_interval_ms = sse.retry
353356

357+
if response_complete:
358+
continue
359+
354360
is_complete = await self._handle_sse_event(
355361
sse,
356362
ctx.read_stream_writer,
@@ -359,10 +365,11 @@ async def _handle_sse_response(
359365
is_initialization=is_initialization,
360366
)
361367
# If the SSE event indicates completion, like returning response/error
362-
# break the loop
368+
# keep draining the stream so the underlying HTTP connection remains reusable.
363369
if is_complete:
364-
await response.aclose()
365-
return # Normal completion, no reconnect needed
370+
response_complete = True
371+
if response_complete:
372+
return # Normal completion, no reconnect needed
366373
except Exception:
367374
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
368375

@@ -404,27 +411,32 @@ async def _handle_reconnection(
404411
# Track for potential further reconnection
405412
reconnect_last_event_id: str = last_event_id
406413
reconnect_retry_ms = retry_interval_ms
414+
response_complete = False
407415

408416
async for sse in event_source.aiter_sse():
409417
if sse.id: # pragma: no branch
410418
reconnect_last_event_id = sse.id
411419
if sse.retry is not None:
412420
reconnect_retry_ms = sse.retry
413421

422+
if response_complete:
423+
continue
424+
414425
is_complete = await self._handle_sse_event(
415426
sse,
416427
ctx.read_stream_writer,
417428
original_request_id,
418429
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
419430
)
420431
if is_complete:
421-
await event_source.response.aclose()
422-
return
432+
response_complete = True
433+
if response_complete:
434+
return
423435

424436
# Stream ended again without response - reconnect again (reset attempt counter)
425437
logger.info("SSE stream disconnected, reconnecting...")
426438
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
427-
except Exception as e: # pragma: no cover
439+
except Exception as e:
428440
logger.debug(f"Reconnection failed: {e}")
429441
# Try to reconnect again if we still have an event ID
430442
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

tests/interaction/transports/test_hosting_resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
357357
http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION
358358
tg.cancel_scope.cancel()
359359

360-
with anyio.fail_after(5): # pragma: no branch
361-
release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
360+
with anyio.fail_after(5): # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event
361+
release.set()
362362
# init priming + init response + call priming + "first" + "second" + result = 6 stored events.
363363
await store.wait_until_stored(6)
364364
async with ( # pragma: no branch

tests/shared/test_streamable_http.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,16 @@
2323
from starlette.requests import Request
2424
from starlette.routing import Mount
2525

26+
import mcp.client.streamable_http as streamable_http_module
2627
from mcp import MCPError, types
2728
from mcp.client.session import ClientSession
28-
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
29+
from mcp.client.streamable_http import (
30+
RequestContext as StreamableHTTPClientRequestContext,
31+
)
32+
from mcp.client.streamable_http import (
33+
StreamableHTTPTransport,
34+
streamable_http_client,
35+
)
2936
from mcp.server import Server, ServerRequestContext
3037
from mcp.server.streamable_http import (
3138
MCP_PROTOCOL_VERSION_HEADER,
@@ -41,7 +48,12 @@
4148
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4249
from mcp.server.transport_security import TransportSecuritySettings
4350
from mcp.shared._context import RequestContext
44-
from mcp.shared._context_streams import create_context_streams
51+
from mcp.shared._context_streams import ContextSendStream, create_context_streams
52+
from mcp.shared._httpx_utils import (
53+
MCP_DEFAULT_SSE_READ_TIMEOUT,
54+
MCP_DEFAULT_TIMEOUT,
55+
create_mcp_http_client,
56+
)
4557
from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage
4658
from mcp.shared.session import RequestResponder
4759
from mcp.types import (
@@ -1583,6 +1595,150 @@ async def test_handle_sse_event_skips_empty_data() -> None:
15831595
await read_stream.aclose()
15841596

15851597

1598+
class _FakeStreamResponse(httpx.Response):
1599+
def __init__(self) -> None:
1600+
super().__init__(200, request=httpx.Request("POST", "http://localhost:8000/mcp"))
1601+
self.closed_by_transport = False
1602+
1603+
async def aclose(self) -> None: # pragma: no cover
1604+
self.closed_by_transport = True
1605+
await super().aclose()
1606+
1607+
1608+
def _response_sse(request_id: int | str) -> ServerSentEvent:
1609+
return ServerSentEvent(
1610+
event="message",
1611+
data=json.dumps({"jsonrpc": "2.0", "id": request_id, "result": {}}),
1612+
id="response-event",
1613+
)
1614+
1615+
1616+
def _make_streamable_http_request_context(
1617+
request_id: int | str,
1618+
client: httpx.AsyncClient,
1619+
write_stream: ContextSendStream[SessionMessage | Exception],
1620+
) -> StreamableHTTPClientRequestContext:
1621+
return StreamableHTTPClientRequestContext(
1622+
client=client,
1623+
session_id=None,
1624+
session_message=SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=request_id, method="tools/list")),
1625+
metadata=None,
1626+
read_stream_writer=write_stream,
1627+
)
1628+
1629+
1630+
@pytest.mark.anyio
1631+
async def test_sse_response_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1632+
"""Terminal POST SSE responses are drained instead of force-closed."""
1633+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1634+
response = _FakeStreamResponse()
1635+
1636+
class FakeEventSource:
1637+
def __init__(self, response: _FakeStreamResponse) -> None:
1638+
self.response = response
1639+
1640+
async def aiter_sse(self):
1641+
yield _response_sse(1)
1642+
yield ServerSentEvent(event="message", data="", id="drained-event")
1643+
1644+
async def fail_reconnect(*args: Any, **kwargs: Any) -> None: # pragma: no cover
1645+
raise AssertionError("terminal responses should not reconnect after draining")
1646+
1647+
monkeypatch.setattr(streamable_http_module, "EventSource", FakeEventSource)
1648+
monkeypatch.setattr(transport, "_handle_reconnection", fail_reconnect)
1649+
1650+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1651+
async with httpx.AsyncClient() as client:
1652+
try:
1653+
ctx = _make_streamable_http_request_context(1, client, write_stream)
1654+
await transport._handle_sse_response(response, ctx)
1655+
1656+
assert response.closed_by_transport is False
1657+
message = await read_stream.receive()
1658+
assert isinstance(message, SessionMessage)
1659+
assert isinstance(message.message, types.JSONRPCResponse)
1660+
assert message.message.id == 1
1661+
finally:
1662+
await write_stream.aclose()
1663+
await read_stream.aclose()
1664+
1665+
1666+
@pytest.mark.anyio
1667+
async def test_reconnection_drains_after_terminal_response(monkeypatch: pytest.MonkeyPatch):
1668+
"""Resumed GET responses use EOF draining instead of response.aclose()."""
1669+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1670+
response = _FakeStreamResponse()
1671+
1672+
class FakeReconnectionEventSource:
1673+
def __init__(self, response: _FakeStreamResponse) -> None:
1674+
self.response = response
1675+
1676+
async def aiter_sse(self):
1677+
yield _response_sse("abc")
1678+
yield ServerSentEvent(event="message", data="", id="drained-event")
1679+
1680+
@asynccontextmanager
1681+
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
1682+
yield FakeReconnectionEventSource(response)
1683+
1684+
monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)
1685+
1686+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1687+
async with httpx.AsyncClient() as client:
1688+
try:
1689+
ctx = _make_streamable_http_request_context("abc", client, write_stream)
1690+
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)
1691+
1692+
assert response.closed_by_transport is False
1693+
message = await read_stream.receive()
1694+
assert isinstance(message, SessionMessage)
1695+
assert isinstance(message.message, types.JSONRPCResponse)
1696+
assert message.message.id == "abc"
1697+
finally:
1698+
await write_stream.aclose()
1699+
await read_stream.aclose()
1700+
1701+
1702+
@pytest.mark.anyio
1703+
async def test_reconnection_retries_after_failed_resume(monkeypatch: pytest.MonkeyPatch):
1704+
"""A failed resume attempt falls back to the next reconnection attempt."""
1705+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
1706+
response = _FakeStreamResponse()
1707+
attempts = 0
1708+
1709+
class FakeReconnectionEventSource:
1710+
def __init__(self, response: _FakeStreamResponse) -> None:
1711+
self.response = response
1712+
1713+
async def aiter_sse(self):
1714+
yield _response_sse("abc")
1715+
1716+
@asynccontextmanager
1717+
async def fake_aconnect_sse(*args: Any, **kwargs: Any):
1718+
nonlocal attempts
1719+
attempts += 1
1720+
if attempts == 1:
1721+
raise RuntimeError("resume failed")
1722+
yield FakeReconnectionEventSource(response)
1723+
1724+
monkeypatch.setattr(streamable_http_module, "aconnect_sse", fake_aconnect_sse)
1725+
1726+
write_stream, read_stream = create_context_streams[SessionMessage | Exception](2)
1727+
async with httpx.AsyncClient() as client:
1728+
try:
1729+
ctx = _make_streamable_http_request_context("abc", client, write_stream)
1730+
await transport._handle_reconnection(ctx, "previous-event", retry_interval_ms=0)
1731+
1732+
assert attempts == 2
1733+
message = await read_stream.receive()
1734+
assert isinstance(message, SessionMessage)
1735+
assert isinstance(message.message, types.JSONRPCResponse)
1736+
assert message.message.id == "abc"
1737+
finally:
1738+
await write_stream.aclose()
1739+
await read_stream.aclose()
1740+
1741+
15861742
@pytest.mark.anyio
15871743
async def test_priming_event_not_sent_for_old_protocol_version() -> None:
15881744
"""_maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)