34 changes: 32 additions & 2 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -3043,7 +3043,17 @@ void BlockManager::schedulingReleaseBlocks(RequestIdType requestId)

void WindowBlockManager::schedulingReleaseBlocks(RequestIdType requestId)
{
for (auto& block : mAllocatedBlocksPerSeq.at(requestId))
auto const seqIt = mAllocatedBlocksPerSeq.find(requestId);
if (seqIt == mAllocatedBlocksPerSeq.end())
{
TLLM_LOG_WARNING(
"%s::schedulingReleaseBlocks skipped request %lu because no allocated blocks are tracked. "
"The request was likely already removed before the max-utilization scheduler simulated a pause.",
mLogPrefix.c_str(), requestId);
return;
}

for (auto& block : seqIt->second)
{
// Decrease ref count
block->decSchedulingRefCount();
@@ -3288,7 +3298,27 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(LlmRequest const& req, bool tw
return 0;
}

auto const numCurrTokens = getSequence(req.mRequestId).getNumTokens();
auto const maybeNumCurrTokens = [this, requestId = req.mRequestId]() -> std::optional<SizeType32>
{
auto lck = std::scoped_lock(mSequencesMtx);
auto const seqIt = mSequences.find(requestId);
if (seqIt == mSequences.end())
{
return std::nullopt;
}
return seqIt->second.getNumTokens();
}();
if (!maybeNumCurrTokens.has_value())
{
auto const unschedulableBlocks = mBlockManager.getWindowSizeMetadata(windowSize).maxNumBlocks + 1;
TLLM_LOG_WARNING(
"[kv cache manager] getNeededBlocksOneStep: request %lu is generation-in-progress but no KV "
"sequence exists. Returning %d required blocks so the scheduler pauses or drops the stale request.",
req.mRequestId, unschedulableBlocks);
return unschedulableBlocks;
}

auto const numCurrTokens = maybeNumCurrTokens.value();
auto const generatedTokens = numCurrTokens - req.getPromptLen();
auto const maxTokensToAddToKVCache = req.mMaxNewTokens - generatedTokens;
auto const tokensPerStep = req.getNumDraftTokens() + 1;
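A minimal sketch of the scheduling effect of this guard, written in Python rather than the actual C++ BlockManager/KVCacheManager code; the function name and arguments below are illustrative assumptions, not real interfaces. The key idea is that a generation-in-progress request with no tracked KV sequence reports one block more than the window can ever provide, so a max-utilization scheduler can never fit it and must pause or drop it.

    # Illustrative sketch only; not the TensorRT-LLM API.
    def needed_blocks_one_step(sequences, request_id, tokens_per_block, max_num_blocks):
        num_curr_tokens = sequences.get(request_id)
        if num_curr_tokens is None:
            # No KV sequence is tracked for this request: make it unschedulable.
            return max_num_blocks + 1
        # Simplified: one new token this step needs a fresh block only on a block boundary.
        return 1 if num_curr_tokens % tokens_per_block == 0 else 0

    # A stale request (id 42, no sequence) can never fit into a 16-block window.
    assert needed_blocks_one_step({}, 42, tokens_per_block=4, max_num_blocks=16) == 17
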
13 changes: 12 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -49,6 +49,7 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using CacheBlockIds = std::vector<std::vector<SizeType32>>;

@@ -400,7 +401,17 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("__hash__", [](tbk::BlockKey const& key) -> size_t { return tbk::BlockKeyHasher{}(key); });

nb::class_<tbk::BlockKeyHasher>(m, "BlockKeyHasher")
.def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0);
.def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0)
.def_static(
"hash_token_ids",
[](VecTokens const& tokenIds, std::size_t parentHash,
std::optional<CacheSaltIDType> cacheSaltID) -> std::size_t
{
auto blockKey = BlockKey(tokenIds);
blockKey.cacheSaltID = cacheSaltID;
return tbk::BlockKeyHasher::hash(blockKey, parentHash);
},
nb::arg("token_ids"), nb::arg("parent_hash") = 0, nb::arg("cache_salt_id") = std::nullopt);

nb::class_<tbk::KVCacheEventManager>(m, "KVCacheEventManager")
.def(nb::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
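Assuming the binding is exposed through the internal batch-manager bindings module (the import path below is an assumption, not confirmed by this diff), the new hash_token_ids static method could be exercised from Python roughly like this:

    # Hypothetical usage sketch; the import path is an assumption.
    from tensorrt_llm.bindings.internal.batch_manager import BlockKeyHasher

    token_ids = [101, 7592, 2088, 102]
    # Hash one block of token ids with no parent block and no cache salt.
    root_hash = BlockKeyHasher.hash_token_ids(token_ids)
    # Chain a child block onto the parent hash, optionally salting the cache key.
    child_hash = BlockKeyHasher.hash_token_ids([2023, 2003], parent_hash=root_hash, cache_salt_id=7)
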
44 changes: 44 additions & 0 deletions cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
@@ -3387,6 +3387,50 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest)
EXPECT_EQ(blockManager.getNumFreeBlocks(), 0);
}

TEST_F(KVCacheManagerTest, SchedulingRemoveSequenceIgnoresAlreadyRemovedSequence)
{
auto constexpr numLayers = 1;
auto constexpr numKvHeads = 1;
auto constexpr sizePerHead = 8;
auto constexpr tokensPerBlock = 4;
auto constexpr totalNumBlocks = 16;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr maxNumSequences = 2;
auto constexpr maxBeamWidth = 1;
auto constexpr sinkTokenLength = 0;
auto constexpr maxSequenceLength = 16;
auto constexpr maxAttentionWindow = maxSequenceLength;
auto constexpr requestId = 17;
auto constexpr inputLength = 8;
auto constexpr maxNewTokens = 0;
auto const stream = std::make_shared<tr::CudaStream>();
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}};
tr::SamplingConfig const samplingConfig{maxBeamWidth};
bool constexpr isStreaming{false};

auto inputTokens = std::make_shared<VecTokens>();
for (SizeType32 token = 0; token < inputLength; ++token)
{
inputTokens->push_back(token);
}

KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt,
nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, false);
kvCacheManager.allocatePools(false);

auto llmRequest = std::make_shared<LlmRequest>(
LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokens, samplingConfig, isStreaming);
EXPECT_NO_THROW(
kvCacheManager.addSequenceBatch({{{requestId, inputLength, maxBeamWidth}}}, {std::ref(*llmRequest)}));

kvCacheManager.startScheduling();
EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId));
EXPECT_NO_THROW(static_cast<void>(kvCacheManager.removeSequence(requestId)));
EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId));
EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId + 1));
}

TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest)
{
using DType = half;
51 changes: 40 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -3084,7 +3084,7 @@ def _check_disagg_gen_transfer_status(self):
return

@nvtx_range("_check_kv_transfer_timeout")
def _check_kv_transfer_timeout(self):
def _check_kv_transfer_timeout(self) -> None:
if not self.kv_cache_transceiver:
return
timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
@@ -3391,17 +3391,28 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str):
f"Error in kv cache transfer for {error_msg_prefix}",
requests=error_requests)

def _complete_context_transfer_with_error(self, request: LlmRequest,
error_msg: str) -> None:
request.py_kv_transfer_start_time = None
request.state = LlmRequestState.DISAGG_TRANS_ERROR
self.async_transfer_manager.end_transfer(request)
if request in self.active_requests:
self._handle_errors(error_msg=error_msg, requests=[request])

@nvtx_range("_check_disagg_ctx_cache_transfer_status")
def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
def _check_disagg_ctx_cache_transfer_status(self,
atLeastNum: int = 0) -> None:
finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status(
atLeastNum)

completed_req_ids = set(finished_requests + error_requests)
finished_req_ids = set(finished_requests)
error_req_ids = set(error_requests)
completed_req_ids = finished_req_ids | error_req_ids

requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
)

for request_id in completed_req_ids:
for request_id in finished_req_ids:

if request_id not in requests_in_transfer:
logger.warning(
@@ -3410,7 +3421,24 @@ def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):

request = requests_in_transfer[request_id]

self._end_transfer_and_maybe_terminate(request)
if request.py_kv_transfer_timed_out:
self._complete_context_transfer_with_error(
request,
f"Context KV cache transfer completed after timeout for request {request_id}"
)
else:
self._end_transfer_and_maybe_terminate(request)

for request_id in error_req_ids:
if request_id not in requests_in_transfer:
logger.warning(
f"Request {request_id} not found in transfer manager")
continue

request = requests_in_transfer[request_id]
self._complete_context_transfer_with_error(
request,
f"Error in kv cache transfer for context request {request_id}")

# The set of requests in transfer may have changed since we terminated some requests.
requests_in_transfer = self.async_transfer_manager.requests_in_transfer(
@@ -3420,13 +3448,14 @@ def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
request = requests_in_transfer[request_id]
if request.py_kv_transfer_timed_out and request_id not in completed_req_ids:
is_cancelled = self.kv_cache_transceiver.cancel_request(request)
# If cancel is successful, mark as complete so it can be cleaned up
# Otherwise, try at next iteration
# If cancel is successful, mark as transfer error so metrics and
# cleanup do not report this as a clean context completion.
# Otherwise, try at next iteration.
if is_cancelled:
request.py_kv_transfer_start_time = None
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE

self._end_transfer_and_maybe_terminate(request)
self._complete_context_transfer_with_error(
request,
f"Context KV cache transfer timed out for request {request_id}"
)

self._check_cache_transfer_errors("context requests")

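A condensed sketch of the completion paths introduced above, using illustrative stand-ins rather than the real PyExecutor and transfer-manager interfaces: finished requests that had already timed out, requests reported as errored, and timed-out requests whose transfer was cancelled all take the error path; only clean finishes terminate normally.

    # Illustrative sketch of the completion paths; not the real API.
    def resolve_context_transfers(finished_ids, error_ids, in_transfer, timed_out_ids,
                                  complete_ok, complete_error):
        for request_id in finished_ids:
            if request_id not in in_transfer:
                continue  # already cleaned up elsewhere
            if request_id in timed_out_ids:
                complete_error(request_id, "completed after timeout")
            else:
                complete_ok(request_id)
        for request_id in error_ids:
            if request_id in in_transfer:
                complete_error(request_id, "transfer error")
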
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -838,7 +838,7 @@ def add_dummy_requests(
def update_resources(self,
scheduled_batch: ScheduledRequests,
attn_metadata: "AttentionMetadata" = None,
kv_cache_dtype_byte_size: float = None):
kv_cache_dtype_byte_size: float = None) -> None:
if not self.is_draft:
_update_kv_cache_draft_token_location(self, scheduled_batch,
attn_metadata,
@@ -862,6 +862,10 @@ def update_resources(self,
# storing, so that SWA windows are safe to store — blocks won't go out-of-window
# and be evicted while the context is still in-flight.
for request in scheduled_batch.context_requests:
if (request.is_disagg_context_transmission_state
and self.enable_partial_reuse and not self.is_vswa
and self.mapping.pp_size == 1):
continue
self.impl.store_context_blocks(request)

def free_resources(self, request: LlmRequest, pin_on_release: bool = False):
124 changes: 109 additions & 15 deletions tensorrt_llm/serve/harmony_adapter.py
@@ -519,6 +519,90 @@ def _safe_decode_utf8(self,

return self.encoding.decode(tokens)

def _strip_harmony_stream_control_text(self, text: str) -> str:
"""Strip Harmony framing tokens from best-effort streaming fallback text."""
if not text:
return ""
if not any(marker in text for marker in self._harmony_special_tokens):
return text

cleaned = text
for marker in self._harmony_special_tokens:
cleaned = cleaned.replace(marker, "\n")

# Remove Harmony header-only lines while preserving real content lines.
cleaned = re.sub(
r"(?m)^\s*(assistant|analysis|commentary|final|system|developer|user)(?:\s+to=[^\n]*)?\s*$\n?",
"", cleaned)
cleaned = re.sub(r"(?m)^\s*(json|code)\s*$\n?", "", cleaned)
cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
cleaned = cleaned.strip("\n")
if not cleaned.strip():
return ""
return cleaned

def _decode_streaming_fallback_text(self, request_id: str,
tokens: list[int]) -> str:
"""Decode a failed streaming batch as visible text when Harmony parsing fails."""
try:
decoded_text = self._safe_decode_utf8(tokens)
except (HarmonyError, UnicodeDecodeError, ValueError) as decode_error:
logger.error(
f"Streaming: Failed to decode fallback text for request {request_id}: "
f"{type(decode_error).__name__}: {decode_error}")
logger.debug(
f"Problematic fallback streaming tokens for request {request_id}: {tokens}"
)
return ""

return self._strip_harmony_stream_control_text(decoded_text)

def _streaming_error_state_summary(
self, stream_state: HarmonyStreamState | None) -> dict[str, Any]:
"""Return compact parser state for streaming parse-failure logs."""
if stream_state is None:
return {}

debug_info = stream_state.get_debug_info()
return {
"tokens_processed": debug_info.get("tokens_processed"),
"current_channel": debug_info.get("current_channel"),
"current_recipient": debug_info.get("current_recipient"),
"current_channel_state": debug_info.get("current_channel_state"),
"generated_channels": debug_info.get("generated_channels"),
"channel_started": debug_info.get("channel_started"),
"has_preamble_content": debug_info.get("has_preamble_content"),
"should_filter_tools": debug_info.get("should_filter_tools"),
}

def _handle_streaming_parse_error(
self, request_id: str, tokens: list[int],
available_tools: list[dict[str, Any]] | None,
tool_choice: str | None, parse_error: Exception,
stream_state: HarmonyStreamState | None) -> str:
"""Log, decode, and reset parser state after a streaming Harmony parse error."""
token_sample = {
"count": len(tokens),
"first": tokens[:8],
"last": tokens[-8:],
}
fallback_text = self._decode_streaming_fallback_text(request_id, tokens)
logger.error(
f"Streaming: Failed to process token batch for request {request_id}: "
f"{type(parse_error).__name__}: {parse_error}; "
f"fallback_chars={len(fallback_text)}; "
f"token_sample={token_sample}; "
f"state={self._streaming_error_state_summary(stream_state)}")
logger.debug(
f"Problematic streaming tokens for request {request_id}: {tokens}")

# A StreamableParser can remain poisoned after a parse exception. Reset it so
# later batches still have a chance to parse normally.
self.cleanup_stream_state(request_id)
self.create_stream_state(request_id, available_tools, tool_choice)

return fallback_text

def harmony_system_message(self,
reasoning_effort: ReasoningEffort | None = None,
system_instructions: list[str] = []) -> Message:
@@ -1376,14 +1460,17 @@ def stateful_stream_harmony_tokens_to_openai_deltas(
deltas = stream_state.process_token_batch(tokens)
# logger.info(">> GENERATED DELTAS: %s", deltas)
return deltas
except (HarmonyError, UnicodeDecodeError, ValueError):
logger.error(
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}"
)
logger.debug(f"Problematic streaming tokens: {tokens}")

# Return empty deltas to continue processing
return []
except (HarmonyError, UnicodeDecodeError, ValueError) as parse_error:
fallback_text = self._handle_streaming_parse_error(
request_id=request_id,
tokens=tokens,
available_tools=available_tools,
tool_choice=tool_choice,
parse_error=parse_error,
stream_state=stream_state)
if not fallback_text:
return []
return [{"content": fallback_text}]

def stateful_stream_harmony_tokens_to_openai_messages(
self,
@@ -1413,13 +1500,20 @@ def stateful_stream_harmony_tokens_to_openai_messages(
try:
messages = stream_state.process_token_batch_to_messages(tokens)
return messages
except (HarmonyError, UnicodeDecodeError, ValueError):
logger.error(
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
)
logger.debug(f"Problematic streaming tokens: {tokens}")

return []
except (HarmonyError, UnicodeDecodeError, ValueError) as parse_error:
fallback_text = self._handle_streaming_parse_error(
request_id=request_id,
tokens=tokens,
available_tools=available_tools,
tool_choice=tool_choice,
parse_error=parse_error,
stream_state=stream_state)
if not fallback_text:
return []
return [
Message.from_role_and_content(
Role.ASSISTANT, fallback_text).with_channel("final")
]

def create_openai_streaming_response(
self,
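A self-contained illustration of the stripping behaviour the streaming fallback relies on; the marker strings and channel names below are simplified stand-ins for self._harmony_special_tokens and the patterns above, not the adapter's actual configuration.

    import re

    # Simplified stand-ins for the Harmony framing markers; the real adapter reads
    # them from self._harmony_special_tokens.
    SPECIAL = ["<|start|>", "<|channel|>", "<|message|>", "<|end|>"]

    def strip_control_text(text: str) -> str:
        if not any(marker in text for marker in SPECIAL):
            return text
        for marker in SPECIAL:
            text = text.replace(marker, "\n")
        # Drop header-only lines such as "assistant" or "final", keep real content.
        text = re.sub(r"(?m)^\s*(assistant|analysis|commentary|final)\s*$\n?", "", text)
        return text.strip("\n")

    raw = "<|start|>assistant<|channel|>final<|message|>Hello there!<|end|>"
    print(strip_control_text(raw))  # -> Hello there!
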