Skip to content

Commit 5332b0e

Browse files
committed
fix(server): opt-in drain on read EOF
1 parent d14452b commit 5332b0e

5 files changed

Lines changed: 43 additions & 19 deletions

File tree

src/mcp/server/lowlevel/server.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ async def run(
347347
# the initialization lifecycle, but can do so with any available node
348348
# rather than requiring initialization for each connection.
349349
stateless: bool = False,
350+
# When True, treat read EOF as a half-close and allow in-flight handlers
351+
# to drain their responses via the still-open write stream (e.g. stdio
352+
# with bash-redirected stdin).
353+
drain_on_read_close: bool = False,
350354
):
351355
async with AsyncExitStack() as stack:
352356
lifespan_context = await stack.enter_async_context(self.lifespan(self))
@@ -356,26 +360,35 @@ async def run(
356360
write_stream,
357361
initialization_options,
358362
stateless=stateless,
363+
close_write_stream_on_read_close=not drain_on_read_close,
359364
)
360365
)
361366

362367
async with anyio.create_task_group() as tg:
363-
async for message in session.incoming_messages:
364-
logger.debug("Received message: %s", message)
365-
366-
if isinstance(message, RequestResponder) and message.context is not None:
367-
context = message.context
368-
else:
369-
context = contextvars.copy_context()
370-
371-
context.run(
372-
tg.start_soon,
373-
self._handle_message,
374-
message,
375-
session,
376-
lifespan_context,
377-
raise_exceptions,
378-
)
368+
try:
369+
async for message in session.incoming_messages:
370+
logger.debug("Received message: %s", message)
371+
372+
if isinstance(message, RequestResponder) and message.context is not None:
373+
context = message.context
374+
else:
375+
context = contextvars.copy_context()
376+
377+
context.run(
378+
tg.start_soon,
379+
self._handle_message,
380+
message,
381+
session,
382+
lifespan_context,
383+
raise_exceptions,
384+
)
385+
finally:
386+
if not drain_on_read_close:
387+
# Transport closed: cancel in-flight handlers. Without this the
388+
# TG join waits for them, and when they eventually try to
389+
# respond they hit a closed write stream (the session's
390+
# _receive_loop closed it when the read stream ended).
391+
tg.cancel_scope.cancel()
379392

380393
async def _handle_message(
381394
self,

src/mcp/server/mcpserver/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
852852
read_stream,
853853
write_stream,
854854
self._lowlevel_server.create_initialization_options(),
855+
drain_on_read_close=True,
855856
)
856857

857858
async def run_sse_async( # pragma: no cover

src/mcp/server/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def __init__(
8080
write_stream: WriteStream[SessionMessage],
8181
init_options: InitializationOptions,
8282
stateless: bool = False,
83+
close_write_stream_on_read_close: bool = True,
8384
) -> None:
84-
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=False)
85+
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=close_write_stream_on_read_close)
8586
self._stateless = stateless
8687
self._initialization_state = (
8788
InitializationState.Initialized if stateless else InitializationState.NotInitialized

tests/server/test_cancel_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
120120
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)
121121

122122
async def run_server():
123-
await server.run(server_read, server_write, server.create_initialization_options())
123+
await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True)
124124
server_run_returned.set()
125125

126126
init_req = JSONRPCRequest(

tests/server/test_stdio.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,16 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar
170170
):
171171
with anyio.fail_after(5):
172172
async with anyio.create_task_group() as tg: # pragma: no branch
173-
tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options())
173+
174+
async def run_server() -> None:
175+
await server.run(
176+
read_stream,
177+
write_stream,
178+
server.create_initialization_options(),
179+
drain_on_read_close=True,
180+
)
181+
182+
tg.start_soon(run_server)
174183
await both_tools_started.wait()
175184
allow_tools_to_finish.set()
176185

0 commit comments

Comments
 (0)