Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,15 @@ class BaseCacheTransceiver
{
public:
virtual ~BaseCacheTransceiver() = default;
virtual void respondAndSendAsync(LlmRequest* llmRequest) = 0;
// Transfers are asynchronous. Pass shared_ptr so the transceiver and its
// workers keep the request alive until the corresponding future resolves.
virtual void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress)
= 0;

virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;

/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
Expand All @@ -221,7 +223,7 @@ class BaseCacheTransceiver

[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) = 0;
};

class CacheTransceiver : public BaseCacheTransceiver
Expand Down Expand Up @@ -252,13 +254,13 @@ class CacheTransceiver : public BaseCacheTransceiver

virtual ~CacheTransceiver();

void respondAndSendAsync(LlmRequest* llmRequest) override;
void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) override;

void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress) override;

void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;
void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) override;
void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) override;

RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
Expand All @@ -267,7 +269,7 @@ class CacheTransceiver : public BaseCacheTransceiver

[[nodiscard]] bool checkGenTransferComplete() const override;

virtual bool cancelRequest(LlmRequest* llmRequest) override;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) override;

private:
void initializeCommState();
Expand All @@ -276,8 +278,10 @@ class CacheTransceiver : public BaseCacheTransceiver

std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
// Hold strong references while futures are outstanding so Python-side
// cleanup cannot leave C++ with dangling LlmRequest pointers.
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mSenderFutures;
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mRequesterFutures;
mpi::MpiComm const* mMpiWorldComm{nullptr};

std::shared_ptr<CacheTransceiverComm> mGroupComm;
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/tensorrt_llm/executor/transferAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,14 @@ class TransferStatus
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;
    /// Release the backend transfer request. If the request is still active,
    /// backends may attempt to cancel it. A true return only means the backend
    /// accepted release of the transfer handle; callers must still treat remote
    /// memory quiescence as backend-specific.
    /// @return false by default, so backends that do not override this keep
    ///         their pre-existing behavior of never releasing a handle early.
    [[nodiscard]] virtual bool release()
    {
        return false;
    }
};

struct BaseAgentConfig
Expand Down
131 changes: 129 additions & 2 deletions cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,112 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <exception>
#include <mutex>

