Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def _create_error_response(
status_code: HTTPStatus,
error_code: int = INVALID_REQUEST,
headers: dict[str, str] | None = None,
request_id: RequestId | None = None,
) -> Response:
"""Create an error response with a simple string message."""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
Expand All @@ -302,7 +303,7 @@ def _create_error_response(
# Return a properly formatted JSON error response
error_response = JSONRPCError(
jsonrpc="2.0",
id=None,
id=request_id,
error=ErrorData(code=error_code, message=error_message),
)

Expand Down Expand Up @@ -482,19 +483,16 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)

# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
request_session_id = self._get_session_id(request)
if (
self.mcp_session_id and request_session_id and request_session_id != self.mcp_session_id
): # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
return
elif not await self._validate_request_headers(request, send):
return

Expand Down Expand Up @@ -525,6 +523,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

# Extract the request ID outside the try block for proper scope
request_id = str(message.id)
if request_id in self._request_streams:
response = self._create_error_response(
f"Conflict: Request ID {request_id!r} is already in flight on this session",
HTTPStatus.CONFLICT,
request_id=message.id,
)
await response(scope, receive, send)
return
# Register this stream for the request ID
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
request_stream_reader = self._request_streams[request_id][1]
Expand Down
105 changes: 105 additions & 0 deletions tests/issues/test_2655_streamable_http_duplicate_request_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import anyio
import httpx
import pytest

from mcp.server import Server, ServerRequestContext
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
from mcp.types import (
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
CallToolRequestParams,
CallToolResult,
TextContent,
)


@pytest.mark.anyio
async def test_streamable_http_duplicate_request_id_returns_409_and_preserves_in_flight_request() -> None:
started = anyio.Event()
release = anyio.Event()

async def handle_call_tool(
ctx: ServerRequestContext[object],
params: CallToolRequestParams,
) -> CallToolResult:
started.set()
await release.wait()
return CallToolResult(content=[TextContent(type="text", text="ok")])

server = Server("test-duplicate-request-id", on_call_tool=handle_call_tool)
mcp_app = server.streamable_http_app(json_response=True, host="testserver")

async with (
mcp_app.router.lifespan_context(mcp_app),
httpx.ASGITransport(mcp_app) as transport,
httpx.AsyncClient(transport=transport, base_url="http://testserver", timeout=5.0) as client,
):
base_headers = {"Accept": "application/json", "Content-Type": "application/json"}

init_response = await client.post(
"/mcp",
headers=base_headers,
json={
"jsonrpc": "2.0",
"method": "initialize",
"id": "init-1",
"params": {
"clientInfo": {"name": "test-client", "version": "0"},
"protocolVersion": LATEST_PROTOCOL_VERSION,
"capabilities": {},
},
},
)
assert init_response.status_code == 200
session_id = init_response.headers.get(MCP_SESSION_ID_HEADER)
assert session_id is not None

session_headers = {**base_headers, MCP_SESSION_ID_HEADER: session_id}

initialized = await client.post(
"/mcp",
headers=session_headers,
json={"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}},
)
assert initialized.status_code == 202

request_id = "dup-id-1"
slow_response: httpx.Response | None = None

async def run_slow_request() -> None:
nonlocal slow_response
slow_response = await client.post(
"/mcp",
headers=session_headers,
json={
"jsonrpc": "2.0",
"method": "tools/call",
"id": request_id,
"params": {"name": "slow_tool", "arguments": {}},
},
)

async with anyio.create_task_group() as tg:
tg.start_soon(run_slow_request)
with anyio.fail_after(5):
await started.wait()

duplicate = await client.post(
"/mcp",
headers=session_headers,
json={"jsonrpc": "2.0", "method": "ping", "id": request_id, "params": {}},
)
assert duplicate.status_code == 409
duplicate_body = duplicate.json()
assert duplicate_body["jsonrpc"] == "2.0"
assert duplicate_body["id"] == request_id
assert duplicate_body["error"]["code"] == INVALID_REQUEST

release.set()

assert slow_response is not None
assert slow_response.status_code == 200
slow_body = slow_response.json()
assert slow_body["jsonrpc"] == "2.0"
assert slow_body["id"] == request_id
assert slow_body["result"]["content"][0]["text"] == "ok"
Loading