Skip to content

Commit 74dfacd

Browse files
committed
fix(streamable_http): reject duplicate request IDs
1 parent 616476f commit 74dfacd

2 files changed

Lines changed: 132 additions & 1 deletion

File tree

src/mcp/server/streamable_http.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def _create_error_response(
290290
status_code: HTTPStatus,
291291
error_code: int = INVALID_REQUEST,
292292
headers: dict[str, str] | None = None,
293+
request_id: RequestId | None = None,
293294
) -> Response:
294295
"""Create an error response with a simple string message."""
295296
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
@@ -302,7 +303,7 @@ def _create_error_response(
302303
# Return a properly formatted JSON error response
303304
error_response = JSONRPCError(
304305
jsonrpc="2.0",
305-
id=None,
306+
id=request_id,
306307
error=ErrorData(code=error_code, message=error_message),
307308
)
308309

@@ -525,6 +526,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
525526

526527
# Extract the request ID outside the try block for proper scope
527528
request_id = str(message.id)
529+
if request_id in self._request_streams:
530+
response = self._create_error_response(
531+
f"Conflict: Request ID {request_id!r} is already in flight on this session",
532+
HTTPStatus.CONFLICT,
533+
request_id=message.id,
534+
)
535+
await response(scope, receive, send)
536+
return
528537
# Register this stream for the request ID
529538
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
530539
request_stream_reader = self._request_streams[request_id][1]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import anyio
2+
import httpx
3+
import pytest
4+
5+
from mcp.server import Server, ServerRequestContext
6+
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
7+
from mcp.types import (
8+
INVALID_REQUEST,
9+
LATEST_PROTOCOL_VERSION,
10+
CallToolRequestParams,
11+
CallToolResult,
12+
ListToolsResult,
13+
PaginatedRequestParams,
14+
TextContent,
15+
Tool,
16+
)
17+
18+
19+
@pytest.mark.anyio
20+
async def test_streamable_http_duplicate_request_id_returns_409_and_preserves_in_flight_request() -> None:
21+
started = anyio.Event()
22+
release = anyio.Event()
23+
24+
async def handle_list_tools(
25+
ctx: ServerRequestContext[object],
26+
params: PaginatedRequestParams | None,
27+
) -> ListToolsResult:
28+
return ListToolsResult(
29+
tools=[
30+
Tool(
31+
name="slow_tool",
32+
description="Blocks until released by the test",
33+
input_schema={"type": "object", "properties": {}},
34+
)
35+
]
36+
)
37+
38+
async def handle_call_tool(
39+
ctx: ServerRequestContext[object],
40+
params: CallToolRequestParams,
41+
) -> CallToolResult:
42+
started.set()
43+
await release.wait()
44+
return CallToolResult(content=[TextContent(type="text", text="ok")])
45+
46+
server = Server("test-duplicate-request-id", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool)
47+
mcp_app = server.streamable_http_app(json_response=True, host="testserver")
48+
49+
async with (
50+
mcp_app.router.lifespan_context(mcp_app),
51+
httpx.ASGITransport(mcp_app) as transport,
52+
httpx.AsyncClient(transport=transport, base_url="http://testserver", timeout=5.0) as client,
53+
):
54+
base_headers = {"Accept": "application/json", "Content-Type": "application/json"}
55+
56+
init_response = await client.post(
57+
"/mcp",
58+
headers=base_headers,
59+
json={
60+
"jsonrpc": "2.0",
61+
"method": "initialize",
62+
"id": "init-1",
63+
"params": {
64+
"clientInfo": {"name": "test-client", "version": "0"},
65+
"protocolVersion": LATEST_PROTOCOL_VERSION,
66+
"capabilities": {},
67+
},
68+
},
69+
)
70+
assert init_response.status_code == 200
71+
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
72+
assert session_id is not None
73+
74+
session_headers = {**base_headers, MCP_SESSION_ID_HEADER: session_id}
75+
76+
initialized = await client.post(
77+
"/mcp",
78+
headers=session_headers,
79+
json={"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}},
80+
)
81+
assert initialized.status_code == 202
82+
83+
request_id = "dup-id-1"
84+
slow_response: httpx.Response | None = None
85+
86+
async def run_slow_request() -> None:
87+
nonlocal slow_response
88+
slow_response = await client.post(
89+
"/mcp",
90+
headers=session_headers,
91+
json={
92+
"jsonrpc": "2.0",
93+
"method": "tools/call",
94+
"id": request_id,
95+
"params": {"name": "slow_tool", "arguments": {}},
96+
},
97+
)
98+
99+
async with anyio.create_task_group() as tg:
100+
tg.start_soon(run_slow_request)
101+
with anyio.fail_after(5):
102+
await started.wait()
103+
104+
duplicate = await client.post(
105+
"/mcp",
106+
headers=session_headers,
107+
json={"jsonrpc": "2.0", "method": "ping", "id": request_id, "params": {}},
108+
)
109+
assert duplicate.status_code == 409
110+
duplicate_body = duplicate.json()
111+
assert duplicate_body["jsonrpc"] == "2.0"
112+
assert duplicate_body["id"] == request_id
113+
assert duplicate_body["error"]["code"] == INVALID_REQUEST
114+
115+
release.set()
116+
117+
assert slow_response is not None
118+
assert slow_response.status_code == 200
119+
slow_body = slow_response.json()
120+
assert slow_body["jsonrpc"] == "2.0"
121+
assert slow_body["id"] == request_id
122+
assert slow_body["result"]["content"][0]["text"] == "ok"

0 commit comments

Comments
 (0)