|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | import logging |
5 | | -from typing import Any |
| 5 | +from types import SimpleNamespace |
| 6 | +from typing import Any, cast |
6 | 7 | from unittest.mock import AsyncMock, patch |
7 | 8 |
|
8 | 9 | import anyio |
@@ -64,6 +65,50 @@ async def try_run(): |
64 | 65 | assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) |
65 | 66 |
|
66 | 67 |
|
| 68 | +@pytest.mark.anyio |
| 69 | +async def test_run_terminates_active_streaming_session_before_shutdown(): |
| 70 | + """run() should close active SSE transports before task cancellation.""" |
| 71 | + app = Server("test-shutdown-cleanup") |
| 72 | + manager = StreamableHTTPSessionManager(app=app) |
| 73 | + transport = StreamableHTTPServerTransport(mcp_session_id="session-id") |
| 74 | + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) |
| 75 | + |
| 76 | + try: |
| 77 | + transport._sse_stream_writers["request-id"] = sse_stream_writer |
| 78 | + |
| 79 | + async with manager.run(): |
| 80 | + manager._server_instances["session-id"] = transport |
| 81 | + |
| 82 | + assert transport.is_terminated |
| 83 | + assert transport._sse_stream_writers == {} |
| 84 | + assert manager._server_instances == {} |
| 85 | + with pytest.raises(anyio.ClosedResourceError): |
| 86 | + await sse_stream_writer.send({"data": "still-open"}) |
| 87 | + finally: |
| 88 | + await sse_stream_reader.aclose() |
| 89 | + |
| 90 | + |
| 91 | +@pytest.mark.anyio |
| 92 | +async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture): |
| 93 | + """One failed transport shutdown should not skip later active sessions.""" |
| 94 | + app = Server("test-shutdown-cleanup-error") |
| 95 | + manager = StreamableHTTPSessionManager(app=app) |
| 96 | + failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed")) |
| 97 | + healthy_terminate = AsyncMock() |
| 98 | + failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate)) |
| 99 | + healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate)) |
| 100 | + |
| 101 | + with caplog.at_level(logging.ERROR): |
| 102 | + async with manager.run(): |
| 103 | + manager._server_instances["bad-session"] = failing_transport |
| 104 | + manager._server_instances["healthy-session"] = healthy_transport |
| 105 | + |
| 106 | + failing_terminate.assert_awaited_once_with() |
| 107 | + healthy_terminate.assert_awaited_once_with() |
| 108 | + assert "Error terminating StreamableHTTP session during shutdown" in caplog.text |
| 109 | + assert manager._server_instances == {} |
| 110 | + |
| 111 | + |
67 | 112 | @pytest.mark.anyio |
68 | 113 | async def test_handle_request_without_run_raises_error(): |
69 | 114 | """Test that handle_request raises error if run() hasn't been called.""" |
@@ -271,6 +316,43 @@ async def mock_receive(): |
271 | 316 | assert len(transport._request_streams) == 0, "Transport should have no active request streams" |
272 | 317 |
|
273 | 318 |
|
| 319 | +@pytest.mark.anyio |
| 320 | +async def test_transport_terminate_closes_sse_stream_writers(): |
| 321 | + """terminate() should close active SSE writers so streaming responses can finish.""" |
| 322 | + transport = StreamableHTTPServerTransport(mcp_session_id="test-session") |
| 323 | + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1) |
| 324 | + |
| 325 | + try: |
| 326 | + transport._sse_stream_writers["request-id"] = sse_stream_writer |
| 327 | + |
| 328 | + await transport.terminate() |
| 329 | + |
| 330 | + assert transport._sse_stream_writers == {} |
| 331 | + with pytest.raises(anyio.ClosedResourceError): |
| 332 | + await sse_stream_writer.send({"data": "still-open"}) |
| 333 | + |
| 334 | + await transport.terminate() |
| 335 | + finally: |
| 336 | + await sse_stream_reader.aclose() |
| 337 | + |
| 338 | + |
| 339 | +@pytest.mark.anyio |
| 340 | +async def test_transport_connect_cleans_request_streams_on_exit(): |
| 341 | + """connect() should close registered request streams when the transport exits.""" |
| 342 | + transport = StreamableHTTPServerTransport(mcp_session_id="test-session") |
| 343 | + request_stream_writer, request_stream_reader = anyio.create_memory_object_stream[Any](1) |
| 344 | + |
| 345 | + transport._request_streams["request-id"] = (request_stream_writer, request_stream_reader) |
| 346 | + |
| 347 | + async with transport.connect(): |
| 348 | + assert "request-id" in transport._request_streams |
| 349 | + transport._terminated = True |
| 350 | + |
| 351 | + assert transport._request_streams == {} |
| 352 | + with pytest.raises(anyio.ClosedResourceError): |
| 353 | + await request_stream_writer.send(cast(Any, object())) |
| 354 | + |
| 355 | + |
274 | 356 | @pytest.mark.anyio |
275 | 357 | async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): |
276 | 358 | """Test that requests with unknown session IDs return HTTP 404 per MCP spec.""" |
|
0 commit comments