diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs
index 5fc8934f..e0fd7642 100644
--- a/crates/rmcp/src/service.rs
+++ b/crates/rmcp/src/service.rs
@@ -434,7 +434,7 @@ impl<R: ServiceRole> Peer<R> {
 pub struct RunningService<R: ServiceRole, S: Service<R>> {
     service: Arc<S>,
     peer: Peer<R>,
-    handle: tokio::task::JoinHandle<QuitReason>,
+    handle: Option<tokio::task::JoinHandle<QuitReason>>,
     cancellation_token: CancellationToken,
     dg: DropGuard,
 }
@@ -459,14 +459,104 @@ impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
     pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
         RunningServiceCancellationToken(self.cancellation_token.clone())
     }
+
+    /// Returns true if the service has been closed or cancelled.
     #[inline]
-    pub async fn waiting(self) -> Result<QuitReason, tokio::task::JoinError> {
-        self.handle.await
+    pub fn is_closed(&self) -> bool {
+        self.handle.is_none() || self.cancellation_token.is_cancelled()
+    }
+
+    /// Wait for the service to complete.
+    ///
+    /// This will block until the service loop terminates (either due to
+    /// cancellation, transport closure, or an error).
+    #[inline]
+    pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
+        match self.handle.take() {
+            Some(handle) => handle.await,
+            None => Ok(QuitReason::Closed),
+        }
+    }
+
+    /// Gracefully close the connection and wait for cleanup to complete.
+    ///
+    /// This method cancels the service, waits for the background task to finish
+    /// (which includes calling `transport.close()`), and ensures all cleanup
+    /// operations complete before returning.
+    ///
+    /// Unlike [`cancel`](Self::cancel), this method takes `&mut self` and can be
+    /// called without consuming the `RunningService`. After calling this method,
+    /// the service is considered closed and subsequent operations will fail.
+    ///
+    /// # Example
+    ///
+    /// ```rust,ignore
+    /// let mut client = ().serve(transport).await?;
+    /// // ... use the client ...
+    /// client.close().await?;
+    /// ```
+    pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
+        if let Some(handle) = self.handle.take() {
+            // Disarm the drop guard so it doesn't try to cancel again
+            // We need to cancel manually and wait for completion
+            self.cancellation_token.cancel();
+            handle.await
+        } else {
+            // Already closed
+            Ok(QuitReason::Closed)
+        }
     }
-    pub async fn cancel(self) -> Result<QuitReason, tokio::task::JoinError> {
-        let RunningService { dg, handle, .. } = self;
-        dg.disarm().cancel();
-        handle.await
+
+    /// Gracefully close the connection with a timeout.
+    ///
+    /// Similar to [`close`](Self::close), but returns after the specified timeout
+    /// if the cleanup doesn't complete in time. This is useful for ensuring
+    /// a bounded shutdown time.
+    ///
+    /// Returns `Ok(Some(reason))` if shutdown completed within the timeout,
+    /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error.
+    pub async fn close_with_timeout(
+        &mut self,
+        timeout: Duration,
+    ) -> Result<Option<QuitReason>, tokio::task::JoinError> {
+        if let Some(handle) = self.handle.take() {
+            self.cancellation_token.cancel();
+            match tokio::time::timeout(timeout, handle).await {
+                Ok(result) => result.map(Some),
+                Err(_elapsed) => {
+                    tracing::warn!(
+                        "close_with_timeout: cleanup did not complete within {:?}",
+                        timeout
+                    );
+                    Ok(None)
+                }
+            }
+        } else {
+            Ok(Some(QuitReason::Closed))
+        }
+    }
+
+    /// Cancel the service and wait for cleanup to complete.
+    ///
+    /// This consumes the `RunningService` and ensures the connection is properly
+    /// closed. For a non-consuming alternative, see [`close`](Self::close).
+    pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
+        // Disarm the drop guard since we're handling cancellation explicitly
+        let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
+        self.close().await
+    }
+}
+
+impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
+    fn drop(&mut self) {
+        if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
+            tracing::debug!(
+                "RunningService dropped without explicit close(). \
+                 The connection will be closed asynchronously. \
+                 For guaranteed cleanup, call close() or cancel() before dropping."
+            );
+        }
+        // The DropGuard will handle cancellation
+    }
 }
@@ -847,7 +937,7 @@ where
     RunningService {
         service,
         peer: peer_return,
-        handle,
+        handle: Some(handle),
        cancellation_token: ct.clone(),
         dg: ct.drop_guard(),
     }
diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs
index 4db461a4..6756d58a 100644
--- a/crates/rmcp/src/transport/streamable_http_client.rs
+++ b/crates/rmcp/src/transport/streamable_http_client.rs
@@ -333,37 +333,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
             }
             None
         };
-        // delete session when drop guard is dropped
-        if let Some(session_id) = &session_id {
-            let ct = transport_task_ct.clone();
-            let client = self.client.clone();
-            let session_id = session_id.clone();
-            let url = config.uri.clone();
-            let auth_header = config.auth_header.clone();
-            tokio::spawn(async move {
-                ct.cancelled().await;
-                let delete_session_result = client
-                    .delete_session(url, session_id.clone(), auth_header.clone())
-                    .await;
-                match delete_session_result {
-                    Ok(_) => {
-                        tracing::info!(session_id = session_id.as_ref(), "delete session success")
-                    }
-                    Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => {
-                        tracing::info!(
-                            session_id = session_id.as_ref(),
-                            "server doesn't support delete session"
-                        )
-                    }
-                    Err(e) => {
-                        tracing::error!(
-                            session_id = session_id.as_ref(),
-                            "fail to delete session: {e}"
-                        );
-                    }
-                };
-            });
-        }
+        // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
+        let session_cleanup_info = session_id.as_ref().map(|sid| {
+            (self.client.clone(), config.uri.clone(), sid.clone(), config.auth_header.clone())
+        });
         context.send_to_handler(message).await?;
         let initialized_notification = context.recv_from_handler().await?;
@@ -437,20 +410,23 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
                 }
             });
         }
-        loop {
+        // Main event loop - capture exit reason so we can do cleanup before returning
+        let loop_result: Result<(), WorkerQuitReason> = 'main_loop: loop {
             let event = tokio::select! {
                 _ = transport_task_ct.cancelled() => {
                     tracing::debug!("cancelled");
-                    return Err(WorkerQuitReason::Cancelled);
+                    break 'main_loop Err(WorkerQuitReason::Cancelled);
                 }
                 message = context.recv_from_handler() => {
-                    let message = message?;
-                    Event::ClientMessage(message)
+                    match message {
+                        Ok(msg) => Event::ClientMessage(msg),
+                        Err(e) => break 'main_loop Err(e),
+                    }
                 },
                 message = sse_worker_rx.recv() => {
                     let Some(message) = message else {
                         tracing::trace!("transport dropped, exiting");
-                        return Err(WorkerQuitReason::HandlerTerminated);
+                        break 'main_loop Err(WorkerQuitReason::HandlerTerminated);
                     };
                     Event::ServerMessage(message)
                 },
@@ -525,7 +501,9 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
                 }
                 Event::ServerMessage(json_rpc_message) => {
                     // send the message to the handler
-                    context.send_to_handler(json_rpc_message).await?;
+                    if let Err(e) = context.send_to_handler(json_rpc_message).await {
+                        break 'main_loop Err(e);
+                    }
                 }
                 Event::StreamResult(result) => {
                     if result.is_err() {
@@ -536,7 +514,44 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
                     }
                 }
             }
+        };
+
+        // Cleanup session before returning (ensures close() waits for session deletion)
+        // Use a timeout to prevent indefinite hangs if the server is unresponsive
+        if let Some((client, url, session_id, auth_header)) = session_cleanup_info {
+            const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
+            match tokio::time::timeout(
+                SESSION_CLEANUP_TIMEOUT,
+                client.delete_session(url, session_id.clone(), auth_header),
+            )
+            .await
+            {
+                Ok(Ok(_)) => {
+                    tracing::info!(session_id = session_id.as_ref(), "delete session success")
+                }
+                Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
+                    tracing::info!(
+                        session_id = session_id.as_ref(),
+                        "server doesn't support delete session"
+                    )
+                }
+                Ok(Err(e)) => {
+                    tracing::error!(
+                        session_id = session_id.as_ref(),
+                        "fail to delete session: {e}"
+                    );
+                }
+                Err(_elapsed) => {
+                    tracing::warn!(
+                        session_id = session_id.as_ref(),
+                        "session cleanup timed out after {:?}",
+                        SESSION_CLEANUP_TIMEOUT
+                    );
+                }
+            }
         }
+
+        loop_result
     }
 }
diff --git a/crates/rmcp/tests/test_close_connection.rs b/crates/rmcp/tests/test_close_connection.rs
new file mode 100644
index 00000000..903c8d55
--- /dev/null
+++ b/crates/rmcp/tests/test_close_connection.rs
@@ -0,0 +1,127 @@
+//cargo test --test test_close_connection --features "client server"
+
+mod common;
+use std::time::Duration;
+
+use common::handlers::{TestClientHandler, TestServer};
+use rmcp::{service::QuitReason, ServiceExt};
+
+/// Test that close() properly shuts down the connection
+#[tokio::test]
+async fn test_close_method() -> anyhow::Result<()> {
+    let (server_transport, client_transport) = tokio::io::duplex(4096);
+
+    // Start server
+    let server_handle = tokio::spawn(async move {
+        let server = TestServer::new().serve(server_transport).await?;
+        server.waiting().await?;
+        anyhow::Ok(())
+    });
+
+    // Start client
+    let handler = TestClientHandler::new(true, true);
+    let mut client = handler.serve(client_transport).await?;
+
+    // Verify client is not closed
+    assert!(!client.is_closed());
+
+    // Call close() and verify it returns
+    let result = client.close().await?;
+    assert!(matches!(result, QuitReason::Cancelled));
+
+    // Verify client is now closed
+    assert!(client.is_closed());
+
+    // Calling close() again should return Closed immediately
+    let result = client.close().await?;
+    assert!(matches!(result, QuitReason::Closed));
+
+    // Wait for server to finish
+    server_handle.await??;
+    Ok(())
+}
+
+/// Test that close_with_timeout() respects the timeout
+#[tokio::test]
+async fn test_close_with_timeout() -> anyhow::Result<()> {
+    let (server_transport, client_transport) = tokio::io::duplex(4096);
+
+    // Start server
+    let server_handle = tokio::spawn(async move {
+        let server = TestServer::new().serve(server_transport).await?;
+        server.waiting().await?;
+        anyhow::Ok(())
+    });
+
+    // Start client
+    let handler = TestClientHandler::new(true, true);
+    let mut client = handler.serve(client_transport).await?;
+
+    // Close with a reasonable timeout
+    let result = client.close_with_timeout(Duration::from_secs(5)).await?;
+    assert!(result.is_some());
+    assert!(matches!(result.unwrap(), QuitReason::Cancelled));
+
+    // Verify client is now closed
+    assert!(client.is_closed());
+
+    // Wait for server to finish
+    server_handle.await??;
+    Ok(())
+}
+
+/// Test that cancel() still works and consumes self
+#[tokio::test]
+async fn test_cancel_method() -> anyhow::Result<()> {
+    let (server_transport, client_transport) = tokio::io::duplex(4096);
+
+    // Start server
+    let server_handle = tokio::spawn(async move {
+        let server = TestServer::new().serve(server_transport).await?;
+        server.waiting().await?;
+        anyhow::Ok(())
+    });
+
+    // Start client
+    let handler = TestClientHandler::new(true, true);
+    let client = handler.serve(client_transport).await?;
+
+    // Cancel should work as before
+    let result = client.cancel().await?;
+    assert!(matches!(result, QuitReason::Cancelled));
+
+    // Wait for server to finish
+    server_handle.await??;
+    Ok(())
+}
+
+/// Test that dropping without close() logs a debug message (we can't easily test
+/// the log output, but we can verify the drop doesn't panic)
+#[tokio::test]
+async fn test_drop_without_close() -> anyhow::Result<()> {
+    let (server_transport, client_transport) = tokio::io::duplex(4096);
+
+    // Start server that will handle the drop
+    let server_handle = tokio::spawn(async move {
+        let server = TestServer::new().serve(server_transport).await?;
+        // The server should close when the client drops
+        let result = server.waiting().await?;
+        // Server should detect closure
+        assert!(matches!(result, QuitReason::Closed | QuitReason::Cancelled));
+        anyhow::Ok(())
+    });
+
+    // Create and immediately drop the client
+    {
+        let handler = TestClientHandler::new(true, true);
+        let _client = handler.serve(client_transport).await?;
+        // Client dropped here without calling close()
+    }
+
+    // Give the async cleanup a moment to run
+    tokio::time::sleep(Duration::from_millis(100)).await;
+
+    // Wait for server to finish (it should detect the closure)
+    server_handle.await??;
+    Ok(())
+}
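
Reviewer note: below is a minimal sketch of the shutdown pattern this change enables, assembled from the doc example and the tests above. It reuses the `TestServer`/`TestClientHandler` helpers from this repo's `tests/common` module and an in-memory duplex transport purely for illustration; the test name and the bounded-shutdown-then-log fallback are illustrative suggestions, not code that is part of this patch.

```rust
// Sketch only: combines close_with_timeout() with an is_closed() check.
// TestServer / TestClientHandler come from this repo's tests/common helpers;
// a real client would serve() over stdio, child-process, or streamable HTTP.
mod common;
use std::time::Duration;

use common::handlers::{TestClientHandler, TestServer};
use rmcp::ServiceExt;

#[tokio::test]
async fn graceful_shutdown_pattern() -> anyhow::Result<()> {
    let (server_transport, client_transport) = tokio::io::duplex(4096);

    // Server side runs until the peer closes the connection.
    let server_handle = tokio::spawn(async move {
        let server = TestServer::new().serve(server_transport).await?;
        server.waiting().await?;
        anyhow::Ok(())
    });

    let mut client = TestClientHandler::new(true, true)
        .serve(client_transport)
        .await?;

    // ... use `client` ...

    // Bounded shutdown: wait up to 5s for cleanup (which now includes the
    // streamable-HTTP session DELETE) before giving up.
    match client.close_with_timeout(Duration::from_secs(5)).await? {
        Some(_reason) => println!("connection closed cleanly"),
        None => eprintln!("cleanup did not finish within the timeout"),
    }
    assert!(client.is_closed());

    server_handle.await??;
    Ok(())
}
```

Because `close()` and `close_with_timeout()` take `&mut self`, the client can stay owned by a larger struct and still be shut down deterministically, which is the main ergonomic difference from the consuming `cancel()`.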