Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions tensorrt_llm/serve/openai_disagg_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: U


class OpenAIDisaggServer:
_CONVERSATION_ID_HEADERS = (
"x-session-id",
"x-correlation-id",
"x-session-affinity",
"x-multi-turn-session-id",
)

def __init__(self,
config: DisaggServerConfig,
Expand Down Expand Up @@ -164,25 +170,31 @@ def register_routes(self):

@staticmethod
def _extract_conversation_id(req: UCompletionRequest, raw_req: Request):
"""Populate conversation_id from the X-Correlation-ID header.
"""Populate conversation_id from supported session headers.

When not already set in the request body, copies the header value
into ``disaggregated_params.conversation_id``.

aiperf sends multi-turn session IDs via the ``X-Correlation-ID``
header (see aiperf ``base_transports.build_headers``). We mirror
that convention so the ConversationRouter can provide session
affinity without requiring clients to set the body field.
Supported headers are checked in priority order: ``X-Session-ID``,
``X-Correlation-ID``, ``x-session-affinity``, and
``x-multi-turn-session-id``. We mirror these conventions so the
ConversationRouter can provide session affinity without requiring
clients to set the body field.

When ``disaggregated_params`` is ``None`` (standard OpenAI
requests without disagg fields), a minimal instance is created
to carry the conversation_id. The service layer always rebuilds
``disaggregated_params`` in ``_get_ctx_request`` /
``_get_gen_request`` before forwarding to workers.
"""
header_conv_id = raw_req.headers.get("x-correlation-id")
if header_conv_id is None:
header_conv_id = None
for header_name in OpenAIDisaggServer._CONVERSATION_ID_HEADERS:
header_conv_id = raw_req.headers.get(header_name)
if header_conv_id is not None and header_conv_id.strip():
break
else:
return
Comment thread
reasonsolo marked this conversation as resolved.

if req.disaggregated_params is None:
req.disaggregated_params = DisaggregatedParams(
request_type="context_only",
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL::test_auto_dtype[forced
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL_MOE::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype
accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray
unittest/disaggregated/test_openai_disagg_server.py
disaggregated/test_auto_scaling.py::test_disagg_server_restart[etcd-round_robin]
disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
disaggregated/test_auto_scaling.py::test_minimal_instances[etcd-round_robin]
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ l0_a10:
- unittest/others/test_tracing.py
- unittest/disaggregated/test_disagg_openai_client.py
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_openai_disagg_server.py
- unittest/disaggregated/test_openai_disagg_service.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
Expand Down
114 changes: 114 additions & 0 deletions tests/unittest/disaggregated/test_openai_disagg_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from types import SimpleNamespace

from starlette.datastructures import Headers

from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams


def _raw_request(headers: dict[str, str]):
    """Build a minimal stand-in for an incoming request.

    Only the ``headers`` attribute is provided, wrapped in Starlette's
    ``Headers`` so the tests can pass header names in any letter case.
    """
    header_view = Headers(headers=headers)
    return SimpleNamespace(headers=header_view)


def test_extract_conversation_id_from_headers():
    """Check that each supported session header populates ``conversation_id``.

    Every case pairs a set of request headers with the conversation id that
    ``OpenAIDisaggServer._extract_conversation_id`` is expected to store on
    the request. The cases cover each header on its own, the priority order
    when several headers are present, and fall-through past an empty
    higher-priority header.
    """
    cases = [
        # Each supported header is honoured when it is the only one present.
        ({"X-Session-ID": "session-id"}, "session-id"),
        ({"X-Correlation-ID": "correlation-id"}, "correlation-id"),
        ({"x-session-affinity": "session-affinity"}, "session-affinity"),
        ({"x-multi-turn-session-id": "multi-turn-session-id"}, "multi-turn-session-id"),
        # With all four present, X-Session-ID is expected to win regardless of
        # the order in which the headers were supplied.
        (
            {
                "X-Correlation-ID": "correlation-id",
                "X-Session-ID": "session-id",
                "x-session-affinity": "session-affinity",
                "x-multi-turn-session-id": "multi-turn-session-id",
            },
            "session-id",
        ),
        # x-session-affinity is expected to outrank x-multi-turn-session-id.
        (
            {
                "x-session-affinity": "session-affinity",
                "x-multi-turn-session-id": "multi-turn-session-id",
            },
            "session-affinity",
        ),
        # An empty higher-priority header is skipped in favour of the next one.
        (
            {
                "X-Session-ID": "",
                "X-Correlation-ID": "correlation-id",
            },
            "correlation-id",
        ),
    ]

    for headers, expected_conversation_id in cases:
        # Build a fresh request per case so no state carries over between cases.
        request = CompletionRequest(model="test-model", prompt="hello")

        OpenAIDisaggServer._extract_conversation_id(request, _raw_request(headers))

        # Name the failing case in the message; without it a failure inside
        # this loop does not say which header combination triggered it.
        assert request.disaggregated_params is not None, (
            f"disaggregated_params was not created for headers {headers!r}"
        )
        assert request.disaggregated_params.conversation_id == expected_conversation_id, (
            f"unexpected conversation_id for headers {headers!r}"
        )


def test_extract_conversation_id_ignores_empty_headers():
    """Leave ``disaggregated_params`` unset when every header is empty or blank."""
    request = CompletionRequest(model="test-model", prompt="hello")
    blank_headers = {
        "X-Session-ID": "",
        "X-Correlation-ID": " ",
        "x-session-affinity": "",
        "x-multi-turn-session-id": " ",
    }
    raw_request = _raw_request(blank_headers)

    OpenAIDisaggServer._extract_conversation_id(request, raw_request)

    # No header carried a usable value, so no params object should exist.
    assert request.disaggregated_params is None


def test_extract_conversation_id_preserves_body_conversation_id():
    """Keep a conversation id set in the request body over a header value."""
    body_params = DisaggregatedParams(
        request_type="context_only",
        conversation_id="body-id",
    )
    request = CompletionRequest(
        model="test-model",
        prompt="hello",
        disaggregated_params=body_params,
    )
    raw_request = _raw_request({"X-Session-ID": "header-id"})

    OpenAIDisaggServer._extract_conversation_id(request, raw_request)

    # The header value must not replace the id that came with the body.
    assert request.disaggregated_params.conversation_id == "body-id"


def test_extract_conversation_id_populates_existing_disaggregated_params():
    """Fill ``conversation_id`` on params that exist but carry no id yet."""
    existing_params = DisaggregatedParams(request_type="context_only")
    request = CompletionRequest(
        model="test-model",
        prompt="hello",
        disaggregated_params=existing_params,
    )
    raw_request = _raw_request({"x-multi-turn-session-id": "multi-turn-session-id"})

    OpenAIDisaggServer._extract_conversation_id(request, raw_request)

    # The header value should land on the request's disaggregated params.
    assert request.disaggregated_params.conversation_id == "multi-turn-session-id"
Loading