diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 67d4a418e86c..be5a65c0e45a 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -80,6 +80,12 @@ def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: U class OpenAIDisaggServer: + _CONVERSATION_ID_HEADERS = ( + "x-session-id", + "x-correlation-id", + "x-session-affinity", + "x-multi-turn-session-id", + ) def __init__(self, config: DisaggServerConfig, @@ -164,15 +170,16 @@ def register_routes(self): @staticmethod def _extract_conversation_id(req: UCompletionRequest, raw_req: Request): - """Populate conversation_id from the X-Correlation-ID header. + """Populate conversation_id from supported session headers. When not already set in the request body, copies the header value into ``disaggregated_params.conversation_id``. - aiperf sends multi-turn session IDs via the ``X-Correlation-ID`` - header (see aiperf ``base_transports.build_headers``). We mirror - that convention so the ConversationRouter can provide session - affinity without requiring clients to set the body field. + Supported headers are checked in priority order: ``X-Session-ID``, + ``X-Correlation-ID``, ``x-session-affinity``, and + ``x-multi-turn-session-id``. We mirror these conventions so the + ConversationRouter can provide session affinity without requiring + clients to set the body field. When ``disaggregated_params`` is ``None`` (standard OpenAI requests without disagg fields), a minimal instance is created @@ -180,9 +187,14 @@ def _extract_conversation_id(req: UCompletionRequest, raw_req: Request): ``disaggregated_params`` in ``_get_ctx_request`` / ``_get_gen_request`` before forwarding to workers. 
""" - header_conv_id = raw_req.headers.get("x-correlation-id") - if header_conv_id is None: + header_conv_id = None + for header_name in OpenAIDisaggServer._CONVERSATION_ID_HEADERS: + header_conv_id = raw_req.headers.get(header_name) + if header_conv_id is not None and header_conv_id.strip(): + break + else: return + if req.disaggregated_params is None: req.disaggregated_params = DisaggregatedParams( request_type="context_only", diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 9285526830dd..09fb9cf06f46 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -784,6 +784,7 @@ accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL::test_auto_dtype[forced accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL_MOE::test_auto_dtype accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray +unittest/disaggregated/test_openai_disagg_server.py disaggregated/test_auto_scaling.py::test_disagg_server_restart[etcd-round_robin] disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin] disaggregated/test_auto_scaling.py::test_minimal_instances[etcd-round_robin] diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index b64d418aacfd..4839e481ef57 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -45,6 +45,7 @@ l0_a10: - unittest/others/test_tracing.py - unittest/disaggregated/test_disagg_openai_client.py - unittest/disaggregated/test_disagg_utils.py + - unittest/disaggregated/test_openai_disagg_server.py - unittest/disaggregated/test_openai_disagg_service.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py diff --git 
a/tests/unittest/disaggregated/test_openai_disagg_server.py b/tests/unittest/disaggregated/test_openai_disagg_server.py
new file mode 100644
index 000000000000..298bfcc93a3f
--- /dev/null
+++ b/tests/unittest/disaggregated/test_openai_disagg_server.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2026, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for header-based conversation ID extraction in OpenAIDisaggServer."""
+
+from types import SimpleNamespace
+
+from starlette.datastructures import Headers
+
+from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer
+from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
+
+
+def _raw_request(headers: dict[str, str]) -> SimpleNamespace:
+    """Build a minimal stand-in for a Starlette request exposing only ``headers``.
+
+    ``Headers`` gives the case-insensitive lookup that the extraction code relies on
+    when it queries lower-case header names.
+    """
+    return SimpleNamespace(headers=Headers(headers=headers))
+
+
+def test_extract_conversation_id_from_headers():
+    """Each supported header is honoured, in priority order, skipping blank values."""
+    cases = [
+        ({"X-Session-ID": "session-id"}, "session-id"),
+        ({"X-Correlation-ID": "correlation-id"}, "correlation-id"),
+        ({"x-session-affinity": "session-affinity"}, "session-affinity"),
+        ({"x-multi-turn-session-id": "multi-turn-session-id"}, "multi-turn-session-id"),
+        # All four present: the highest-priority header (X-Session-ID) wins.
+        (
+            {
+                "X-Correlation-ID": "correlation-id",
+                "X-Session-ID": "session-id",
+                "x-session-affinity": "session-affinity",
+                "x-multi-turn-session-id": "multi-turn-session-id",
+            },
+            "session-id",
+        ),
+        (
+            {
+                "x-session-affinity": "session-affinity",
+                "x-multi-turn-session-id": "multi-turn-session-id",
+            },
+            "session-affinity",
+        ),
+        # An empty higher-priority header falls through to the next one.
+        (
+            {
+                "X-Session-ID": "",
+                "X-Correlation-ID": "correlation-id",
+            },
+            "correlation-id",
+        ),
+        # A whitespace-only higher-priority header falls through as well.
+        (
+            {
+                "X-Session-ID": " ",
+                "x-multi-turn-session-id": "multi-turn-session-id",
+            },
+            "multi-turn-session-id",
+        ),
+    ]
+
+    for headers, expected_conversation_id in cases:
+        request = CompletionRequest(model="test-model", prompt="hello")
+
+        OpenAIDisaggServer._extract_conversation_id(request, _raw_request(headers))
+
+        # Name the failing header set: the table runs inside a single test function.
+        assert request.disaggregated_params is not None, (
+            f"no disaggregated_params created for headers {headers!r}"
+        )
+        assert request.disaggregated_params.conversation_id == expected_conversation_id, (
+            f"unexpected conversation_id for headers {headers!r}"
+        )
+
+
+def test_extract_conversation_id_ignores_empty_headers():
+    """Empty or whitespace-only headers leave the request untouched."""
+    request = CompletionRequest(model="test-model", prompt="hello")
+
+    OpenAIDisaggServer._extract_conversation_id(
+        request,
+        _raw_request(
+            {
+                "X-Session-ID": "",
+                "X-Correlation-ID": " ",
+                "x-session-affinity": "",
+                "x-multi-turn-session-id": " ",
+            }
+        ),
+    )
+
+    assert request.disaggregated_params is None
+
+
+def test_extract_conversation_id_preserves_body_conversation_id():
+    """A conversation_id already set in the request body is not overwritten by a header."""
+    request = CompletionRequest(
+        model="test-model",
+        prompt="hello",
+        disaggregated_params=DisaggregatedParams(
+            request_type="context_only",
+            conversation_id="body-id",
+        ),
+    )
+
+    OpenAIDisaggServer._extract_conversation_id(
+        request,
+        _raw_request({"X-Session-ID": "header-id"}),
+    )
+
+    assert request.disaggregated_params.conversation_id == "body-id"
+
+
+def test_extract_conversation_id_populates_existing_disaggregated_params():
+    """Existing disaggregated_params without a conversation_id receive the header value."""
+    request = CompletionRequest(
+        model="test-model",
+        prompt="hello",
+        disaggregated_params=DisaggregatedParams(request_type="context_only"),
+    )
+
+    OpenAIDisaggServer._extract_conversation_id(
+        request,
+        _raw_request({"x-multi-turn-session-id": "multi-turn-session-id"}),
+    )
+
+    assert request.disaggregated_params.conversation_id == "multi-turn-session-id"