diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 6c35de56d..d98e18cc2 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -118,6 +118,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private KeepAliveScheduler keepAliveScheduler; + /** + * sse session timeout + */ + private final Duration sessionTimeout; + /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param jsonMapper The McpJsonMapper to use for JSON serialization/deserialization @@ -135,18 +140,20 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private WebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval, - McpTransportContextExtractor contextExtractor) { + McpTransportContextExtractor contextExtractor, Duration sessionTimeout) { Assert.notNull(jsonMapper, "McpJsonMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); + Assert.notNull(sessionTimeout, "Session timeout must not be null"); this.jsonMapper = jsonMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.contextExtractor = contextExtractor; + this.sessionTimeout = sessionTimeout; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -279,7 +286,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { this.sessions.remove(sessionId); sseBuilder.error(e); } - }, Duration.ZERO); + }, this.sessionTimeout); } /** @@ -471,6 +478,8 @@ public static class Builder { private Duration keepAliveInterval; + private Duration sessionTimeout = Duration.ZERO; + private McpTransportContextExtractor contextExtractor = ( serverRequest) -> McpTransportContext.EMPTY; @@ -549,6 +558,11 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + public Builder sessionTimeout(Duration sessionTimeout) { + this.sessionTimeout = sessionTimeout; + return this; + } + /** * Builds a new instance of WebMvcSseServerTransportProvider with the configured * settings. @@ -560,7 +574,7 @@ public WebMvcSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new WebMvcSseServerTransportProvider(jsonMapper == null ? McpJsonMapper.getDefault() : jsonMapper, - baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor); + baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, contextExtractor, sessionTimeout); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java index 1074e8a35..bb40297da 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProviderTests.java @@ -24,6 +24,9 @@ import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; +import java.time.Duration; +import java.util.concurrent.TimeUnit; + import static org.assertj.core.api.Assertions.assertThat; /** @@ -66,7 +69,7 @@ public void before() { } @Test - void validBaseUrl() { + void validBaseUrl() throws InterruptedException { McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build(); try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) .build()) { @@ -106,6 +109,7 @@ public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) .jsonMapper(McpJsonMapper.getDefault()) .contextExtractor(req -> McpTransportContext.EMPTY) + .sessionTimeout(Duration.ofSeconds(1)) .build(); }