diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 67f4f4e1d..3b4ff6be7 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -43,6 +43,34 @@ public class HttpServerTransportOptions /// public bool Stateless { get; set; } + /// + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. + /// + /// + /// When configured, the server will: + /// + /// Generate unique event IDs for each SSE message + /// Store events for later replay + /// Replay missed events when a client reconnects with a Last-Event-ID header + /// Send priming events to establish resumability before any actual messages + /// + /// + public ISseEventStore? EventStore { get; set; } + + /// + /// Gets or sets the retry interval to suggest to clients in SSE retry field. + /// + /// + /// The retry interval. The default is 5 seconds. + /// + /// + /// When is set, the server will include a retry field in priming events. + /// This value suggests to clients how long to wait before attempting to reconnect after a connection is lost. + /// Clients may use this value to implement polling behavior during long-running operations. + /// + public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(5); + /// /// Gets or sets a value that indicates whether the server uses a single execution context for the entire session. /// diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 9f4af7ea5..32faf4f9f 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -23,6 +23,7 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + private const string LastEventIdHeaderName = "Last-Event-ID"; private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); @@ -88,10 +89,15 @@ await WriteJsonRpcErrorAsync(context, return; } - if (!session.TryStartGetRequest()) + // Check for Last-Event-ID header for resumability + var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); + var isResumption = !string.IsNullOrEmpty(lastEventId); + + // Only check TryStartGetRequest for new connections, not resumptions + if (!isResumption && !session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context, - "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.", + "Bad Request: This server does not support multiple GET requests. Use Last-Event-ID header to resume or start a new session.", StatusCodes.Status400BadRequest); return; } @@ -111,7 +117,7 @@ await WriteJsonRpcErrorAsync(context, // will be sent in response to a different POST request. It might be a while before we send a message // over this response body. await context.Response.Body.FlushAsync(cancellationToken); - await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); + await session.Transport.HandleGetRequestAsync(context.Response.Body, isResumption ? lastEventId : null, cancellationToken); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { @@ -194,7 +200,10 @@ private async ValueTask StartNewSessionAsync(HttpContext { SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, + EventStore = HttpServerTransportOptions.EventStore, + RetryInterval = HttpServerTransportOptions.RetryInterval, }; + context.Response.Headers[McpSessionIdHeaderName] = sessionId; } else diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 624a14aa1..582a6f8e0 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -106,4 +106,17 @@ public required Uri Endpoint /// Gets sor sets the authorization provider to use for authentication. /// public ClientOAuthOptions? OAuth { get; set; } + + /// + /// Gets or sets the maximum number of reconnection attempts when an SSE stream is disconnected. + /// + /// + /// The maximum number of reconnection attempts. The default is 2. + /// + /// + /// When an SSE stream is disconnected (e.g., due to a network issue), the client will attempt to + /// reconnect using the Last-Event-ID header to resume from where it left off. This property controls + /// how many reconnection attempts are made before giving up. + /// + public int MaxReconnectionAttempts { get; set; } = 2; } diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 534249038..5723c2dc2 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -16,6 +16,8 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_applicationJsonMediaType = new("application/json"); private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); + private static readonly TimeSpan s_defaultReconnectionDelay = TimeSpan.FromSeconds(1); + private readonly McpHttpClient _httpClient; private readonly HttpClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; @@ -106,7 +108,17 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") { using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken); - rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + var sseState = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = sseState.Response; + + // Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming), + // attempt to resume by sending a GET request with Last-Event-ID header. The server will replay + // events from the event store, allowing us to receive the pending response. + if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null) + { + var resumeResult = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = resumeResult.Response; + } } if (rpcRequest is null) @@ -188,54 +200,135 @@ public override async ValueTask DisposeAsync() private async Task ReceiveUnsolicitedMessagesAsync() { - // Send a GET request to handle any unsolicited messages not sent over a POST response. - using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); - request.Headers.Accept.Add(s_textEventStreamMediaType); - CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); + var state = new SseStreamState(); - // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. - HttpResponseMessage response; - try - { - response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); - } - catch (HttpRequestException) + // Continuously receive unsolicited messages until cancelled + while (!_connectionCts.Token.IsCancellationRequested) { - return; + var result = await SendGetSseRequestWithRetriesAsync( + relatedRpcRequest: null, + state, + _connectionCts.Token).ConfigureAwait(false); + + // Update state for next reconnection attempt + state.UpdateFrom(result); + + // If we exhausted retries without receiving any events, stop trying + if (result.LastEventId is null) + { + return; + } } + } + + /// + /// Sends a GET request for SSE with retry logic and resumability support. + /// + private async Task SendGetSseRequestWithRetriesAsync( + JsonRpcRequest? relatedRpcRequest, + SseStreamState state, + CancellationToken cancellationToken) + { + int attempt = 0; + + // Delay before first attempt if we're reconnecting (have a Last-Event-ID) + bool shouldDelay = state.LastEventId is not null; - using (response) + while (attempt < _options.MaxReconnectionAttempts) { - if (!response.IsSuccessStatusCode) + cancellationToken.ThrowIfCancellationRequested(); + + if (shouldDelay) { - return; + var delay = state.RetryInterval ?? s_defaultReconnectionDelay; + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); } + shouldDelay = true; + + using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); + request.Headers.Accept.Add(s_textEventStreamMediaType); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId); + + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) + { + attempt++; + continue; + } + + using (response) + { + if (!response.IsSuccessStatusCode) + { + attempt++; + continue; + } + + using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var result = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, cancellationToken).ConfigureAwait(false); + + state.UpdateFrom(result); + + if (result.Response is not null) + { + return state; + } - using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); - await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); + // Stream closed without the response + if (state.LastEventId is null) + { + // No event ID means server may not support resumability - don't retry indefinitely + attempt++; + } + else + { + // We have an event ID, so reconnection should work - reset attempts + attempt = 0; + } + } } + + return state; } - private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) + private async Task ProcessSseResponseAsync( + Stream responseStream, + JsonRpcRequest? relatedRpcRequest, + CancellationToken cancellationToken) { + var state = new SseStreamState(); + await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) { - if (sseEvent.EventType != "message") + // Track event ID and retry interval for resumability + if (!string.IsNullOrEmpty(sseEvent.EventId)) + { + state.LastEventId = sseEvent.EventId; + } + if (sseEvent.ReconnectionInterval.HasValue) + { + state.RetryInterval = sseEvent.ReconnectionInterval.Value; + } + + // Skip events with empty data (priming events, keep-alives) + if (string.IsNullOrEmpty(sseEvent.Data) || sseEvent.EventType != "message") { continue; } var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); - - // The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes - // a GET request for any notifications that might need to be sent after the completion of each POST. if (rpcResponseOrError is not null) { - return rpcResponseOrError; + state.Response = rpcResponseOrError; + return state; } } - return null; + return state; } private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) @@ -292,7 +385,8 @@ internal static void CopyAdditionalHeaders( HttpRequestHeaders headers, IDictionary? additionalHeaders, string? sessionId, - string? protocolVersion) + string? protocolVersion, + string? lastEventId = null) { if (sessionId is not null) { @@ -304,6 +398,11 @@ internal static void CopyAdditionalHeaders( headers.Add("MCP-Protocol-Version", protocolVersion); } + if (lastEventId is not null) + { + headers.Add("Last-Event-ID", lastEventId); + } + if (additionalHeaders is null) { return; @@ -317,4 +416,21 @@ internal static void CopyAdditionalHeaders( } } } + + /// + /// Tracks state across SSE stream connections. + /// + private struct SseStreamState + { + public JsonRpcMessageWithId? Response; + public string? LastEventId; + public TimeSpan? RetryInterval; + + public void UpdateFrom(SseStreamState other) + { + Response ??= other.Response; + LastEventId ??= other.LastEventId; + RetryInterval ??= other.RetryInterval; + } + } } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index dd6814640..683e1d18d 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -29,16 +29,32 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; + internal const string LatestProtocolVersion = "2025-11-25"; /// All protocol versions supported by this implementation. internal static readonly string[] SupportedProtocolVersions = [ "2024-11-05", "2025-03-26", + "2025-06-18", LatestProtocolVersion, ]; + /// + /// Checks if the given protocol version supports priming events. + /// + /// The protocol version to check. + /// True if the protocol version supports resumability. + /// + /// Priming events are only supported in protocol version >= 2025-11-25. + /// Older clients may crash when receiving SSE events with empty data. + /// + internal static bool SupportsPrimingEvent(string? protocolVersion) + { + const string MinResumabilityProtocolVersion = "2025-11-25"; + return string.Compare(protocolVersion, MinResumabilityProtocolVersion, StringComparison.Ordinal) >= 0; + } + private readonly bool _isServer; private readonly string _transportKind; private readonly ITransport _transport; diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index 0b9cd0416..763f62e40 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Server; using System.Security.Claims; -using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol; diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStore.cs b/src/ModelContextProtocol.Core/Server/ISseEventStore.cs new file mode 100644 index 000000000..c8686f8c6 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStore.cs @@ -0,0 +1,54 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Defines the contract for storing and replaying SSE events to support resumability. +/// +/// +/// +/// When a client reconnects with a Last-Event-ID header, the server uses the event store +/// to replay events that were sent after the specified event ID. This enables clients to +/// recover from connection drops without losing messages. +/// +/// +/// Events are scoped to streams, where each stream corresponds to either a specific request ID +/// (for POST SSE responses) or a special "standalone" stream ID (for unsolicited GET SSE messages). +/// +/// +/// Implementations should be thread-safe, as events may be stored and replayed concurrently. +/// +/// +public interface ISseEventStore +{ + /// + /// Stores an event for later retrieval. + /// + /// + /// The ID of the session, or null. + /// + /// + /// The ID of the stream the event belongs to. This is typically the JSON-RPC request ID + /// for POST SSE responses, or a special identifier for the standalone GET SSE stream. + /// + /// + /// The JSON-RPC message to store, or for priming events. + /// Priming events establish the event ID without carrying a message payload. + /// + /// A token to cancel the operation. + /// The generated event ID for the stored event. + ValueTask StoreEventAsync(string sessionId, string streamId, JsonRpcMessage? message, CancellationToken cancellationToken = default); + + /// + /// Replays events that occurred after the specified event ID. + /// + /// The ID of the last event the client received. + /// A token to cancel the operation. + /// + /// An containing the events to replay if the event ID was found; + /// if the event ID was not found in the store. + /// + ValueTask GetEventsAfterAsync( + string lastEventId, + CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/SseReplayResult.cs b/src/ModelContextProtocol.Core/Server/SseReplayResult.cs new file mode 100644 index 000000000..fd9fdd103 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseReplayResult.cs @@ -0,0 +1,33 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Represents the result of replaying SSE events to a client during resumption. +/// +/// +/// This class is returned by when a client +/// reconnects with a Last-Event-ID header. It contains the stream and session identifiers +/// along with an async enumerable of events to replay. +/// +public sealed class SseReplayResult +{ + /// + /// Gets the session ID that the events belong to. + /// + public required string SessionId { get; init; } + + /// + /// Gets the stream ID that the events belong to. + /// + /// + /// This is typically the JSON-RPC request ID for POST SSE responses, + /// or a special identifier for the standalone GET SSE stream. + /// + public required string StreamId { get; init; } + + /// + /// Gets the async enumerable of events to replay to the client. + /// + public required IAsyncEnumerable Events { get; init; } +} diff --git a/src/ModelContextProtocol.Core/Server/SseStreamEventStore.cs b/src/ModelContextProtocol.Core/Server/SseStreamEventStore.cs new file mode 100644 index 000000000..3f2ad78c1 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseStreamEventStore.cs @@ -0,0 +1,47 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Wraps an with session and stream context for a specific SSE stream. +/// +/// +/// This class simplifies event storage by binding the session ID, stream ID, and retry interval +/// so that callers only need to provide the message when storing events. +/// +internal sealed class SseStreamEventStore +{ + private readonly ISseEventStore _eventStore; + private readonly string _sessionId; + private readonly string _streamId; + private readonly TimeSpan _retryInterval; + + /// + /// Gets the retry interval to suggest to clients in SSE retry field. + /// + public TimeSpan RetryInterval => _retryInterval; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying event store to use for storage. + /// The session ID, or to generate a new one. + /// The stream ID for this SSE stream. + /// The retry interval to suggest to clients. + public SseStreamEventStore(ISseEventStore eventStore, string? sessionId, string streamId, TimeSpan retryInterval) + { + _eventStore = eventStore; + _sessionId = sessionId ?? Guid.NewGuid().ToString("N"); + _streamId = streamId; + _retryInterval = retryInterval; + } + + /// + /// Stores an event in the underlying event store with the bound session and stream context. + /// + /// The JSON-RPC message to store, or for priming events. + /// A token to cancel the operation. + /// The generated event ID for the stored event. + public ValueTask StoreEventAsync(JsonRpcMessage? message, CancellationToken cancellationToken = default) + => _eventStore.StoreEventAsync(_sessionId, _streamId, message, cancellationToken); +} diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index a2314e623..12e80d792 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -22,9 +22,11 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; + public SseStreamEventStore? EventStore { get; set; } + public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; } - public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) + public async Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { Throw.IfNull(sseResponseStream); @@ -44,24 +46,78 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati } _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); - return _writeTask; + await _writeTask.ConfigureAwait(false); + } + + /// + /// Sends a priming event with an event ID but no message payload. + /// This establishes resumability for the stream before any actual messages are sent. + /// + public async Task SendPrimingEventAsync(CancellationToken cancellationToken = default) + { + if (EventStore is null) + { + return null; + } + + using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false); + + if (_disposed) + { + return null; + } + + // Store a null message to get an event ID for the priming event + var eventId = await EventStore.StoreEventAsync(message: null, cancellationToken); + + // Create a priming event: empty data with an event ID + // We use a special "priming" event type that the formatter will handle + var primingItem = new SseItem(null, "priming") + { + EventId = eventId, + ReconnectionInterval = EventStore.RetryInterval, + }; + + await _messages.Writer.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); + return eventId; } public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => await SendMessageAsync(message, eventId: null, cancellationToken).ConfigureAwait(false); + + /// + /// Sends a message with an optional pre-assigned event ID. + /// This is used for replaying stored events with their original IDs. + /// + /// + /// If resumability is enabled and no event ID is provided, the message is stored + /// in the event store before being written to the stream. This ensures messages + /// are persisted even if the stream is closed before they can be written. + /// + public async Task SendMessageAsync(JsonRpcMessage message, string? eventId, CancellationToken cancellationToken = default) { Throw.IfNull(message); using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false); + // Store the event first (even if stream is completed) so clients can retrieve it via Last-Event-ID. + // Skip if eventId is already provided (replayed events are already stored). + if (eventId is null && EventStore is not null) + { + eventId = await EventStore.StoreEventAsync(message, cancellationToken).ConfigureAwait(false); + } + if (_disposed) { - // Don't throw ObjectDisposedException here; just return false to indicate the message wasn't sent. - // The calling transport can determine what to do in this case (drop the message, or fall back to another transport). + // Message is stored but stream is closed - client can retrieve via Last-Event-ID. return false; } - // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); + var item = new SseItem(message, SseParser.EventTypeDefault) + { + EventId = eventId, + }; + await _messages.Writer.WriteAsync(item, cancellationToken).ConfigureAwait(false); return true; } @@ -74,7 +130,8 @@ public async ValueTask DisposeAsync() return; } - _messages.Writer.Complete(); + // Signal completion if not already done (e.g., by Complete()) + _messages.Writer.TryComplete(); try { if (_writeTask is not null) @@ -101,6 +158,12 @@ private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferW return; } + // Priming events have empty data - just write nothing + if (item.EventType == "priming") + { + return; + } + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); } diff --git a/src/ModelContextProtocol.Core/Server/StoredSseEvent.cs b/src/ModelContextProtocol.Core/Server/StoredSseEvent.cs new file mode 100644 index 000000000..67fb3ae45 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/StoredSseEvent.cs @@ -0,0 +1,27 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Represents a stored SSE event that can be replayed to a client during resumption. +/// +/// +/// This struct is used when replaying events from an +/// after a client reconnects with a Last-Event-ID header. +/// +public readonly struct StoredSseEvent +{ + /// + /// Gets the JSON-RPC message that was stored for this event. + /// + public required JsonRpcMessage Message { get; init; } + + /// + /// Gets the unique event ID that was assigned when the event was stored. + /// + /// + /// This ID is sent to the client so it can be used in subsequent + /// Last-Event-ID headers for further resumption. + /// + public required string EventId { get; init; } +} diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 1109c2b2b..89d787ec6 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,9 +1,7 @@ using ModelContextProtocol.Protocol; using System.Diagnostics; -using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; -using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -35,11 +33,11 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio { _pendingRequest = request.Id; - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + // Invoke the initialize request handler if applicable. + if (request.Method == RequestMethods.Initialize) { var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); + await parentTransport.HandleInitRequestAsync(initializeRequest).ConfigureAwait(false); } } @@ -58,6 +56,18 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio return false; } + // Configure the SSE writer for resumability if we have an event store + if (parentTransport.EventStore is not null && McpSessionHandler.SupportsPrimingEvent(parentTransport.NegotiatedProtocolVersion)) + { + _sseWriter.EventStore = new SseStreamEventStore( + parentTransport.EventStore, + sessionId: parentTransport.SessionId, + streamId: _pendingRequest.Id.ToString()!, + retryInterval: parentTransport.RetryInterval); + + await _sseWriter.SendPrimingEventAsync(cancellationToken).ConfigureAwait(false); + } + _sseWriter.MessageFilter = StopOnFinalResponseFilter; await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); return true; @@ -72,11 +82,13 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); } + // SseWriter.SendMessageAsync stores the message in the event store before checking if + // the stream is completed, so we don't need to handle storage here. bool isAccepted = await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); if (!isAccepted) { - // The underlying writer didn't accept the message because the underlying request has completed. - // Rather than drop the message, fall back to sending it via the parent transport. + // The stream is closed - fall back to sending via the parent transport's standalone SSE stream. + // The message is already stored in the event store by SseWriter. await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 8a64094e4..a4b0e5377 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol; -using System.IO.Pipelines; using System.Security.Claims; using System.Threading.Channels; @@ -21,6 +20,11 @@ namespace ModelContextProtocol.Server; /// public sealed class StreamableHttpServerTransport : ITransport { + /// + /// The stream ID used for unsolicited messages sent via the standalone GET SSE stream. + /// + internal const string GetStreamId = "__get__"; + // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages. private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1) { @@ -43,7 +47,7 @@ public sealed class StreamableHttpServerTransport : ITransport /// /// Gets or initializes a value that indicates whether the transport should be in stateless mode that does not require all requests for a given session /// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode, - /// so calling results in an . + /// so calling results in an . /// Server-to-client requests are also unsupported, because the responses might arrive at another ASP.NET Core application process. /// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests. /// @@ -63,20 +67,56 @@ public sealed class StreamableHttpServerTransport : ITransport /// public Func? OnInitRequestReceived { get; set; } + /// + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. + /// + public ISseEventStore? EventStore { get; set; } + + /// + /// Gets or sets the retry interval to suggest to clients in SSE retry field. + /// When is set, the server will include a retry field in priming events. + /// + public TimeSpan RetryInterval { get; set; } = TimeSpan.FromSeconds(5); + + /// + /// Gets or sets the negotiated protocol version for this session. + /// + public string? NegotiatedProtocolVersion { get; set; } + /// public ChannelReader MessageReader => _incomingChannel.Reader; internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// + /// Handles the initialize request by capturing the protocol version and invoking the user callback. + /// + internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams) + { + // Capture the negotiated protocol version for resumability checks + NegotiatedProtocolVersion = initParams?.ProtocolVersion; + + // Invoke user-provided callback if specified + if (OnInitRequestReceived is { } callback) + { + await callback(initParams).ConfigureAwait(false); + } + } + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via /// to the SSE response stream until cancellation is requested or the transport is disposed. /// /// The response stream to write MCP JSON-RPC messages as SSE events to. + /// + /// The Last-Event-ID header value from the client request for resumability. + /// When provided, the server will replay events that occurred after this ID before streaming new events. + /// /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) + public async Task HandleGetRequestAsync(Stream sseResponseStream, string? lastEventId = null, CancellationToken cancellationToken = default) { Throw.IfNull(sseResponseStream); @@ -85,9 +125,55 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo throw new InvalidOperationException("GET requests are not supported in stateless mode."); } + var isResumingGetStream = false; + + if (!string.IsNullOrEmpty(lastEventId) && EventStore is not null) + { + var eventsToReplay = await EventStore.GetEventsAfterAsync(lastEventId!, cancellationToken); + + if (eventsToReplay is not null) + { + if (eventsToReplay.StreamId == GetStreamId) + { + isResumingGetStream = true; + + // Replay messages onto the new background GET stream + await foreach (var e in eventsToReplay.Events) + { + await _sseWriter.SendMessageAsync(e.Message, e.EventId, cancellationToken); + } + } + else + { + // Replay messages onto a new stream specific to this request + await using var newStream = new SseWriter(); + await foreach (var e in eventsToReplay.Events) + { + await newStream.SendMessageAsync(e.Message, e.EventId, cancellationToken); + } + + await newStream.WriteAllAsync(sseResponseStream, cancellationToken); + return; + } + } + } + + // New GET stream (not resumption) - only allow one per session if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1) { - throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + throw new InvalidOperationException("Only one GET SSE stream is allowed per session. Use Last-Event-ID header to resume."); + } + + // Configure the SSE writer for resumability if we have an event store and the client supports priming + if (EventStore is not null && McpSessionHandler.SupportsPrimingEvent(NegotiatedProtocolVersion)) + { + _sseWriter.EventStore = new(EventStore, SessionId, GetStreamId, RetryInterval); + + if (!isResumingGetStream) + { + // Send a priming event to establish resumability + await _sseWriter.SendPrimingEventAsync(cancellationToken).ConfigureAwait(false); + } } // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs new file mode 100644 index 000000000..f430edfde --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -0,0 +1,266 @@ +using System.ComponentModel; +using System.Net.ServerSentEvents; +using System.Text; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Integration tests for SSE resumability with full client-server flow. +/// These tests use McpClient for end-to-end testing and only use raw HTTP +/// for SSE format verification where McpClient abstracts away the details. +/// +public class ResumabilityIntegrationTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private const string InitializeRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0.0"}}} + """; + + [Fact] + public async Task Server_StoresEvents_WhenEventStoreConfigured() + { + // Arrange + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore); + await using var client = await ConnectClientAsync(); + + // Act - Make a tool call which generates events + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Assert - Events were stored + Assert.NotNull(result); + Assert.True(eventStore.StoreEventCallCount > 0, "Expected events to be stored when EventStore is configured"); + } + + [Fact] + public async Task Server_StoresMultipleEvents_ForMultipleToolCalls() + { + // Arrange + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore); + await using var client = await ConnectClientAsync(); + + // Act - Make multiple tool calls + var initialCount = eventStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test1" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterFirst = eventStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test2" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterSecond = eventStore.StoreEventCallCount; + + // Assert - More events were stored for each call + Assert.True(countAfterFirst > initialCount, "Expected more events after first call"); + Assert.True(countAfterSecond > countAfterFirst, "Expected more events after second call"); + } + + [Fact] + public async Task Client_CanMakeMultipleRequests_WithResumabilityEnabled() + { + // Arrange + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore); + await using var client = await ConnectClientAsync(); + + // Act - Make many requests to verify stability + for (int i = 0; i < 5; i++) + { + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = $"test{i}" }, + cancellationToken: TestContext.Current.CancellationToken); + + var textContent = Assert.Single(result.Content.OfType()); + Assert.Equal($"Echo: test{i}", textContent.Text); + } + + // Assert - All requests succeeded and events were stored + Assert.True(eventStore.StoreEventCallCount >= 5, "Expected events to be stored for each request"); + } + + [Fact] + public async Task Ping_WorksWithResumabilityEnabled() + { + // Arrange + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore); + await using var client = await ConnectClientAsync(); + + // Act & Assert - Ping should work + await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task ListTools_WorksWithResumabilityEnabled() + { + // Arrange + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore); + await using var client = await ConnectClientAsync(); + + // Act + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(tools); + Assert.Single(tools); + } + + [Fact] + public async Task Server_IncludesEventIdAndRetry_InSseResponse() + { + // Arrange + var expectedRetryInterval = TimeSpan.FromSeconds(5); + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore, retryInterval: expectedRetryInterval); + + // Act + var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - Event IDs and retry field should be present in the response + Assert.True(sseResponse.LastEventId is not null, "Expected SSE response to contain event IDs"); + Assert.Equal(expectedRetryInterval, sseResponse.RetryInterval); + } + + [Fact] + public async Task Server_WithoutEventStore_DoesNotIncludeEventIdAndRetry() + { + // Arrange - Server without event store + await using var app = await CreateServerAsync(); + + // Act + var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - No event IDs or retry field when EventStore is not configured + Assert.True(sseResponse.LastEventId is null, "Did not expect event IDs when EventStore is not configured"); + Assert.True(sseResponse.RetryInterval is null, "Did not expect retry field when EventStore is not configured"); + } + + [Fact] + public async Task Server_DoesNotSendPrimingEvents_ToOlderProtocolVersionClients() + { + // Arrange - Server with resumability enabled + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore, retryInterval: TimeSpan.FromSeconds(5)); + + // Use an older protocol version that doesn't support resumability + const string OldProtocolInitRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"OldClient","version":"1.0.0"}}} + """; + + var sseResponse = await SendInitializeAndReadSseResponseAsync(OldProtocolInitRequest); + + // Assert - Old clients should not receive event IDs or retry fields (no priming events) + Assert.True(sseResponse.LastEventId is null, "Old protocol clients should not receive event IDs"); + Assert.True(sseResponse.RetryInterval is null, "Old protocol clients should not receive retry field"); + + // Event store should not have been called for old clients + Assert.Equal(0, eventStore.StoreEventCallCount); + } + + [Fact] + public async Task Client_ReceivesRetryInterval_FromServer() + { + // Arrange - Server with specific retry interval + var expectedRetry = TimeSpan.FromMilliseconds(3000); + var eventStore = new InMemoryEventStore(); + await using var app = await CreateServerAsync(eventStore, retryInterval: expectedRetry); + + // Act - Send initialize and read the retry field + var sseItem = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - Client receives the retry interval from server + Assert.Equal(expectedRetry, sseItem.RetryInterval); + } + + [McpServerToolType] + private class ResumabilityTestTools + { + [McpServerTool(Name = "echo"), Description("Echoes the message back")] + public static string Echo(string message) => $"Echo: {message}"; + } + + private async Task CreateServerAsync(ISseEventStore? eventStore = null, TimeSpan? retryInterval = null) + { + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.EventStore = eventStore; + if (retryInterval.HasValue) + { + options.RetryInterval = retryInterval.Value; + } + }) + .WithTools(); + + var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task ConnectClientAsync() + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = new Uri("http://localhost:5000/"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + return await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } + + private async Task SendInitializeAndReadSseResponseAsync(string initializeRequest) + { + using var requestContent = new StringContent(initializeRequest, Encoding.UTF8, "application/json"); + using var request = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream") } + }, + Content = requestContent, + }; + + var response = await HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, + TestContext.Current.CancellationToken); + + response.EnsureSuccessStatusCode(); + + var sseResponse = new SseResponse(); + await using var stream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(stream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + if (!string.IsNullOrEmpty(sseItem.EventId)) + { + sseResponse.LastEventId = sseItem.EventId; + } + if (sseItem.ReconnectionInterval.HasValue) + { + sseResponse.RetryInterval = sseItem.ReconnectionInterval.Value; + } + } + + return sseResponse; + } + + private struct SseResponse + { + public string? LastEventId { get; set; } + public TimeSpan? RetryInterval { get; set; } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityTests.cs new file mode 100644 index 000000000..a977eb30c --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityTests.cs @@ -0,0 +1,230 @@ +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for SSE resumability and redelivery features. +/// Tests focus on the ISseEventStore interface and unit-level behavior. +/// +public class ResumabilityTests : LoggedTest +{ + private const string TestSessionId = "test-session"; + + public ResumabilityTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { + } + + [Fact] + public async Task EventStore_StoresAndRetrievesEvents() + { + // Arrange + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + + // Act + var eventId1 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test1" }, ct); + var eventId2 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test2" }, ct); + + // Assert + Assert.Equal("1", eventId1); + Assert.Equal("2", eventId2); + Assert.Equal(2, eventStore.StoredEventIds.Count); + } + + [Fact] + public async Task EventStore_TracksMultipleStreams() + { + // Arrange + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + + // Store events for different streams + var stream1Event1 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test1" }, ct); + var stream1Event2 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test2" }, ct); + _ = await eventStore.StoreEventAsync(TestSessionId, "stream2", + new JsonRpcNotification { Method = "test3" }, ct); + var stream1Event3 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test4" }, ct); + + // Act - Get events after stream1Event1 + var result = await eventStore.GetEventsAfterAsync(stream1Event1, ct); + + // Assert + Assert.NotNull(result); + Assert.Equal("stream1", result.StreamId); + Assert.Equal(TestSessionId, result.SessionId); + + var storedEvents = new List(); + await foreach (var evt in result.Events) + { + storedEvents.Add(evt); + } + + Assert.Equal(2, storedEvents.Count); // Only stream1 events after stream1Event1 + + var notification1 = Assert.IsType(storedEvents[0].Message); + Assert.Equal("test2", notification1.Method); + Assert.Equal(stream1Event2, storedEvents[0].EventId); + + var notification2 = Assert.IsType(storedEvents[1].Message); + Assert.Equal("test4", notification2.Method); + Assert.Equal(stream1Event3, storedEvents[1].EventId); + } + + [Fact] + public async Task EventStore_ReturnsDefault_ForUnknownEventId() + { + // Arrange + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + await eventStore.StoreEventAsync(TestSessionId, "stream1", new JsonRpcNotification { Method = "test" }, ct); + + // Act + var result = await eventStore.GetEventsAfterAsync("unknown-event-id", ct); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task EventStore_ReplaysNoEvents_WhenLastEventIsLatest() + { + // Arrange + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + + var eventId1 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test1" }, ct); + var eventId2 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test2" }, ct); + + // Act - Get events after the last event + var result = await eventStore.GetEventsAfterAsync(eventId2, ct); + + // Assert - No events should be returned + Assert.NotNull(result); + Assert.Equal("stream1", result.StreamId); + + var storedEvents = new List(); + await foreach (var evt in result.Events) + { + storedEvents.Add(evt); + } + Assert.Empty(storedEvents); + } + + [Fact] + public async Task EventStore_HandlesPrimingEvents() + { + // Arrange - Priming events have null messages + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + + // Store a priming event (null message) followed by real events + var primingEventId = await eventStore.StoreEventAsync(TestSessionId, "stream1", null, ct); + var eventId1 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test1" }, ct); + var eventId2 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "test2" }, ct); + + // Act - Get events after priming event + var result = await eventStore.GetEventsAfterAsync(primingEventId, ct); + + // Assert - Should return the two real events, not the priming event + Assert.NotNull(result); + Assert.Equal("stream1", result.StreamId); + + var storedEvents = new List(); + await foreach (var evt in result.Events) + { + storedEvents.Add(evt); + } + + Assert.Equal(2, storedEvents.Count); + Assert.Equal(eventId1, storedEvents[0].EventId); + Assert.Equal(eventId2, storedEvents[1].EventId); + } + + [Fact] + public async Task EventStore_ReplaysMixedMessageTypes() + { + // Arrange + var eventStore = new InMemoryEventStore(); + var ct = TestContext.Current.CancellationToken; + + var eventId1 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcNotification { Method = "notification" }, ct); + var eventId2 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcResponse { Id = new RequestId("1"), Result = null }, ct); + var eventId3 = await eventStore.StoreEventAsync(TestSessionId, "stream1", + new JsonRpcRequest { Id = new RequestId("2"), Method = "request" }, ct); + + // Act + var result = await eventStore.GetEventsAfterAsync(eventId1, ct); + + // Assert + Assert.NotNull(result); + Assert.Equal("stream1", result.StreamId); + + var storedEvents = new List(); + await foreach (var evt in result.Events) + { + storedEvents.Add(evt); + } + + Assert.Equal(2, storedEvents.Count); + + Assert.IsType(storedEvents[0].Message); + Assert.Equal(eventId2, storedEvents[0].EventId); + + Assert.IsType(storedEvents[1].Message); + Assert.Equal(eventId3, storedEvents[1].EventId); + } + + [Fact] + public void StreamableHttpServerTransport_HasEventStoreProperty() + { + // Arrange + var transport = new StreamableHttpServerTransport(); + + // Assert - EventStore property exists and is null by default + Assert.Null(transport.EventStore); + + // Act - Can set EventStore + var eventStore = new InMemoryEventStore(); + transport.EventStore = eventStore; + + // Assert + Assert.Same(eventStore, transport.EventStore); + } + + [Fact] + public void StreamableHttpServerTransport_GetStreamIdConstant_IsCorrect() + { + // The GetStreamId constant is internal, but we can test that transports + // with resumability configured behave consistently + var transport1 = new StreamableHttpServerTransport(); + var transport2 = new StreamableHttpServerTransport(); + + // Both should have null EventStore by default + Assert.Null(transport1.EventStore); + Assert.Null(transport2.EventStore); + + // Setting event stores should work independently + var store1 = new InMemoryEventStore(); + var store2 = new InMemoryEventStore(); + + transport1.EventStore = store1; + transport2.EventStore = store2; + + Assert.Same(store1, transport1.EventStore); + Assert.Same(store2, transport2.EventStore); + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/InMemoryEventStore.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/InMemoryEventStore.cs new file mode 100644 index 000000000..7b4c320c1 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/InMemoryEventStore.cs @@ -0,0 +1,90 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; + +namespace ModelContextProtocol.AspNetCore.Tests.Utils; + +/// +/// In-memory event store for testing resumability. +/// This is a simple implementation intended for testing, not for production use. +/// +public class InMemoryEventStore : ISseEventStore +{ + private readonly ConcurrentDictionary _events = new(); + private long _eventCounter; + + /// + /// Gets the list of stored event IDs in order of storage. + /// + public List StoredEventIds { get; } = []; + + /// + /// Gets the count of events that have been stored. + /// + public int StoreEventCallCount => StoredEventIds.Count; + + /// + public ValueTask StoreEventAsync(string sessionId, string streamId, JsonRpcMessage? message, CancellationToken cancellationToken = default) + { + var eventId = Interlocked.Increment(ref _eventCounter).ToString(); + _events[eventId] = (sessionId, streamId, message); + lock (StoredEventIds) + { + StoredEventIds.Add(eventId); + } + return new ValueTask(eventId); + } + + /// + public ValueTask GetEventsAfterAsync( + string lastEventId, + CancellationToken cancellationToken = default) + { + if (!_events.TryGetValue(lastEventId, out var lastEvent)) + { + return ValueTask.FromResult(null); + } + + var sessionId = lastEvent.SessionId; + var streamId = lastEvent.StreamId; + + return new ValueTask(new SseReplayResult + { + SessionId = sessionId, + StreamId = streamId, + Events = GetEventsAsync(lastEventId, sessionId, streamId, cancellationToken) + }); + } + + private async IAsyncEnumerable GetEventsAsync( + string lastEventId, + string sessionId, + string streamId, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var startReplay = false; + + foreach (var kvp in _events.OrderBy(e => long.Parse(e.Key))) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (kvp.Key == lastEventId) + { + startReplay = true; + continue; + } + + if (startReplay && kvp.Value.SessionId == sessionId && kvp.Value.StreamId == streamId && kvp.Value.Message is not null) + { + yield return new StoredSseEvent + { + Message = kvp.Value.Message, + EventId = kvp.Key + }; + } + } + + await Task.CompletedTask; + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index b400c6b0b..e5f55685b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -107,7 +107,7 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] @@ -157,7 +157,7 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new ImageContentBlock @@ -492,7 +492,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) { await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion }); - Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion); + Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion); } [Fact] @@ -500,7 +500,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int getWeatherToolCallCount = 0; int askClientToolCallCount = 0; - + Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create( async (McpServer server, string query, CancellationToken cancellationToken) => { @@ -513,14 +513,14 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn return $"Weather in {location}: sunny, 22°C"; }, "get_weather", "Gets the weather for a location"); - + var response = await server .AsSamplingChatClient() .AsBuilder() .UseFunctionInvocation() .Build() .GetResponseAsync(query, new ChatOptions { Tools = [weatherTool] }, cancellationToken); - + return response.Text ?? "No response"; }, new() { Name = "ask_client", Description = "Asks the client a question using sampling" })); @@ -530,7 +530,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int currentCall = samplingCallCount++; var lastMessage = messages.LastOrDefault(); - + // First call: Return a tool call request for get_weather if (currentCall == 0) { @@ -552,7 +552,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn string resultText = toolResult.Result?.ToString() ?? string.Empty; Assert.Contains("Weather in Paris: sunny", resultText); - + return Task.FromResult(new([ new ChatMessage(ChatRole.User, messages.First().Contents), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]), @@ -577,7 +577,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Null(result.IsError); - + var textContent = result.Content.OfType().FirstOrDefault(); Assert.NotNull(textContent); Assert.Contains("Weather in Paris: sunny, 22", textContent.Text); @@ -585,7 +585,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn Assert.Equal(1, askClientToolCallCount); Assert.Equal(2, samplingCallCount); } - + /// Simple test IChatClient implementation for testing. private sealed class TestChatClient(Func, ChatOptions?, CancellationToken, Task> getResponse) : IChatClient { @@ -594,7 +594,7 @@ public Task GetResponseAsync( ChatOptions? options = null, CancellationToken cancellationToken = default) => getResponse(messages, options, cancellationToken); - + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( IEnumerable messages, ChatOptions? options, @@ -606,7 +606,7 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync yield return update; } } - + object? IChatClient.GetService(Type serviceType, object? serviceKey) => null; void IDisposable.Dispose() { } }