Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def cross_compile_for_windows(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -174,6 +176,8 @@ def cross_compile_for_windows(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -339,6 +343,8 @@ def cross_compile_for_windows(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down Expand Up @@ -451,6 +457,8 @@ def compile(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -547,6 +555,8 @@ def compile(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -755,6 +765,8 @@ def compile(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down Expand Up @@ -1169,6 +1181,8 @@ def convert_exported_program_to_serialized_trt_engine(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -1246,6 +1260,8 @@ def convert_exported_program_to_serialized_trt_engine(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -1420,6 +1436,8 @@ def convert_exported_program_to_serialized_trt_engine(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@
DYNAMICALLY_ALLOCATE_RESOURCES = False
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True
# Default strategy for compiling shape-specialized kernels at runtime for
# dynamic shapes (TensorRT-RTX only). Documented options: "lazy", "eager",
# "none".
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"
# Default CUDA graph capture/replay strategy (TensorRT-RTX only). Documented
# options: "disabled" (manual capture) or "whole_graph_capture" (handled
# internally by TRT-RTX).
CUDA_GRAPH_STRATEGY = "disabled"

if platform.system() == "Linux":
import pwd
Expand Down
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
AUTOCAST_MAX_OUTPUT_THRESHOLD,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
CUDA_GRAPH_STRATEGY,
DECOMPOSE_ATTENTION,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
DYNAMICALLY_ALLOCATE_RESOURCES,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
Expand Down Expand Up @@ -100,6 +102,8 @@ class CompilationSettings:
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled".
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
Expand Down Expand Up @@ -154,6 +158,10 @@ class CompilationSettings:
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
runtime_cache_path: str = RUNTIME_CACHE_PATH
# Runtime kernel-specialization strategy for dynamic shapes (TensorRT-RTX
# only); see the class docstring for the accepted values.
dynamic_shapes_kernel_specialization_strategy: str = (
    DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
)
# CUDA graph capture/replay strategy (TensorRT-RTX only); see the class
# docstring for the accepted values.
cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
Expand Down
48 changes: 48 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,53 @@ def __del__(self) -> None:
def set_use_output_allocator(self, enable: bool) -> None:
    # Record whether output-allocator-based outputs should be used.
    # NOTE(review): this setter only stores the flag; presumably it is
    # consulted by forward() on this module -- confirm against the full file.
    self.use_output_allocator_outputs = enable

def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
    """Verify all TRT submodules are monolithically capturable on RTX.

    For whole-graph CUDA graph mode with mixed TRT + PyTorch ops,
    all TRT engines must be safe for manual stream capture. If any
    engine has lazy kernel specialization or non-capturable conditions,
    raises RuntimeError.

    Args:
        stream: CUDA stream that will be used for the monolithic capture;
            each TRT submodule is asked whether capture is safe on it.

    Raises:
        RuntimeError: If any ``PythonTorchTensorRTModule`` submodule
            reports it is not monolithically capturable.
    """
    from torch_tensorrt._features import ENABLED_FEATURES

    if not ENABLED_FEATURES.tensorrt_rtx:
        return  # non-RTX: no check needed

    # Import both names once up front: they live in the same module, so
    # there is no reason to repeat the _get_cuda_graph_strategy import
    # inside the loop for every submodule that needs resetting.
    from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
        PythonTorchTensorRTModule,
        _get_cuda_graph_strategy,
    )

    for name, mod in self.compiled_module.named_modules():
        if not isinstance(mod, PythonTorchTensorRTModule):
            continue
        if not mod._is_monolithic_capturable(stream):
            raise RuntimeError(
                f"CUDA graph capture failed: TRT submodule "
                f"'{name}' is not monolithically capturable "
                f"(lazy kernel specialization or non-capturable "
                f"stream). Whole-graph CUDA graph mode with mixed "
                f"TRT + PyTorch ops requires all TRT engines to be "
                f"capturable. Consider using "
                f"cuda_graph_strategy='whole_graph_capture' with "
                f"set_cudagraphs_mode(True) instead of "
                f"enable_cudagraphs()."
            )
        # Ensure RTX-native is DISABLED so TRT engines do not
        # interfere with the outer monolithic capture.
        if mod._rtx_native_cudagraphs:
            mod.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
                "disabled"
            )
            # Recreate the execution context so the new strategy takes effect.
            mod.context = mod._create_context()
            mod._rtx_native_cudagraphs = False
            logger.info(
                f"Disabled RTX-native CUDA graphs for '{name}' "
                f"(using outer monolithic capture instead)"
            )

def forward(
self, *args: Any, **kwargs: Any
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Expand Down Expand Up @@ -183,6 +230,7 @@ def forward(

with torch.cuda.stream(self._engine_stream):
if need_cudagraphs_record:
self._check_monolithic_capturability(self._engine_stream)
self.cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
self._output_buffers = self.compiled_module(*args, **kwargs)
Expand Down
Loading
Loading