namespace tensorrt_llm::batch_manager
{

namespace
{

/// Map a BufferKind enumerator to the short tag used in transfer-buffer log
/// lines. Returns "unknown" for any value outside the declared enumerators.
char const* bufferKindName(BufferKind kind)
{
    if (kind == BufferKind::kKV)
    {
        return "kv";
    }
    if (kind == BufferKind::kKV_INDEXER)
    {
        return "kv_indexer";
    }
    if (kind == BufferKind::kRNN)
    {
        return "rnn";
    }
    return "unknown";
}

} // namespace

/// Wrap a buffer index obtained from @p manager. Ownership is taken iff
/// @p manager is non-null; note a nullopt @p bufferId is still "owned", so the
/// manager's free call is invoked on release (see reset()) — presumably to
/// cover dynamic-buffer bookkeeping; confirm against freeBufferIndex.
BufferIndexHolder::BufferIndexHolder(
    BaseTransBufferManager* manager, Direction direction, std::optional<int> bufferId)
    : mManager{manager}
    , mDirection{direction}
    , mBufferId{bufferId}
    , mOwns{manager != nullptr}
{
}

/// Releases the owned index (if any) via reset(); cannot throw because
/// reset() is noexcept.
BufferIndexHolder::~BufferIndexHolder()
{
    reset();
}

/// Move construction: steal other's state, leaving other as an empty,
/// non-owning holder so its destructor does not free the same index.
BufferIndexHolder::BufferIndexHolder(BufferIndexHolder&& other) noexcept
{
    mManager = other.mManager;
    mDirection = other.mDirection;
    mBufferId = other.mBufferId;
    mOwns = other.mOwns;
    // Disarm the source holder.
    other.mManager = nullptr;
    other.mBufferId = std::nullopt;
    other.mOwns = false;
}

/// Move assignment: release anything currently owned, then take over other's
/// index. Self-move is a no-op.
BufferIndexHolder& BufferIndexHolder::operator=(BufferIndexHolder&& other) noexcept
{
    if (this == &other)
    {
        return *this;
    }
    // Free our own index first so it is not leaked by the overwrite below.
    reset();
    mManager = other.mManager;
    mDirection = other.mDirection;
    mBufferId = other.mBufferId;
    mOwns = other.mOwns;
    // Disarm the source holder so only one holder frees the index.
    other.mManager = nullptr;
    other.mBufferId = std::nullopt;
    other.mOwns = false;
    return *this;
}

/// Reserve a send-side buffer index from @p manager and wrap it in an owning
/// holder; the index is returned to the manager when the holder dies.
BufferIndexHolder BufferIndexHolder::acquireSend(BaseTransBufferManager& manager)
{
    auto assignedId = manager.assignBufferIndexForSend();
    return BufferIndexHolder{&manager, Direction::kSend, assignedId};
}

/// Reserve a recv-side buffer index from @p manager and wrap it in an owning
/// holder; the index is returned to the manager when the holder dies.
BufferIndexHolder BufferIndexHolder::acquireRecv(BaseTransBufferManager& manager)
{
    auto assignedId = manager.assignBufferIndexForRecv();
    return BufferIndexHolder{&manager, Direction::kRecv, assignedId};
}

/// Return the owned buffer index (if any) to the manager and clear this
/// holder to the empty state. Idempotent: safe to call repeatedly.
void BufferIndexHolder::reset() noexcept
{
    // Nothing to release: default-constructed, moved-from, or already reset.
    if (!mOwns || mManager == nullptr)
    {
        return;
    }

    // The free path may throw; this function is noexcept (and runs from the
    // destructor), so log and swallow instead of terminating the process.
    try
    {
        if (mDirection == Direction::kSend)
        {
            mManager->freeBufferIndexForSend(mBufferId);
        }
        else
        {
            mManager->freeBufferIndexForRecv(mBufferId);
        }
    }
    catch (std::exception const& e)
    {
        TLLM_LOG_ERROR(
            "Exception while releasing cache transfer buffer index %d: %s", mBufferId.value_or(-1), e.what());
    }
    catch (...)
    {
        TLLM_LOG_ERROR("Unknown exception while releasing cache transfer buffer index %d", mBufferId.value_or(-1));
    }

    // Clear state even if the free call threw, so a later reset()/destructor
    // cannot attempt a double free.
    mManager = nullptr;
    mBufferId = std::nullopt;
    mOwns = false;
}

BaseTransBufferManager::BaseTransBufferManager(
size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
: mDataType{dataType}
Expand Down Expand Up @@ -56,22 +157,48 @@ BaseTransBufferManager::BaseTransBufferManager(

/// Reserve a send-side transfer buffer slot.
/// @return the reserved buffer index, or std::nullopt when no preallocated
///         slot was assigned (presumably the dynamic-buffer-only mode implied
///         by mOnlyUseDynamicBuffer — confirm against assignBufferIndex).
std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
{
    // NOTE(review): the merged text contained a leftover pre-change
    // `return assignBufferIndex(...)` before this line, which made the
    // logging and the final return unreachable; the residue is removed here.
    auto bufferId = assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_LOG_DEBUG("Assigned send cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceSendResource.mConcurrence.load(),
            mSendBufferCount);
    }
    return bufferId;
}

/// Return a send-side buffer slot to the pool. Always forwards to
/// freeBufferIndex (a nullopt id still reaches the pool bookkeeping); the
/// debug log is emitted only for concrete indices.
void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
    if (!bufferId.has_value())
    {
        return;
    }
    TLLM_LOG_DEBUG("Freed send cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
        bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceSendResource.mConcurrence.load(),
        mSendBufferCount);
}

/// Reserve a recv-side transfer buffer slot.
/// @return the reserved buffer index, or std::nullopt when no preallocated
///         slot was assigned (presumably the dynamic-buffer-only mode implied
///         by mOnlyUseDynamicBuffer — confirm against assignBufferIndex).
std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
{
    // NOTE(review): the merged text contained a leftover pre-change
    // `return assignBufferIndex(...)` before this line, which made the
    // logging and the final return unreachable; the residue is removed here.
    auto bufferId = assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_LOG_DEBUG("Assigned recv cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
            bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceRecvResource.mConcurrence.load(),
            mRecvBufferCount);
    }
    return bufferId;
}

/// Return a recv-side buffer slot to the pool. Always forwards to
/// freeBufferIndex (a nullopt id still reaches the pool bookkeeping); the
/// debug log is emitted only for concrete indices.
void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
    if (!bufferId.has_value())
    {
        return;
    }
    TLLM_LOG_DEBUG("Freed recv cache transfer buffer kind=%s index=%d outstanding=%d/%zu",
        bufferKindName(getBufferKind()), bufferId.value(), mConcurrenceRecvResource.mConcurrence.load(),
        mRecvBufferCount);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
