diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
index 8f8330603893..5afec8fc7167 100644
--- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
+++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -204,13 +204,15 @@ class BaseCacheTransceiver
{
public:
    virtual ~BaseCacheTransceiver() = default;
-    virtual void respondAndSendAsync(LlmRequest* llmRequest) = 0;
+    // Transfers are asynchronous. Pass shared_ptr so the transceiver and its
+    // workers keep the request alive until the corresponding future resolves.
+    virtual void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;
    virtual void respondAndSendLayerWise(
        RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress) = 0;

-    virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
-    virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
+    virtual void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) = 0;
+    virtual void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;

    /// Check all requests transferring context, and return the requests that have completed or encountered an error.
    virtual RequestStatuses checkContextTransferStatus(
@@ -221,7 +223,7 @@ class BaseCacheTransceiver

    [[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

-    virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
+    virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) = 0;
};

class CacheTransceiver : public BaseCacheTransceiver
@@ -252,13 +254,13 @@ class CacheTransceiver : public BaseCacheTransceiver

    virtual ~CacheTransceiver();

-    void respondAndSendAsync(LlmRequest* llmRequest) override;
+    void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) override;

    void respondAndSendLayerWise(
        RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress) override;

-    void requestAndReceiveSync(LlmRequest* llmRequest) override;
-    void requestAndReceiveAsync(LlmRequest* llmRequest) override;
+    void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) override;
+    void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) override;

    RequestStatuses checkContextTransferStatus(
        std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
@@ -267,7 +269,7 @@ class CacheTransceiver : public BaseCacheTransceiver

    [[nodiscard]] bool checkGenTransferComplete() const override;

-    virtual bool cancelRequest(LlmRequest* llmRequest) override;
+    virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) override;

private:
    void initializeCommState();
@@ -276,8 +278,10 @@ class CacheTransceiver : public BaseCacheTransceiver
    std::unique_ptr<CacheSender> mCacheSender;
    std::unique_ptr<CacheReceiver> mCacheReceiver;
-    std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
-    std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
+    // Hold strong references while futures are outstanding so Python-side
+    // cleanup cannot leave C++ with dangling LlmRequest pointers.
+    std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mSenderFutures;
+    std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mRequesterFutures;

    mpi::MpiComm const* mMpiWorldComm{nullptr};
    std::shared_ptr<mpi::MpiComm> mGroupComm;
diff --git a/cpp/include/tensorrt_llm/executor/transferAgent.h b/cpp/include/tensorrt_llm/executor/transferAgent.h
index 8d6a46107675..1d229663cc67 100644
--- a/cpp/include/tensorrt_llm/executor/transferAgent.h
+++ b/cpp/include/tensorrt_llm/executor/transferAgent.h
@@ -288,6 +288,14 @@ class TransferStatus
    virtual ~TransferStatus() = default;
    [[nodiscard]] virtual bool isCompleted() const = 0;
    virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
+    /// Release the backend transfer request. If the request is still active,
+    /// backends may attempt to cancel it. A true return only means the backend
+    /// accepted release of the transfer handle; callers must still treat remote
+    /// memory quiescence as backend-specific.
+    [[nodiscard]] virtual bool release()
+    {
+        return false;
+    }
};

struct BaseAgentConfig
diff --git a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
index 58092897ebbe..60d345275871 100644
--- a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
+++ b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
@@ -21,11 +21,112 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

+#include
#include

namespace tensorrt_llm::batch_manager
{

+namespace
+{
+
+char const* bufferKindName(BufferKind kind)
+{
+    switch (kind)
+    {
+    case BufferKind::kKV: return "kv";
+    case BufferKind::kKV_INDEXER: return "kv_indexer";
+    case BufferKind::kRNN: return "rnn";
+    }
+    return "unknown";
+}
+
+} // namespace
+
+BufferIndexHolder::BufferIndexHolder(
+    BaseTransBufferManager* manager, Direction direction, std::optional<int> bufferId)
+    : mManager{manager}
+    , mDirection{direction}
+    , mBufferId{bufferId}
+    , mOwns{manager != nullptr}
+{
+}
+
+BufferIndexHolder::~BufferIndexHolder()
+{
+    reset();
+}
+
+BufferIndexHolder::BufferIndexHolder(BufferIndexHolder&& other) noexcept
+    : mManager{other.mManager}
+    , mDirection{other.mDirection}
+    , mBufferId{other.mBufferId}
+    , mOwns{other.mOwns}
+{
+    other.mManager = nullptr;
+    other.mBufferId = std::nullopt;
+    other.mOwns = false;
+}
+
+BufferIndexHolder& BufferIndexHolder::operator=(BufferIndexHolder&& other) noexcept
+{
+    if (this != &other)
+    {
+        reset();
+        mManager = other.mManager;
+        mDirection = other.mDirection;
+        mBufferId = other.mBufferId;
+        mOwns = other.mOwns;
+        other.mManager = nullptr;
+        other.mBufferId = std::nullopt;
+        other.mOwns = false;
+    }
+    return *this;
+}
+
+BufferIndexHolder BufferIndexHolder::acquireSend(BaseTransBufferManager& manager)
+{
+    return BufferIndexHolder{&manager, Direction::kSend, manager.assignBufferIndexForSend()};
+}
+
+BufferIndexHolder BufferIndexHolder::acquireRecv(BaseTransBufferManager& manager)
+{
+    return BufferIndexHolder{&manager, Direction::kRecv, manager.assignBufferIndexForRecv()};
+}
+
+void BufferIndexHolder::reset() noexcept
+{
+    if (!mOwns || mManager == nullptr)
+    {
+        return;
+    }
+
+    try
+    {
+        if (mDirection == Direction::kSend)
+        {
+            mManager->freeBufferIndexForSend(mBufferId);
+        }
+        else
+        {
+            mManager->freeBufferIndexForRecv(mBufferId);
+        }
+    }
+    catch (std::exception const& e)
+    {
+        TLLM_LOG_ERROR(
+            "Exception while releasing cache transfer buffer index %d: %s", mBufferId.value_or(-1), e.what());
+    }
+    catch (...)
+    {
+        TLLM_LOG_ERROR("Unknown exception while releasing cache transfer buffer index %d", mBufferId.value_or(-1));
+    }
+
+    mManager = nullptr;
+    mBufferId = std::nullopt;
+    mOwns = false;
+}
+
BaseTransBufferManager::BaseTransBufferManager(
    size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
    : mDataType{dataType}
@@ -56,22 +157,48 @@ BaseTransBufferManager::BaseTransBufferManager(

std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
{
-    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
+    auto bufferId = assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
+    if (bufferId.has_value())
+    {
+        TLLM_LOG_DEBUG("Assigned send cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
+            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceSendResource.mConcurrence.load(),
+            mSendBufferCount);
+    }
+    return bufferId;
}

void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
+    if (bufferId.has_value())
+    {
+        TLLM_LOG_DEBUG("Freed send cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
+            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceSendResource.mConcurrence.load(),
+            mSendBufferCount);
+    }
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
{
-    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
+    auto bufferId = assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
+    if (bufferId.has_value())
+    {
+        TLLM_LOG_DEBUG("Assigned recv cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
+            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceRecvResource.mConcurrence.load(),
+            mRecvBufferCount);
+    }
+    return bufferId;
}

void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
+    if (bufferId.has_value())
+    {
+        TLLM_LOG_DEBUG("Freed recv cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
+            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceRecvResource.mConcurrence.load(),
+            mRecvBufferCount);
+    }
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
diff --git a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h
index 1efeb89ccc04..e9c192224a4c 100644
--- a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h
+++ b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h
@@ -39,6 +39,8 @@ class FabricMemory;
namespace tensorrt_llm::batch_manager
{

+class BaseTransBufferManager;
+
enum class BufferKind : uint8_t
{
    kKV = 0,
@@ -46,6 +48,46 @@ enum class BufferKind : uint8_t
    kRNN = 2
};

+class BufferIndexHolder
+{
+public:
+    enum class Direction : uint8_t
+    {
+        kSend = 0,
+        kRecv = 1
+    };
+
+    BufferIndexHolder() = default;
+    BufferIndexHolder(BaseTransBufferManager* manager, Direction direction, std::optional<int> bufferId);
+    ~BufferIndexHolder();
+
+    BufferIndexHolder(BufferIndexHolder const&) = delete;
+    BufferIndexHolder& operator=(BufferIndexHolder const&) = delete;
+    BufferIndexHolder(BufferIndexHolder&& other) noexcept;
+    BufferIndexHolder& operator=(BufferIndexHolder&& other) noexcept;
+
+    [[nodiscard]] static BufferIndexHolder acquireSend(BaseTransBufferManager& manager);
+    [[nodiscard]] static BufferIndexHolder acquireRecv(BaseTransBufferManager& manager);
+
+    [[nodiscard]] std::optional<int> get() const noexcept
+    {
+        return mBufferId;
+    }
+
+    [[nodiscard]] bool owns() const noexcept
+    {
+        return mOwns;
+    }
+
+    void reset() noexcept;
+
+private:
+    BaseTransBufferManager* mManager{nullptr};
+    Direction mDirection{Direction::kSend};
+    std::optional<int> mBufferId{std::nullopt};
+    bool mOwns{false};
+};
+
/// @brief Base class for cache transfer buffer management.
/// Handles buffer pool allocation, index assignment, and slicing.
/// Derived classes provide cache-specific size calculations.
diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
index e05e8d6f76fc..e5d71f316318 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -494,7 +494,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
    // cache blocks to the corresponding buffer.
    // 5. send the buffer to the corresponding target. Ideally, we send only once (one buffer) for each target.

-    auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
+    auto cacheBufferHolder = BufferIndexHolder::acquireSend(*mCacheTransBufferManager);
+    auto cacheBufferId = cacheBufferHolder.get();
    int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
    auto bufferTargetNum = targetNum / peerDuplicateHeadFactor;
    auto ppRank = selfIdx
@@ -578,7 +579,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)

        session.setTime(TransferSession::kTimeTransmissions);

-        mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
        session.setTime(TransferSession::kTimePostprocess);
    }
    TLLM_LOG_DEBUG(
@@ -800,6 +800,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
    size_t remainNoCoverTargetNum = 0;
    size_t bufferCoverTargetNum = 0;
    std::optional<int> cacheBufferId = std::nullopt;
+    BufferIndexHolder cacheBufferHolder;
    {
        NVTX3_SCOPED_RANGE(formatInputAllocBuffer);
@@ -813,7 +814,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
        }
        else
        {
-            cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv();
+            cacheBufferHolder = BufferIndexHolder::acquireRecv(*mCacheTransBufferManager);
+            cacheBufferId = cacheBufferHolder.get();
        }
        auto [recvSplitCachestmp, bufferCoverTargetNumtmp, onlyUseDynamicBuffer]
            = mCacheTransBufferManager->getOrAllocateRecvBuffers(
@@ -948,10 +950,6 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
                recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
            bufferManager.getStream().synchronize();
-            if (cacheBufferId.has_value())
-            {
-                mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
-            }
        }
        session.setTime(TransferSession::kTimePostprocess);
    }
diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index 2e4bf1f06667..e79035db6f86 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -54,7 +54,9 @@
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include
#include
+#include
#include
+#include
#include

namespace tensorrt_llm::batch_manager
@@ -323,7 +325,7 @@ void CacheTransceiver::setContextState(LlmRequest* llmRequest)
    }
}

-void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
+void CacheTransceiver::respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isContextOnlyRequest());
    llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS);
@@ -337,9 +339,9 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
        }
        return;
    }
-    setContextState(llmRequest);
-    auto future = mCacheSender->sendAsync(*llmRequest);
-    mSenderFutures.emplace_back(llmRequest, std::move(future));
+    setContextState(llmRequest.get());
+    auto future = mCacheSender->sendAsync(llmRequest);
+    mSenderFutures.emplace_back(std::move(llmRequest), std::move(future));
}

void CacheTransceiver::respondAndSendLayerWise(
@@ -354,22 +356,22 @@ void CacheTransceiver::respondAndSendLayerWise(
        llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
        setContextState(llmRequest.get());

-        auto future = mCacheSender->sendAsync(*llmRequest);
-        mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
+        auto future = mCacheSender->sendAsync(llmRequest);
+        mSenderFutures.emplace_back(llmRequest, std::move(future));
    }
}

-void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
+void CacheTransceiver::requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
    {
-        auto future = mCacheReceiver->receiveAsync(*llmRequest);
+        auto future = mCacheReceiver->receiveAsync(llmRequest);
        future.get();
    }
    llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
}

-void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
+void CacheTransceiver::requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest)
{
    TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
@@ -381,9 +383,9 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
        return;
    }

-    auto future = mCacheReceiver->receiveAsync(*llmRequest);
-    mRequesterFutures.emplace_back(llmRequest, std::move(future));
+    auto future = mCacheReceiver->receiveAsync(llmRequest);
    llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
+    mRequesterFutures.emplace_back(std::move(llmRequest), std::move(future));
}

std::vector gatherRequestIds(
@@ -484,6 +486,8 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
    std::optional<int> const& atLeastRequestNum, bool markComplete)
{
    bool blockAll = !atLeastRequestNum.has_value();
+    TLLM_LOG_DEBUG("Checking context KV transfer futures: pending=%zu atLeast=%d blockAll=%d markComplete=%d",
+        mSenderFutures.size(), atLeastRequestNum.value_or(0), blockAll, markComplete);
    std::optional<int> senderFutureTimeoutMs = std::nullopt;
    // If blockAll is true, we want to block and not use a timeout
    if (!blockAll && mCacheTransceiverConfig.has_value())
@@ -564,8 +568,9 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
            }
            else if (status == std::future_status::timeout)
            {
-                TLLM_LOG_WARNING("Timed out waiting for context KV cache transfer after %d milliseconds.",
-                    senderFutureTimeoutMs.value());
+                TLLM_LOG_WARNING(
+                    "Timed out waiting for context KV cache transfer after %d milliseconds for request %ld.",
+                    senderFutureTimeoutMs.value(), request->mRequestId);
                ++it;
            }
            else
@@ -586,6 +591,13 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
                requestsStatus.errorRequestIds.insert(request->mRequestId);
                it = mSenderFutures.erase(it);
            }
+            catch (...)
+            {
+                TLLM_LOG_ERROR("Unknown error occurred during context transfer for request %ld", request->mRequestId);
+                request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                requestsStatus.errorRequestIds.insert(request->mRequestId);
+                it = mSenderFutures.erase(it);
+            }
        }
        else
        {
@@ -599,6 +611,8 @@
void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
{
    bool blockAll = !atLeastRequestNum.has_value();
+    TLLM_LOG_DEBUG("Checking generation KV transfer futures: pending=%zu atLeast=%d blockAll=%d",
+        mRequesterFutures.size(), atLeastRequestNum.value_or(0), blockAll);
    std::vector genTransferReadyRequestIds;
    for (auto&& [request, future] : mRequesterFutures)
    {
@@ -722,7 +736,7 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
            if (!common::getEnvKVCacheTimeOutputPath().empty())
            {
                auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
-                updateKVCacheTransferBW(syncComm, it->first);
+                updateKVCacheTransferBW(syncComm, it->first.get());
            }
        }
        catch (std::exception const& e)
        {
            TLLM_LOG_ERROR(
                "Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
            it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
        }
+        catch (...)
+        {
+            TLLM_LOG_ERROR(
+                "Unknown error occurred during generation transfer for request %ld", it->first->mRequestId);
+            it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+        }
        if (useMPI())
        {
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
@@ -757,7 +777,7 @@ bool CacheTransceiver::checkGenTransferComplete() const
    return mRequesterFutures.empty();
}

-bool CacheTransceiver::cancelRequest(LlmRequest* llmRequest)
+bool CacheTransceiver::cancelRequest(std::shared_ptr<LlmRequest> llmRequest)
{
    if (llmRequest->isContextOnlyRequest())
    {
diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index 3ecceb9f3f2c..83418cafab07 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -30,9 +30,11 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include
+#include
#include
#include
#include
+#include
#include

namespace tensorrt_llm::batch_manager
@@ -78,8 +80,15 @@ void TransferSession::send(size_t idx, void const* data, size_t size)
    }
    catch (std::exception const& e)
    {
+        auto const requestId = mRequest != nullptr ? mRequest->mRequestId : 0;
        throw common::RequestSpecificException(
-            __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
+            __FILE__, __LINE__, e.what(), requestId, common::RequestErrorCode::kNETWORK_ERROR);
+    }
+    catch (...)
+    {
+        auto const requestId = mRequest != nullptr ? mRequest->mRequestId : 0;
+        throw common::RequestSpecificException(__FILE__, __LINE__, "Unknown exception in cache transfer send",
+            requestId, common::RequestErrorCode::kNETWORK_ERROR);
    }
}

@@ -91,8 +100,15 @@ void TransferSession::recv(size_t idx, void* data, size_t size)
    }
    catch (std::exception const& e)
    {
+        auto const requestId = mRequest != nullptr ? mRequest->mRequestId : 0;
        throw common::RequestSpecificException(
-            __FILE__, __LINE__, e.what(), mRequest->mRequestId, common::RequestErrorCode::kNETWORK_ERROR);
+            __FILE__, __LINE__, e.what(), requestId, common::RequestErrorCode::kNETWORK_ERROR);
+    }
+    catch (...)
+    {
+        auto const requestId = mRequest != nullptr ? mRequest->mRequestId : 0;
+        throw common::RequestSpecificException(__FILE__, __LINE__, "Unknown exception in cache transfer recv",
+            requestId, common::RequestErrorCode::kNETWORK_ERROR);
    }
}

@@ -296,16 +312,16 @@ class CacheSender::Impl
        }
    }

-    [[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
+    [[nodiscard]] std::future<void> sendAsync(std::shared_ptr<LlmRequest> const& llmRequest)
    {
        std::promise<void> promise;
        auto future = promise.get_future();
-        llmRequest.setKvCacheTransferStart(LlmRequest::getSteadyClockNow());
+        llmRequest->setKvCacheTransferStart(LlmRequest::getSteadyClockNow());
        {
            {
                std::scoped_lock lkResp(mSenderMutex);
-                mReadyResponses.emplace(
-                    llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
+                // Keep the request alive until the async-send worker finishes.
+                mReadyResponses.emplace(llmRequest->mRequestId, Response{llmRequest, std::move(promise)});
            }
            std::unique_lock lkCond(mCondMutex);
            mAnyReady = true;
@@ -477,7 +493,8 @@ class CacheSender::Impl
private:
    struct Response
    {
-        LlmRequest* mRequest;
+        // Keep the request alive until the async-send worker finishes.
+        std::shared_ptr<LlmRequest> mRequest;
        std::promise<void> mPromise;
    };

@@ -535,6 +552,12 @@ class CacheSender::Impl
            TLLM_LOG_ERROR("Exception in sendAndRemoveResponse: %s request id: %ld", e.what(), id);
            resp.mPromise.set_exception(std::current_exception());
        }
+        catch (...)
+        {
+            TLLM_LOG_ERROR("Unknown exception in sendAndRemoveResponse for request id: %ld", id);
+            resp.mPromise.set_exception(
+                std::make_exception_ptr(std::runtime_error("Unknown exception in sendAndRemoveResponse")));
+        }
    }

    void asyncSendAndRemoveResponse(RequestIdType id, Response resp) noexcept
@@ -584,13 +607,16 @@ class CacheSender::Impl
        {
            // TODO: if the generation does not require the kv cache, the request will
            // not be removed from mCancelledRequests. This should be handled by timeout.
-            auto it = mReadyResponses.find(mCurrentRequest.value());
-            TLLM_CHECK(it != mReadyResponses.end());
+            auto cancelledRequestId = mCurrentRequest.value();
+            Response cancelledResponse;
            {
                std::scoped_lock lkResp(mSenderMutex);
+                auto it = mReadyResponses.find(cancelledRequestId);
+                TLLM_CHECK(it != mReadyResponses.end());
+                cancelledResponse = std::move(it->second);
                mReadyResponses.erase(it);
-                mCancelledRequests.erase(mCurrentRequest.value());
-                mRemainSendCount.erase(mCurrentRequest.value());
+                mCancelledRequests.erase(cancelledRequestId);
+                mRemainSendCount.erase(cancelledRequestId);
            }
            mCurrentRequest = std::nullopt;
@@ -599,6 +625,9 @@ class CacheSender::Impl
                std::unique_lock lk(mCondMutex);
                mAnyReady = false;
            }
+            cancelledResponse.mPromise.set_exception(std::make_exception_ptr(TLLM_REQUEST_EXCEPTION(
+                cancelledRequestId, common::RequestErrorCode::kNETWORK_ERROR,
+                "KV cache transfer for request %zu was cancelled", cancelledRequestId)));
        }
    }
    mCurrentRequest = std::nullopt;
@@ -670,6 +699,16 @@ class CacheSender::Impl
                it.second.mPromise.set_exception(std::current_exception());
            }
        }
+        catch (...)
+        {
+            TLLM_LOG_ERROR("Unknown exception in CacheSender response");
+            auto unknownException
+                = std::make_exception_ptr(std::runtime_error("Unknown exception in CacheSender response"));
+            for (auto& it : mReadyResponses)
+            {
+                it.second.mPromise.set_exception(unknownException);
+            }
+        }
    }

    void terminate()
@@ -753,23 +792,26 @@ class CacheReceiver::Impl
        TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
    }

-    [[nodiscard]] std::future<void> receiveAsync(LlmRequest& llmRequest)
+    [[nodiscard]] std::future<void> receiveAsync(std::shared_ptr<LlmRequest> const& llmRequest)
    {
        // TODO: Modify the implementation here to avoid frequent thread creation.
-        return std::async(std::launch::async, &CacheReceiver::Impl::requestSync, this, std::ref(llmRequest));
+        // Keep the request alive until the async task completes.
+        auto llmRequestCopy = llmRequest;
+        return std::async(std::launch::async,
+            [this, llmRequestCopy]() { requestSync(*llmRequestCopy); });
    }

-    [[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(LlmRequest& llmRequest)
+    [[nodiscard]] std::future<void> requestAndReceiveAsyncMultiThreads(std::shared_ptr<LlmRequest> const& llmRequest)
    {
        try
        {
            auto promise = std::make_unique<std::promise<void>>();
            auto future = promise->get_future();
-            TLLM_CHECK(llmRequest.getDataTransceiverState().getCommState().has_value());
+            TLLM_CHECK(llmRequest->getDataTransceiverState().getCommState().has_value());
            std::string processInfo = kDefaultProcessInfo;
            if (common::getEnvRequestKVCacheConcurrent())
            {
-                processInfo = llmRequest.getDataTransceiverState().getCommState()->toString();
+                processInfo = llmRequest->getDataTransceiverState().getCommState()->toString();
            }
            if (mInstanceToAsyncResource.find(processInfo) == mInstanceToAsyncResource.end())
            {
@@ -782,7 +824,8 @@ class CacheReceiver::Impl
            auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
            {
                std::unique_lock lck(asyncResource->mMtxForQueue);
-                asyncResource->mRequestsQueue.emplace_back(std::addressof(llmRequest), std::move(promise));
+                // Keep the request alive until the worker finishes.
+                asyncResource->mRequestsQueue.emplace_back(llmRequest, std::move(promise));
            }
            asyncResource->mCVforQueue.notify_all();
            return future;
@@ -791,6 +834,10 @@ class CacheReceiver::Impl
        {
            TLLM_THROW("%s", e.what());
        }
+        catch (...)
+        {
+            TLLM_THROW("Unknown exception in requestAndReceiveAsyncMultiThreads");
+        }
    }

    void receiveSync(TransferSession& session)
@@ -847,11 +894,20 @@ class CacheReceiver::Impl
        auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
        std::vector<std::optional<size_t>> cacheBufferIds;
+        std::vector<BufferIndexHolder> cacheBufferHolders;
        if (agentConnectionManager)
        {
-            for (auto& cacheTransBufferManager : agentConnectionManager->getCacheTransBufferManagers())
+            auto& cacheTransBufferManagers = agentConnectionManager->getCacheTransBufferManagers();
+            cacheBufferIds.reserve(cacheTransBufferManagers.size());
+            cacheBufferHolders.reserve(cacheTransBufferManagers.size());
+            for (auto& cacheTransBufferManager : cacheTransBufferManagers)
            {
-                cacheBufferIds.push_back(cacheTransBufferManager->assignBufferIndexForRecv());
+                auto holder = BufferIndexHolder::acquireRecv(*cacheTransBufferManager);
+                auto bufferId = holder.get();
+                cacheBufferIds.push_back(bufferId.has_value()
+                        ? std::optional<size_t>{static_cast<size_t>(bufferId.value())}
+                        : std::nullopt);
+                cacheBufferHolders.emplace_back(std::move(holder));
            }
            TLLM_CHECK(!cacheBufferIds.empty());
        }
@@ -940,10 +996,15 @@ class CacheReceiver::Impl
            }
        }
        auto const& resource = getReceiveCacheResource(llmRequest);
-        return TransferSession(std::move(allConnections), DataContext{tagFromRequestId(requestId), mTerminate},
+        auto session = TransferSession(std::move(allConnections), DataContext{tagFromRequestId(requestId), mTerminate},
            std::move(allCounterparts), mSelfState, contextState, resource->mBufferManager,
            requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(), &llmRequest,
            !common::getEnvKVCacheTimeOutputPath().empty());
+        for (auto& holder : cacheBufferHolders)
+        {
+            session.addBufferIndexHolder(std::move(holder));
+        }
+        return session;
    }

    std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@@ -1074,7 +1135,8 @@ class CacheReceiver::Impl

    struct RequestAndPromise
    {
-        LlmRequest* mRequest;
+        // Keep the request alive while it is queued or owned by a worker.
+        std::shared_ptr<LlmRequest> mRequest;
        std::unique_ptr<std::promise<void>> mPromise;

        RequestAndPromise()
            : mRequest(nullptr)
            , mPromise(nullptr)
        {
        }

-        RequestAndPromise(LlmRequest* request, std::unique_ptr<std::promise<void>>&& promise)
-            : mRequest(request)
+        RequestAndPromise(std::shared_ptr<LlmRequest> request, std::unique_ptr<std::promise<void>>&& promise)
+            : mRequest(std::move(request))
            , mPromise(std::move(promise))
        {
        }

        RequestAndPromise(RequestAndPromise const&) = delete;
-        RequestAndPromise(RequestAndPromise&& other) noexcept
-            : mRequest(other.mRequest)
-            , mPromise(std::move(other.mPromise))
-        {
-            other.mRequest = nullptr;
-        }
-
-        RequestAndPromise& operator=(RequestAndPromise&& other) noexcept
-        {
-            if (this != &other)
-            {
-                mRequest = nullptr;
-                if (mPromise)
-                {
-                    mPromise.reset();
-                }
-
-                mRequest = other.mRequest;
-                mPromise = std::move(other.mPromise);
-
-                other.mRequest = nullptr;
-            }
-            return *this;
-        }
+        RequestAndPromise(RequestAndPromise&& other) noexcept = default;
+        RequestAndPromise& operator=(RequestAndPromise&& other) noexcept = default;
    };

    struct AsyncResource
@@ -1152,6 +1192,17 @@ class CacheReceiver::Impl
                    resource.mRequestsQueue.pop_front();
                }
                {
+                    auto requestId = [&requestAndPromise]() -> size_t
+                    { return requestAndPromise.mRequest != nullptr ? requestAndPromise.mRequest->mRequestId : 0; };
+                    auto contextRequestId = [&requestAndPromise]() -> size_t
+                    {
+                        if (requestAndPromise.mRequest == nullptr
+                            || !requestAndPromise.mRequest->getContextPhaseParams().has_value())
+                        {
+                            return 0;
+                        }
+                        return requestAndPromise.mRequest->getContextPhaseParams().value().getReqId();
+                    };
                    try
                    {
                        TLLM_CHECK_WITH_INFO(requestAndPromise.mRequest != nullptr, "requestAndPromise.mRequest is null");
@@ -1161,19 +1212,24 @@ class CacheReceiver::Impl
                    catch (tensorrt_llm::common::RequestSpecificException const& err)
                    {
                        TLLM_LOG_ERROR("Exception in DataRequester request(): request id:%zu , request context id:%zu : %s",
-                            requestAndPromise.mRequest->mRequestId,
-                            requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
-                        auto new_exception = TLLM_REQUEST_EXCEPTION(
-                            requestAndPromise.mRequest->mRequestId, err.getErrorCode(), "%s", err.what());
+                            requestId(), contextRequestId(), err.what());
+                        auto new_exception = TLLM_REQUEST_EXCEPTION(requestId(), err.getErrorCode(), "%s", err.what());
                        requestAndPromise.mPromise->set_exception(std::make_exception_ptr(new_exception));
                    }
                    catch (std::exception const& err)
                    {
-                        TLLM_LOG_ERROR("Exception in CacheReceiver request(): request id:%ld , request context id:%ld : %s",
-                            requestAndPromise.mRequest->mRequestId,
-                            requestAndPromise.mRequest->getContextPhaseParams().value().getReqId(), err.what());
+                        TLLM_LOG_ERROR("Exception in CacheReceiver request(): request id:%zu , request context id:%zu : %s",
+                            requestId(), contextRequestId(), err.what());
                        requestAndPromise.mPromise->set_exception(std::current_exception());
                    }
+                    catch (...)
+                    {
+                        TLLM_LOG_ERROR(
+                            "Unknown exception in CacheReceiver request(): request id:%zu , request context id:%zu",
+                            requestId(), contextRequestId());
+                        requestAndPromise.mPromise->set_exception(
+                            std::make_exception_ptr(std::runtime_error("Unknown exception in CacheReceiver request")));
+                    }
                }
            }
        }
@@ -1209,7 +1265,7 @@ CacheSender::CacheSender(
{
}

-std::future<void> CacheSender::sendAsync(LlmRequest& llmRequest) const
+std::future<void> CacheSender::sendAsync(std::shared_ptr<LlmRequest> const& llmRequest) const
{
    return mImpl->sendAsync(llmRequest);
}
@@ -1252,7 +1308,7 @@ CacheReceiver::CacheReceiver(
{
}

-std::future<void> CacheReceiver::receiveAsync(LlmRequest& llmRequest) const
+std::future<void> CacheReceiver::receiveAsync(std::shared_ptr<LlmRequest> const& llmRequest) const
{
    return mImpl->requestAndReceiveAsyncMultiThreads(llmRequest);
}
diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h
index 4d072a04521b..80d80e206537 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h
@@ -20,8 +20,10 @@
#include
#include
#include
+#include
#include

+#include "tensorrt_llm/batch_manager/baseTransBuffer.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/cacheTransferLayer.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
@@ -102,6 +104,11 @@ class TransferSession
        }
    }

+    TransferSession(TransferSession const&) = delete;
+    TransferSession& operator=(TransferSession const&) = delete;
+    TransferSession(TransferSession&&) noexcept = default;
+    TransferSession& operator=(TransferSession&&) noexcept = default;
+
    [[nodiscard]] std::vector<executor::kv_cache::Connection const*> const& getConnections() const;

    // should be called only during the initialization of the TransferSession
@@ -151,6 +158,11 @@ class TransferSession
        mCounterPartRanks = std::move(ranks);
    }

+    void addBufferIndexHolder(BufferIndexHolder&& holder)
+    {
+        mBufferIndexHolders.emplace_back(std::move(holder));
+    }
+
private:
    std::vector<executor::kv_cache::Connection const*> mConnections;
    std::vector<int> mCounterPartRanks; // Ranks corresponding to mConnections indices
@@ -162,6 +174,7 @@ class TransferSession
    std::unique_ptr mTimes;
    int32_t mIndexFromEnd{0};
    BlockKey mLastBlockKey{};
+    std::vector<BufferIndexHolder> mBufferIndexHolders;
};

using UniqueToken = tensorrt_llm::runtime::UniqueToken;
@@ -257,9 +270,10 @@ class CacheSender

    /// @brief Asynchronously respond to the request and send data.
    /// @param llmRequest Request object. Its data should be ready when called, and the data for this request
-    /// should remain valid until future synchronization.
+    /// should remain valid until future synchronization. Passed as shared_ptr so
+    /// the async worker keeps the LlmRequest alive until the future resolves.
    /// @return Once the data is fully sent, the future object will become valid.
-    [[nodiscard]] virtual std::future<void> sendAsync(LlmRequest& llmRequest) const;
+    [[nodiscard]] virtual std::future<void> sendAsync(std::shared_ptr<LlmRequest> const& llmRequest) const;

    /// @brief Return the internal communicator status.
    /// @return The communicator status.
@@ -314,9 +328,10 @@ class CacheReceiver

    /// @brief Asynchronously send a request to receive data.
    /// @param llmRequest Request object. Its data should be in an allocated but unwritten state when called, and the
-    /// data for this request should remain intact only after future synchronization.
+    /// data for this request should remain intact only after future synchronization. Passed as shared_ptr so
+    /// the async worker keeps the LlmRequest alive until the future resolves.
    /// @return Once the data is fully received, the future object will become valid.
-    [[nodiscard]] virtual std::future<void> receiveAsync(LlmRequest& llmRequest) const;
+    [[nodiscard]] virtual std::future<void> receiveAsync(std::shared_ptr<LlmRequest> const& llmRequest) const;

    virtual TransferSession sendRequestInfo(LlmRequest const& llmRequest);
diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index c72090867f29..4ad480a63227 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -253,7 +253,8 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
        return bufferSizeForTarget;
    };
    auto bufferEleSizes = getBufferSizeForTarget();
-    auto cacheBufferId = mCacheTransBufferManagers[transferIndexerKCache]->assignBufferIndexForSend();
+    auto cacheBufferHolder = BufferIndexHolder::acquireSend(*mCacheTransBufferManagers[transferIndexerKCache]);
+    auto cacheBufferId = cacheBufferHolder.get();
    auto result = mCacheTransBufferManagers[transferIndexerKCache]->getOrAllocateSendBuffers(
        cacheBufferId, static_cast<int>(pPDomainSize * cPDomainSize), bufferEleSizes, bufferManager);
    auto& outputSplitCaches = std::get<0>(result);
@@ -380,7 +381,6 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
    {
        sendBufferFun(deviceId, pickUpConnections[0]);
    }
-    mCacheTransBufferManagers[transferIndexerKCache]->freeBufferIndexForSend(cacheBufferId);
    session.setTime(TransferSession::kTimeTransmissions);
    session.setTime(TransferSession::kTimePostprocess);
@@ -487,13 +487,15 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
    auto bufferKind = transferIndexerKCache ?
        static_cast<uint8_t>(BufferKind::kKV_INDEXER) : static_cast<uint8_t>(BufferKind::kKV);
    auto preAssignedId = connections[pickUpConnections[0]]->getPreAssignedBufferId(bufferKind);
+    BufferIndexHolder cacheBufferHolder;
    if (preAssignedId.has_value())
    {
        cacheBufferId = static_cast<int>(*preAssignedId);
    }
    else
    {
-        cacheBufferId = mCacheTransBufferManagers[transferIndexerKCache]->assignBufferIndexForRecv();
+        cacheBufferHolder = BufferIndexHolder::acquireRecv(*mCacheTransBufferManagers[transferIndexerKCache]);
+        cacheBufferId = cacheBufferHolder.get();
    }

    auto targetNum = pickUpConnections.size();
@@ -642,10 +644,6 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
            bufferManager.getStream().synchronize();
        }

-        if (cacheBufferId.has_value())
-        {
-            mCacheTransBufferManagers[transferIndexerKCache]->freeBufferIndexForRecv(cacheBufferId);
-        }
    }

    session.setTime(TransferSession::kTimePostprocess);
diff --git a/cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp
index 1fd1cbdc253f..2df131ef4b74 100644
--- a/cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp
@@ -148,7 +148,8 @@ void RnnCacheFormatter::format(TransferSession& session)
            / targetInfo.mDomainTPSize;
    }

-    auto cacheBufferId = mRnnCacheTransBufferManager->assignBufferIndexForSend();
+    auto cacheBufferHolder = BufferIndexHolder::acquireSend(*mRnnCacheTransBufferManager);
+    auto cacheBufferId = cacheBufferHolder.get();
    auto allocationResult = mRnnCacheTransBufferManager->getOrAllocateSendBuffers(
        cacheBufferId, static_cast<int>(bufferTargetNum), bufferSizesPerTarget, bufferManager);
    auto& outputBuffers = std::get<0>(allocationResult);
@@ -191,7 +192,6 @@ void RnnCacheFormatter::format(TransferSession& session)

    session.setTime(TransferSession::kTimeTransmissions);

-    mRnnCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
    session.setTime(TransferSession::kTimePostprocess);

    TLLM_LOG_DEBUG(
@@ -312,6 +312,7 @@ void RnnCacheFormatter::unformat(TransferSession& session)
    size_t remainNoCoverSourceNum = 0;
    size_t bufferCoverSourceNum = 0;
    std::optional<int> cacheBufferId = std::nullopt;
+    BufferIndexHolder cacheBufferHolder;

    auto preAssignedRnnId
        = connections[pickUpConnections[0]]->getPreAssignedBufferId(static_cast<uint8_t>(BufferKind::kRNN));
@@ -321,7 +322,8 @@ void RnnCacheFormatter::unformat(TransferSession& session)
    }
    else
    {
-        cacheBufferId = mRnnCacheTransBufferManager->assignBufferIndexForRecv();
+        cacheBufferHolder = BufferIndexHolder::acquireRecv(*mRnnCacheTransBufferManager);
+        cacheBufferId = cacheBufferHolder.get();
    }

    auto allocationResult = mRnnCacheTransBufferManager->getOrAllocateRecvBuffers(
@@ -469,10 +471,6 @@ void RnnCacheFormatter::unformat(TransferSession& session)

    bufferManager.getStream().synchronize();

-    if (cacheBufferId.has_value())
-    {
-        mRnnCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
-    }
    session.setTime(TransferSession::kTimePostprocess);

    TLLM_LOG_DEBUG(
diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
index 8bb2c0e2ba88..94679a54f022 100644
--- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
+++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -915,7 +915,7 @@ void TrtGptModelInflightBatching::forwardSync()
                    TLLM_CHECK_WITH_INFO(mCacheTransceiver,
                        "Disaggregated serving is not enabled, please check the configuration of "
                        "cacheTransceiverConfig.");
-                    mCacheTransceiver->respondAndSendAsync(llmReq.get());
+                    mCacheTransceiver->respondAndSendAsync(llmReq);
                }
                mSeqSlotManager->freeSequenceSlot(llmReq->mRequestId);
            }
@@ -1596,11 +1596,11 @@ void TrtGptModelInflightBatching::prepareDisaggGenInitRequests(
            mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration.");
        if (common::getEnvDisableKVCacheTransferOverlap())
        {
-            mCacheTransceiver->requestAndReceiveSync(newGenReq.get());
+            mCacheTransceiver->requestAndReceiveSync(newGenReq);
        }
        else
        {
-            mCacheTransceiver->requestAndReceiveAsync(newGenReq.get());
+            mCacheTransceiver->requestAndReceiveAsync(newGenReq);
        }
    }
    if (!common::getEnvDisableKVCacheTransferOverlap())
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
index d46defdf50ad..f4f25c15a6c2 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp
@@ -143,8 +143,36 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size)
    NotificationInfo notificationInfo{syncInfo};
    std::stringstream ss;
    NotificationInfo::serialize(notificationInfo, ss);
-    TransferState transferState = status->wait();
+
+    static constexpr int64_t kCancelPollTimeoutMs = 100;
+    TransferState transferState = TransferState::kIN_PROGRESS;
+    while (transferState == TransferState::kIN_PROGRESS)
+    {
+        transferState = status->wait(kCancelPollTimeoutMs);
+        if (transferState == TransferState::kIN_PROGRESS && ctx.getTransferTerminate().load(std::memory_order_relaxed))
+        {
+            bool const released = status->release();
+            TLLM_LOG_WARNING(
+                "AgentConnection::send cancelled while transfer was in progress (ctx tag=%d, remote=%s, "
+                "releaseAccepted=%d)",
+                ctx.getTag(), mRemoteAgentName.c_str(), released);
+            TLLM_CHECK_WITH_INFO(
+                released, "AgentConnection::send cancel could not release the backend transfer handle");
+            TLLM_THROW("AgentConnection::send cancelled mid-transfer");
+        }
+    }
    TLLM_CHECK_WITH_INFO(transferState == TransferState::kSUCCESS, "AgentConnection::send failed");
+    if (ctx.getTransferTerminate().load(std::memory_order_relaxed))
+    {
+        bool const released = status->release();
+        TLLM_LOG_WARNING(
+            "AgentConnection::send cancelled after transfer completed but before notify (ctx tag=%d, remote=%s, "
+            "releaseAccepted=%d)",
+            ctx.getTag(), mRemoteAgentName.c_str(), released);
+        TLLM_CHECK_WITH_INFO(
+            released, "AgentConnection::send pre-notify cancel could not release the backend transfer handle");
+        TLLM_THROW("AgentConnection::send cancelled pre-notify");
+    }
    // TODO: there is a bug in request_with_notify https://github.com/ai-dynamo/nixl/pull/252
    mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
}
@@ -153,7 +181,11 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) const
{
    NotificationSyncInfo syncInfo{mAgentName, ctx};

-    mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
+    bool const received
+        = mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
+    TLLM_CHECK_WITH_INFO(received,
+        "AgentConnection::recv ended before receiving sync notification (ctx tag=%d, remote=%s)",
+        ctx.getTag(), mRemoteAgentName.c_str());
}

void AgentConnection::sendRequestAndBufferInfo(batch_manager::RequestInfo& requestInfo,
@@ -247,7 +279,10 @@ void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) const
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
{
    ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
-    mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate());
+    if (!mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate()))
+    {
+        return false;
+    }
    return readySignalInfo.mIsReady;
}

@@ -582,7 +617,7 @@ int AgentConnectionManager::getDeviceId() const
}

template <typename NotificationType>
-void AgentConnectionManager::waitForNotification(
+bool AgentConnectionManager::waitForNotification(
    std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag)
{
    while (!terminateFlag.load())
    {

        if (!mIsRunning)
        {
-            return;
+            return false;
        }
        updateUnhandledNotifications();
        std::scoped_lock lock(mNotificationMutex);
@@ -623,7 +658,7 @@ void AgentConnectionManager::waitForNotification(
                {
                    it = mUnhandledNotifications.erase(it);
                }
-                return;
+                return true;
            }
        }
    }
@@ -643,7 +678,7 @@ void AgentConnectionManager::waitForNotification(
                {
                    it = mUnhandledNotifications.erase(it);
                }
-                return;
+                return true;
            }
        }
    }
@@ -663,24 +698,25 @@ void AgentConnectionManager::waitForNotification(
            }
        }
    }
+    return false;
}

// Explicit template instantiations
-template void AgentConnectionManager::waitForNotification(
+template bool AgentConnectionManager::waitForNotification(
    std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
-template void AgentConnectionManager::waitForNotification(
+template bool AgentConnectionManager::waitForNotification(
    std::string const& remoteAgentName, ReadySignalInfo& expectedInfo, std::atomic<bool> const& terminateFlag);

-void AgentConnectionManager::waitForSyncInfo(
+bool AgentConnectionManager::waitForSyncInfo(
    std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag)
{
-    waitForNotification(remoteAgentName, syncInfo, terminateFlag);
+    return waitForNotification(remoteAgentName, syncInfo, terminateFlag);
}

-void AgentConnectionManager::waitForReadySignal(
+bool AgentConnectionManager::waitForReadySignal(
    std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag)
{
-    waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
+    return waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
}

std::string const& AgentConnectionManager::getAgentName() const
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h
index 8ec948cfafeb..1afa787cba2b 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h
+++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h
@@ -320,11 +320,11 @@ class AgentConnectionManager : public ConnectionManager
    [[nodiscard]] std::string const& getAgentName() const;

    template <typename NotificationType>
-    void waitForNotification(
+    [[nodiscard]] bool waitForNotification(
        std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag);
-    void waitForSyncInfo(
+    [[nodiscard]] bool waitForSyncInfo(
        std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag);
-    void waitForReadySignal(
+    [[nodiscard]] bool waitForReadySignal(
        std::string const&
            remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag);

    [[nodiscard]] bool isRunning() const override;
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp
index 2f5eb9342bf6..f9157d59426e 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/agentBindings.cpp
@@ -175,7 +175,8 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
    // subclass type is not directly registered (e.g., agents created via factory).
    nb::class_<kvc::TransferStatus>(m, "TransferStatus")
        .def("is_completed", &kvc::TransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
-        .def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1, nb::call_guard<nb::gil_scoped_release>());
+        .def("wait", &kvc::TransferStatus::wait, nb::arg("timeout_ms") = -1, nb::call_guard<nb::gil_scoped_release>())
+        .def("release", &kvc::TransferStatus::release, nb::call_guard<nb::gil_scoped_release>());

    // BaseAgentConfig struct
    nb::class_<kvc::BaseAgentConfig>(m, "BaseAgentConfig")
@@ -228,7 +229,8 @@ NB_MODULE(tensorrt_llm_transfer_agent_binding, m)
    nb::class_<kvc::NixlTransferStatus>(m, "NixlTransferStatus")
        .def("is_completed", &kvc::NixlTransferStatus::isCompleted, nb::call_guard<nb::gil_scoped_release>())
        .def("wait", &kvc::NixlTransferStatus::wait, nb::arg("timeout_ms") = -1,
-            nb::call_guard<nb::gil_scoped_release>());
+            nb::call_guard<nb::gil_scoped_release>())
+        .def("release", &kvc::NixlTransferStatus::release, nb::call_guard<nb::gil_scoped_release>());

    // NixlTransferAgent class
    nb::class_<kvc::NixlTransferAgent>(m, "NixlTransferAgent")
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
index bad3e184f983..23f061037353 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp
@@ -324,6 +324,14 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)
    TLLM_CHECK(mHandle);
}

+NixlTransferStatus::~NixlTransferStatus()
+{
+    if (!release())
+    {
+        TLLM_LOG_WARNING("NIXL transfer handle release failed during destruction; backend handle may remain active");
+    }
+}
+
[[nodiscard]] MemoryDescs NixlHelper::coalesceMemoryDescs(MemoryDescs const& descs)
{
    auto const& descVec = descs.getDescs();
@@ -484,6 +492,11 @@ NixlTransferStatus::NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle)

TransferState NixlTransferStatus::wait(int64_t timeout_ms) const
{
+    if (mHandle == nullptr)
+    {
+        return TransferState::kFAILURE;
+    }
+
    auto startTime = std::chrono::steady_clock::now();

    while (true)
@@ -520,9 +533,31 @@ TransferState NixlTransferStatus::wait(int64_t timeout_ms) const

[[nodiscard]] bool NixlTransferStatus::isCompleted() const
{
+    if (mHandle == nullptr)
+    {
+        return false;
+    }
    return mRawAgent->getXferStatus(mHandle) == NIXL_SUCCESS;
}

+[[nodiscard]] bool NixlTransferStatus::release()
+{
+    if (mHandle == nullptr)
+    {
+        return true;
+    }
+
+    auto status = mRawAgent->releaseXferReq(mHandle);
+    if (status == NIXL_SUCCESS)
+    {
+        mHandle = nullptr;
+        return true;
+    }
+
+    TLLM_LOG_WARNING("NIXL releaseXferReq failed with status: %s", nixlEnumStrings::statusStr(status).c_str());
+    return false;
+}
+
[[nodiscard]] MemoryDescs NixlHelper::splitVmmDescs(MemoryDescs const& descs, size_t& detectedChunkSize)
{
    detectedChunkSize = 0;
diff --git a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.h b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.h
index b35c5deb182b..26504e1ecf11 100644
--- a/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.h
+++ b/cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.h
@@ -63,11 +63,14 @@ class NixlTransferStatus final : public TransferStatus
{
public:
    NixlTransferStatus(nixlAgent* agent, nixlXferReqH* handle);
+    ~NixlTransferStatus() override;

    [[nodiscard]] bool isCompleted() const override;

    [[nodiscard]] TransferState wait(int64_t timeout_ms = -1) const override;

+    [[nodiscard]] bool release() override;
+
private:
    nixlAgent* mRawAgent{};
    nixlXferReqH* mHandle{};
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
index 61cac5df4c7e..b0df67ba965a 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
@@ -46,17 +46,17 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
    // using BaseCacheTransceiver::BaseCacheTransceiver; // Inherit constructors
    NB_TRAMPOLINE(tb::BaseCacheTransceiver, 6);

-    void respondAndSendAsync(tb::LlmRequest* llmRequest) override
+    void respondAndSendAsync(std::shared_ptr<tb::LlmRequest> llmRequest) override
    {
        NB_OVERRIDE_PURE(respondAndSendAsync, llmRequest);
    }

-    void requestAndReceiveSync(tb::LlmRequest* llmRequest) override
+    void requestAndReceiveSync(std::shared_ptr<tb::LlmRequest> llmRequest) override
    {
        NB_OVERRIDE_PURE(requestAndReceiveSync, llmRequest);
    }

-    void requestAndReceiveAsync(tb::LlmRequest* llmRequest) override
+    void requestAndReceiveAsync(std::shared_ptr<tb::LlmRequest> llmRequest) override
    {
        NB_OVERRIDE_PURE(requestAndReceiveAsync, llmRequest);
    }
@@ -77,7 +77,7 @@ class PyCacheTransceiver : public tb::BaseCacheTransceiver
        NB_OVERRIDE_PURE(checkGenTransferComplete);
    }

-    bool cancelRequest(tb::LlmRequest* llmRequest) override
+    bool cancelRequest(std::shared_ptr<tb::LlmRequest> llmRequest) override
    {
        NB_OVERRIDE_PURE(cancelRequest, llmRequest);
    }
diff --git a/docs/source/developer-guide/disagg-kv-transfer-hardening-plan.md b/docs/source/developer-guide/disagg-kv-transfer-hardening-plan.md
new file mode 100644
index 000000000000..ebeeef9fe0fd
--- /dev/null
+++ b/docs/source/developer-guide/disagg-kv-transfer-hardening-plan.md
@@ -0,0 +1,512 @@
# Disaggregated KV Transfer Hardening Plan

This note describes follow-up hardening work for disaggregated KV cache transfer.
It assumes the conservative request-lifetime patch is already in place: async
transceiver APIs and worker/future tracking structures carry
`std::shared_ptr<LlmRequest>` instead of raw `LlmRequest*`.

That patch fixes one class of bug: stale access to the `LlmRequest` object
itself. It does not, by itself, prove that all resources referenced by a live
`LlmRequest` are still valid. KV cache blocks, sequence slots, transfer buffer
pool entries, and cache-manager request mappings have separate lifetimes and
need their own ownership rules.

## Goals

1. Avoid transfer buffer-index leaks on exceptions and early returns.
2. Avoid double termination after context transfer completion.
3. Improve diagnostics for futures, buffer pools, request IDs, and worker drain.
4. Harden unknown-exception handling without pretending the transport is healthy.
5. Prevent `_terminate_request()` while a context KV send is still in progress.
6. Define a policy for broad `_handle_errors()` when a generation KV receive is
   in flight.

## Non-Goals

- Do not make stuck transport operations magically interruptible.
- Do not erase a C++ future before the worker has reached quiescence.
- Do not report `cancel_request()` success for an already in-flight transfer
  unless the worker is known to have stopped touching request and KV resources.
- Do not keep serving after an unknown transport/backend exception unless the
  backend provides a clear health guarantee.
## Current Implementation Status

This branch has started on the plan with the pieces that are local and low risk:

- `BufferIndexHolder` now owns send/recv transfer buffer slots and releases them
  on scope exit.
- `TransferSession` owns AgentConnection pre-assigned receive buffer slots from
  pre-assignment through formatter `unformat()`.
- `CacheFormatter`, `MLACacheFormatter`, and `RnnCacheFormatter` use RAII
  holders instead of manual free calls.
- `_handle_responses()` now checks
  `request.is_disagg_context_transmission_state` before the partial-reuse
  cleanup branch.
- `_terminate_request()` now defers termination while a context KV send is in
  progress, including the disaggregated PP termination handler path.
- Broad `_handle_errors(requests=None)` now fails closed when a generation KV
  receive is in flight: it emits client error responses where possible, clears
  local scheduling queues, marks the executor shut down, and intentionally does
  not free active request resources in-process.
- C++ worker/future paths now have catch-all handling for unknown exceptions and
  convert them into promise/future errors.
- `TransferStatus` now has a `release()` hook. The NIXL implementation calls
  `nixlAgent::releaseXferReq()`, so sender-side cancellation and object
  destruction use NIXL's intended transfer-handle release path instead of
  leaking backend request handles.
- Agent notification waits now report whether they ended because the expected
  notification arrived or because the transfer was terminated. A terminated
  receive no longer looks like a successful data/sync notification.

Still remaining:

- Add focused unit/fault-injection tests.
- Add an explicit process/transceiver unhealthy flag if we want C++ to expose
  health directly rather than surfacing errors through Python futures.
- If graceful in-process recovery is required later, add real generation-receive
  transfer tracking instead of relying on fail-closed restart.
- NIXL/UCX cancellation is still backend-specific: `releaseXferReq()` releases
  the handle and may request cancellation, but this branch does not treat it as
  proof that a peer can immediately recycle remote KV memory after an ambiguous
  one-sided RMA transfer.

## 1. RAII BufferIndexHolder

### Problem

Several formatter paths manually pair:

- `assignBufferIndexForSend()` with `freeBufferIndexForSend(...)`
- `assignBufferIndexForRecv()` with `freeBufferIndexForRecv(...)`

Important sites:

- `cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp`
  - send path around `assignBufferIndexForSend()`
  - receive path around `assignBufferIndexForRecv()`
- `cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp`
  - send path for the MLA/indexer cache
  - receive path for the MLA/indexer cache
- `cpp/tensorrt_llm/batch_manager/rnnCacheFormatter.cpp`
  - send path
  - receive path
- `cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp`
  - Agent receive pre-assignment path that calls `assignBufferIndexForRecv()`
    before sending request/buffer info to peers.

Manual free is fragile. Any exception, `TLLM_CHECK`, or early return between
assignment and free can leave a pool slot marked in use. That can wedge later
transfers even if the request-level logic is otherwise correct.
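To make the failure mode concrete, here is a minimal sketch of the fragile
manual pattern the holder replaces. The helpers `copyBlocksIntoBuffers` and
`sendBuffersToTargets` are illustrative placeholders, not functions in this
codebase; only the assign/free and allocation calls are the real API:

```cpp
// Fragile manual pairing: the slot is returned to the pool only on the
// happy path. Any throw or early return in between leaks it.
std::optional<int> cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
auto [buffers, bufferCoverTargetNum, onlyUseDynamicBuffer]
    = mCacheTransBufferManager->getOrAllocateSendBuffers(cacheBufferId, targetNum, bufferEleSizes, bufferManager);
copyBlocksIntoBuffers(buffers);          // illustrative; may throw or trip a TLLM_CHECK
sendBuffersToTargets(session, buffers);  // illustrative; may throw on a transport error
mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId); // never reached on error
```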
### Proposed API

Add a small move-only RAII holder near `BaseTransBufferManager`, for example in
`cpp/tensorrt_llm/batch_manager/baseTransBuffer.h`.

```cpp
class BufferIndexHolder
{
public:
    enum class Direction
    {
        kSend,
        kRecv,
    };

    static BufferIndexHolder assignSend(BaseTransBufferManager& manager);
    static BufferIndexHolder assignRecv(BaseTransBufferManager& manager);

    // For pre-assigned ids that are owned by a caller but consumed later, allow
    // construction with an existing id and explicit ownership.
    BufferIndexHolder(BaseTransBufferManager& manager, Direction direction,
        std::optional<int> id, bool owns) noexcept;

    BufferIndexHolder(BufferIndexHolder const&) = delete;
    BufferIndexHolder& operator=(BufferIndexHolder const&) = delete;
    BufferIndexHolder(BufferIndexHolder&& other) noexcept;
    BufferIndexHolder& operator=(BufferIndexHolder&& other) noexcept;

    ~BufferIndexHolder() noexcept;

    std::optional<int> get() const noexcept;
    std::optional<int> release() noexcept;
    void reset() noexcept;

private:
    BaseTransBufferManager* mManager{nullptr};
    Direction mDirection{Direction::kSend};
    std::optional<int> mId{std::nullopt};
    bool mOwns{false};
};
```

Destructor behavior:

- If `mOwns == false`, do nothing.
- If `mId == std::nullopt`, do nothing.
- If the direction is send, call `freeBufferIndexForSend(mId)`.
- If the direction is recv, call `freeBufferIndexForRecv(mId)`.
- Destructors must be `noexcept`. If the free path can throw today, normalize it
  first or catch/log in the holder destructor.

### Usage Pattern

For a locally assigned send buffer index:

```cpp
auto bufferId = BufferIndexHolder::assignSend(*mCacheTransBufferManager);
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
    bufferId.get(), targetNum, bufferSizes, bufferManager);

// Any exception before function exit frees the slot.
// Normal exit also frees the slot automatically.
```

For a locally assigned receive buffer index:

```cpp
auto bufferId = BufferIndexHolder::assignRecv(*mCacheTransBufferManager);
auto result = mCacheTransBufferManager->getOrAllocateRecvBuffers(
    bufferId.get(), targetNum, bufferSizes, bufferManager);
```

For pre-assigned AgentConnection ids, ownership is trickier; see the handoff
sketch after this list:

- If the id is assigned in the same formatter scope, use a holder directly.
- If `dataTransceiver.cpp` assigns ids and passes them through connection
  metadata for a later formatter to consume, do not let a stack holder free the
  id at the end of `sendRequestInfo()`. That would be too early.
- Prefer storing recv buffer-index holders in `TransferSession`, or an
  equivalent per-request transfer object, so ownership spans from pre-assignment
  through `CacheFormatter::unformat()` / `MLACacheFormatter::unformat()` /
  `RnnCacheFormatter::unformat()`.
- Once the formatter takes ownership, avoid a second manual free.
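A minimal sketch of the session-owned handoff, assuming the
`addBufferIndexHolder()` hook added on this branch; the step that advertises
the id to the peer is paraphrased and its setter name is hypothetical:

```cpp
// At the pre-assignment site in the receive path: acquire the slot through a
// holder, advertise the raw id to the peer, then move ownership into the
// session so the slot outlives sendRequestInfo().
auto holder = BufferIndexHolder::acquireRecv(*cacheTransBufferManager);
auto bufferId = holder.get();                  // raw id carried in the request/buffer info
requestInfo.setPreAssignedBufferId(bufferId);  // hypothetical setter; actual wiring differs
session.addBufferIndexHolder(std::move(holder));
// From here on the session owns the slot: unformat() consumes the pre-assigned
// buffer, and the slot is freed exactly once when the session is destroyed.
```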
+
+### Implementation Steps
+
+1. Add `BufferIndexHolder` and unit tests for move, release, reset, send free,
+   and recv free (a test sketch follows this list).
+2. Replace manual send/recv buffer-index frees in `cacheFormatter.cpp`.
+3. Replace manual send/recv buffer-index frees in `mlaCacheFormatter.cpp`.
+4. Replace manual send/recv buffer-index frees in `rnnCacheFormatter.cpp`.
+5. For AgentConnection pre-assigned recv ids in `dataTransceiver.cpp`, decide
+   and implement the ownership handoff:
+   - either store holders in `TransferSession`, or
+   - store them in a request-scoped object owned by the receive worker until
+     `unformat()` completes.
+6. Add debug logs on assign/free with request id, direction, buffer id, and
+   manager/buffer kind.
+7. Add a stress test that injects exceptions after assignment and verifies the
+   next transfer can still acquire a slot.
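+
+A sketch of those step-1 tests, assuming a gtest setup and a hypothetical
+`FakeTransBufferManager` test double that counts outstanding slots (neither
+exists yet; `outstandingSendSlots()` / `outstandingRecvSlots()` are
+illustrative):
+
+```cpp
+#include <gtest/gtest.h>
+
+TEST(BufferIndexHolderTest, MoveTransfersOwnershipExactlyOnce)
+{
+    FakeTransBufferManager manager; // hypothetical double derived from
+                                    // BaseTransBufferManager
+    {
+        auto a = BufferIndexHolder::assignSend(manager);
+        auto b = std::move(a); // `a` is disowned by the move
+        EXPECT_FALSE(a.get().has_value());
+        EXPECT_TRUE(b.get().has_value());
+    } // only `b` frees on scope exit
+    EXPECT_EQ(manager.outstandingSendSlots(), 0);
+}
+
+TEST(BufferIndexHolderTest, ReleaseDisownsWithoutFreeing)
+{
+    FakeTransBufferManager manager;
+    std::optional<int> id;
+    {
+        auto holder = BufferIndexHolder::assignRecv(manager);
+        id = holder.release(); // caller takes over the raw id
+        EXPECT_TRUE(id.has_value());
+    } // destructor must not free a released id
+    EXPECT_EQ(manager.outstandingRecvSlots(), 1);
+    manager.freeBufferIndexForRecv(id); // new owner frees manually
+}
+```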
+
+## 2. `is_disagg_context_complete_state` Guard
+
+This guard is already present in the conservative PR. The intended shape in
+`_handle_responses()` is:
+
+```python
+if request.is_disagg_context_complete_state:
+    # Transfer completion already handled cleanup for this request.
+    pass
+elif request.is_disagg_context_transmission_state:
+    # Do not terminate while KV send is still in flight.
+    pass
+elif (self.enable_partial_reuse_for_disagg
+        and not self.kv_cache_manager.is_vswa
+        and self.dist.pp_size == 1):
+    # Partial-reuse cleanup path; must not run before the guards above.
+    requests_to_terminate.append(request)
+else:
+    requests_to_terminate.append(request)
+```
+
+The current PR includes the complete-state guard. The next hardening pass should
+also make the in-transmission guard explicit and place it before the partial
+reuse branch, so partial reuse cannot bypass the transmission check.
+
+## 3. Prevent `_terminate_request()` During Context Transmission
+
+### Problem
+
+For context/prefill send, `AsyncTransferManager.start_transfer()` is the object
+that tracks in-flight context transfers. It stores the request in
+`_requests_in_transfer`, optionally pins blocks with
+`kv_cache_manager.store_blocks_for_reuse(request, True)`, and only unpins in
+`end_transfer()`.
+
+Calling `_terminate_request()` while
+`request.is_disagg_context_transmission_state` is true can fight that ownership
+model. Even if `shared_ptr` keeps the request object alive,
+`_terminate_request()` may call `resource_manager.free_resources(request)`, and
+resource managers can remove sequence/block mappings that the sender still
+needs.
+
+### Proposed Python Guard
+
+Add a helper:
+
+```python
+def _can_terminate_request_now(self, request: LlmRequest) -> bool:
+    if self.kv_cache_transceiver is None:
+        return True
+    if request.is_disagg_context_transmission_state:
+        return False
+    return True
+```
+
+Use it in all normal termination paths that are not fatal broad-error cleanup:
+
+- `_handle_responses()`
+- `_end_transfer_and_maybe_terminate()`
+- request-scoped `_handle_errors(..., requests=[...])`
+- cancellation handling after `finish_by_reason(...)`
+
+For `_handle_responses()`, prefer the explicit ordering:
+
+```python
+if request.is_disagg_context_complete_state:
+    pass
+elif request.is_disagg_context_transmission_state:
+    pass
+elif partial_reuse_cleanup_condition:
+    requests_to_terminate.append(request)
+else:
+    requests_to_terminate.append(request)
+```
+
+This means a context request can leave `active_requests` but remain owned by
+`AsyncTransferManager.requests_in_transfer()`. When
+`_check_disagg_ctx_cache_transfer_status()` later sees completion, it calls
+`_end_transfer_and_maybe_terminate()`, which calls
+`async_transfer_manager.end_transfer(request)` and only then terminates.
+
+### Tests
+
+Add a Python unit test with a fake transceiver / fake transfer manager:
+
+1. Put a context-only request in `DISAGG_CONTEXT_TRANS_IN_PROGRESS`.
+2. Force `_handle_responses()` to see `request_done`.
+3. Assert `_terminate_request()` is not called.
+4. Simulate `_check_disagg_ctx_cache_transfer_status()` completion.
+5. Assert termination happens after `end_transfer()`.
+
+## 4. Broad `_handle_errors()` Policy for In-Flight Generation Transfers
+
+### Problem
+
+Generation receive is different from context send:
+
+- It is tracked by C++ `mRequesterFutures`.
+- The receiver worker can be actively writing destination KV blocks.
+- There is no Python `AsyncTransferManager` equivalent that pins/unpins the
+  destination KV resources for receive.
+
+The broad error path:
+
+```python
+def _handle_errors(self, error_msg=None, *, requests=None):
+    failed_requests = requests if requests is not None else self.active_requests
+    if requests is None:
+        self.active_requests.clear()
+    ...
+    for request in failed_requests:
+        self._terminate_request(request)
+```
+
+does not filter out `DISAGG_GENERATION_TRANS_IN_PROGRESS`. If it frees all
+active requests while a generation receive worker is still in flight,
+`shared_ptr` keeps the request object alive, but it does not pin the
+destination KV blocks. The worker may still be writing to resources that the
+resource manager has removed or reused.
+
+### Recommended Policy
+
+Treat broad `_handle_errors(requests=None)` as fatal if any active request is in
+disaggregated transfer:
+
+```python
+inflight_transfer_requests = [
+    request for request in self.active_requests
+    if request.is_disagg_context_transmission_state
+    or request.is_disagg_generation_transmission_in_progress
+]
+
+if requests is None and inflight_transfer_requests:
+    self._mark_unhealthy_for_restart(...)
+    self.should_stop_processing = True
+    self.shutdown_event.set()
+    # Do not call _terminate_request() on in-flight transfer requests.
+    # Let process restart release GPU and transport resources.
+```
+
+Reasoning:
+
+- Broad `_handle_errors()` callers include hang detector, decode exceptions,
+  forward exceptions, sampling/setup/update exceptions.
+- These are not narrow request-local cleanup paths.
+- If a worker is inside UCX/NIXL/MPI or writing KV blocks, we do not have a
+  proven safe in-process recovery protocol.
+- The deployment already has canary/health checking that can restart an
+  unhealthy pod.
+
+The fatal path can still enqueue client error responses if that is safe and
+non-blocking, but it should not free in-flight KV transfer resources and then
+continue serving.
+
+### Alternative: Graceful Deferred Termination
+
+If we later need graceful recovery instead of process restart, add real
+generation-transfer tracking:
+
+1. Have `check_gen_transfer_status()` return completed/error request ids, like
+   context status does.
+2. Add a Python `generation_transfer_manager` that records in-flight generation
+   receives and blocks termination while they are active.
+3. Add `_pending_termination_after_transfer[request_id]`.
+4. On request-local error/cancel while generation transfer is in flight:
+   - mark pending error/cancel,
+   - keep resources alive,
+   - do not free the request,
+   - let `check_gen_transfer_status()` complete the future,
+   - then terminate and free resources.
+5. If the future never becomes ready, rely on unhealthy-process restart.
+
+Do not erase the C++ future merely because Python wants cleanup. Erasing the
+future is only safe after the worker has completed or the worker has a proven
+cancel/quiescence handshake.
+
+## 5. Better Diagnostics
+
+Add diagnostics that are actionable but not noisy at normal INFO level.
+
+### Future Tracking
+
+In `CacheTransceiver` (a logging sketch follows this list):
+
+- On insertion into `mSenderFutures` / `mRequesterFutures`, log request id,
+  pointer address, state, vector size, and whether overlap is enabled.
+- On status poll, log at DEBUG:
+  - number of tracked sender/requester futures,
+  - number ready,
+  - number selected for completion,
+  - oldest transfer age.
+- On timeout or repeated non-ready status, log at WARNING with request id,
+  elapsed time, transfer start time, and vector size.
+- On erase, log request id, final state, and vector size after erase.
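+
+A minimal sketch of the insertion/poll side, assuming each tracked entry pairs
+the request with its future as described above, and that `mTransferStartTimes`
+plus both helper methods are new, illustrative names rather than existing API:
+
+```cpp
+// Illustrative only. Called right after pushing into mSenderFutures.
+void CacheTransceiver::logSenderFutureInsert(LlmRequest const& request)
+{
+    mTransferStartTimes[request.mRequestId] = std::chrono::steady_clock::now();
+    TLLM_LOG_DEBUG("cacheTransceiver: tracking sender future, requestId=%ld, tracked=%zu",
+        static_cast<long>(request.mRequestId), mSenderFutures.size());
+}
+
+// Illustrative only. Called once per status poll.
+void CacheTransceiver::logSenderFuturePoll()
+{
+    auto const now = std::chrono::steady_clock::now();
+    std::size_t ready = 0;
+    auto oldest = now;
+    for (auto const& [request, future] : mSenderFutures)
+    {
+        if (future.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
+        {
+            ++ready;
+        }
+        auto const start = mTransferStartTimes[request->mRequestId];
+        if (start < oldest)
+        {
+            oldest = start;
+        }
+    }
+    auto const oldestMs
+        = std::chrono::duration_cast<std::chrono::milliseconds>(now - oldest).count();
+    TLLM_LOG_DEBUG("cacheTransceiver: sender futures tracked=%zu ready=%zu oldestAgeMs=%lld",
+        mSenderFutures.size(), ready, static_cast<long long>(oldestMs));
+}
+```
+
+The WARNING-level escalation for stuck transfers can reuse the same
+start-time bookkeeping rather than adding a second clock.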
+
+### Buffer Pools
+
+In `BaseTransBufferManager`:
+
+- Log assign/free with direction, buffer id, buffer kind, configured count, and
+  dynamic-buffer mode.
+- Expose a debug-only method to count outstanding send/recv slots.
+- On destructor or shutdown, warn if any non-dynamic buffer slot is still
+  marked in use.
+
+### Worker Drain
+
+In `CacheSender::Impl` and `CacheReceiver::Impl`:
+
+- Log worker start/stop and request ids currently queued.
+- On destructor/drain, log queue sizes and number of futures being joined.
+- When a worker sets a promise exception, include request id and context request
+  id when available.
+
+### Python Error Paths
+
+In `_handle_errors()`:
+
+- Log whether `requests` is `None` or request-scoped.
+- Log request ids grouped by state.
+- If broad error cleanup sees in-flight transfer requests, emit a single
+  high-severity log explaining that the process is being marked unhealthy and
+  that in-flight transfer resources are intentionally not freed in-process.
+
+## 6. `catch (...)` Hardening
+
+### Current Gap
+
+Several worker paths catch `std::exception`, set promise exceptions, and keep
+going. Unknown non-`std::exception` failures may bypass those handlers. Also,
+continuing to serve after unknown transport/backend exceptions is risky because
+the connection manager, CUDA stream, or transport backend may be in an unknown
+state.
+
+### Proposed C++ Pattern
+
+Add catch-all blocks immediately after existing `catch (std::exception const&)`
+blocks in worker/future completion paths:
+
+```cpp
+catch (...)
+{
+    auto error = std::runtime_error("Unknown exception in KV cache transfer worker");
+    TLLM_LOG_ERROR("%s, request id: %ld", error.what(), requestId);
+    markUnhealthy(error.what());
+    promise.set_exception(std::make_exception_ptr(error));
+}
+```
+
+Candidate locations:
+
+- `CacheSender::Impl::sendAndRemoveResponse`
+- `CacheSender::Impl::response`
+- `CacheReceiver::Impl::request`
+- `CacheTransceiver::checkContextTransferStatus`
+- `CacheTransceiver::checkGenTransferStatus`
+- async send/receive helper lambdas launched with `std::async`
+
+### Health Semantics
+
+Prefer fail-closed:
+
+- Convert the unknown exception to `std::runtime_error` for the future/promise.
+- Mark the cache transceiver or process unhealthy.
+- Surface the failure to Python.
+- Let Python stop accepting work and rely on canary/pod restart.
+
+Avoid catching unknown exceptions and continuing to serve unless the specific
+transport backend guarantees it is still valid after that exception class.
+
+## 7. Validation Plan
+
+### Unit Tests
+
+- RAII holder frees send/recv ids on normal scope exit.
+- RAII holder frees send/recv ids on exception.
+- Move construction and move assignment transfer ownership exactly once.
+- Pre-assigned/borrowed ids are not double-freed.
+- `_handle_responses()` does not terminate a context request in
+  `DISAGG_CONTEXT_TRANS_IN_PROGRESS`.
+- `_handle_errors(requests=None)` with in-flight generation transfer marks + unhealthy and avoids `_terminate_request()` for those requests. + +### Integration / Fault Injection + +- Inject exception after send buffer assignment before manual free location. + Verify subsequent transfer can acquire the slot. +- Inject exception after recv buffer assignment before concat/free. Verify slot + is released. +- Simulate a generation receive future that never becomes ready, then trigger + broad `_handle_errors()`. Verify process is marked unhealthy and no + in-process free happens for the in-flight receive request. +- Simulate context send in progress and force `_handle_responses()` to emit the + logical `LlmResponse`. Verify resource free waits until + `AsyncTransferManager.end_transfer()`. + +### Observability Checks + +- Logs include request id for future insertion, poll, completion, erase, and + exception. +- Buffer pool logs show balanced assign/free counts under fault injection. +- Broad-error fatal path emits one clear health log, not per-iteration spam. + +## Suggested Implementation Order + +1. Add `_can_terminate_request_now()` and context transmission guard in Python. +2. Decide and implement broad `_handle_errors()` policy for in-flight generation + transfer. Prefer fail-closed/unhealthy restart first. +3. Add `BufferIndexHolder` for straightforward local send/recv formatter sites. +4. Handle AgentConnection pre-assigned recv ids with a request/session-scoped + holder. +5. Add diagnostics. +6. Add `catch (...)` hardening and unhealthy marking. +7. Add fault-injection tests. diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py index 1aac24d98e75..4c5548bd2472 100644 --- a/tensorrt_llm/_torch/disaggregation/base/agent.py +++ b/tensorrt_llm/_torch/disaggregation/base/agent.py @@ -89,6 +89,9 @@ def is_completed(self) -> bool: ... @abstractmethod def wait(self, timeout_ms: int | None = None) -> bool: ... + def release(self) -> bool: + return False + class BaseTransferAgent(ABC): @abstractmethod diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py index 84f85d41b325..1dd85e177a13 100644 --- a/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py @@ -30,6 +30,10 @@ def wait(self, timeout_ms=None) -> bool: timeout_ms = -1 return self._cpp_status.wait(timeout_ms) == TransferState.SUCCESS + def release(self) -> bool: + """Release the transfer handle, requesting backend cancel if still active.""" + return self._cpp_status.release() + class BindingsNixlTransferAgent(BaseTransferAgent): """NixlTransferAgent using C++ bindings with GIL release support. 
diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py index dcfa28210f81..703dc42f33a1 100644 --- a/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py @@ -26,6 +26,14 @@ def is_completed(self): status = TransferState(self.agent.check_xfer_state(self.handle)) return status == TransferState.DONE + def release(self): + try: + self.handle.release() + return True + except Exception: + logger.exception("Failed to release NIXL transfer handle (agent=%s).", self.agent.name) + return False + def wait(self, timeout_ms=None): start_time = time.time() status = TransferState.PENDING diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py index 2d3583f5ab9e..21a2283b9720 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py @@ -77,6 +77,11 @@ def remove_by_ids(self, request_ids: set[int]) -> None: """Remove requests with the given IDs.""" pass + @abstractmethod + def clear(self) -> None: + """Remove all requests from the queue.""" + pass + @abstractmethod def __bool__(self) -> bool: """Check if queue has any requests.""" @@ -145,6 +150,10 @@ def remove_by_ids(self, request_ids: set[int]) -> None: self.clear() self.extend(filtered_requests) + def clear(self) -> None: + """Remove all requests from the queue.""" + super().clear() + def __bool__(self) -> bool: """Check if queue has any requests.""" return len(self) > 0 @@ -248,6 +257,10 @@ def remove_by_ids(self, request_ids: set[int]) -> None: self._heap = [e for e in self._heap if e[2].id not in request_ids] heapq.heapify(self._heap) + def clear(self) -> None: + """Remove all requests from the queue.""" + self._heap.clear() + def __bool__(self) -> bool: return len(self._heap) > 0