From 0b3b1934ae924db510c70d434ecbdf9526fa29da Mon Sep 17 00:00:00 2001
From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
Date: Thu, 30 Apr 2026 02:10:39 -0700
Subject: [PATCH 1/3] [None][test] add Nemotron Ultra V3 AutoDeploy accuracy
test
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
---
.../model_registry/configs/ultra_v3.yaml | 57 ++++++++++++++++
.../defs/accuracy/references/gsm8k.yaml | 4 ++
.../defs/accuracy/references/mmlu.yaml | 4 ++
.../defs/accuracy/test_llm_api_autodeploy.py | 66 +++++++++++++++++++
.../test_lists/test-db/l0_dgx_b200.yml | 8 +++
tests/test_common/llm_data.py | 1 +
6 files changed, 140 insertions(+)
create mode 100644 examples/auto_deploy/model_registry/configs/ultra_v3.yaml
diff --git a/examples/auto_deploy/model_registry/configs/ultra_v3.yaml b/examples/auto_deploy/model_registry/configs/ultra_v3.yaml
new file mode 100644
index 000000000000..9c8a1bad1f1f
--- /dev/null
+++ b/examples/auto_deploy/model_registry/configs/ultra_v3.yaml
@@ -0,0 +1,57 @@
+runtime: trtllm
+compile_backend: torch-cudagraph
+max_batch_size: 16
+max_seq_len: 12288 # tuned for 0.80 free_gpu_memory_fraction with 8-way TP
+enable_chunked_prefill: true
+attn_backend: trtllm
+model_factory: AutoModelForCausalLM
+skip_loading_weights: false
+cuda_graph_config:
+ batch_sizes: [1, 2, 4, 8, 16]
+kv_cache_config:
+ # tunable mamba cache dtype
+ # --> use float32 for accuracy and default (auto) for speed
+ mamba_ssm_cache_dtype: float32
+ free_gpu_memory_fraction: 0.80
+transforms:
+ detect_sharding:
+ allreduce_strategy: SYMM_MEM
+ # NOTE: add 'tp' to sharding dims only for high-throughput runs
+ # For low-latency, keep mamba and attention replicated
+ sharding_dims: ['ep', 'bmm', 'tp']
+ # NOTE: sharding_source applies only to TP sharding
+ sharding_source: ['manual']
+ manual_config:
+ head_dim: 128
+ tp_plan:
+ # mamba SSM layer
+ "in_proj": "mamba"
+ "out_proj": "rowwise"
+ # attention layer
+ "q_proj": "colwise"
+ "k_proj": "colwise"
+ "v_proj": "colwise"
+ "o_proj": "rowwise"
+ # moe layer: SHARED experts
+ "up_proj": "colwise"
+ "down_proj": "rowwise"
+ # MoLE: latent projections: simple shard
+ "fc1_latent_proj": "gather"
+ "fc2_latent_proj": "gather"
+ multi_stream_moe:
+ stage: compile
+ enabled: true
+ gather_logits_before_lm_head:
+ # TODO: fix https://github.com/NVIDIA/TensorRT-LLM/issues/9878 to enable by default
+ enabled: true
+ fuse_mamba_a_log:
+ stage: post_load_fusion
+ enabled: true
+ insert_cached_ssm_attention:
+ backend: flashinfer_ssm
+ fuse_fp8_moe:
+ allow_different_input_scales: true
+ fuse_nvfp4_moe:
+ allow_different_input_scales: true
+ load_weights:
+ disable_preload: false
diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml
index ad38c829cd19..4086aef881c2 100644
--- a/tests/integration/defs/accuracy/references/gsm8k.yaml
+++ b/tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -424,6 +424,10 @@ nvidia/Nemotron-Super-V3:
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 91.88
+nvidia/Nemotron-Ultra-V3:
+ - quant_algo: NVFP4
+ kv_cache_quant_algo: FP8
+ accuracy: 91.797
nvidia/Nemotron-3-Nano:
- accuracy: 69.37
- quant_algo: FP8
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 6df8325a28a8..ce06d462974b 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -433,6 +433,10 @@ nvidia/Nemotron-Super-V3:
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 85.13
+nvidia/Nemotron-Ultra-V3:
+ - quant_algo: NVFP4
+ kv_cache_quant_algo: FP8
+ accuracy: 85.70
nvidia/Nemotron-3-Nano:
- accuracy: 73.85
- quant_algo: FP8
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index ce19f0392715..094e9f3fa99f 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -750,6 +750,72 @@ def test_mtp(self, world_size, attn_backend):
print_memory_usage("after evaluation")
+class TestNemotronUltraV3(LlmapiAccuracyTestHarness):
+ MODEL_NAME = "nvidia/Nemotron-Ultra-V3"
+ CONFIG_YAML = str(
+ Path(get_llm_root()) / "examples" / "auto_deploy" / "model_registry" /
+ "configs" / "ultra_v3.yaml")
+ MODEL_PATHS = {
+ "nvfp4": hf_id_to_local_model_dir("nvidia/Nemotron-Ultra-V3-NVFP4"),
+ }
+
+ def get_default_sampling_params(self):
+ # Use end_id=None to allow framework to read tokenizer's EOS tokens [2, 11]
+ # and enable task-specific stop sequences (critical for GSM8K)
+ return SamplingParams(end_id=None,
+ pad_id=None,
+ n=1,
+ use_beam_search=False)
+
+ @pytest.mark.parametrize("attn_backend", ["flashinfer", "trtllm"])
+ @pytest.mark.parametrize("enable_attention_dp", [False, True],
+ ids=["attn_dp_off", "attn_dp_on"])
+ @pytest.mark.parametrize("world_size", [4, 8])
+ @pytest.mark.parametrize("model_id", ["nvfp4"])
+ def test_accuracy(self, model_id, world_size, enable_attention_dp,
+ attn_backend):
+ if get_device_count() < world_size:
+ pytest.skip(f"Not enough devices for world_size={world_size}")
+
+ model_path = self.MODEL_PATHS[model_id]
+ kwargs = {}
+ kwargs["attn_backend"] = attn_backend
+ kwargs.setdefault("transforms", {}).setdefault(
+ "detect_sharding", {})["enable_attention_dp"] = enable_attention_dp
+
+ print_memory_usage("test start")
+ with AutoDeployLLM(model=model_path,
+ tokenizer=model_path,
+ world_size=world_size,
+ yaml_extra=[self.CONFIG_YAML],
+ trust_remote_code=True,
+ **kwargs) as llm:
+ _set_quant_config(llm, model_id)
+ print_memory_usage("after engine build")
+
+ sampling_params = self.get_default_sampling_params()
+ task = MMLU(self.MODEL_NAME)
+ task.evaluate(llm, sampling_params=sampling_params)
+
+ # Ultra V3 uses extended thinking: pass enable_thinking=True so the
+ # model can emit its chain of thought before the "#### N" answer.
+ # Raise max_tokens to 1024 so the full thinking chain can complete
+ # before the "#### N" answer is generated -- 256 is too short.
+ sampling_params.max_tokens = 1024
+ task = GSM8K(self.MODEL_NAME)
+ task.NUM_SAMPLES = 128
+ task.evaluate(llm,
+ sampling_params=sampling_params,
+ extra_evaluator_kwargs={
+ "apply_chat_template": True,
+ "chat_template_kwargs": {
+ "enable_thinking": True
+ },
+ })
+
+ print_memory_usage("after evaluation")
+
+
class TestGLM4Flash(LlmapiAccuracyTestHarness):
"""Accuracy regression tests for GLM-4.7-Flash variants"""
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index c7c64ad95111..0bc6f9646463 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -345,6 +345,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[nvfp4-4-attn_dp_on-trtllm]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[ws4_180gb-flashinfer]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[ws4_180gb-trtllm]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_off-flashinfer]
- accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True]
# ------------- AutoDeploy Perf Sanity ---------------
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws4_1k1k] TIMEOUT (120)
@@ -363,6 +364,9 @@ l0_dgx_b200:
backend: autodeploy
orchestrator: mpi
tests:
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_on-flashinfer]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_off-trtllm]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_on-trtllm]
# ------------- AutoDeploy Perf Sanity ---------------
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws4_1k1k] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_mtp_ad_blackwell-super_mtp_ad_ws4_1k1k] TIMEOUT (120)
@@ -385,3 +389,7 @@ l0_dgx_b200:
tests:
- accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[deepseek-ai_DeepSeek-R1-0528-True]
- accuracy/test_llm_api_autodeploy.py::TestQwen3_5_397B_MoE::test_nvfp4[8]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_off-flashinfer]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_on-flashinfer]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_off-trtllm]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_on-trtllm]
diff --git a/tests/test_common/llm_data.py b/tests/test_common/llm_data.py
index a6339814813c..8fa89555dab6 100644
--- a/tests/test_common/llm_data.py
+++ b/tests/test_common/llm_data.py
@@ -68,6 +68,7 @@
"google/gemma-4-26B-A4B-it": "gemma/gemma-4-26B-A4B-it",
"Qwen/Qwen3.5-35B-A3B": "Qwen3.5-35B-A3B",
"MiniMaxAI/MiniMax-M2": "MiniMax-M2",
+ "nvidia/Nemotron-Ultra-V3-NVFP4": "nemotron-ultra-sample-ckpt-old-format-sft_nvfp4_aggressive_03_04_26_nvfp4",
}
From d62d4562435ed59aa92347f23be2cdf328b4553d Mon Sep 17 00:00:00 2001
From: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
Date: Sat, 2 May 2026 23:44:31 -0700
Subject: [PATCH 2/3] remove flavors, skip test for pre blackwell
Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
---
.../defs/accuracy/test_llm_api_autodeploy.py | 25 +++++++------------
.../test_lists/test-db/l0_dgx_b200.yml | 11 +++-----
2 files changed, 12 insertions(+), 24 deletions(-)
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index 094e9f3fa99f..3c8ebce6febe 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -767,29 +767,22 @@ def get_default_sampling_params(self):
n=1,
use_beam_search=False)
- @pytest.mark.parametrize("attn_backend", ["flashinfer", "trtllm"])
- @pytest.mark.parametrize("enable_attention_dp", [False, True],
- ids=["attn_dp_off", "attn_dp_on"])
+ @skip_pre_blackwell
@pytest.mark.parametrize("world_size", [4, 8])
@pytest.mark.parametrize("model_id", ["nvfp4"])
- def test_accuracy(self, model_id, world_size, enable_attention_dp,
- attn_backend):
+ def test_accuracy(self, model_id, world_size):
if get_device_count() < world_size:
pytest.skip(f"Not enough devices for world_size={world_size}")
model_path = self.MODEL_PATHS[model_id]
- kwargs = {}
- kwargs["attn_backend"] = attn_backend
- kwargs.setdefault("transforms", {}).setdefault(
- "detect_sharding", {})["enable_attention_dp"] = enable_attention_dp
-
print_memory_usage("test start")
- with AutoDeployLLM(model=model_path,
- tokenizer=model_path,
- world_size=world_size,
- yaml_extra=[self.CONFIG_YAML],
- trust_remote_code=True,
- **kwargs) as llm:
+ with AutoDeployLLM(
+ model=model_path,
+ tokenizer=model_path,
+ world_size=world_size,
+ yaml_extra=[self.CONFIG_YAML],
+ trust_remote_code=True,
+ ) as llm:
_set_quant_config(llm, model_id)
print_memory_usage("after engine build")
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index 0bc6f9646463..ab3f8c18d5d8 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -345,7 +345,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_accuracy[nvfp4-4-attn_dp_on-trtllm]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[ws4_180gb-flashinfer]
- accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_mtp[ws4_180gb-trtllm]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_off-flashinfer]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4]
- accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-8B-Instruct-NVFP4-True]
# ------------- AutoDeploy Perf Sanity ---------------
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws4_1k1k] TIMEOUT (120)
@@ -364,9 +364,7 @@ l0_dgx_b200:
backend: autodeploy
orchestrator: mpi
tests:
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_on-flashinfer]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_off-trtllm]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4-attn_dp_on-trtllm]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-4]
# ------------- AutoDeploy Perf Sanity ---------------
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_ad_blackwell-super_ad_ws4_1k1k] TIMEOUT (120)
- perf/test_perf_sanity.py::test_e2e[aggr_upload-super_mtp_ad_blackwell-super_mtp_ad_ws4_1k1k] TIMEOUT (120)
@@ -389,7 +387,4 @@ l0_dgx_b200:
tests:
- accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[deepseek-ai_DeepSeek-R1-0528-True]
- accuracy/test_llm_api_autodeploy.py::TestQwen3_5_397B_MoE::test_nvfp4[8]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_off-flashinfer]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_on-flashinfer]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_off-trtllm]
- - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8-attn_dp_on-trtllm]
+ - accuracy/test_llm_api_autodeploy.py::TestNemotronUltraV3::test_accuracy[nvfp4-8]
From 312bdb83bea4a9b52682653372678ba97418252e Mon Sep 17 00:00:00 2001
From: Tal Cherckez
Date: Mon, 11 May 2026 03:00:42 -0700
Subject: [PATCH 3/3] fix: make FlashInfer Mamba decode inputs contiguous
Signed-off-by: Tal Cherckez
---
.../auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py
index e6bab2d5807e..7232afd17c6a 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py
@@ -126,6 +126,11 @@ def _flashinfer_cached_ssm(
import flashinfer
+ # FlashInfer's selective_state_update needs x/B/C contiguous and 128-byte aligned.
+ x_decode = x_decode.contiguous()
+ B_decode = B_decode.contiguous()
+ C_decode = C_decode.contiguous()
+
slot_idx_decode_i32 = slot_idx_decode.to(torch.int32)
y_decode = flashinfer.mamba.selective_state_update(
ssm_state_cache,