Expand Down
42 changes: 42 additions & 0 deletions cpp/tensorrt_llm/batch_manager/baseTransBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,55 @@ class FabricMemory;
namespace tensorrt_llm::batch_manager
{

class BaseTransBufferManager;

/// Identifies which cache family a transfer-buffer manager serves; used to
/// tag transfer log messages (see bufferKindName in baseTransBuffer.cpp).
/// Enumerator semantics beyond the names (KV cache / KV indexer / RNN state)
/// are defined by the derived managers — confirm there before relying on them.
enum class BufferKind : uint8_t
{
    kKV = 0,
    kKV_INDEXER = 1,
    kRNN = 2
};

/// RAII guard for a transfer-buffer index obtained from a
/// BaseTransBufferManager. The matching send/recv free call runs on
/// destruction (or reset()), so early returns and exceptions cannot leak the
/// index. Movable but not copyable; a moved-from holder owns nothing.
class BufferIndexHolder
{
public:
    /// Which side of the transfer the held index belongs to; selects
    /// freeBufferIndexForSend vs freeBufferIndexForRecv on release.
    enum class Direction : uint8_t
    {
        kSend = 0,
        kRecv = 1
    };

    // Default-constructed holders own nothing and release nothing.
    BufferIndexHolder() = default;
    // Takes ownership iff manager is non-null; bufferId may be nullopt and the
    // manager's free call is still invoked on release.
    BufferIndexHolder(BaseTransBufferManager* manager, Direction direction, std::optional<int> bufferId);
    ~BufferIndexHolder();

    BufferIndexHolder(BufferIndexHolder const&) = delete;
    BufferIndexHolder& operator=(BufferIndexHolder const&) = delete;
    BufferIndexHolder(BufferIndexHolder&& other) noexcept;
    BufferIndexHolder& operator=(BufferIndexHolder&& other) noexcept;

    /// Reserve a send-side index from manager, wrapped in an owning holder.
    [[nodiscard]] static BufferIndexHolder acquireSend(BaseTransBufferManager& manager);
    /// Reserve a recv-side index from manager, wrapped in an owning holder.
    [[nodiscard]] static BufferIndexHolder acquireRecv(BaseTransBufferManager& manager);

    /// The held index; nullopt when empty or when no preallocated slot was assigned.
    [[nodiscard]] std::optional<int> get() const noexcept
    {
        return mBufferId;
    }

    /// True while this holder is responsible for freeing the index.
    [[nodiscard]] bool owns() const noexcept
    {
        return mOwns;
    }

    /// Release the owned index (if any) and return to the empty state.
    void reset() noexcept;

private:
    BaseTransBufferManager* mManager{nullptr};
    Direction mDirection{Direction::kSend};
    std::optional<int> mBufferId{std::nullopt};
    bool mOwns{false};
};

/// @brief Base class for cache transfer buffer management.
/// Handles buffer pool allocation, index assignment, and slicing.
/// Derived classes provide cache-specific size calculations.
Expand Down
12 changes: 5 additions & 7 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
// cache blocks to the corresponding buffer.
// 5. send the buffer to the corresponding target. Ideally, we send only once (one buffer) for each target.

auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
auto cacheBufferHolder = BufferIndexHolder::acquireSend(*mCacheTransBufferManager);
auto cacheBufferId = cacheBufferHolder.get();
int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
auto bufferTargetNum = targetNum / peerDuplicateHeadFactor;
auto ppRank = selfIdx
Expand Down Expand Up @@ -578,7 +579,6 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio

session.setTime(TransferSession::kTimeTransmissions);

mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
session.setTime(TransferSession::kTimePostprocess);
}
TLLM_LOG_DEBUG(
Expand Down Expand Up @@ -800,6 +800,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
size_t remainNoCoverTargetNum = 0;
size_t bufferCoverTargetNum = 0;
std::optional<int> cacheBufferId = std::nullopt;
BufferIndexHolder cacheBufferHolder;
{
NVTX3_SCOPED_RANGE(formatInputAllocBuffer);

Expand All @@ -813,7 +814,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
}
else
{
cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv();
cacheBufferHolder = BufferIndexHolder::acquireRecv(*mCacheTransBufferManager);
cacheBufferId = cacheBufferHolder.get();
}
auto [recvSplitCachestmp, bufferCoverTargetNumtmp, onlyUseDynamicBuffer]
= mCacheTransBufferManager->getOrAllocateRecvBuffers(
Expand Down Expand Up @@ -948,10 +950,6 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager);

bufferManager.getStream().synchronize();
if (cacheBufferId.has_value())
{
mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
}
}
session.setTime(TransferSession::kTimePostprocess);
}
Expand Down
Loading
Loading