diff --git a/mcp/streamable.go b/mcp/streamable.go index d6452b28..5a6d828e 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1376,6 +1376,10 @@ func (c *streamableServerConn) Close() error { // A StreamableClientTransport is a [Transport] that can communicate with an MCP // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. +// +// If the server terminates the session (returning 404 Not Found), subsequent +// operations on the [ClientSession] will return an error wrapping +// [ErrSessionMissing]. type StreamableClientTransport struct { Endpoint string HTTPClient *http.Client @@ -1493,16 +1497,17 @@ type streamableClientConn struct { sessionID string } -// errSessionMissing distinguishes if the session is known to not be present on -// the server (see [streamableClientConn.fail]). +// ErrSessionMissing is a sentinel error returned by a [ClientSession] using the +// [StreamableClientTransport], indicating the session is not present on the +// server. This occurs when the server returns a 404 Not Found response. // -// TODO(rfindley): should we expose this error value (and its corresponding -// API) to the user? +// According to the MCP spec ([Session Management]), clients should reestablish +// a session when encountering this error. Users can check for this error using +// [errors.Is] to implement session recovery logic by creating a new +// [ClientSession]. // -// The spec says that if the server returns 404, clients should reestablish -// a session. For now, we delegate that to the user, but do they need a way to -// differentiate a 'NotFound' error from other errors? -var errSessionMissing = errors.New("session not found") +// [Session Management]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#session-management +var ErrSessionMissing = errors.New("session not found") var _ clientConnection = (*streamableClientConn)(nil) @@ -1572,7 +1577,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { // If err is non-nil, it is terminal, and subsequent (or pending) Reads will // fail. // -// If err wraps errSessionMissing, the failure indicates that the session is no +// If err wraps [ErrSessionMissing], the failure indicates that the session is no // longer present on the server, and no final DELETE will be performed when // closing the connection. func (c *streamableClientConn) fail(err error) { @@ -1816,9 +1821,9 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // which it MUST respond to requests containing that session ID with HTTP // 404 Not Found." if resp.StatusCode == http.StatusNotFound { - // Return an errSessionMissing to avoid sending a redundant DELETE when the + // Return an ErrSessionMissing to avoid sending a redundant DELETE when the // session is already gone. - return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing) + return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, ErrSessionMissing) } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode)) @@ -1970,7 +1975,7 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { - if errors.Is(c.failure(), errSessionMissing) { + if errors.Is(c.failure(), ErrSessionMissing) { // If the session is missing, no need to delete it. } else { req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) diff --git a/mcp/streamable_client.go b/mcp/streamable_client.go index 41a10046..c2cc25b8 100644 --- a/mcp/streamable_client.go +++ b/mcp/streamable_client.go @@ -161,7 +161,7 @@ The client must handle two response formats from POST requests: - DELETE: Terminate the session - Used by [streamableClientConn.Close] - - Skipped if session is already known to be gone ([errSessionMissing]) + - Skipped if session is already known to be gone ([ErrSessionMissing]) # Error Handling @@ -173,7 +173,7 @@ Errors are categorized and handled differently: - Triggers reconnection in [streamableClientConn.handleSSE] 2. Terminal (breaks the connection): - - 404 Not Found: Session terminated by server ([errSessionMissing]) + - 404 Not Found: Session terminated by server ([ErrSessionMissing]) - Message decode errors: Protocol violation - Context cancellation: Client closed connection - Mismatched session IDs: Protocol error @@ -183,7 +183,7 @@ Terminal errors are stored via [streamableClientConn.fail] and returned by subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed] channel signals that the connection is broken. -Special case: [errSessionMissing] indicates the server has terminated the session, +Special case: [ErrSessionMissing] indicates the server has terminated the session, so [streamableClientConn.Close] skips the DELETE request. # Protocol Version Header diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index dcdda322..53f61b4f 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "errors" "fmt" "io" "net/http" @@ -234,6 +235,9 @@ func TestStreamableClientRedundantDelete(t *testing.T) { if err == nil { t.Errorf("Listing tools: got nil error, want non-nil") } + if !errors.Is(err, ErrSessionMissing) { + t.Errorf("Listing tools: got %v, want error wrapping ErrSessionMissing", err) + } _ = session.Wait() // must not hang if missing := fake.missingRequests(); len(missing) > 0 { t.Errorf("did not receive expected requests: %v", missing) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 9eff74b8..e335ed10 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -367,6 +367,52 @@ func TestStreamableServerShutdown(t *testing.T) { } } +// TestStreamableServerSessionClose verifies that when the server closes a +// session, the client observes ErrSessionMissing on subsequent requests. +func TestStreamableServerSessionClose(t *testing.T) { + ctx := context.Background() + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "greet"}, sayHi) + + handler := NewStreamableHTTPHandler( + func(req *http.Request) *Server { return server }, + nil, + ) + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + clientSession, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + + // Verify the connection works initially. + if _, err := clientSession.ListTools(ctx, nil); err != nil { + t.Fatalf("ListTools failed: %v", err) + } + + // Close the server session. + for session := range server.Sessions() { + if err := session.Close(); err != nil { + t.Fatalf("closing server session: %v", err) + } + } + + // Subsequent requests should fail with ErrSessionMissing. + _, err = clientSession.ListTools(ctx, nil) + if err == nil { + t.Fatalf("ListTools after server session close: got nil error, want error") + } + if !errors.Is(err, ErrSessionMissing) { + t.Errorf("ListTools after server session close: got %v, want error wrapping ErrSessionMissing", err) + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network