From 20d5e023d47bb746beb7c719177bb4e6254e8536 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Sun, 12 Apr 2026 13:34:18 -0700 Subject: [PATCH 1/3] feat: add dynamic shapes kernel specialization strategy for TRT-RTX Expose IRuntimeConfig.setDynamicShapesKernelSpecializationStrategy() through the Torch-TensorRT Python API. Users can now control how shape-specialized kernels are compiled at runtime for dynamic shapes on TensorRT-RTX via the new `dynamic_shapes_kernel_specialization_strategy` compilation setting ("lazy", "eager", or "none"). Co-Authored-By: Claude Opus 4.6 (1M context) --- py/torch_tensorrt/dynamo/_compiler.py | 9 ++ py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 5 + .../runtime/_PythonTorchTensorRTModule.py | 17 +++ ...t_dynamic_shapes_kernel_strategy_models.py | 128 ++++++++++++++++ ...test_001_dynamic_shapes_kernel_strategy.py | 139 ++++++++++++++++++ 6 files changed, 299 insertions(+) create mode 100644 tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py create mode 100644 tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9d5b5eaad8..d04c294ad9 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -93,6 +93,7 @@ def cross_compile_for_windows( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, + dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -174,6 +175,7 @@ def cross_compile_for_windows( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. + dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -339,6 +341,7 @@ def cross_compile_for_windows( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, + "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -451,6 +454,7 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, + dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -547,6 +551,7 @@ def compile( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. + dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -755,6 +760,7 @@ def compile( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, + "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1169,6 +1175,7 @@ def convert_exported_program_to_serialized_trt_engine( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, + dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1246,6 +1253,7 @@ def convert_exported_program_to_serialized_trt_engine( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. + dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1420,6 +1428,7 @@ def convert_exported_program_to_serialized_trt_engine( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, + "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 3525080e8c..8998479a63 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -72,6 +72,7 @@ DYNAMICALLY_ALLOCATE_RESOURCES = False DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 8ef3d5d2e6..595f9dcb55 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -23,6 +23,7 @@ DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, DYNAMICALLY_ALLOCATE_RESOURCES, ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, @@ -100,6 +101,7 @@ class CompilationSettings: hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. + dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. @@ -154,6 +156,9 @@ class CompilationSettings: hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH runtime_cache_path: str = RUNTIME_CACHE_PATH + dynamic_shapes_kernel_specialization_strategy: str = ( + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY + ) lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 9d122446fe..e7cc9a10d9 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -27,6 +27,15 @@ logger = logging.getLogger(__name__) +def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT enum. Only called on RTX builds.""" + return { + "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, + "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, + "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, + }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) + + class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: trt.IOutputAllocator.__init__(self) @@ -345,6 +354,14 @@ def _setup_runtime_config(self) -> None: self.runtime_config.set_execution_context_allocation_strategy( trt.ExecutionContextAllocationStrategy.STATIC ) + self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( + _get_dynamic_shapes_kernel_strategy( + self.settings.dynamic_shapes_kernel_specialization_strategy + ) + ) + logger.info( + f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}" + ) self.runtime_cache = self.runtime_config.create_runtime_cache() self._load_runtime_cache() self.runtime_config.set_runtime_cache(self.runtime_cache) diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py new file mode 100644 index 0000000000..f02d47a285 --- /dev/null +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -0,0 +1,128 @@ +import importlib +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", +) +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +class TestDynamicShapesKernelStrategyModels(TestCase): + """End-to-end model tests with different kernel specialization strategies.""" + + def tearDown(self): + torch._dynamo.reset() + + def _compile_and_verify(self, model, input_tensor, strategy): + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " + f"with strategy={strategy}", + ) + + def test_resnet18_lazy_strategy(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + self._compile_and_verify(model, input_tensor, "lazy") + + def test_resnet18_eager_strategy(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + self._compile_and_verify(model, input_tensor, "eager") + + def test_resnet18_none_strategy(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + self._compile_and_verify(model, input_tensor, "none") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", +) +class TestDynamicShapesKernelStrategyDynamic(TestCase): + """Tests kernel specialization strategies with dynamic input shapes.""" + + def tearDown(self): + torch._dynamo.reset() + + def _test_dynamic_batch_with_strategy(self, strategy): + class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = ConvModel().eval().cuda() + + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ], + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + + for batch_size in (1, 4, 8): + with self.subTest(batch_size=batch_size, strategy=strategy): + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size}, strategy={strategy}: cosine similarity " + f"{cos_sim} below threshold {COSINE_THRESHOLD}", + ) + + def test_dynamic_batch_lazy(self): + self._test_dynamic_batch_with_strategy("lazy") + + def test_dynamic_batch_eager(self): + self._test_dynamic_batch_with_strategy("eager") + + def test_dynamic_batch_none(self): + self._test_dynamic_batch_with_strategy("none") + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py new file mode 100644 index 0000000000..68a1920ca4 --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -0,0 +1,139 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(**extra_kwargs): + """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + } + kwargs.update(extra_kwargs) + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled, inputs + + +def _find_python_trt_module(compiled): + """Walk the compiled graph module to find PythonTorchTensorRTModule instances.""" + from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, + ) + + for name, mod in compiled.named_modules(): + if isinstance(mod, PythonTorchTensorRTModule): + return mod + return None + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", +) +class TestDynamicShapesKernelStrategySetup(TestCase): + """Tests that the dynamic shapes kernel specialization strategy is correctly applied.""" + + def test_default_strategy_is_lazy(self): + import tensorrt as trt + + compiled, _ = _compile_simple() + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod, "No PythonTorchTensorRTModule found") + self.assertIsNotNone(mod.runtime_config, "runtime_config should be set for RTX") + self.assertEqual( + mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, + trt.DynamicShapesKernelSpecializationStrategy.LAZY, + ) + + def test_eager_strategy(self): + import tensorrt as trt + + compiled, _ = _compile_simple( + dynamic_shapes_kernel_specialization_strategy="eager" + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertEqual( + mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, + trt.DynamicShapesKernelSpecializationStrategy.EAGER, + ) + + def test_none_strategy(self): + import tensorrt as trt + + compiled, _ = _compile_simple( + dynamic_shapes_kernel_specialization_strategy="none" + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertEqual( + mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, + trt.DynamicShapesKernelSpecializationStrategy.NONE, + ) + + def test_context_created_with_each_strategy(self): + for strategy in ("lazy", "eager", "none"): + with self.subTest(strategy=strategy): + compiled, inputs = _compile_simple( + dynamic_shapes_kernel_specialization_strategy=strategy + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone( + mod.context, f"Execution context should be created for {strategy}" + ) + output = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(output.shape, inputs[0].shape) + + def test_setting_in_compilation_settings(self): + for strategy in ("lazy", "eager", "none"): + settings = CompilationSettings( + dynamic_shapes_kernel_specialization_strategy=strategy + ) + self.assertEqual( + settings.dynamic_shapes_kernel_specialization_strategy, strategy + ) + + def test_default_compilation_settings(self): + settings = CompilationSettings() + self.assertEqual(settings.dynamic_shapes_kernel_specialization_strategy, "lazy") + + +@unittest.skipIf( + ENABLED_FEATURES.tensorrt_rtx, + "This test verifies standard TRT behavior (non-RTX)", +) +class TestDynamicShapesKernelStrategyNonRTX(TestCase): + """Tests that the setting is ignored on non-RTX builds.""" + + def test_setting_ignored_on_non_rtx(self): + compiled, inputs = _compile_simple( + dynamic_shapes_kernel_specialization_strategy="eager" + ) + mod = _find_python_trt_module(compiled) + if mod is not None: + self.assertIsNone( + mod.runtime_config, + "runtime_config should be None for standard TRT", + ) + # Inference should still work + output = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(output.shape, inputs[0].shape) + + +if __name__ == "__main__": + run_tests() From d7619caf0cc22723830cf1c7d55ebc964ffc9fae Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Sun, 12 Apr 2026 14:45:02 -0700 Subject: [PATCH 2/3] test: use dynamic shape inputs in kernel strategy tests Address review feedback: compile with torchtrt.Input min/opt/max ranges so dynamic shapes are actually exercised. Co-Authored-By: Claude Opus 4.6 (1M context) --- ...t_dynamic_shapes_kernel_strategy_models.py | 21 ++++++----- ...test_001_dynamic_shapes_kernel_strategy.py | 35 +++++++++++-------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py index f02d47a285..badfff81ea 100644 --- a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -22,11 +22,19 @@ class TestDynamicShapesKernelStrategyModels(TestCase): def tearDown(self): torch._dynamo.reset() - def _compile_and_verify(self, model, input_tensor, strategy): + def _compile_and_verify(self, model, strategy): + input_tensor = torch.randn(4, 3, 224, 224).cuda() compiled = torchtrt.compile( model, ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + inputs=[ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ], enabled_precisions={torch.float32}, use_python_runtime=True, min_block_size=1, @@ -45,22 +53,19 @@ def test_resnet18_lazy_strategy(self): import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() - input_tensor = torch.randn(1, 3, 224, 224).cuda() - self._compile_and_verify(model, input_tensor, "lazy") + self._compile_and_verify(model, "lazy") def test_resnet18_eager_strategy(self): import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() - input_tensor = torch.randn(1, 3, 224, 224).cuda() - self._compile_and_verify(model, input_tensor, "eager") + self._compile_and_verify(model, "eager") def test_resnet18_none_strategy(self): import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() - input_tensor = torch.randn(1, 3, 224, 224).cuda() - self._compile_and_verify(model, input_tensor, "none") + self._compile_and_verify(model, "none") @unittest.skipIf( diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 68a1920ca4..8c0a12cbdf 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -13,9 +13,16 @@ def forward(self, x): def _compile_simple(**extra_kwargs): - """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" + """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] kwargs = { "ir": "dynamo", "inputs": inputs, @@ -26,7 +33,7 @@ def _compile_simple(**extra_kwargs): kwargs.update(extra_kwargs) compiled = torchtrt.compile(model, **kwargs) torch._dynamo.reset() - return compiled, inputs + return compiled def _find_python_trt_module(compiled): @@ -51,7 +58,7 @@ class TestDynamicShapesKernelStrategySetup(TestCase): def test_default_strategy_is_lazy(self): import tensorrt as trt - compiled, _ = _compile_simple() + compiled = _compile_simple() mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod, "No PythonTorchTensorRTModule found") self.assertIsNotNone(mod.runtime_config, "runtime_config should be set for RTX") @@ -63,7 +70,7 @@ def test_default_strategy_is_lazy(self): def test_eager_strategy(self): import tensorrt as trt - compiled, _ = _compile_simple( + compiled = _compile_simple( dynamic_shapes_kernel_specialization_strategy="eager" ) mod = _find_python_trt_module(compiled) @@ -76,9 +83,7 @@ def test_eager_strategy(self): def test_none_strategy(self): import tensorrt as trt - compiled, _ = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="none" - ) + compiled = _compile_simple(dynamic_shapes_kernel_specialization_strategy="none") mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod) self.assertEqual( @@ -89,15 +94,17 @@ def test_none_strategy(self): def test_context_created_with_each_strategy(self): for strategy in ("lazy", "eager", "none"): with self.subTest(strategy=strategy): - compiled, inputs = _compile_simple( + compiled = _compile_simple( dynamic_shapes_kernel_specialization_strategy=strategy ) mod = _find_python_trt_module(compiled) self.assertIsNotNone( mod.context, f"Execution context should be created for {strategy}" ) - output = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(output.shape, inputs[0].shape) + # Test inference with multiple dynamic batch sizes + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) def test_setting_in_compilation_settings(self): for strategy in ("lazy", "eager", "none"): @@ -121,7 +128,7 @@ class TestDynamicShapesKernelStrategyNonRTX(TestCase): """Tests that the setting is ignored on non-RTX builds.""" def test_setting_ignored_on_non_rtx(self): - compiled, inputs = _compile_simple( + compiled = _compile_simple( dynamic_shapes_kernel_specialization_strategy="eager" ) mod = _find_python_trt_module(compiled) @@ -131,8 +138,8 @@ def test_setting_ignored_on_non_rtx(self): "runtime_config should be None for standard TRT", ) # Inference should still work - output = compiled(*[inp.clone() for inp in inputs]) - self.assertEqual(output.shape, inputs[0].shape) + output = compiled(torch.randn(2, 3).cuda()) + self.assertEqual(output.shape, (2, 3)) if __name__ == "__main__": From 52d9ba5d295abeeccc5e29b30565f746ccca2b94 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 14 Apr 2026 04:06:53 -0700 Subject: [PATCH 3/3] feat: add TRT-RTX native CUDA graph support Add cuda_graph_strategy compilation setting and automatic RTX-native CUDA graph integration for the Python runtime path. Key changes: - New cuda_graph_strategy setting ("disabled" / "whole_graph_capture") on CompilationSettings, mapped to trt.CudaGraphStrategy on IRuntimeConfig (same pattern as dynamic_shapes_kernel_specialization) - In SUBGRAPH cudagraph mode on RTX, always use RTX-native CUDA graphs (manual torch.cuda.CUDAGraph capture is not safe due to lazy kernel specialization and potential runtime allocation) - _is_monolithic_capturable() check using context.is_stream_capturable() and strategy != "lazy" for WHOLE_GRAPH mode safety validation - _enable_rtx_native_cudagraphs() for runtime context recreation - _check_monolithic_capturability() in CudaGraphsTorchTensorRTModule for mixed TRT + PyTorch graph validation - Comprehensive unit tests covering all code paths Co-Authored-By: Claude Opus 4.6 (1M context) --- py/torch_tensorrt/dynamo/_compiler.py | 9 + py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 3 + .../runtime/_CudaGraphsTorchTensorRTModule.py | 48 +++ .../runtime/_PythonTorchTensorRTModule.py | 80 +++- .../models/test_cuda_graph_strategy_models.py | 186 +++++++++ .../runtime/test_001_cuda_graph_strategy.py | 354 ++++++++++++++++++ 7 files changed, 674 insertions(+), 7 deletions(-) create mode 100644 tests/py/dynamo/models/test_cuda_graph_strategy_models.py create mode 100644 tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d04c294ad9..4e2a5402c2 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -94,6 +94,7 @@ def cross_compile_for_windows( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -176,6 +177,7 @@ def cross_compile_for_windows( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -342,6 +344,7 @@ def cross_compile_for_windows( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -455,6 +458,7 @@ def compile( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -552,6 +556,7 @@ def compile( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -761,6 +766,7 @@ def compile( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1176,6 +1182,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1254,6 +1261,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1429,6 +1437,7 @@ def convert_exported_program_to_serialized_trt_engine( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 8998479a63..4ad05da6ce 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -73,6 +73,7 @@ DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 595f9dcb55..94515a720b 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -102,6 +103,7 @@ class CompilationSettings: timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. @@ -159,6 +161,7 @@ class CompilationSettings: dynamic_shapes_kernel_specialization_strategy: str = ( DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..a166ab859e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -114,6 +114,53 @@ def __del__(self) -> None: def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable + def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: + """Verify all TRT submodules are monolithically capturable on RTX. + + For whole-graph CUDA graph mode with mixed TRT + PyTorch ops, + all TRT engines must be safe for manual stream capture. If any + engine has lazy kernel specialization or non-capturable conditions, + raises RuntimeError. + """ + from torch_tensorrt._features import ENABLED_FEATURES + + if not ENABLED_FEATURES.tensorrt_rtx: + return # non-RTX: no check needed + from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, + ) + + for name, mod in self.compiled_module.named_modules(): + if isinstance(mod, PythonTorchTensorRTModule): + if not mod._is_monolithic_capturable(stream): + raise RuntimeError( + f"CUDA graph capture failed: TRT submodule " + f"'{name}' is not monolithically capturable " + f"(lazy kernel specialization or non-capturable " + f"stream). Whole-graph CUDA graph mode with mixed " + f"TRT + PyTorch ops requires all TRT engines to be " + f"capturable. Consider using " + f"cuda_graph_strategy='whole_graph_capture' with " + f"set_cudagraphs_mode(True) instead of " + f"enable_cudagraphs()." + ) + # Ensure RTX-native is DISABLED so TRT engines do not + # interfere with the outer monolithic capture + if mod._rtx_native_cudagraphs: + from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + _get_cuda_graph_strategy, + ) + + mod.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "disabled" + ) + mod.context = mod._create_context() + mod._rtx_native_cudagraphs = False + logger.info( + f"Disabled RTX-native CUDA graphs for '{name}' " + f"(using outer monolithic capture instead)" + ) + def forward( self, *args: Any, **kwargs: Any ) -> torch.Tensor | Tuple[torch.Tensor, ...]: @@ -183,6 +230,7 @@ def forward( with torch.cuda.stream(self._engine_stream): if need_cudagraphs_record: + self._check_monolithic_capturability(self._engine_stream) self.cudagraph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): self._output_buffers = self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e7cc9a10d9..d1e13f3cd2 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -36,6 +36,14 @@ def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) +def _get_cuda_graph_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT CudaGraphStrategy enum. Only called on RTX builds.""" + return { + "disabled": trt.CudaGraphStrategy.DISABLED, + "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) + + class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: trt.IOutputAllocator.__init__(self) @@ -241,6 +249,7 @@ def __init__( self.runtime_config: Any = None self.runtime_cache: Any = None self.runtime_cache_path = settings.runtime_cache_path + self._rtx_native_cudagraphs = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -309,6 +318,10 @@ def setup_engine(self) -> None: if ENABLED_FEATURES.tensorrt_rtx: self._setup_runtime_config() + self._rtx_native_cudagraphs = ( + ENABLED_FEATURES.tensorrt_rtx + and self.settings.cuda_graph_strategy != "disabled" + ) self.context = self._create_context() assert self.context is not None, "Failed to create execution context" @@ -336,7 +349,10 @@ def setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if ( + torch_tensorrt.runtime.get_cudagraphs_mode() + and not self._rtx_native_cudagraphs + ): self.cudagraph = torch.cuda.CUDAGraph() self.is_shape_inference_io = { @@ -362,6 +378,10 @@ def _setup_runtime_config(self) -> None: logger.info( f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}" ) + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + self.settings.cuda_graph_strategy + ) + logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") self.runtime_cache = self.runtime_config.create_runtime_cache() self._load_runtime_cache() self.runtime_config.set_runtime_cache(self.runtime_cache) @@ -466,6 +486,32 @@ def _reset_captured_graph(self) -> None: self.cudagraph.reset() self.cudagraph = None + def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: + """Check if manual torch.cuda.CUDAGraph capture is safe for this engine. + + Returns False on RTX if the engine has conditions that prevent + manual stream capture (runtime allocation, DDS, lazy kernels). + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return True # non-RTX: assume capturable (existing behavior) + # Check 1: TRT-RTX stream capturability (runtime allocation, DDS, etc.) + if not self.context.is_stream_capturable(stream.cuda_stream): + return False + # Check 2: Lazy kernel specialization would invalidate captured graph + if self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy": + return False + return True + + def _enable_rtx_native_cudagraphs(self) -> None: + """Switch to RTX-native CUDA graphs by recreating the execution context.""" + if self.runtime_config is not None: + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "whole_graph_capture" + ) + self.context = self._create_context() + self._rtx_native_cudagraphs = True + logger.info("Switched to TRT-RTX native CUDA graphs") + def __del__(self) -> None: self._save_runtime_cache() self._reset_captured_graph() @@ -559,13 +605,32 @@ def create_output_allocator(self) -> None: def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: + # On RTX + SUBGRAPH cudagraphs: always use RTX-native CUDA graphs. + # Manual torch.cuda.CUDAGraph capture is not safe on TRT-RTX because + # lazy kernel specialization can invalidate captured graphs and + # runtime allocation can prevent stream capture. + if ENABLED_FEATURES.tensorrt_rtx and self.cudagraphs_enabled: + if not self._rtx_native_cudagraphs: + logger.warning( + "Manual CUDA graph capture is not guaranteed to work " + "on TRT-RTX (lazy kernel specialization or " + "non-capturable stream). Switching to TRT-RTX native " + "CUDA graphs. Set cuda_graph_strategy=" + '"whole_graph_capture" at compile time to avoid ' + "this warning." + ) + self._enable_rtx_native_cudagraphs() + + effective_cudagraphs = ( + self.cudagraphs_enabled and not self._rtx_native_cudagraphs + ) shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset, ) = self.runtime_states.set_runtime_states( - self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed + effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed ) if need_cudagraphs_reset: @@ -587,7 +652,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." self.setup_input_tensors( - contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record + contiguous_inputs, effective_cudagraphs, need_cudagraphs_record ) if shape_changed: @@ -623,7 +688,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() - if self.cudagraphs_enabled: + if effective_cudagraphs: self.context.set_tensor_address( output_name, self._output_buffers[o].data_ptr() ) @@ -649,7 +714,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): - if self.cudagraphs_enabled: + if effective_cudagraphs: if need_cudagraphs_record: self.cudagraph = torch.cuda.CUDAGraph() @@ -683,7 +748,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: ): self.pre_allocated_outputs = self.create_output_tensors() - if self.cudagraphs_enabled: + if effective_cudagraphs: for idx, o in enumerate(outputs): o.copy_(self._output_buffers[idx]) @@ -840,7 +905,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: return run_output_allocator() else: logger.debug( - f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." + f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}" + + (" (RTX native)" if self._rtx_native_cudagraphs else "") ) return run_standard_execution() diff --git a/tests/py/dynamo/models/test_cuda_graph_strategy_models.py b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py new file mode 100644 index 0000000000..bce596d15f --- /dev/null +++ b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py @@ -0,0 +1,186 @@ +import unittest + +import torch +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES + + +class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return F.relu(self.conv(x)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyModels(TestCase): + """End-to-end model tests with cuda_graph_strategy.""" + + def _check_cosine_similarity(self, output, ref_output, threshold=0.99): + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > threshold, + f"Cosine similarity {cos_sim.item():.4f} below threshold {threshold}", + ) + + def test_resnet18_whole_graph_capture(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + def test_resnet18_disabled_strategy(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="disabled", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyDynamic(TestCase): + """Tests with dynamic batch sizes and cudagraph mode integration.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_dynamic_batch_whole_graph_capture(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_dynamic_batch_with_subgraph_cudagraphs(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + torchtrt.runtime.set_cudagraphs_mode(True) + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py new file mode 100644 index 0000000000..554b7f6b2f --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -0,0 +1,354 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(**extra_kwargs): + """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" + model = SimpleModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + } + kwargs.update(extra_kwargs) + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled + + +def _find_python_trt_module(compiled): + """Walk the compiled graph module to find PythonTorchTensorRTModule instances.""" + from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, + ) + + for name, mod in compiled.named_modules(): + if isinstance(mod, PythonTorchTensorRTModule): + return mod + return None + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy requires TensorRT-RTX", +) +class TestCudaGraphStrategySetup(TestCase): + """Tests that cuda_graph_strategy is correctly applied on TRT-RTX.""" + + def test_default_strategy_is_disabled(self): + import tensorrt as trt + + compiled = _compile_simple() + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod, "No PythonTorchTensorRTModule found") + self.assertIsNotNone(mod.runtime_config, "runtime_config should be set for RTX") + self.assertEqual( + mod.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.DISABLED, + ) + + def test_whole_graph_capture_strategy(self): + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertEqual( + mod.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_rtx_native_flag_set(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertTrue(mod._rtx_native_cudagraphs) + + def test_rtx_native_flag_disabled(self): + compiled = _compile_simple(cuda_graph_strategy="disabled") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertFalse(mod._rtx_native_cudagraphs) + + def test_inference_with_each_strategy(self): + for strategy in ("disabled", "whole_graph_capture"): + with self.subTest(strategy=strategy): + compiled = _compile_simple(cuda_graph_strategy=strategy) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone( + mod.context, + f"Execution context should be created for {strategy}", + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_setting_in_compilation_settings(self): + for strategy in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=strategy) + self.assertEqual(settings.cuda_graph_strategy, strategy) + + def test_default_compilation_settings(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, "disabled") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy integration requires TensorRT-RTX", +) +class TestCudaGraphStrategyWithSubgraphCudagraphs(TestCase): + """Tests integration with set_cudagraphs_mode().""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_rtx_native_bypasses_manual_capture(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference a few times to ensure capture would have happened + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Manual cudagraph should NOT have been recorded (RTX handles it natively) + self.assertFalse( + hasattr(mod, "cudagraph") + and isinstance(mod.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded when RTX native is active", + ) + + def test_subgraph_mode_always_uses_rtx_native(self): + """Even with cuda_graph_strategy=disabled, SUBGRAPH mode on RTX + should override to RTX-native because manual capture is not safe.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + # Initially, _rtx_native_cudagraphs is False (disabled strategy) + self.assertFalse(mod._rtx_native_cudagraphs) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference — should trigger override to RTX-native + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Should have been overridden to RTX-native + self.assertTrue( + mod._rtx_native_cudagraphs, + "RTX-native should be enabled automatically in SUBGRAPH mode", + ) + # Manual cudagraph should NOT have been recorded + self.assertFalse( + hasattr(mod, "cudagraph") + and isinstance(mod.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded on RTX", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Monolithic capturability tests require TensorRT-RTX", +) +class TestMonolithicCapturability(TestCase): + """Tests for _is_monolithic_capturable() and related logic.""" + + def test_lazy_strategy_not_monolithic_capturable(self): + """Lazy kernel specialization makes monolithic capture unsafe.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="lazy", + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + stream = torch.cuda.Stream() + self.assertFalse(mod._is_monolithic_capturable(stream)) + + def test_eager_strategy_monolithic_capturable(self): + """Eager strategy with capturable stream should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="eager", + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + stream = torch.cuda.Stream() + # is_stream_capturable depends on engine properties. + # With eager strategy, the strategy check passes. + if mod.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(mod._is_monolithic_capturable(stream)) + + def test_none_strategy_monolithic_capturable(self): + """None strategy (always fallback) should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="none", + ) + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + stream = torch.cuda.Stream() + if mod.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(mod._is_monolithic_capturable(stream)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Context recreation tests require TensorRT-RTX", +) +class TestContextRecreation(TestCase): + """Tests for _enable_rtx_native_cudagraphs() context recreation.""" + + def test_enable_rtx_native_recreates_context(self): + """Calling _enable_rtx_native_cudagraphs recreates the execution context.""" + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="disabled") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertFalse(mod._rtx_native_cudagraphs) + + old_context_id = id(mod.context) + mod._enable_rtx_native_cudagraphs() + + self.assertTrue(mod._rtx_native_cudagraphs) + self.assertNotEqual( + id(mod.context), + old_context_id, + "Context should be recreated", + ) + self.assertEqual( + mod.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_explicit_whole_graph_capture_no_override_needed(self): + """With explicit whole_graph_capture, SUBGRAPH mode should not + need to override (already RTX-native).""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod) + self.assertTrue(mod._rtx_native_cudagraphs) + + old_context_id = id(mod.context) + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) + torchtrt.runtime.set_cudagraphs_mode(False) + + # Context should NOT have been recreated (was already RTX-native) + self.assertEqual( + id(mod.context), + old_context_id, + "Context should not be recreated if already RTX-native", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Cudagraph mode toggle tests require TensorRT-RTX", +) +class TestCudagraphModeToggle(TestCase): + """Tests for toggling cudagraph mode with RTX-native.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_cudagraphs_off_after_rtx_native_override(self): + """After RTX-native override, disabling cudagraphs should still + produce correct results (RTX-native continues transparently).""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) # triggers override + + torchtrt.runtime.set_cudagraphs_mode(False) + + # Should still work -- RTX-native is transparent + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_no_cudagraphs_with_whole_graph_capture(self): + """With cuda_graph_strategy='whole_graph_capture' but no + set_cudagraphs_mode, RTX-native runs transparently.""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + self.assertTrue(mod._rtx_native_cudagraphs) + + # No set_cudagraphs_mode(True) -- RTX-native still active transparently + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_toggle_on_off_on(self): + """Toggle cudagraphs on -> off -> on, verify correctness each time.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + inp = torch.randn(2, 3).cuda() + + # Phase 1: on + torchtrt.runtime.set_cudagraphs_mode(True) + out1 = compiled(inp) + self.assertEqual(out1.shape, (2, 3)) + + # Phase 2: off + torchtrt.runtime.set_cudagraphs_mode(False) + out2 = compiled(inp) + self.assertEqual(out2.shape, (2, 3)) + + # Phase 3: on again + torchtrt.runtime.set_cudagraphs_mode(True) + out3 = compiled(inp) + self.assertEqual(out3.shape, (2, 3)) + + +@unittest.skipIf( + ENABLED_FEATURES.tensorrt_rtx, + "This test verifies standard TRT behavior (non-RTX)", +) +class TestCudaGraphStrategyNonRTX(TestCase): + """Tests that the setting is ignored on non-RTX builds.""" + + def test_setting_ignored_on_non_rtx(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + mod = _find_python_trt_module(compiled) + if mod is not None: + self.assertIsNone( + mod.runtime_config, + "runtime_config should be None for standard TRT", + ) + self.assertFalse(mod._rtx_native_cudagraphs) + output = compiled(torch.randn(2, 3).cuda()) + self.assertEqual(output.shape, (2, 3)) + + +if __name__ == "__main__": + run_tests()