diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 241f7d8b5..5cfdbe7e0 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -4,19 +4,12 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpInitRequestHandler; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; -import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,6 +17,14 @@ import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + /** * Represents a Model Context Protocol (MCP) session on the server side. It manages * bidirectional JSON-RPC communication with the client. @@ -36,7 +37,9 @@ public class McpServerSession implements McpLoggableSession { private final String id; - /** Duration to wait for request responses before timing out */ + /** + * Duration to wait for request responses before timing out + */ private final Duration requestTimeout; private final AtomicLong requestCounter = new AtomicLong(0); @@ -65,6 +68,8 @@ public class McpServerSession implements McpLoggableSession { private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + private volatile AtomicBoolean closed = new AtomicBoolean(false); + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -345,14 +350,23 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { - // TODO: clear pendingResponses and emit errors? - return this.transport.closeGracefully(); + if (this.closed.compareAndSet(false, true)) { + this.pendingResponses.forEach((id, response) -> response.error(new RuntimeException("Session closed"))); + this.pendingResponses.clear(); + return this.transport.closeGracefully(); + } + else { + return Mono.empty(); + } } @Override public void close() { - // TODO: clear pendingResponses and emit errors? - this.transport.close(); + if (this.closed.compareAndSet(false, true)) { + this.pendingResponses.forEach((id, response) -> response.error(new RuntimeException("Session closed"))); + this.pendingResponses.clear(); + this.transport.close(); + } } /** diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 6c35de56d..807ccc0fe 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -263,10 +263,13 @@ private ServerResponse handleSseConnection(ServerRequest request) { logger.debug("Creating new SSE connection for session: {}", sessionId); sseBuilder.onComplete(() -> { logger.debug("SSE connection completed for session: {}", sessionId); + // explicitly close the session when the SSE connection is completed + session.close(); sessions.remove(sessionId); }); sseBuilder.onTimeout(() -> { logger.debug("SSE connection timed out for session: {}", sessionId); + session.close(); sessions.remove(sessionId); }); this.sessions.put(sessionId, session); @@ -383,6 +386,12 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { String jsonText = jsonMapper.writeValueAsString(message); sseBuilder.event(MESSAGE_EVENT_TYPE).data(jsonText); } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to send message: {}", e.getMessage()); + } + sseBuilder.error(e); + } catch (Exception e) { logger.error("Failed to send message: {}", e.getMessage()); sseBuilder.error(e);