From 36e71a406b089b22759033883bf0bbe2e31fdd9c Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Thu, 30 Apr 2026 11:03:41 -0700 Subject: [PATCH 1/5] [None][fix] Fix GPT-OSS KV-aware router hashing Route GPT-OSS chat requests through Harmony tokenization before computing KV-aware router block hashes, and make the Harmony serving path honor pre-tokenized prompt_token_ids from the router. Also forward cache_salt from chat_harmony into generation. Add regression coverage for cache-salt hash partitioning, zero-block prompt routing, and partial final-block matched-token accounting. Signed-off-by: Simeng Liu --- tensorrt_llm/serve/openai_server.py | 26 +- tensorrt_llm/serve/router.py | 318 ++++++++++++++++---- tests/unittest/disaggregated/test_router.py | 234 +++++++++++++- 3 files changed, 496 insertions(+), 82 deletions(-) diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index c8b4d2bfee9d..bc234f4bbdf2 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -1490,17 +1490,20 @@ async def create_streaming_generator(promise: RequestOutput, # Get tool_choice from request tool_choice = getattr(request, 'tool_choice', None) - try: - harmony_tokens = self.harmony_adapter.openai_to_harmony_tokens( - request.messages, - tools_dict, - reasoning_effort=reasoning_effort, - tool_choice=tool_choice) - except Exception as e: - logger.error(f"messages_dict: {request.messages}") - logger.error(f"tools_dict: {tools_dict}") - logger.error(f"request: {request}") - raise e + if request.prompt_token_ids is not None: + harmony_tokens = request.prompt_token_ids + else: + try: + harmony_tokens = self.harmony_adapter.openai_to_harmony_tokens( + request.messages, + tools_dict, + reasoning_effort=reasoning_effort, + tool_choice=tool_choice) + except Exception as e: + logger.error(f"messages_dict: {request.messages}") + logger.error(f"tools_dict: {tools_dict}") + logger.error(f"request: {request}") + raise e # Get harmony stop tokens harmony_stop_tokens = self.harmony_adapter.get_stop_tokens() @@ -1534,6 +1537,7 @@ async def create_streaming_generator(promise: RequestOutput, streaming=bool(request.stream), lora_request=request.lora_request, disaggregated_params=disaggregated_params, + cache_salt=request.cache_salt, trace_headers=trace_headers, ) if not self.postproc_worker_enabled: diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 42370825098f..f645308ba0a4 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -1,12 +1,14 @@ import asyncio +import json import os import time from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union +from typing import (Awaitable, Callable, Dict, Iterable, List, Mapping, + Optional, TypeAlias, Union) import aiohttp -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from tensorrt_llm.bindings.internal.batch_manager import (BlockKey, BlockKeyHasher) @@ -18,6 +20,80 @@ CompletionRequest) OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest] +KvCacheEventRecord: TypeAlias = Mapping[str, object] +UINT32_MASK = (1 << 32) - 1 +UINT64_MASK = (1 << 64) - 1 + + +def _hash32_mix(value: int, seed: int) -> int: + value &= UINT32_MASK + value = (((value >> 16) ^ value) * 0x45d9f3b) & UINT32_MASK + value = (((value >> 16) ^ value) * 0x45d9f3b) & UINT32_MASK + value = ((value >> 16) ^ value) & UINT32_MASK + value = (value + 0x9e3779b9) & 
UINT32_MASK + return (seed ^ (value + ((seed << 6) & UINT64_MASK) + + (seed >> 2))) & UINT64_MASK + + +def _hash64_mix(value: int, seed: int) -> int: + value &= UINT64_MASK + value = ((value ^ (value >> 30)) * 0xbf58476d1ce4e5b9) & UINT64_MASK + value = ((value ^ (value >> 27)) * 0x94d049bb133111eb) & UINT64_MASK + value = (value ^ (value >> 31)) & UINT64_MASK + return (seed ^ (value + 0x9e3779b9 + ((seed << 6) & UINT64_MASK) + + (seed >> 2))) & UINT64_MASK + + +def _get_cache_salt_id(cache_salt: str) -> int: + from blake3 import blake3 + + cache_salt_id = int.from_bytes(blake3( + cache_salt.encode("utf-8")).digest(length=8), + "little", + signed=False) + if cache_salt_id < 0 or cache_salt_id >= (1 << 64): + raise ValueError( + f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.") + return cache_salt_id + + +def _as_event_records(event_raw: object) -> list[KvCacheEventRecord]: + if isinstance(event_raw, str): + try: + decoded = json.loads(event_raw) + except json.JSONDecodeError: + return [] + return _as_event_records(decoded) + if isinstance(event_raw, Mapping): + return [event_raw] + if isinstance(event_raw, list): + records: list[KvCacheEventRecord] = [] + for item in event_raw: + records.extend(_as_event_records(item)) + return records + return [] + + +def _extract_block_hashes_from_blocks(blocks: object) -> list[int]: + if not isinstance(blocks, list): + return [] + block_hashes: list[int] = [] + for block in blocks: + if not isinstance(block, Mapping): + continue + block_hash = block.get("block_hash") + if isinstance(block_hash, int): + block_hashes.append(block_hash) + return block_hashes + + +def _extract_block_hashes(block_hashes_raw: object) -> list[int]: + if not isinstance(block_hashes_raw, list): + return [] + return [ + block_hash for block_hash in block_hashes_raw + if isinstance(block_hash, int) + ] def get_request_num_tokens(request: OpenAIRequest) -> int: @@ -101,53 +177,71 @@ def __init__( self._kv_cache_block_table: set[int] = set() self._tokens_per_block = tokens_per_block - def add_blocks(self, block_hashes: Iterable[int]): - for hash in block_hashes: - self._kv_cache_block_table.add(hash) + def add_blocks(self, block_hashes: Iterable[int]) -> None: + for block_hash in block_hashes: + self._kv_cache_block_table.add(block_hash) - def remove_blocks(self, block_hashes: Iterable[int]): - for hash in block_hashes: - self._kv_cache_block_table.discard(hash) + def remove_blocks(self, block_hashes: Iterable[int]) -> None: + for block_hash in block_hashes: + self._kv_cache_block_table.discard(block_hash) - def update_with_events(self, events: Iterable[dict]): - # event_raw: {"id": , "data": } + def update_with_events(self, events: Iterable[object]) -> None: for event_raw in events: - if "data" in event_raw: - event = event_raw["data"] - else: - event = event_raw - - if event["type"] == "stored": - self.add_blocks(block["block_hash"] - for block in event["blocks"]) - elif event["type"] == "removed": - self.remove_blocks(event["block_hashes"]) - - async def poll_events(self, session: aiohttp.ClientSession): + event_records = _as_event_records(event_raw) + + for event_record in event_records: + event = event_record.get("data", event_record) + if not isinstance(event, Mapping): + continue + + event_type = event.get("type") + if event_type == "stored": + block_hashes = _extract_block_hashes_from_blocks( + event.get("blocks")) + self.add_blocks(block_hashes) + elif event_type == "removed": + block_hashes = _extract_block_hashes( + event.get("block_hashes")) + 
self.remove_blocks(block_hashes) + + async def poll_events( + self, session: aiohttp.ClientSession) -> list[object] | None: async with session.post( f"{self._base_url}/kv_cache_events") as response: events_raw = await response.json() - return events_raw + if events_raw is None: + return None + if isinstance(events_raw, list): + return events_raw + return [events_raw] - async def matched_tokens(self, block_hashes: list[list[int]]) -> int: + async def matched_tokens( + self, + block_hashes: list[list[int]], + block_lengths: Optional[list[list[int]]] = None) -> int: match_count = 0 async with self._lock: - for hash_list in block_hashes: - for hash in hash_list: - # TODO: 1) parent hash verification, 2) partial matching - if hash in self._kv_cache_block_table: - match_count += self._tokens_per_block + for prompt_index, hash_list in enumerate(block_hashes): + lengths = None if block_lengths is None else block_lengths[ + prompt_index] + for block_index, block_hash in enumerate(hash_list): + # TODO: parent hash verification + if block_hash in self._kv_cache_block_table: + if lengths is None: + match_count += self._tokens_per_block + else: + match_count += lengths[block_index] else: break return match_count - async def decrement_load(self, request: OpenAIRequest): + async def decrement_load(self, request: OpenAIRequest) -> None: num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 async with self._lock: self._num_active_requests -= 1 self._num_active_tokens -= num_tokens - async def poll_and_update(self): + async def poll_and_update(self) -> None: """Poll KV cache events and update block table. Called outside the critical path.""" try: assert self._session is not None, "session must be set on KvCacheAwareServerState" @@ -159,10 +253,10 @@ async def poll_and_update(self): logger.warning( f"Failed to poll KV cache events from {self._server}: {e}") - def num_active_tokens(self): + def num_active_tokens(self) -> int: return self._num_active_tokens - def num_active_requests(self): + def num_active_requests(self) -> int: return self._num_active_requests @@ -644,11 +738,27 @@ async def finish_request(self, request: OpenAIRequest): await self._unregister_request(request) +def _python_block_key_hash(token_ids: list[int], parent_hash: int, + cache_salt_id: Optional[int]) -> int: + seed = (len(token_ids) ^ ((parent_hash * 0xbf58476d1ce4e5b9) & UINT64_MASK)) + seed &= UINT64_MASK + if parent_hash == 0 and cache_salt_id is not None: + seed = _hash64_mix(cache_salt_id, seed) + for token_id in token_ids: + seed = _hash32_mix(token_id, seed) + return seed + + def block_key_hasher(token_ids: list[int], - parent_hash: Optional[int] = None) -> int: + parent_hash: Optional[int] = None, + cache_salt_id: Optional[int] = None) -> int: + normalized_parent_hash = 0 if parent_hash is None else parent_hash + if cache_salt_id is not None: + return _python_block_key_hash(token_ids, normalized_parent_hash, + cache_salt_id) + block_key = BlockKey(token_ids) - return BlockKeyHasher.hash(block_key, - 0 if parent_hash is None else parent_hash) + return BlockKeyHasher.hash(block_key, normalized_parent_hash) class BlockHashMixin: @@ -659,13 +769,14 @@ class BlockHashMixin: def _init_block_hashing(self, tokens_per_block: int = 32, - custom_tokenizer: Optional[str] = None): + custom_tokenizer: Optional[str] = None) -> None: env_tokens_per_block = os.environ.get( "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK") if env_tokens_per_block is not None: tokens_per_block = int(env_tokens_per_block) 
self._tokens_per_block = tokens_per_block self._tokenizers: dict = {} + self._model_types: dict[str, Optional[str]] = {} self._custom_tokenizer = custom_tokenizer logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}" f", custom_tokenizer={self._custom_tokenizer}") @@ -681,21 +792,62 @@ def _get_tokenizer(self, model: str): model, trust_remote_code=True) return self._tokenizers[model] + def _get_model_type(self, model: str) -> Optional[str]: + if model not in self._model_types: + model_type = None + normalized_model = model.lower().replace("_", "-") + if "gpt-oss" in normalized_model or "gptoss" in normalized_model: + model_type = "gpt_oss" + elif os.path.exists(model): + try: + model_config = AutoConfig.from_pretrained( + model, trust_remote_code=True, local_files_only=True) + model_type = getattr(model_config, "model_type", None) + except Exception as e: + logger.debug( + "BlockHashMixin: failed to read model config for " + f"{model}: {e}") + self._model_types[model] = model_type + return self._model_types[model] + + def _uses_harmony_tokenization(self, + request: ChatCompletionRequest) -> bool: + return self._get_model_type(request.model) == "gpt_oss" + + @staticmethod + def _tool_dicts(request: ChatCompletionRequest) -> Optional[list[dict]]: + return (None if getattr(request, "tools", None) is None else [ + tool.model_dump() if hasattr(tool, "model_dump") else tool + for tool in request.tools + ]) + + def _tokenize_harmony_chat( + self, request: ChatCompletionRequest) -> list[list[int]]: + from tensorrt_llm.serve import harmony_adapter + + result = harmony_adapter.get_harmony_adapter().openai_to_harmony_tokens( + request.messages, + self._tool_dicts(request), + reasoning_effort=harmony_adapter.maybe_transform_reasoning_effort( + request.reasoning_effort), + tool_choice=getattr(request, "tool_choice", None), + ) + request.prompt_token_ids = result + return [result] + def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: # Handle ChatCompletionRequest (has messages, not prompt) if isinstance(request, ChatCompletionRequest): if request.prompt_token_ids is not None: return [request.prompt_token_ids] + if self._uses_harmony_tokenization(request): + return self._tokenize_harmony_chat(request) tokenizer = self._get_tokenizer(request.model) # Forward tools and chat_template_kwargs so custom tokenizers # (e.g. DeepseekV32Tokenizer) render tool schemas and respect # template flags like `thinking=true` when computing the prompt # token ids used for cache-aware routing AND passed downstream # (prompt_token_ids makes the worker skip re-tokenization). - tool_dicts = (None if getattr(request, "tools", None) is None else [ - tool.model_dump() if hasattr(tool, "model_dump") else tool - for tool in request.tools - ]) chat_template_kwargs = (request.chat_template_kwargs if getattr( request, "chat_template_kwargs", None) else {}) result = tokenizer.apply_chat_template( @@ -705,7 +857,7 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: ], add_generation_prompt=request.add_generation_prompt, tokenize=True, - tools=tool_dicts, + tools=self._tool_dicts(request), **chat_template_kwargs, ) # Some custom tokenizers (e.g. 
DeepseekV32Tokenizer) return a @@ -736,8 +888,29 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: if len(token_lists) > 1 else token_lists[0]) return token_lists - def _compute_block_hashes(self, - token_lists: list[list[int]]) -> list[list[int]]: + @staticmethod + def _request_cache_salt_id(request: OpenAIRequest) -> Optional[int]: + cache_salt = getattr(request, "cache_salt", None) + if cache_salt is None: + return None + + return _get_cache_salt_id(cache_salt) + + def _compute_block_lengths(self, + token_lists: list[list[int]]) -> list[list[int]]: + block_lengths: list[list[int]] = [] + for token_list in token_lists: + lengths = [] + for t in range(0, len(token_list) - 1, self._tokens_per_block): + t_end = min(t + self._tokens_per_block, len(token_list) - 1) + lengths.append(t_end - t) + block_lengths.append(lengths) + return block_lengths + + def _compute_block_hashes( + self, + token_lists: list[list[int]], + cache_salt_id: Optional[int] = None) -> list[list[int]]: block_hashes: list[list[int]] = [] for token_list in token_lists: hash_list = [] @@ -747,13 +920,14 @@ def _compute_block_hashes(self, t_end = min(t + self._tokens_per_block, len(token_list) - 1) hash_list.append( block_key_hasher(token_list[t:t_end], - None if t == 0 else hash_list[-1])) + None if t == 0 else hash_list[-1], + cache_salt_id)) block_hashes.append(hash_list) return block_hashes def _tokenize_and_compute_block_hashes( - self, - request: OpenAIRequest) -> tuple[list[list[int]], list[list[int]]]: + self, request: OpenAIRequest + ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]: """Synchronous tokenize + block-hash, combined for thread offload. Factored into one method so ``get_next_server`` can offload the whole @@ -762,8 +936,10 @@ def _tokenize_and_compute_block_hashes( requests in parallel. """ token_lists = self._tokenize(request) - block_hashes = self._compute_block_hashes(token_lists) - return token_lists, block_hashes + block_lengths = self._compute_block_lengths(token_lists) + block_hashes = self._compute_block_hashes( + token_lists, self._request_cache_salt_id(request)) + return token_lists, block_hashes, block_lengths @staticmethod def _text_to_int_sequences(texts: list[str]) -> list[list[int]]: @@ -809,6 +985,10 @@ async def get_next_server( server for server in self._server_state.keys() if server != exclude_server ] + if not servers: + raise ValueError( + f"No available servers after excluding {exclude_server}") + # Tokenize + block-hash is CPU-bound (~50 ms p50 for a 40 k-token # chat request with a Rust-backed tokenizer). Running it directly # inside the async handler blocks the orchestrator's event loop and @@ -816,26 +996,32 @@ async def get_next_server( # tokenizers releasing the GIL, offloading to a thread lets multiple # tokenize calls run in parallel and frees the event loop to # dispatch HTTP traffic to the CTX/GEN workers meanwhile. 
- token_lists, block_hashes = await asyncio.to_thread( + token_lists, block_hashes, block_lengths = await asyncio.to_thread( self._tokenize_and_compute_block_hashes, request) - padded_tokens = sum( - len(hash_list) - for hash_list in block_hashes) * self._tokens_per_block + hashable_tokens = sum( + sum(prompt_block_lengths) for prompt_block_lengths in block_lengths) # select the server by (KV match - load) # TODO: more options - workloads = [ - state.num_active_requests() - for state in self._server_state.values() - ] + workloads_by_server = { + server: self._server_state[server].num_active_requests() + for server in servers + } + active_tokens_by_server = { + server: self._server_state[server].num_active_tokens() + for server in servers + } + workloads = [workloads_by_server[server] for server in servers] + active_tokens = [active_tokens_by_server[server] for server in servers] scores = [] matches = [] - for i in range(len(servers)): - server = servers[i] + for server in servers: # https://github.com/ai-dynamo/dynamo/blob/main/docs/kv_cache_routing.md#kv-cache-routing-and-load-balancing - matches.append( - await self._server_state[server].matched_tokens(block_hashes)) - score = matches[-1] / padded_tokens - workloads[ - i] / self._max_batch_size + matched_tokens = await self._server_state[server].matched_tokens( + block_hashes, block_lengths) + matches.append(matched_tokens) + match_ratio = matched_tokens / hashable_tokens if hashable_tokens else 0.0 + score = (match_ratio - + workloads_by_server[server] / self._max_batch_size) scores.append(score) max_score = max(scores) tied = [i for i, s in enumerate(scores) if s == max_score] @@ -846,8 +1032,14 @@ async def get_next_server( await self._register_request(server, request) return server, { "block_hashes": block_hashes, # list[list[int]] + "block_lengths": block_lengths, # list[list[int]] "token_lists": token_lists, # list[list[int]] + "hashable_tokens": hashable_tokens, # int "matches": matches, # list[int] + "scores": scores, # list[float] + "workloads": workloads, # list[int] + "active_tokens": active_tokens, # list[int] + "candidate_servers": servers, # list[str] "server_info": self._server_info.get(server, {}), } diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index cb2856df4cd8..7f7722f45be3 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -14,7 +14,7 @@ FunctionDefinition) from tensorrt_llm.serve.router import (ConversationRouter, KvCacheAwareRouter, LoadBalancingRouter, RoundRobinRouter, - create_router) + block_key_hasher, create_router) def _make_mock_aiohttp_session(return_value=None): @@ -301,20 +301,22 @@ def matches_by_server(info): hit_servers, hit_infos = zip(*results) assert hit_servers == (server_of[2], server_of[1], server_of[0]) - # matched partial block will be counted as a whole block - # req2 ([1002]*300): only matches server_of[2] → 320 tokens + # matched partial blocks are counted by their actual token lengths + # req2 ([1002]*300): only matches server_of[2] -> 299 hashable tokens m0 = matches_by_server(hit_infos[0]) - assert m0[server_of[2]] == 320 + assert m0[server_of[2]] == 299 assert m0[server_of[0]] == 0 assert m0[server_of[1]] == 0 - # req1 ([1000]*50+[1001]*150): full match server_of[1] → 224, partial server_of[0] → 32 + # req1 ([1000]*50+[1001]*150): full match server_of[1] -> 199, + # partial server_of[0] -> 32 m1 = matches_by_server(hit_infos[1]) - assert m1[server_of[1]] == 224 + assert 
m1[server_of[1]] == 199 assert m1[server_of[0]] == 32 assert m1[server_of[2]] == 0 - # req0 ([1000]*100): full match server_of[0] → 128, partial server_of[1] → 32 + # req0 ([1000]*100): full match server_of[0] -> 99, + # partial server_of[1] -> 32 m2 = matches_by_server(hit_infos[2]) - assert m2[server_of[0]] == 128 + assert m2[server_of[0]] == 99 assert m2[server_of[1]] == 32 assert m2[server_of[2]] == 0 for request in requests: @@ -357,6 +359,129 @@ def matches_by_server(info): assert len(set(final_servers)) == 3 +@pytest.mark.asyncio +async def test_kv_cache_aware_router_exclude_server_uses_candidate_workloads( + servers: list[str]) -> None: + router = KvCacheAwareRouter(server_role=None, + servers=servers, + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) + + router._server_state["server1"]._num_active_requests = 100 + router._server_state["server2"]._num_active_requests = 0 + router._server_state["server3"]._num_active_requests = 1 + + request = CompletionRequest(model="TinyLlama", prompt=[[1234] * 65]) + server, info = await router.get_next_server(request, + exclude_server="server1") + try: + assert server == "server2" + assert info["candidate_servers"] == ["server2", "server3"] + assert info["workloads"] == [0, 1] + finally: + await router.finish_request(request) + + +@pytest.mark.asyncio +async def test_kv_cache_aware_router_short_prompt_without_blocks_uses_load( + servers: list[str]) -> None: + router = KvCacheAwareRouter(server_role=None, + servers=servers, + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) + + router._server_state["server1"]._num_active_requests = 9 + router._server_state["server2"]._num_active_requests = 0 + router._server_state["server3"]._num_active_requests = 1 + + request = CompletionRequest(model="TinyLlama", prompt=[[1234]]) + server, info = await router.get_next_server(request) + try: + assert server == "server2" + assert info["block_hashes"] == [[]] + assert info["block_lengths"] == [[]] + assert info["hashable_tokens"] == 0 + assert info["workloads"] == [9, 0, 1] + assert info["scores"] == [-0.9, 0.0, -0.1] + finally: + await router.finish_request(request) + + +@pytest.mark.asyncio +async def test_kv_cache_aware_router_counts_partial_block_tokens() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) + token_ids = list(range(50)) + block_hashes = router._compute_block_hashes([token_ids]) + block_lengths = router._compute_block_lengths([token_ids]) + + assert block_lengths == [[32, 17]] + router._server_state["server1"].add_blocks(block_hashes[0]) + assert await router._server_state["server1"].matched_tokens( + block_hashes, block_lengths) == 49 + + request = CompletionRequest(model="TinyLlama", prompt=[token_ids]) + server, info = await router.get_next_server(request) + try: + assert server == "server1" + assert info["matches"] == [49] + assert info["hashable_tokens"] == 49 + assert info["scores"] == [1.0] + finally: + await router.finish_request(request) + + +@pytest.mark.asyncio +async def test_kv_cache_aware_router_cache_salt_partitions_hashes() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1", "server2"], + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) + token_ids = list(range(65)) + tenant_a = ChatCompletionRequest(model="TinyLlama", + messages=[{ + "role": "user", + "content": "unused" + }], + prompt_token_ids=token_ids, + cache_salt="tenant-a") + token_lists, tenant_a_hashes, 
block_lengths = ( + router._tokenize_and_compute_block_hashes(tenant_a)) + salt_id = router._request_cache_salt_id(tenant_a) + + assert token_lists == [token_ids] + assert block_lengths == [[32, 32]] + assert salt_id is not None + assert tenant_a_hashes != router._compute_block_hashes([token_ids]) + assert tenant_a_hashes[0][0] != block_key_hasher(token_ids[:32]) + assert block_key_hasher(token_ids[32:64], tenant_a_hashes[0][0], + salt_id) == block_key_hasher( + token_ids[32:64], tenant_a_hashes[0][0]) + + router._server_state["server1"].add_blocks(tenant_a_hashes[0]) + tenant_b = ChatCompletionRequest(model="TinyLlama", + messages=[{ + "role": "user", + "content": "unused" + }], + prompt_token_ids=token_ids, + cache_salt="tenant-b") + server, info = await router.get_next_server(tenant_b) + try: + server1_index = info["candidate_servers"].index("server1") + assert info["block_hashes"][0] != tenant_a_hashes[0] + assert info["matches"][server1_index] == 0 + assert server == "server1" + finally: + await router.finish_request(tenant_b) + + @pytest.mark.asyncio @pytest.mark.parametrize("api_type", ["completion", "chat"]) async def test_kv_cache_aware_router_multi_turn_conversation( @@ -1000,6 +1125,59 @@ def test_tokenize_preserves_empty_tools_list(): assert kwargs["tools"] == [] +def test_gpt_oss_tokenize_uses_harmony_tokens_for_router_hashes() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=32, + tokens_per_block=32) + + tokenizer = _mock_tokenizer(token_ids=[900, 901, 902, 903]) + harmony_tokens = [100, 101, 102, 103, 104] + harmony = mock.MagicMock() + harmony.openai_to_harmony_tokens.return_value = harmony_tokens + + with mock.patch("tensorrt_llm.serve.harmony_adapter.get_harmony_adapter", + return_value=harmony), mock.patch( + "tensorrt_llm.serve.harmony_adapter." + "maybe_transform_reasoning_effort", + return_value="medium"), mock.patch.object( + router, "_get_tokenizer", return_value=tokenizer): + req = ChatCompletionRequest( + model="openai/gpt-oss-20b", + messages=[{ + "role": "developer", + "content": "Use tools when useful." + }, { + "role": "user", + "content": "what's the weather in Paris?" + }], + tools=[_get_weather_tool()], + tool_choice="none", + reasoning_effort="medium", + ) + token_lists, block_hashes, block_lengths = ( + router._tokenize_and_compute_block_hashes(req)) + + tokenizer.apply_chat_template.assert_not_called() + harmony.openai_to_harmony_tokens.assert_called_once() + assert token_lists == [harmony_tokens] + assert block_lengths == router._compute_block_lengths([harmony_tokens]) + assert router._request_cache_salt_id(req) is None + assert req.prompt_token_ids == harmony_tokens + + call_args = harmony.openai_to_harmony_tokens.call_args + assert call_args.args[0] == req.messages + tool_dicts = call_args.args[1] + assert isinstance(tool_dicts, list) + assert tool_dicts[0]["function"]["name"] == "get_current_weather" + assert call_args.kwargs["reasoning_effort"] == "medium" + assert call_args.kwargs["tool_choice"] == "none" + + expected_hashes = router._compute_block_hashes([harmony_tokens]) + assert block_hashes == expected_hashes + + def test_tokenize_skipped_when_prompt_token_ids_already_set(): """Skip tokenization when ``prompt_token_ids`` is already populated. 
@@ -1031,3 +1209,43 @@ def test_tokenize_skipped_when_prompt_token_ids_already_set(): assert out == [[10, 20, 30]] get_tok.assert_not_called() tok.apply_chat_template.assert_not_called() + + +@pytest.mark.asyncio +async def test_chat_harmony_uses_prompt_token_ids_and_cache_salt() -> None: + from tensorrt_llm.serve.openai_server import OpenAIServer + + server = OpenAIServer.__new__(OpenAIServer) + server.harmony_adapter = mock.MagicMock() + server.harmony_adapter.openai_to_harmony_tokens.return_value = [900, 901] + server.harmony_adapter.get_stop_tokens.return_value = [42] + server.tokenizer = mock.MagicMock() + server.tokenizer.tokenizer.vocab_size = 1000 + server.await_disconnected = mock.AsyncMock() + + promise = mock.MagicMock() + promise.prompt_token_ids = [10, 11, 12, 13] + server.generator = mock.MagicMock() + server.generator.args.num_postprocess_workers = 1 + server.generator.generate_async.return_value = promise + + req = ChatCompletionRequest( + model="openai/gpt-oss-20b", + messages=[{ + "role": "user", + "content": "hello" + }], + prompt_token_ids=[10, 11, 12, 13], + stream=True, + max_completion_tokens=1, + cache_salt="tenant-a", + ) + + await server.chat_harmony(req, raw_request=None) + await asyncio.sleep(0) + + server.harmony_adapter.openai_to_harmony_tokens.assert_not_called() + server.generator.generate_async.assert_called_once() + generate_kwargs = server.generator.generate_async.call_args.kwargs + assert generate_kwargs["inputs"] == [10, 11, 12, 13] + assert generate_kwargs["cache_salt"] == "tenant-a" From 035da99a662d86a975de63a10e595b521642b3c6 Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Thu, 30 Apr 2026 18:24:19 -0700 Subject: [PATCH 2/5] [None][fix] Align router hashing with C++ Bind the C++ BlockKeyHasher path for token IDs with optional cache salt and route all Python block hashing through it. Add use_harmony as an explicit router override while preserving GPT-OSS name/config detection when unset. Keep router tokenization side-effect free so generation workers can retokenize requests, and cover the behavior with router tests. 
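For reference, a minimal sketch (not part of the diff) of how the rebound
hasher chains per-block hashes with an optional cache salt; it assumes the
block_key_hasher(token_ids, parent_hash, cache_salt_id) signature introduced
in this series and mirrors _compute_block_hashes:

    # Illustrative only: chain block hashes the way the KV-aware router does.
    from typing import Optional

    from tensorrt_llm.serve.router import block_key_hasher

    def chained_block_hashes(token_ids: list[int],
                             tokens_per_block: int = 32,
                             cache_salt_id: Optional[int] = None) -> list[int]:
        hashes: list[int] = []
        # The last prompt token never completes a reusable KV block, so it is
        # excluded, matching _compute_block_hashes.
        for start in range(0, len(token_ids) - 1, tokens_per_block):
            end = min(start + tokens_per_block, len(token_ids) - 1)
            parent = hashes[-1] if hashes else None
            hashes.append(
                block_key_hasher(token_ids[start:end], parent, cache_salt_id))
        return hashes
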
Signed-off-by: Simeng Liu --- .../nanobind/batch_manager/kvCacheManager.cpp | 13 ++- tensorrt_llm/serve/router.py | 76 ++++-------- tests/unittest/disaggregated/test_router.py | 110 +++++++++++++++++- 3 files changed, 136 insertions(+), 63 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 295114b00711..f165ddf4ec0d 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -49,6 +49,7 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using SizeType32 = tensorrt_llm::runtime::SizeType32; using TokenIdType = tensorrt_llm::runtime::TokenIdType; using VecTokens = std::vector; +using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; using CudaStreamPtr = std::shared_ptr; using CacheBlockIds = std::vector>; @@ -400,7 +401,17 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("__hash__", [](tbk::BlockKey const& key) -> size_t { return tbk::BlockKeyHasher{}(key); }); nb::class_(m, "BlockKeyHasher") - .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0); + .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0) + .def_static( + "hash_token_ids", + [](VecTokens const& tokenIds, std::size_t parentHash, + std::optional cacheSaltID) -> std::size_t + { + auto blockKey = BlockKey(tokenIds); + blockKey.cacheSaltID = cacheSaltID; + return tbk::BlockKeyHasher::hash(blockKey, parentHash); + }, + nb::arg("token_ids"), nb::arg("parent_hash") = 0, nb::arg("cache_salt_id") = std::nullopt); nb::class_(m, "KVCacheEventManager") .def(nb::init, std::optional, SizeType32>(), diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index f645308ba0a4..c22d82709e4f 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -10,8 +10,7 @@ import aiohttp from transformers import AutoConfig, AutoTokenizer -from tensorrt_llm.bindings.internal.batch_manager import (BlockKey, - BlockKeyHasher) +from tensorrt_llm.bindings.internal.batch_manager import BlockKeyHasher from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, RouterConfig, ServerRole) from tensorrt_llm.logger import logger @@ -21,27 +20,7 @@ OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest] KvCacheEventRecord: TypeAlias = Mapping[str, object] -UINT32_MASK = (1 << 32) - 1 -UINT64_MASK = (1 << 64) - 1 - - -def _hash32_mix(value: int, seed: int) -> int: - value &= UINT32_MASK - value = (((value >> 16) ^ value) * 0x45d9f3b) & UINT32_MASK - value = (((value >> 16) ^ value) * 0x45d9f3b) & UINT32_MASK - value = ((value >> 16) ^ value) & UINT32_MASK - value = (value + 0x9e3779b9) & UINT32_MASK - return (seed ^ (value + ((seed << 6) & UINT64_MASK) + - (seed >> 2))) & UINT64_MASK - - -def _hash64_mix(value: int, seed: int) -> int: - value &= UINT64_MASK - value = ((value ^ (value >> 30)) * 0xbf58476d1ce4e5b9) & UINT64_MASK - value = ((value ^ (value >> 27)) * 0x94d049bb133111eb) & UINT64_MASK - value = (value ^ (value >> 31)) & UINT64_MASK - return (seed ^ (value + 0x9e3779b9 + ((seed << 6) & UINT64_MASK) + - (seed >> 2))) & UINT64_MASK +CACHE_SALT_ID_UPPER_BOUND = 1 << 64 def _get_cache_salt_id(cache_salt: str) -> int: @@ -51,7 +30,7 @@ def _get_cache_salt_id(cache_salt: str) -> int: cache_salt.encode("utf-8")).digest(length=8), "little", signed=False) - if cache_salt_id < 0 or 
cache_salt_id >= (1 << 64): + if cache_salt_id < 0 or cache_salt_id >= CACHE_SALT_ID_UPPER_BOUND: raise ValueError( f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.") return cache_salt_id @@ -738,27 +717,12 @@ async def finish_request(self, request: OpenAIRequest): await self._unregister_request(request) -def _python_block_key_hash(token_ids: list[int], parent_hash: int, - cache_salt_id: Optional[int]) -> int: - seed = (len(token_ids) ^ ((parent_hash * 0xbf58476d1ce4e5b9) & UINT64_MASK)) - seed &= UINT64_MASK - if parent_hash == 0 and cache_salt_id is not None: - seed = _hash64_mix(cache_salt_id, seed) - for token_id in token_ids: - seed = _hash32_mix(token_id, seed) - return seed - - def block_key_hasher(token_ids: list[int], parent_hash: Optional[int] = None, cache_salt_id: Optional[int] = None) -> int: normalized_parent_hash = 0 if parent_hash is None else parent_hash - if cache_salt_id is not None: - return _python_block_key_hash(token_ids, normalized_parent_hash, - cache_salt_id) - - block_key = BlockKey(token_ids) - return BlockKeyHasher.hash(block_key, normalized_parent_hash) + return BlockKeyHasher.hash_token_ids(token_ids, normalized_parent_hash, + cache_salt_id) class BlockHashMixin: @@ -769,7 +733,8 @@ class BlockHashMixin: def _init_block_hashing(self, tokens_per_block: int = 32, - custom_tokenizer: Optional[str] = None) -> None: + custom_tokenizer: Optional[str] = None, + use_harmony: Optional[bool] = None) -> None: env_tokens_per_block = os.environ.get( "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK") if env_tokens_per_block is not None: @@ -778,8 +743,10 @@ def _init_block_hashing(self, self._tokenizers: dict = {} self._model_types: dict[str, Optional[str]] = {} self._custom_tokenizer = custom_tokenizer + self._use_harmony = use_harmony logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}" - f", custom_tokenizer={self._custom_tokenizer}") + f", custom_tokenizer={self._custom_tokenizer}" + f", use_harmony={self._use_harmony}") def _get_tokenizer(self, model: str): if model not in self._tokenizers: @@ -812,6 +779,8 @@ def _get_model_type(self, model: str) -> Optional[str]: def _uses_harmony_tokenization(self, request: ChatCompletionRequest) -> bool: + if self._use_harmony is not None: + return self._use_harmony return self._get_model_type(request.model) == "gpt_oss" @staticmethod @@ -832,7 +801,6 @@ def _tokenize_harmony_chat( request.reasoning_effort), tool_choice=getattr(request, "tool_choice", None), ) - request.prompt_token_ids = result return [result] def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: @@ -846,8 +814,7 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: # Forward tools and chat_template_kwargs so custom tokenizers # (e.g. DeepseekV32Tokenizer) render tool schemas and respect # template flags like `thinking=true` when computing the prompt - # token ids used for cache-aware routing AND passed downstream - # (prompt_token_ids makes the worker skip re-tokenization). + # token ids used for cache-aware routing. chat_template_kwargs = (request.chat_template_kwargs if getattr( request, "chat_template_kwargs", None) else {}) result = tokenizer.apply_chat_template( @@ -865,8 +832,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: # Encode to token IDs if needed. 
if isinstance(result, str): result = tokenizer.encode(result, add_special_tokens=False) - # Set prompt_token_ids so the worker server skips re-tokenization - request.prompt_token_ids = result return [result] # Handle CompletionRequest (has prompt) @@ -882,10 +847,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: tokenizer = self._get_tokenizer(request.model) token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts] - # Replace string prompts with token IDs so the worker server - # skips re-tokenization - request.prompt = (token_lists - if len(token_lists) > 1 else token_lists[0]) return token_lists @staticmethod @@ -963,10 +924,12 @@ def __init__(self, max_batch_size: int = 64, tokens_per_block: int = 32, custom_tokenizer: Optional[str] = None, - **kwargs): + use_harmony: Optional[bool] = None, + **kwargs) -> None: super().__init__(server_role, servers, metadata_server_cfg, metadata_server, **kwargs) - self._init_block_hashing(tokens_per_block, custom_tokenizer) + self._init_block_hashing(tokens_per_block, custom_tokenizer, + use_harmony) self._init_load_balancing(servers, use_tokens) # TODO: use max_num_tokens? per server? self._max_batch_size = max_batch_size @@ -1162,13 +1125,14 @@ def __init__(self, match_threshold: float = 0.75, tokens_per_block: int = 128, use_token_ids: bool = False, + use_harmony: Optional[bool] = None, hash_skip_count: int = 0, max_sessions: int = 100000, - **kwargs): + **kwargs) -> None: super().__init__(server_role, servers, metadata_server_cfg, metadata_server, **kwargs) self._init_load_balancing(servers) - self._init_block_hashing(tokens_per_block) + self._init_block_hashing(tokens_per_block, use_harmony=use_harmony) self._match_threshold = match_threshold self._use_token_ids = use_token_ids diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index 7f7722f45be3..f847a4de930e 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -6,6 +6,8 @@ import aiohttp import pytest +from tensorrt_llm.bindings.internal.batch_manager import (BlockKey, + BlockKeyHasher) from tensorrt_llm.llmapi.disagg_utils import RouterConfig from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, ChatCompletionToolsParam, @@ -482,6 +484,27 @@ async def test_kv_cache_aware_router_cache_salt_partitions_hashes() -> None: await router.finish_request(tenant_b) +def test_block_key_hasher_matches_bound_cpp_hasher_with_cache_salt() -> None: + token_ids = list(range(65)) + salt_id = 1234567890123456789 + first_block = token_ids[:32] + second_block = token_ids[32:64] + + assert block_key_hasher(first_block) == BlockKeyHasher.hash( + BlockKey(first_block)) + + salted_first_hash = BlockKeyHasher.hash_token_ids(first_block, 0, salt_id) + assert block_key_hasher(first_block, + cache_salt_id=salt_id) == salted_first_hash + + salted_second_hash = BlockKeyHasher.hash_token_ids( + second_block, salted_first_hash, salt_id) + assert block_key_hasher(second_block, salted_first_hash, + salt_id) == salted_second_hash + assert salted_second_hash == BlockKeyHasher.hash_token_ids( + second_block, salted_first_hash, None) + + @pytest.mark.asyncio @pytest.mark.parametrize("api_type", ["completion", "chat"]) async def test_kv_cache_aware_router_multi_turn_conversation( @@ -639,7 +662,7 @@ def make_request(token_ids: list[int]): f"but got {server_b1}") -def test_create_router(servers): +def test_create_router(servers: list[str]) -> None: default_router = 
create_router(None, servers) assert isinstance(default_router, RoundRobinRouter) @@ -659,8 +682,10 @@ def test_create_router(servers): assert tokens_load_balancing_router._use_tokens router_config.type = "kv_cache_aware" + router_config.args["use_harmony"] = True kv_cache_aware_router = create_router(router_config, servers) assert isinstance(kv_cache_aware_router, KvCacheAwareRouter) + assert kv_cache_aware_router._use_harmony is True with pytest.raises(ValueError): create_router(RouterConfig(type="unsupported_router"), servers) @@ -1023,15 +1048,15 @@ def _mock_tokenizer(token_ids=None): @pytest.mark.parametrize("router_class", [KvCacheAwareRouter, ConversationRouter]) -def test_tokenize_forwards_tools_and_chat_template_kwargs(router_class): +def test_tokenize_forwards_tools_and_chat_template_kwargs( + router_class) -> None: """Regression test for PR #13232. ``BlockHashMixin._tokenize`` must forward the request's ``tools`` (as a list of dicts) and ``chat_template_kwargs`` to ``tokenizer.apply_chat_template``. Without this, custom tokenizers that render tool schemas into the prompt (e.g. DeepSeek-V3.2) produce - truncated token ids, breaking cache-aware routing decisions and the - ``prompt_token_ids`` handed to the worker downstream. + truncated token ids, breaking cache-aware routing decisions. """ router = router_class(server_role=None, servers=["server1"], @@ -1053,6 +1078,7 @@ def test_tokenize_forwards_tools_and_chat_template_kwargs(router_class): router._tokenize(req) tok.apply_chat_template.assert_called_once() + assert req.prompt_token_ids is None kwargs = tok.apply_chat_template.call_args.kwargs # tools must be forwarded as a list of dicts (model_dump), not the # Pydantic objects themselves. @@ -1068,7 +1094,7 @@ def test_tokenize_forwards_tools_and_chat_template_kwargs(router_class): @pytest.mark.parametrize("router_class", [KvCacheAwareRouter, ConversationRouter]) -def test_tokenize_without_tools_passes_none(router_class): +def test_tokenize_without_tools_passes_none(router_class) -> None: """Bare chat request: no tools, no chat_template_kwargs. 
``apply_chat_template`` still runs but receives ``tools=None`` and no @@ -1090,6 +1116,7 @@ def test_tokenize_without_tools_passes_none(router_class): router._tokenize(req) tok.apply_chat_template.assert_called_once() + assert req.prompt_token_ids is None kwargs = tok.apply_chat_template.call_args.kwargs assert kwargs["tools"] is None assert "thinking" not in kwargs @@ -1164,7 +1191,7 @@ def test_gpt_oss_tokenize_uses_harmony_tokens_for_router_hashes() -> None: assert token_lists == [harmony_tokens] assert block_lengths == router._compute_block_lengths([harmony_tokens]) assert router._request_cache_salt_id(req) is None - assert req.prompt_token_ids == harmony_tokens + assert req.prompt_token_ids is None call_args = harmony.openai_to_harmony_tokens.call_args assert call_args.args[0] == req.messages @@ -1178,6 +1205,77 @@ def test_gpt_oss_tokenize_uses_harmony_tokens_for_router_hashes() -> None: assert block_hashes == expected_hashes +def test_use_harmony_flag_for_alias_model() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=32, + tokens_per_block=32, + use_harmony=True) + + tokenizer = _mock_tokenizer(token_ids=[900, 901, 902, 903]) + harmony_tokens = [100, 101, 102, 103, 104] + harmony = mock.MagicMock() + harmony.openai_to_harmony_tokens.return_value = harmony_tokens + + with mock.patch("tensorrt_llm.serve.harmony_adapter.get_harmony_adapter", + return_value=harmony), mock.patch( + "tensorrt_llm.serve.harmony_adapter." + "maybe_transform_reasoning_effort", + return_value="medium"), mock.patch.object( + router, "_get_tokenizer", return_value=tokenizer): + req = ChatCompletionRequest(model="served-gptoss-alias", + messages=[{ + "role": "user", + "content": "hello" + }]) + token_lists, _, _ = router._tokenize_and_compute_block_hashes(req) + + tokenizer.apply_chat_template.assert_not_called() + harmony.openai_to_harmony_tokens.assert_called_once() + assert token_lists == [harmony_tokens] + assert req.prompt_token_ids is None + + +def test_use_harmony_false_prefers_tokenizer_for_gpt_oss_name() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=32, + tokens_per_block=32, + use_harmony=False) + + tokenizer = _mock_tokenizer(token_ids=[900, 901, 902, 903]) + with mock.patch.object(router, "_get_tokenizer", return_value=tokenizer): + req = ChatCompletionRequest(model="openai/gpt-oss-20b", + messages=[{ + "role": "user", + "content": "hello" + }]) + token_lists = router._tokenize(req) + + tokenizer.apply_chat_template.assert_called_once() + assert token_lists == [[900, 901, 902, 903]] + assert req.prompt_token_ids is None + + +def test_tokenize_does_not_rewrite_completion_prompt() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=32, + tokens_per_block=32) + + tokenizer = mock.MagicMock() + tokenizer.return_value = {"input_ids": [10, 20, 30]} + with mock.patch.object(router, "_get_tokenizer", return_value=tokenizer): + req = CompletionRequest(model="TinyLlama", prompt="hello") + token_lists = router._tokenize(req) + + assert token_lists == [[10, 20, 30]] + assert req.prompt == "hello" + + def test_tokenize_skipped_when_prompt_token_ids_already_set(): """Skip tokenization when ``prompt_token_ids`` is already populated. 
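As a quick, hypothetical illustration of the cache-salt partitioning the tests
above exercise (mirrors _get_cache_salt_id from this series; assumes the
blake3 package is installed):

    # Distinct cache salts yield distinct 64-bit ids, which is what keeps
    # per-tenant block hashes from matching each other in the router.
    from blake3 import blake3

    def cache_salt_to_id(cache_salt: str) -> int:
        digest = blake3(cache_salt.encode("utf-8")).digest(length=8)
        return int.from_bytes(digest, "little", signed=False)

    assert cache_salt_to_id("tenant-a") != cache_salt_to_id("tenant-b")
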
From a98270c04ac0e3d95fc4a15219e9d0f54d019a19 Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Fri, 1 May 2026 09:56:20 -0700 Subject: [PATCH 3/5] [None][fix] Guard KV-aware routing edge cases Add stable round-robin tie breaking for KV-aware router candidate sets and ignore weak prefix-cache matches so low-affinity requests fall back to load balancing. Make KV-cache scheduling tolerate stale or already removed sequences instead of throwing during simulated pause accounting. Signed-off-by: Simeng Liu --- .../batch_manager/kvCacheManager.cpp | 34 +++++++++++- .../batch_manager/kvCacheManagerTest.cpp | 44 ++++++++++++++++ tensorrt_llm/serve/router.py | 39 +++++++++++--- tests/unittest/disaggregated/test_router.py | 52 +++++++++++++++++++ 4 files changed, 160 insertions(+), 9 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index f402a2c0d4ae..2355e7764af9 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -3043,7 +3043,17 @@ void BlockManager::schedulingReleaseBlocks(RequestIdType requestId) void WindowBlockManager::schedulingReleaseBlocks(RequestIdType requestId) { - for (auto& block : mAllocatedBlocksPerSeq.at(requestId)) + auto const seqIt = mAllocatedBlocksPerSeq.find(requestId); + if (seqIt == mAllocatedBlocksPerSeq.end()) + { + TLLM_LOG_WARNING( + "%s::schedulingReleaseBlocks skipped request %lu because no allocated blocks are tracked. " + "The request was likely already removed before the max-utilization scheduler simulated a pause.", + mLogPrefix.c_str(), requestId); + return; + } + + for (auto& block : seqIt->second) { // Decrease ref count block->decSchedulingRefCount(); @@ -3288,7 +3298,27 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(LlmRequest const& req, bool tw return 0; } - auto const numCurrTokens = getSequence(req.mRequestId).getNumTokens(); + auto const maybeNumCurrTokens = [this, requestId = req.mRequestId]() -> std::optional + { + auto lck = std::scoped_lock(mSequencesMtx); + auto const seqIt = mSequences.find(requestId); + if (seqIt == mSequences.end()) + { + return std::nullopt; + } + return seqIt->second.getNumTokens(); + }(); + if (!maybeNumCurrTokens.has_value()) + { + auto const unuschedulableBlocks = mBlockManager.getWindowSizeMetadata(windowSize).maxNumBlocks + 1; + TLLM_LOG_WARNING( + "[kv cache manager] getNeededBlocksOneStep: request %lu is generation-in-progress but no KV " + "sequence exists. 
Returning %d required blocks so the scheduler pauses or drops the stale request.", + req.mRequestId, unuschedulableBlocks); + return unuschedulableBlocks; + } + + auto const numCurrTokens = maybeNumCurrTokens.value(); auto const generatedTokens = numCurrTokens - req.getPromptLen(); auto const maxTokensToAddToKVCache = req.mMaxNewTokens - generatedTokens; auto const tokensPerStep = req.getNumDraftTokens() + 1; diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index d1cdeb3c16bb..9da9b7cb923d 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -3387,6 +3387,50 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), 0); } +TEST_F(KVCacheManagerTest, SchedulingRemoveSequenceIgnoresAlreadyRemovedSequence) +{ + auto constexpr numLayers = 1; + auto constexpr numKvHeads = 1; + auto constexpr sizePerHead = 8; + auto constexpr tokensPerBlock = 4; + auto constexpr totalNumBlocks = 16; + auto constexpr blocksInSecondaryPool = 0; + auto constexpr maxNumSequences = 2; + auto constexpr maxBeamWidth = 1; + auto constexpr sinkTokenLength = 0; + auto constexpr maxSequenceLength = 16; + auto constexpr maxAttentionWindow = maxSequenceLength; + auto constexpr requestId = 17; + auto constexpr inputLength = 8; + auto constexpr maxNewTokens = 0; + auto const stream = std::make_shared(); + auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}}; + tr::SamplingConfig const samplingConfig{maxBeamWidth}; + bool constexpr isStreaming{false}; + + auto inputTokens = std::make_shared(); + for (SizeType32 token = 0; token < inputLength; ++token) + { + inputTokens->push_back(token); + } + + KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, + maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, false); + kvCacheManager.allocatePools(false); + + auto llmRequest = std::make_shared( + LlmRequest::RequestIdType{requestId}, maxNewTokens, inputTokens, samplingConfig, isStreaming); + EXPECT_NO_THROW( + kvCacheManager.addSequenceBatch({{{requestId, inputLength, maxBeamWidth}}}, {std::ref(*llmRequest)})); + + kvCacheManager.startScheduling(); + EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId)); + EXPECT_NO_THROW(static_cast(kvCacheManager.removeSequence(requestId))); + EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId)); + EXPECT_NO_THROW(kvCacheManager.schedulingRemoveSequence(requestId + 1)); +} + TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest) { using DType = half; diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index c22d82709e4f..22fe0cc5023f 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -267,6 +267,23 @@ def _get_server_load(self, server: str) -> int: return state._num_active_tokens if self._use_tokens \ else state._num_active_requests + def _select_round_robin_tied(self, candidates: list[str]) -> str: + """Select the next candidate using a stable server-ring pointer.""" + if not candidates: + raise ValueError("No tied candidates available") + + ordered_servers = list(self._server_state) + candidate_set = set(candidates) + start = self._rr_counter % len(ordered_servers) + for offset in range(len(ordered_servers)): + index = 
(start + offset) % len(ordered_servers) + server = ordered_servers[index] + if server in candidate_set: + self._rr_counter = index + 1 + return server + + raise ValueError("No tied candidate is present in server state") + def _validate_servers_available(self): if not self._servers: if self._metadata_server: @@ -298,10 +315,9 @@ def _select_least_loaded(self, loads = {s: self._get_server_load(s) for s in candidates} min_load = min(loads.values()) tied = [s for s in candidates if loads[s] == min_load] - server = tied[self._rr_counter % len(tied)] - self._rr_counter += 1 + server = self._select_round_robin_tied(tied) logger.debug(f"LoadBalancingMixin: selected={server}, " - f"loads={loads}, tied={tied}, rr={self._rr_counter - 1}") + f"loads={loads}, tied={tied}, rr={self._rr_counter}") return server @@ -923,6 +939,7 @@ def __init__(self, use_tokens: bool = False, max_batch_size: int = 64, tokens_per_block: int = 32, + match_rate_threshold: float = 0.1, custom_tokenizer: Optional[str] = None, use_harmony: Optional[bool] = None, **kwargs) -> None: @@ -933,6 +950,7 @@ def __init__(self, self._init_load_balancing(servers, use_tokens) # TODO: use max_num_tokens? per server? self._max_batch_size = max_batch_size + self._match_rate_threshold = match_rate_threshold def _create_server_state(self, server: str) -> KvCacheAwareServerState: return KvCacheAwareServerState(server, self._use_tokens, @@ -982,16 +1000,20 @@ async def get_next_server( matched_tokens = await self._server_state[server].matched_tokens( block_hashes, block_lengths) matches.append(matched_tokens) - match_ratio = matched_tokens / hashable_tokens if hashable_tokens else 0.0 + max_match_rate = (max(matches) / + hashable_tokens) if hashable_tokens else 0.0 + cache_affinity_active = max_match_rate > self._match_rate_threshold + for matched_tokens, server in zip(matches, servers): + effective_match_tokens = matched_tokens if cache_affinity_active else 0 + match_ratio = effective_match_tokens / hashable_tokens if hashable_tokens else 0.0 score = (match_ratio - workloads_by_server[server] / self._max_batch_size) scores.append(score) max_score = max(scores) tied = [i for i, s in enumerate(scores) if s == max_score] - winner = tied[self._rr_counter % len(tied)] - self._rr_counter += 1 - server = servers[winner] async with self._lock: + tied_servers = [servers[i] for i in tied] + server = self._select_round_robin_tied(tied_servers) await self._register_request(server, request) return server, { "block_hashes": block_hashes, # list[list[int]] @@ -1003,6 +1025,9 @@ async def get_next_server( "workloads": workloads, # list[int] "active_tokens": active_tokens, # list[int] "candidate_servers": servers, # list[str] + "match_rate_threshold": self._match_rate_threshold, + "cache_affinity_active": cache_affinity_active, + "max_match_rate": max_match_rate, "server_info": self._server_info.get(server, {}), } diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index f847a4de930e..457ac3677d1c 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -411,6 +411,58 @@ async def test_kv_cache_aware_router_short_prompt_without_blocks_uses_load( await router.finish_request(request) +@pytest.mark.asyncio +async def test_kv_cache_aware_router_tie_breaks_changing_candidate_sets( +) -> None: + router = KvCacheAwareRouter(server_role=None, + servers=[ + "server1", "server2", "server3", "server4" + ], + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) 
+ + requests = [ + CompletionRequest(model="TinyLlama", prompt=[[index] * 65]) + for index in range(2) + ] + + server0, _ = await router.get_next_server(requests[0]) + server1, _ = await router.get_next_server(requests[1]) + try: + assert server0 == "server1" + assert server1 == "server2" + finally: + for request in requests: + await router.finish_request(request) + + +@pytest.mark.asyncio +async def test_kv_cache_aware_router_gates_low_match_rate_to_load_balance( +) -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1", "server2"], + use_tokens=False, + max_batch_size=10, + tokens_per_block=32, + match_rate_threshold=0.5) + token_ids = list(range(101)) + block_hashes = router._compute_block_hashes([token_ids]) + router._server_state["server1"].add_blocks(block_hashes[0][:1]) + router._server_state["server1"]._num_active_requests = 1 + + request = CompletionRequest(model="TinyLlama", prompt=[token_ids]) + server, info = await router.get_next_server(request) + try: + assert info["matches"] == [32, 0] + assert info["max_match_rate"] == pytest.approx(0.32) + assert info["cache_affinity_active"] is False + assert info["scores"] == pytest.approx([-0.1, 0.0]) + assert server == "server2" + finally: + await router.finish_request(request) + + @pytest.mark.asyncio async def test_kv_cache_aware_router_counts_partial_block_tokens() -> None: router = KvCacheAwareRouter(server_role=None, From 8994ed74668e193223a836998e243dd8e051fd5b Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Fri, 1 May 2026 10:00:51 -0700 Subject: [PATCH 4/5] [None][fix] Handle disagg transfer and Harmony parse failures Preserve visible streaming output when Harmony parsing fails by decoding a fallback batch, stripping only Harmony control framing, and resetting parser state for later chunks. Treat failed or timed-out context KV transfers as transfer errors instead of clean context completions, and avoid re-storing context blocks that were already stored for asynchronous disaggregated transfer. 
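For illustration (not part of the diff), the streaming fallback clean-up boils
down to replacing Harmony control markers with newlines and dropping
header-only lines; the marker list below is a placeholder for the adapter's
_harmony_special_tokens:

    # Sketch of the control-framing strip used by the streaming fallback.
    import re

    def strip_harmony_framing(text: str, markers: list[str]) -> str:
        for marker in markers:
            text = text.replace(marker, "\n")
        # Drop header-only lines such as "assistant", "analysis", or "final".
        text = re.sub(
            r"(?m)^\s*(assistant|analysis|commentary|final|system|developer|user)"
            r"(?:\s+to=[^\n]*)?\s*$\n?", "", text)
        text = re.sub(r"(?m)^\s*(json|code)\s*$\n?", "", text)
        text = re.sub(r"\n{3,}", "\n\n", text)
        return text.strip("\n")
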
Signed-off-by: Simeng Liu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 51 +++++-- .../_torch/pyexecutor/resource_manager.py | 6 +- tensorrt_llm/serve/harmony_adapter.py | 127 +++++++++++++++--- .../llmapi/apps/test_harmony_parsing.py | 86 ++++++++++++ 4 files changed, 243 insertions(+), 27 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 48bd51fa06c1..50b79700aab3 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -3084,7 +3084,7 @@ def _check_disagg_gen_transfer_status(self): return @nvtx_range("_check_kv_transfer_timeout") - def _check_kv_transfer_timeout(self): + def _check_kv_transfer_timeout(self) -> None: if not self.kv_cache_transceiver: return timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms @@ -3391,17 +3391,28 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str): f"Error in kv cache transfer for {error_msg_prefix}", requests=error_requests) + def _complete_context_transfer_with_error(self, request: LlmRequest, + error_msg: str) -> None: + request.py_kv_transfer_start_time = None + request.state = LlmRequestState.DISAGG_TRANS_ERROR + self.async_transfer_manager.end_transfer(request) + if request in self.active_requests: + self._handle_errors(error_msg=error_msg, requests=[request]) + @nvtx_range("_check_disagg_ctx_cache_transfer_status") - def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0): + def _check_disagg_ctx_cache_transfer_status(self, + atLeastNum: int = 0) -> None: finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status( atLeastNum) - completed_req_ids = set(finished_requests + error_requests) + finished_req_ids = set(finished_requests) + error_req_ids = set(error_requests) + completed_req_ids = finished_req_ids | error_req_ids requests_in_transfer = self.async_transfer_manager.requests_in_transfer( ) - for request_id in completed_req_ids: + for request_id in finished_req_ids: if request_id not in requests_in_transfer: logger.warning( @@ -3410,7 +3421,24 @@ def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0): request = requests_in_transfer[request_id] - self._end_transfer_and_maybe_terminate(request) + if request.py_kv_transfer_timed_out: + self._complete_context_transfer_with_error( + request, + f"Context KV cache transfer completed after timeout for request {request_id}" + ) + else: + self._end_transfer_and_maybe_terminate(request) + + for request_id in error_req_ids: + if request_id not in requests_in_transfer: + logger.warning( + f"Request {request_id} not found in transfer manager") + continue + + request = requests_in_transfer[request_id] + self._complete_context_transfer_with_error( + request, + f"Error in kv cache transfer for context request {request_id}") # The set of requests in transfer may have changed since we terminated some requests. requests_in_transfer = self.async_transfer_manager.requests_in_transfer( @@ -3420,13 +3448,14 @@ def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0): request = requests_in_transfer[request_id] if request.py_kv_transfer_timed_out and request_id not in completed_req_ids: is_cancelled = self.kv_cache_transceiver.cancel_request(request) - # If cancel is successful, mark as complete so it can be cleaned up - # Otherwise, try at next iteration + # If cancel is successful, mark as transfer error so metrics and + # cleanup do not report this as a clean context completion. 
+ # Otherwise, try at next iteration. if is_cancelled: - request.py_kv_transfer_start_time = None - request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE - - self._end_transfer_and_maybe_terminate(request) + self._complete_context_transfer_with_error( + request, + f"Context KV cache transfer timed out for request {request_id}" + ) self._check_cache_transfer_errors("context requests") diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 5a71a122fc66..939827ab5781 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -838,7 +838,7 @@ def add_dummy_requests( def update_resources(self, scheduled_batch: ScheduledRequests, attn_metadata: "AttentionMetadata" = None, - kv_cache_dtype_byte_size: float = None): + kv_cache_dtype_byte_size: float = None) -> None: if not self.is_draft: _update_kv_cache_draft_token_location(self, scheduled_batch, attn_metadata, @@ -862,6 +862,10 @@ def update_resources(self, # storing, so that SWA windows are safe to store — blocks won't go out-of-window # and be evicted while the context is still in-flight. for request in scheduled_batch.context_requests: + if (request.is_disagg_context_transmission_state + and self.enable_partial_reuse and not self.is_vswa + and self.mapping.pp_size == 1): + continue self.impl.store_context_blocks(request) def free_resources(self, request: LlmRequest, pin_on_release: bool = False): diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index 6e9949802c45..cfd971787df0 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -519,6 +519,93 @@ def _safe_decode_utf8(self, return self.encoding.decode(tokens) + def _strip_harmony_stream_control_text(self, text: str) -> str: + """Strip Harmony framing tokens from best-effort streaming fallback text.""" + if not text: + return "" + if not any(marker in text for marker in self._harmony_special_tokens): + return text + + cleaned = text + for marker in self._harmony_special_tokens: + cleaned = cleaned.replace(marker, "\n") + + # Remove Harmony header-only lines while preserving real content lines. 
+ cleaned = re.sub( + r"(?m)^\s*(assistant|analysis|commentary|final|system|developer|user)(?:\s+to=[^\n]*)?\s*$\n?", + "", + cleaned) + cleaned = re.sub(r"(?m)^\s*(json|code)\s*$\n?", "", cleaned) + cleaned = re.sub(r"\n{3,}", "\n\n", cleaned) + cleaned = cleaned.strip("\n") + if not cleaned.strip(): + return "" + return cleaned + + def _decode_streaming_fallback_text(self, request_id: str, + tokens: list[int]) -> str: + """Decode a failed streaming batch as visible text when Harmony parsing fails.""" + try: + decoded_text = self._safe_decode_utf8(tokens) + except (HarmonyError, UnicodeDecodeError, ValueError) as decode_error: + logger.error( + f"Streaming: Failed to decode fallback text for request {request_id}: " + f"{type(decode_error).__name__}: {decode_error}") + logger.debug( + f"Problematic fallback streaming tokens for request {request_id}: {tokens}" + ) + return "" + + return self._strip_harmony_stream_control_text(decoded_text) + + def _streaming_error_state_summary( + self, stream_state: HarmonyStreamState | None) -> dict[str, Any]: + """Return compact parser state for streaming parse-failure logs.""" + if stream_state is None: + return {} + + debug_info = stream_state.get_debug_info() + return { + "tokens_processed": debug_info.get("tokens_processed"), + "current_channel": debug_info.get("current_channel"), + "current_recipient": debug_info.get("current_recipient"), + "current_channel_state": debug_info.get("current_channel_state"), + "generated_channels": debug_info.get("generated_channels"), + "channel_started": debug_info.get("channel_started"), + "has_preamble_content": debug_info.get("has_preamble_content"), + "should_filter_tools": debug_info.get("should_filter_tools"), + } + + def _handle_streaming_parse_error( + self, + request_id: str, + tokens: list[int], + available_tools: list[dict[str, Any]] | None, + tool_choice: str | None, + parse_error: Exception, + stream_state: HarmonyStreamState | None) -> str: + """Log, decode, and reset parser state after a streaming Harmony parse error.""" + token_sample = { + "count": len(tokens), + "first": tokens[:8], + "last": tokens[-8:], + } + fallback_text = self._decode_streaming_fallback_text(request_id, tokens) + logger.error( + f"Streaming: Failed to process token batch for request {request_id}: " + f"{type(parse_error).__name__}: {parse_error}; " + f"fallback_chars={len(fallback_text)}; " + f"token_sample={token_sample}; " + f"state={self._streaming_error_state_summary(stream_state)}") + logger.debug(f"Problematic streaming tokens for request {request_id}: {tokens}") + + # A StreamableParser can remain poisoned after a parse exception. Reset it so + # later batches still have a chance to parse normally. 
+ self.cleanup_stream_state(request_id) + self.create_stream_state(request_id, available_tools, tool_choice) + + return fallback_text + def harmony_system_message(self, reasoning_effort: ReasoningEffort | None = None, system_instructions: list[str] = []) -> Message: @@ -1376,14 +1463,17 @@ def stateful_stream_harmony_tokens_to_openai_deltas( deltas = stream_state.process_token_batch(tokens) # logger.info(">> GENERATED DELTAS: %s", deltas) return deltas - except (HarmonyError, UnicodeDecodeError, ValueError): - logger.error( - f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}" - ) - logger.debug(f"Problematic streaming tokens: {tokens}") - - # Return empty deltas to continue processing - return [] + except (HarmonyError, UnicodeDecodeError, ValueError) as parse_error: + fallback_text = self._handle_streaming_parse_error( + request_id=request_id, + tokens=tokens, + available_tools=available_tools, + tool_choice=tool_choice, + parse_error=parse_error, + stream_state=stream_state) + if not fallback_text: + return [] + return [{"content": fallback_text}] def stateful_stream_harmony_tokens_to_openai_messages( self, @@ -1413,13 +1503,20 @@ def stateful_stream_harmony_tokens_to_openai_messages( try: messages = stream_state.process_token_batch_to_messages(tokens) return messages - except (HarmonyError, UnicodeDecodeError, ValueError): - logger.error( - f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}", - ) - logger.debug(f"Problematic streaming tokens: {tokens}") - - return [] + except (HarmonyError, UnicodeDecodeError, ValueError) as parse_error: + fallback_text = self._handle_streaming_parse_error( + request_id=request_id, + tokens=tokens, + available_tools=available_tools, + tool_choice=tool_choice, + parse_error=parse_error, + stream_state=stream_state) + if not fallback_text: + return [] + return [ + Message.from_role_and_content( + Role.ASSISTANT, fallback_text).with_channel("final") + ] def create_openai_streaming_response( self, diff --git a/tests/unittest/llmapi/apps/test_harmony_parsing.py b/tests/unittest/llmapi/apps/test_harmony_parsing.py index acf51f73c1d6..f19df40c6896 100644 --- a/tests/unittest/llmapi/apps/test_harmony_parsing.py +++ b/tests/unittest/llmapi/apps/test_harmony_parsing.py @@ -617,6 +617,92 @@ def test_create_response_message_without_reasoning(self): assert "reasoning_content" not in result +class TestStreamingParseFallback: + """Verify streaming parse failures preserve visible output text.""" + + def test_plain_fallback_text_preserves_spacing(self) -> None: + adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) + + assert adapter._strip_harmony_stream_control_text(" next token") == " next token" + + def test_parse_failure_falls_back_to_decoded_content_and_resets_state( + self) -> None: + adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) + request_id = "test-streaming-parse-fallback" + stream_state = adapter.create_stream_state( + request_id=request_id, available_tools=None, tool_choice=None) + stream_state.process_token_batch = Mock( + side_effect=ValueError("bad harmony state")) + + with patch.object( + adapter, + "_safe_decode_utf8", + return_value= + "<|start|>assistant<|channel|>final<|message|>Hello<|return|>" + ): + deltas = adapter.stateful_stream_harmony_tokens_to_openai_deltas( + request_id=request_id, + tokens=[1, 2, 3], + available_tools=None, + tool_choice=None, + ) + + assert deltas == [{"content": "Hello"}] + assert 
adapter.get_stream_state(request_id) is not stream_state + + def test_parse_failure_drops_control_only_fallback_and_resets_state( + self) -> None: + adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) + request_id = "test-streaming-control-only-fallback" + stream_state = adapter.create_stream_state( + request_id=request_id, available_tools=None, tool_choice=None) + stream_state.process_token_batch = Mock( + side_effect=ValueError("bad harmony state")) + + with patch.object( + adapter, + "_safe_decode_utf8", + return_value= + "<|start|>assistant<|channel|>final<|message|><|return|>" + ): + deltas = adapter.stateful_stream_harmony_tokens_to_openai_deltas( + request_id=request_id, + tokens=[1, 2, 3], + available_tools=None, + tool_choice=None, + ) + + assert deltas == [] + assert adapter.get_stream_state(request_id) is not stream_state + + def test_streaming_response_emits_fallback_content(self) -> None: + adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) + request_id = "test-streaming-response-fallback" + stream_state = adapter.create_stream_state( + request_id=request_id, available_tools=None, tool_choice=None) + stream_state.process_token_batch = Mock( + side_effect=ValueError("bad harmony state")) + + with patch.object( + adapter, + "_safe_decode_utf8", + return_value= + "<|start|>assistant<|channel|>final<|message|>Hello<|return|>" + ): + responses, should_stop = adapter.create_openai_streaming_response( + request_id=request_id, + tokens=[1, 2, 3], + available_tools=None, + model_name="test-model", + tool_choice=None, + ) + + assert not should_stop + assert len(responses) == 1 + data = json.loads(responses[0].replace("data: ", "").strip()) + assert data["choices"][0]["delta"]["content"] == "Hello" + + class TestRemainingTokensOnDone: """Verify handle_streaming_response processes leftover tokens on done=True.""" From 6dfc75481d2ad022bc72c3c41920b0a878aa864f Mon Sep 17 00:00:00 2001 From: Simeng Liu Date: Fri, 1 May 2026 12:59:14 -0700 Subject: [PATCH 5/5] [None][fix] Align Harmony empty tools routing Match the router Harmony tokenization path to chat_harmony when a request explicitly provides an empty tools list, and include pre-commit formatting updates for the changed branch files. 
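A hedged illustration of the mismatch this patch closes (the `normalize_tools` helper is a placeholder, not the real router or serving API): if the serving path tokenizes with `tools=None` while the router tokenizes the same request with `tools=[]`, the Harmony system prompt can differ, so the router's block hashes no longer line up with the KV cache produced by serving.

```python
def normalize_tools(tools):
    """Treat a missing *and* an explicitly empty tools list the same way,
    mirroring the chat_harmony behavior the router now follows."""
    return tools if tools else None


# Both the absent and the empty case map to None, so both paths build the
# same Harmony prompt and therefore the same block hashes.
assert normalize_tools(None) is None
assert normalize_tools([]) is None
assert normalize_tools([{"type": "function"}]) == [{"type": "function"}]
```

This matches the `tools = None if not request.tools else self._tool_dicts(request)` branch added to `router.py` in the diff below.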
Signed-off-by: Simeng Liu --- .../batch_manager/kvCacheManagerTest.cpp | 4 +- tensorrt_llm/serve/harmony_adapter.py | 13 ++-- tensorrt_llm/serve/router.py | 3 +- tests/unittest/disaggregated/test_router.py | 59 +++++++++++++++---- .../llmapi/apps/test_harmony_parsing.py | 45 +++++++------- 5 files changed, 77 insertions(+), 47 deletions(-) diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 9da9b7cb923d..dd4360cea888 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -3414,8 +3414,8 @@ TEST_F(KVCacheManagerTest, SchedulingRemoveSequenceIgnoresAlreadyRemovedSequence inputTokens->push_back(token); } - KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, - maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, + KVCacheManager kvCacheManager(numLayers, numKvHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, + maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, false); kvCacheManager.allocatePools(false); diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index cfd971787df0..f8472d3b5994 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -533,8 +533,7 @@ def _strip_harmony_stream_control_text(self, text: str) -> str: # Remove Harmony header-only lines while preserving real content lines. cleaned = re.sub( r"(?m)^\s*(assistant|analysis|commentary|final|system|developer|user)(?:\s+to=[^\n]*)?\s*$\n?", - "", - cleaned) + "", cleaned) cleaned = re.sub(r"(?m)^\s*(json|code)\s*$\n?", "", cleaned) cleaned = re.sub(r"\n{3,}", "\n\n", cleaned) cleaned = cleaned.strip("\n") @@ -577,12 +576,9 @@ def _streaming_error_state_summary( } def _handle_streaming_parse_error( - self, - request_id: str, - tokens: list[int], + self, request_id: str, tokens: list[int], available_tools: list[dict[str, Any]] | None, - tool_choice: str | None, - parse_error: Exception, + tool_choice: str | None, parse_error: Exception, stream_state: HarmonyStreamState | None) -> str: """Log, decode, and reset parser state after a streaming Harmony parse error.""" token_sample = { @@ -597,7 +593,8 @@ def _handle_streaming_parse_error( f"fallback_chars={len(fallback_text)}; " f"token_sample={token_sample}; " f"state={self._streaming_error_state_summary(stream_state)}") - logger.debug(f"Problematic streaming tokens for request {request_id}: {tokens}") + logger.debug( + f"Problematic streaming tokens for request {request_id}: {tokens}") # A StreamableParser can remain poisoned after a parse exception. Reset it so # later batches still have a chance to parse normally. 
diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 22fe0cc5023f..6b48ffc8957d 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -810,9 +810,10 @@ def _tokenize_harmony_chat( self, request: ChatCompletionRequest) -> list[list[int]]: from tensorrt_llm.serve import harmony_adapter + tools = None if not request.tools else self._tool_dicts(request) result = harmony_adapter.get_harmony_adapter().openai_to_harmony_tokens( request.messages, - self._tool_dicts(request), + tools, reasoning_effort=harmony_adapter.maybe_transform_reasoning_effort( request.reasoning_effort), tool_choice=getattr(request, "tool_choice", None), diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index 457ac3677d1c..d2472be62bf6 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -414,13 +414,12 @@ async def test_kv_cache_aware_router_short_prompt_without_blocks_uses_load( @pytest.mark.asyncio async def test_kv_cache_aware_router_tie_breaks_changing_candidate_sets( ) -> None: - router = KvCacheAwareRouter(server_role=None, - servers=[ - "server1", "server2", "server3", "server4" - ], - use_tokens=False, - max_batch_size=10, - tokens_per_block=32) + router = KvCacheAwareRouter( + server_role=None, + servers=["server1", "server2", "server3", "server4"], + use_tokens=False, + max_batch_size=10, + tokens_per_block=32) requests = [ CompletionRequest(model="TinyLlama", prompt=[[index] * 65]) @@ -549,8 +548,9 @@ def test_block_key_hasher_matches_bound_cpp_hasher_with_cache_salt() -> None: assert block_key_hasher(first_block, cache_salt_id=salt_id) == salted_first_hash - salted_second_hash = BlockKeyHasher.hash_token_ids( - second_block, salted_first_hash, salt_id) + salted_second_hash = BlockKeyHasher.hash_token_ids(second_block, + salted_first_hash, + salt_id) assert block_key_hasher(second_block, salted_first_hash, salt_id) == salted_second_hash assert salted_second_hash == BlockKeyHasher.hash_token_ids( @@ -1100,8 +1100,7 @@ def _mock_tokenizer(token_ids=None): @pytest.mark.parametrize("router_class", [KvCacheAwareRouter, ConversationRouter]) -def test_tokenize_forwards_tools_and_chat_template_kwargs( - router_class) -> None: +def test_tokenize_forwards_tools_and_chat_template_kwargs(router_class) -> None: """Regression test for PR #13232. ``BlockHashMixin._tokenize`` must forward the request's ``tools`` (as a @@ -1257,6 +1256,44 @@ def test_gpt_oss_tokenize_uses_harmony_tokens_for_router_hashes() -> None: assert block_hashes == expected_hashes +def test_gpt_oss_harmony_empty_tools_matches_chat_harmony_path() -> None: + router = KvCacheAwareRouter(server_role=None, + servers=["server1"], + use_tokens=False, + max_batch_size=32, + tokens_per_block=32) + + harmony_tokens = [100, 101, 102, 103, 104] + harmony = mock.MagicMock() + harmony.openai_to_harmony_tokens.return_value = harmony_tokens + + with mock.patch("tensorrt_llm.serve.harmony_adapter.get_harmony_adapter", + return_value=harmony), mock.patch( + "tensorrt_llm.serve.harmony_adapter." + "maybe_transform_reasoning_effort", + return_value="medium"): + req = ChatCompletionRequest( + model="openai/gpt-oss-20b", + messages=[{ + "role": "user", + "content": "what's the weather in Paris?" 
+ }], + tools=[], + tool_choice="auto", + reasoning_effort="medium", + ) + token_lists = router._tokenize(req) + + harmony.openai_to_harmony_tokens.assert_called_once() + assert token_lists == [harmony_tokens] + + call_args = harmony.openai_to_harmony_tokens.call_args + assert call_args.args[0] == req.messages + assert call_args.args[1] is None + assert call_args.kwargs["reasoning_effort"] == "medium" + assert call_args.kwargs["tool_choice"] == "auto" + + def test_use_harmony_flag_for_alias_model() -> None: router = KvCacheAwareRouter(server_role=None, servers=["server1"], diff --git a/tests/unittest/llmapi/apps/test_harmony_parsing.py b/tests/unittest/llmapi/apps/test_harmony_parsing.py index f19df40c6896..f2edc84dce74 100644 --- a/tests/unittest/llmapi/apps/test_harmony_parsing.py +++ b/tests/unittest/llmapi/apps/test_harmony_parsing.py @@ -625,20 +625,18 @@ def test_plain_fallback_text_preserves_spacing(self) -> None: assert adapter._strip_harmony_stream_control_text(" next token") == " next token" - def test_parse_failure_falls_back_to_decoded_content_and_resets_state( - self) -> None: + def test_parse_failure_falls_back_to_decoded_content_and_resets_state(self) -> None: adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) request_id = "test-streaming-parse-fallback" stream_state = adapter.create_stream_state( - request_id=request_id, available_tools=None, tool_choice=None) - stream_state.process_token_batch = Mock( - side_effect=ValueError("bad harmony state")) + request_id=request_id, available_tools=None, tool_choice=None + ) + stream_state.process_token_batch = Mock(side_effect=ValueError("bad harmony state")) with patch.object( - adapter, - "_safe_decode_utf8", - return_value= - "<|start|>assistant<|channel|>final<|message|>Hello<|return|>" + adapter, + "_safe_decode_utf8", + return_value="<|start|>assistant<|channel|>final<|message|>Hello<|return|>", ): deltas = adapter.stateful_stream_harmony_tokens_to_openai_deltas( request_id=request_id, @@ -650,20 +648,18 @@ def test_parse_failure_falls_back_to_decoded_content_and_resets_state( assert deltas == [{"content": "Hello"}] assert adapter.get_stream_state(request_id) is not stream_state - def test_parse_failure_drops_control_only_fallback_and_resets_state( - self) -> None: + def test_parse_failure_drops_control_only_fallback_and_resets_state(self) -> None: adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) request_id = "test-streaming-control-only-fallback" stream_state = adapter.create_stream_state( - request_id=request_id, available_tools=None, tool_choice=None) - stream_state.process_token_batch = Mock( - side_effect=ValueError("bad harmony state")) + request_id=request_id, available_tools=None, tool_choice=None + ) + stream_state.process_token_batch = Mock(side_effect=ValueError("bad harmony state")) with patch.object( - adapter, - "_safe_decode_utf8", - return_value= - "<|start|>assistant<|channel|>final<|message|><|return|>" + adapter, + "_safe_decode_utf8", + return_value="<|start|>assistant<|channel|>final<|message|><|return|>", ): deltas = adapter.stateful_stream_harmony_tokens_to_openai_deltas( request_id=request_id, @@ -679,15 +675,14 @@ def test_streaming_response_emits_fallback_content(self) -> None: adapter = HarmonyAdapter(harmony_input=False, harmony_output=False) request_id = "test-streaming-response-fallback" stream_state = adapter.create_stream_state( - request_id=request_id, available_tools=None, tool_choice=None) - stream_state.process_token_batch = Mock( - 
side_effect=ValueError("bad harmony state")) + request_id=request_id, available_tools=None, tool_choice=None + ) + stream_state.process_token_batch = Mock(side_effect=ValueError("bad harmony state")) with patch.object( - adapter, - "_safe_decode_utf8", - return_value= - "<|start|>assistant<|channel|>final<|message|>Hello<|return|>" + adapter, + "_safe_decode_utf8", + return_value="<|start|>assistant<|channel|>final<|message|>Hello<|return|>", ): responses, should_stop = adapter.create_openai_streaming_response( request_id=request_id,