Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,10 @@ func (c *streamableServerConn) Close() error {
// A StreamableClientTransport is a [Transport] that can communicate with an MCP
// endpoint serving the streamable HTTP transport defined by the 2025-03-26
// version of the spec.
//
// If the server terminates the session (returning 404 Not Found), subsequent
// operations on the [ClientSession] will return an error wrapping
// [ErrSessionMissing].
type StreamableClientTransport struct {
Endpoint string
HTTPClient *http.Client
Expand Down Expand Up @@ -1493,16 +1497,17 @@ type streamableClientConn struct {
sessionID string
}

// errSessionMissing distinguishes if the session is known to not be present on
// the server (see [streamableClientConn.fail]).
// ErrSessionMissing is a sentinel error returned by a [ClientSession] using the
// [StreamableClientTransport], indicating the session is not present on the
// server. This occurs when the server returns a 404 Not Found response.
//
// TODO(rfindley): should we expose this error value (and its corresponding
// API) to the user?
// According to the MCP spec ([Session Management]), clients should reestablish
// a session when encountering this error. Users can check for this error using
// [errors.Is] to implement session recovery logic by creating a new
// [ClientSession].
//
// The spec says that if the server returns 404, clients should reestablish
// a session. For now, we delegate that to the user, but do they need a way to
// differentiate a 'NotFound' error from other errors?
var errSessionMissing = errors.New("session not found")
// [Session Management]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#session-management
var ErrSessionMissing = errors.New("session not found")

var _ clientConnection = (*streamableClientConn)(nil)

Expand Down Expand Up @@ -1572,7 +1577,7 @@ func (c *streamableClientConn) connectStandaloneSSE() {
// If err is non-nil, it is terminal, and subsequent (or pending) Reads will
// fail.
//
// If err wraps errSessionMissing, the failure indicates that the session is no
// If err wraps [ErrSessionMissing], the failure indicates that the session is no
// longer present on the server, and no final DELETE will be performed when
// closing the connection.
func (c *streamableClientConn) fail(err error) {
Expand Down Expand Up @@ -1816,9 +1821,9 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
// which it MUST respond to requests containing that session ID with HTTP
// 404 Not Found."
if resp.StatusCode == http.StatusNotFound {
// Return an errSessionMissing to avoid sending a redundant DELETE when the
// Return an ErrSessionMissing to avoid sending a redundant DELETE when the
// session is already gone.
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, ErrSessionMissing)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
Expand Down Expand Up @@ -1970,7 +1975,7 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin
// Close implements the [Connection] interface.
func (c *streamableClientConn) Close() error {
c.closeOnce.Do(func() {
if errors.Is(c.failure(), errSessionMissing) {
if errors.Is(c.failure(), ErrSessionMissing) {
// If the session is missing, no need to delete it.
} else {
req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil)
Expand Down
6 changes: 3 additions & 3 deletions mcp/streamable_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ The client must handle two response formats from POST requests:

- DELETE: Terminate the session
- Used by [streamableClientConn.Close]
- Skipped if session is already known to be gone ([errSessionMissing])
- Skipped if session is already known to be gone ([ErrSessionMissing])

# Error Handling

Expand All @@ -173,7 +173,7 @@ Errors are categorized and handled differently:
- Triggers reconnection in [streamableClientConn.handleSSE]

2. Terminal (breaks the connection):
- 404 Not Found: Session terminated by server ([errSessionMissing])
- 404 Not Found: Session terminated by server ([ErrSessionMissing])
- Message decode errors: Protocol violation
- Context cancellation: Client closed connection
- Mismatched session IDs: Protocol error
Expand All @@ -183,7 +183,7 @@ Terminal errors are stored via [streamableClientConn.fail] and returned by
subsequent [streamableClientConn.Read] calls. The [streamableClientConn.failed]
channel signals that the connection is broken.

Special case: [errSessionMissing] indicates the server has terminated the session,
Special case: [ErrSessionMissing] indicates the server has terminated the session,
so [streamableClientConn.Close] skips the DELETE request.

# Protocol Version Header
Expand Down
4 changes: 4 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mcp

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -234,6 +235,9 @@ func TestStreamableClientRedundantDelete(t *testing.T) {
if err == nil {
t.Errorf("Listing tools: got nil error, want non-nil")
}
if !errors.Is(err, ErrSessionMissing) {
t.Errorf("Listing tools: got %v, want error wrapping ErrSessionMissing", err)
}
_ = session.Wait() // must not hang
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests: %v", missing)
Expand Down
46 changes: 46 additions & 0 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,52 @@ func TestStreamableServerShutdown(t *testing.T) {
}
}

// TestStreamableServerSessionClose verifies that when the server closes a
// session, the client observes ErrSessionMissing on subsequent requests.
func TestStreamableServerSessionClose(t *testing.T) {
ctx := context.Background()

server := NewServer(testImpl, nil)
AddTool(server, &Tool{Name: "greet"}, sayHi)

handler := NewStreamableHTTPHandler(
func(req *http.Request) *Server { return server },
nil,
)

httpServer := httptest.NewServer(handler)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
clientSession, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
defer clientSession.Close()

// Verify the connection works initially.
if _, err := clientSession.ListTools(ctx, nil); err != nil {
t.Fatalf("ListTools failed: %v", err)
}

// Close the server session.
for session := range server.Sessions() {
if err := session.Close(); err != nil {
t.Fatalf("closing server session: %v", err)
}
}

// Subsequent requests should fail with ErrSessionMissing.
_, err = clientSession.ListTools(ctx, nil)
if err == nil {
t.Fatalf("ListTools after server session close: got nil error, want error")
}
if !errors.Is(err, ErrSessionMissing) {
t.Errorf("ListTools after server session close: got %v, want error wrapping ErrSessionMissing", err)
}
}

// TestClientReplay verifies that the client can recover from a mid-stream
// network failure and receive replayed messages (if replay is configured). It
// uses a proxy that is killed and restarted to simulate a recoverable network
Expand Down