106 changes: 98 additions & 8 deletions crates/rmcp/src/service.rs
@@ -434,7 +434,7 @@ impl<R: ServiceRole> Peer<R> {
pub struct RunningService<R: ServiceRole, S: Service<R>> {
service: Arc<S>,
peer: Peer<R>,
handle: tokio::task::JoinHandle<QuitReason>,
handle: Option<tokio::task::JoinHandle<QuitReason>>,
cancellation_token: CancellationToken,
dg: DropGuard,
}
@@ -459,14 +459,104 @@ impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
RunningServiceCancellationToken(self.cancellation_token.clone())
}

/// Returns true if the service has been closed or cancelled.
#[inline]
pub async fn waiting(self) -> Result<QuitReason, tokio::task::JoinError> {
self.handle.await
pub fn is_closed(&self) -> bool {
self.handle.is_none() || self.cancellation_token.is_cancelled()
}

/// Wait for the service to complete.
///
/// This will block until the service loop terminates (either due to
/// cancellation, transport closure, or an error).
#[inline]
pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
match self.handle.take() {
Some(handle) => handle.await,
None => Ok(QuitReason::Closed),
}
}

/// Gracefully close the connection and wait for cleanup to complete.
///
/// This method cancels the service, waits for the background task to finish
/// (which includes calling `transport.close()`), and ensures all cleanup
/// operations complete before returning.
///
/// Unlike [`cancel`](Self::cancel), this method takes `&mut self` and can be
/// called without consuming the `RunningService`. After calling this method,
/// the service is considered closed and subsequent operations will fail.
///
/// # Example
///
/// ```rust,ignore
/// let mut client = ().serve(transport).await?;
/// // ... use the client ...
/// client.close().await?;
/// ```
pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
if let Some(handle) = self.handle.take() {
// Cancel explicitly and wait for the background task to finish
// (the drop guard's later cancel is a no-op on an already-cancelled token)
self.cancellation_token.cancel();
handle.await
} else {
// Already closed
Ok(QuitReason::Closed)
}
}
pub async fn cancel(self) -> Result<QuitReason, tokio::task::JoinError> {
let RunningService { dg, handle, .. } = self;
dg.disarm().cancel();
handle.await

/// Gracefully close the connection with a timeout.
///
/// Similar to [`close`](Self::close), but returns after the specified timeout
/// if the cleanup doesn't complete in time. This is useful for ensuring
/// a bounded shutdown time.
///
/// Returns `Ok(Some(reason))` if shutdown completed within the timeout,
/// `Ok(None)` if the timeout was reached, or `Err` if there was a join error.
pub async fn close_with_timeout(
&mut self,
timeout: Duration,
) -> Result<Option<QuitReason>, tokio::task::JoinError> {
if let Some(handle) = self.handle.take() {
self.cancellation_token.cancel();
match tokio::time::timeout(timeout, handle).await {
Ok(result) => result.map(Some),
Err(_elapsed) => {
tracing::warn!(
"close_with_timeout: cleanup did not complete within {:?}",
timeout
);
Ok(None)
}
}
} else {
Ok(Some(QuitReason::Closed))
}
}

/// Cancel the service and wait for cleanup to complete.
///
/// This consumes the `RunningService` and ensures the connection is properly
/// closed. For a non-consuming alternative, see [`close`](Self::close).
pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
// Swap out the drop guard; dropping the old one cancels the token, and close() below waits for the task to finish
let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
self.close().await
}
}

impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
fn drop(&mut self) {
if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
tracing::debug!(
"RunningService dropped without explicit close(). \
The connection will be closed asynchronously. \
For guaranteed cleanup, call close() or cancel() before dropping."
);
}
// The DropGuard will handle cancellation
}
}

@@ -847,7 +937,7 @@ where
RunningService {
service,
peer: peer_return,
handle,
handle: Some(handle),
cancellation_token: ct.clone(),
dg: ct.drop_guard(),
}
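Taken together, `service.rs` now offers three shutdown paths: the non-consuming `close()`, the bounded `close_with_timeout()`, and the consuming `cancel()`. Below is a minimal sketch of the intended call pattern, assuming `transport` is one side of an in-memory duplex pair whose other side is served by a live peer (as in the tests added in this PR); the `()` client handler mirrors the doc example above:

```rust
use rmcp::{service::QuitReason, ServiceExt};

async fn shutdown_demo(transport: tokio::io::DuplexStream) -> anyhow::Result<()> {
    let mut client = ().serve(transport).await?;
    assert!(!client.is_closed());

    // Non-consuming shutdown: close() waits for the background task,
    // including transport.close(), before returning.
    assert!(matches!(client.close().await?, QuitReason::Cancelled));
    assert!(client.is_closed());

    // Idempotent: a second close() reports Closed immediately.
    assert!(matches!(client.close().await?, QuitReason::Closed));

    // Bounded variant: Ok(None) would signal the timeout elapsed first.
    // let done = client.close_with_timeout(std::time::Duration::from_secs(5)).await?;

    Ok(())
}
```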
89 changes: 52 additions & 37 deletions crates/rmcp/src/transport/streamable_http_client.rs
@@ -333,37 +333,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
None
};
// delete session when drop guard is dropped
if let Some(session_id) = &session_id {
let ct = transport_task_ct.clone();
let client = self.client.clone();
let session_id = session_id.clone();
let url = config.uri.clone();
let auth_header = config.auth_header.clone();
tokio::spawn(async move {
ct.cancelled().await;
let delete_session_result = client
.delete_session(url, session_id.clone(), auth_header.clone())
.await;
match delete_session_result {
Ok(_) => {
tracing::info!(session_id = session_id.as_ref(), "delete session success")
}
Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => {
tracing::info!(
session_id = session_id.as_ref(),
"server doesn't support delete session"
)
}
Err(e) => {
tracing::error!(
session_id = session_id.as_ref(),
"fail to delete session: {e}"
);
}
};
});
}
// Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
let session_cleanup_info = session_id.as_ref().map(|sid| {
(self.client.clone(), config.uri.clone(), sid.clone(), config.auth_header.clone())
});

context.send_to_handler(message).await?;
let initialized_notification = context.recv_from_handler().await?;
@@ -437,20 +410,23 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
});
}
loop {
// Main event loop - capture exit reason so we can do cleanup before returning
let loop_result: Result<(), WorkerQuitReason<Self::Error>> = 'main_loop: loop {
let event = tokio::select! {
_ = transport_task_ct.cancelled() => {
tracing::debug!("cancelled");
return Err(WorkerQuitReason::Cancelled);
break 'main_loop Err(WorkerQuitReason::Cancelled);
}
message = context.recv_from_handler() => {
let message = message?;
Event::ClientMessage(message)
match message {
Ok(msg) => Event::ClientMessage(msg),
Err(e) => break 'main_loop Err(e),
}
},
message = sse_worker_rx.recv() => {
let Some(message) = message else {
tracing::trace!("transport dropped, exiting");
return Err(WorkerQuitReason::HandlerTerminated);
break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
};
Event::ServerMessage(message)
},
@@ -525,7 +501,9 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
Event::ServerMessage(json_rpc_message) => {
// send the message to the handler
context.send_to_handler(json_rpc_message).await?;
if let Err(e) = context.send_to_handler(json_rpc_message).await {
break 'main_loop Err(e);
}
}
Event::StreamResult(result) => {
if result.is_err() {
@@ -536,7 +514,44 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
}
}
};

// Clean up the session before returning (ensures close() waits for session deletion)
// Use a timeout to prevent indefinite hangs if the server is unresponsive
if let Some((client, url, session_id, auth_header)) = session_cleanup_info {
const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
match tokio::time::timeout(
SESSION_CLEANUP_TIMEOUT,
client.delete_session(url, session_id.clone(), auth_header),
)
.await
{
Ok(Ok(_)) => {
tracing::info!(session_id = session_id.as_ref(), "delete session success")
}
Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
tracing::info!(
session_id = session_id.as_ref(),
"server doesn't support delete session"
)
}
Ok(Err(e)) => {
tracing::error!(
session_id = session_id.as_ref(),
"fail to delete session: {e}"
);
}
Err(_elapsed) => {
tracing::warn!(
session_id = session_id.as_ref(),
"session cleanup timed out after {:?}",
SESSION_CLEANUP_TIMEOUT
);
}
}
}

loop_result
}
}

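The behavioral change worth calling out: session deletion used to run on a detached `tokio::spawn` keyed on the cancellation token, so `close()` could return before the HTTP DELETE was even sent. Moving it inline into the worker's exit path means the cleanup (or its timeout) completes before the task that `close()` awaits finishes. A stripped-down sketch of that ordering, with `delete_session` stubbed out as a plain async block:

```rust
use std::time::Duration;

use tokio_util::sync::CancellationToken;

async fn worker(ct: CancellationToken) {
    // ... main event loop elided ...
    ct.cancelled().await;

    // Inline, bounded cleanup instead of a detached tokio::spawn.
    const SESSION_CLEANUP_TIMEOUT: Duration = Duration::from_secs(5);
    let delete_session = async { /* client.delete_session(...) */ };
    if tokio::time::timeout(SESSION_CLEANUP_TIMEOUT, delete_session)
        .await
        .is_err()
    {
        eprintln!("session cleanup timed out");
    }
}

#[tokio::main]
async fn main() {
    let ct = CancellationToken::new();
    let handle = tokio::spawn(worker(ct.clone()));
    ct.cancel();
    // Awaiting the handle now also waits for the inline cleanup above;
    // the old detached-task version could not give that guarantee.
    handle.await.unwrap();
}
```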
127 changes: 127 additions & 0 deletions crates/rmcp/tests/test_close_connection.rs
@@ -0,0 +1,127 @@
// cargo test --test test_close_connection --features "client server"

mod common;
use std::time::Duration;

use common::handlers::{TestClientHandler, TestServer};
use rmcp::{service::QuitReason, ServiceExt};

/// Test that close() properly shuts down the connection
#[tokio::test]
async fn test_close_method() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);

// Start server
let server_handle = tokio::spawn(async move {
let server = TestServer::new().serve(server_transport).await?;
server.waiting().await?;
anyhow::Ok(())
});

// Start client
let handler = TestClientHandler::new(true, true);
let mut client = handler.serve(client_transport).await?;

// Verify client is not closed
assert!(!client.is_closed());

// Call close() and verify it returns
let result = client.close().await?;
assert!(matches!(result, QuitReason::Cancelled));

// Verify client is now closed
assert!(client.is_closed());

// Calling close() again should return Closed immediately
let result = client.close().await?;
assert!(matches!(result, QuitReason::Closed));

// Wait for server to finish
server_handle.await??;
Ok(())
}

/// Test that close_with_timeout() respects the timeout
#[tokio::test]
async fn test_close_with_timeout() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);

// Start server
let server_handle = tokio::spawn(async move {
let server = TestServer::new().serve(server_transport).await?;
server.waiting().await?;
anyhow::Ok(())
});

// Start client
let handler = TestClientHandler::new(true, true);
let mut client = handler.serve(client_transport).await?;

// Close with a reasonable timeout
let result = client.close_with_timeout(Duration::from_secs(5)).await?;
assert!(result.is_some());
assert!(matches!(result.unwrap(), QuitReason::Cancelled));

// Verify client is now closed
assert!(client.is_closed());

// Wait for server to finish
server_handle.await??;
Ok(())
}

/// Test that cancel() still works and consumes self
#[tokio::test]
async fn test_cancel_method() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);

// Start server
let server_handle = tokio::spawn(async move {
let server = TestServer::new().serve(server_transport).await?;
server.waiting().await?;
anyhow::Ok(())
});

// Start client
let handler = TestClientHandler::new(true, true);
let client = handler.serve(client_transport).await?;

// Cancel should work as before
let result = client.cancel().await?;
assert!(matches!(result, QuitReason::Cancelled));

// Wait for server to finish
server_handle.await??;
Ok(())
}

/// Test that dropping without close() logs a debug message (we can't easily test
/// the log output, but we can verify the drop doesn't panic)
#[tokio::test]
async fn test_drop_without_close() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);

// Start server that will handle the drop
let server_handle = tokio::spawn(async move {
let server = TestServer::new().serve(server_transport).await?;
// The server should close when the client drops
let result = server.waiting().await?;
// Server should detect closure
assert!(matches!(result, QuitReason::Closed | QuitReason::Cancelled));
anyhow::Ok(())
});

// Create and immediately drop the client
{
let handler = TestClientHandler::new(true, true);
let _client = handler.serve(client_transport).await?;
// Client dropped here without calling close()
}

// Give the async cleanup a moment to run
tokio::time::sleep(Duration::from_millis(100)).await;

// Wait for server to finish (it should detect the closure)
server_handle.await??;
Ok(())
}
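
One pattern the tests leave out is wiring the bounded shutdown into application code. A hedged sketch, where `MyHandler` and `my_transport()` are hypothetical placeholders for an application's own client handler and transport:

```rust
use std::time::Duration;

use rmcp::ServiceExt;

async fn run() -> anyhow::Result<()> {
    // Placeholders: substitute a real ClientHandler and transport here.
    let mut client = MyHandler::default().serve(my_transport()).await?;

    // Work until interrupted.
    tokio::signal::ctrl_c().await?;

    // Bounded shutdown: give session cleanup up to five seconds.
    if client.close_with_timeout(Duration::from_secs(5)).await?.is_none() {
        tracing::warn!("shutdown timed out; the server session may be left open");
    }
    Ok(())
}
```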