From b96440ef3894eb7b278eb9e94fa4eb8b6b7275e2 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 21 Apr 2026 02:20:06 -0700 Subject: [PATCH 1/9] feat(runtime): introduce IRuntimeConfig scaffolding and bump ABI to v9 Lay the shared infrastructure used by three upcoming TensorRT-RTX-only runtime features (runtime cache, dynamic shapes kernel specialization strategy, native CUDA graph strategy) in the C++ runtime path. Core changes - Bump ABI_VERSION from "8" to "9" and add three new SerializedInfoIndex entries (RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, CUDA_GRAPH_STRATEGY_IDX). One bump covers all three feature fields. - Add an IRuntimeConfig + IRuntimeCache shared_ptr pair to TRTEngine behind TRT_MAJOR_RTX, plus three plain string/int fields that remain serializable on non-RTX builds so the ABI is stable across both. - Extract a private recreate_execution_context() helper that is the single site where exec_ctx is built. On RTX builds it creates (once) the IRuntimeConfig, invokes per-feature appliers, and then creates the execution context via createExecutionContext(IRuntimeConfig*). Replaces four prior direct createExecutionContext call sites in the constructor, disable_profiling, set_device_memory_budget, and set_resource_allocation_strategy so each automatically inherits the runtime-config path on RTX. - Declare apply_runtime_cache / apply_dynamic_shapes_kernel_strategy / apply_cuda_graph_strategy as private RTX-only helpers with empty bodies; follow-up commits fill these in per feature. The empty stubs keep this commit behavior-neutral. - Extend TRTEngine::serialize, the deserialization constructor, the __obj_flatten__ tuple, and to_str so the new fields round-trip. - Expose RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, and CUDA_GRAPH_STRATEGY_IDX via torch.ops.tensorrt. 
Python side - Add dynamic_shapes_kernel_specialization_strategy ("lazy" default) and cuda_graph_strategy ("disabled" default) to _defaults.py, CompilationSettings, and the three compile() entry points. - Thread runtime_cache_path, dynamic_shapes_kernel_specialization_ strategy, and cuda_graph_strategy through _TorchTensorRTModule._ pack_engine_info with string-to-int maps so the C++ engine sees validated integer codes (0/1/2 for strategies) and raises ValueError for unknown strings. No behavior change yet: the RTX appliers are empty and all new strategy defaults select the prior code paths. --- core/runtime/TRTEngine.cpp | 113 ++++++++++++++---- core/runtime/TRTEngine.h | 41 ++++++- core/runtime/register_jit_hooks.cpp | 3 + core/runtime/runtime.h | 5 +- py/torch_tensorrt/dynamo/_compiler.py | 8 ++ py/torch_tensorrt/dynamo/_defaults.py | 2 + py/torch_tensorrt/dynamo/_settings.py | 7 ++ .../dynamo/runtime/_TorchTensorRTModule.py | 49 +++++++- 8 files changed, 197 insertions(+), 31 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d29daa112b..eadd0398ce 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include "NvInfer.h" @@ -62,7 +64,10 @@ TRTEngine::TRTEngine( bool hardware_compatible, bool requires_output_allocator, const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) + const ResourceAllocationStrategy resource_allocation_strategy, + const std::string& runtime_cache_path, + int dynamic_shapes_kernel_strategy, + int cuda_graph_strategy) : TRTEngine( "deserialized_trt", serialized_engine, @@ -73,7 +78,10 @@ TRTEngine::TRTEngine( hardware_compatible, requires_output_allocator, serialized_metadata, - resource_allocation_strategy) {} + resource_allocation_strategy, + runtime_cache_path, + dynamic_shapes_kernel_strategy, + cuda_graph_strategy) {} TRTEngine::TRTEngine(std::vector serialized_info) : 
TRTEngine( @@ -88,7 +96,10 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic)) {} + : ResourceAllocationStrategy::kStatic), + serialized_info[RUNTIME_CACHE_PATH_IDX], + std::stoi(serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), + std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX])) {} TRTEngine::TRTEngine( const std::string& mod_name, @@ -100,7 +111,20 @@ TRTEngine::TRTEngine( bool hardware_compatible, bool requires_output_allocator, const std::string& serialized_metadata, - const ResourceAllocationStrategy resource_allocation_strategy) { + const ResourceAllocationStrategy resource_allocation_strategy, + const std::string& runtime_cache_path, + int dynamic_shapes_kernel_strategy, + int cuda_graph_strategy) { + this->runtime_cache_path = runtime_cache_path; + TORCHTRT_CHECK( + dynamic_shapes_kernel_strategy >= 0 && dynamic_shapes_kernel_strategy <= 2, + "Invalid dynamic_shapes_kernel_strategy: " << dynamic_shapes_kernel_strategy + << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); + this->dynamic_shapes_kernel_strategy = dynamic_shapes_kernel_strategy; + TORCHTRT_CHECK( + cuda_graph_strategy >= 0 && cuda_graph_strategy <= 1, + "Invalid cuda_graph_strategy: " << cuda_graph_strategy << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); + this->cuda_graph_strategy = cuda_graph_strategy; TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -134,13 +158,7 @@ TRTEngine::TRTEngine( LOG_DEBUG( "Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ?
"Dynamic" : "Static")); - if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context"); + recreate_execution_context(); // Pre-allocate placeholder for empty tensors (TensorRT requires non-null addresses) cudaMalloc(&empty_tensor_placeholder, 1); @@ -278,8 +296,7 @@ void TRTEngine::disable_profiling() { torch::cuda::synchronize(device_info.id); profile_execution = false; trt_engine_profiler.reset(); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context"); + recreate_execution_context(); } void TRTEngine::dump_engine_layer_info_to_file(const std::string& path) { @@ -376,10 +393,7 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) { trt_engine_profiler.reset(); } bool result = cuda_engine->setWeightStreamingBudgetV2(budget); - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - TORCHTRT_CHECK( - (exec_ctx.get() != nullptr), - "Unable to recreate TensorRT execution context after setting new device memory budget"); + recreate_execution_context(); if (profile_execution) { enable_profiling(); } @@ -428,6 +442,11 @@ std::string TRTEngine::to_str() const { ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; + ss << " Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; + ss << " Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_strategy + << " (0=lazy, 1=eager, 2=none)" << std::endl; + ss << " CUDA Graph Strategy: " << cuda_graph_strategy + << " (0=disabled, 1=whole_graph_capture)" << std::endl; // clang-format on return ss.str(); } @@ -472,7 +491,10 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), - std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])); + std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), + std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), + std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])); } std::vector TRTEngine::serialize() { @@ -497,6 +519,9 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
"1" : "0"; + serialized_info[RUNTIME_CACHE_PATH_IDX] = this->runtime_cache_path; + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string(this->dynamic_shapes_kernel_strategy); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(this->cuda_graph_strategy); return serialized_info; } @@ -508,17 +533,53 @@ void TRTEngine::reset_captured_graph() { void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) { if (new_strategy != this->resource_allocation_strategy) { this->resource_allocation_strategy = new_strategy; - if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) { - LOG_DEBUG("Setting resource allocation strategy to dynamic"); - this->exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - LOG_DEBUG("Setting resource allocation strategy to static"); - this->exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } + LOG_DEBUG( + "Setting resource allocation strategy to " + << (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic ? "dynamic" + : "static")); + recreate_execution_context(); } } +void TRTEngine::recreate_execution_context() { +#ifdef TRT_MAJOR_RTX + if (!runtime_config) { + runtime_config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(runtime_config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); + apply_runtime_cache(); + apply_dynamic_shapes_kernel_strategy(); + apply_cuda_graph_strategy(); + } + runtime_config->setExecutionContextAllocationStrategy( + resource_allocation_strategy == ResourceAllocationStrategy::kDynamic + ? 
nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED + : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); + exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_config.get())); +#else + if (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { + exec_ctx = + make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + } else { + exec_ctx = make_trt(cuda_engine->createExecutionContext()); + } +#endif + TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); +} + +#ifdef TRT_MAJOR_RTX +void TRTEngine::apply_runtime_cache() { + // Body added in a follow-up commit that wires the TRT-RTX runtime cache. +} + +void TRTEngine::apply_dynamic_shapes_kernel_strategy() { + // Body added in a follow-up commit that wires the dynamic shapes kernel specialization strategy. +} + +void TRTEngine::apply_cuda_graph_strategy() { + // Body added in a follow-up commit that wires the TRT-RTX native CUDA graph strategy. 
+} +#endif + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 363631863f..51712f7d28 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -30,7 +30,10 @@ using FlattenedState = std::tuple< std::tuple, // requires_output_allocator std::tuple, // serialized metadata std::tuple, // Platform - std::tuple>; // Resource Allocation Strategy + std::tuple, // Resource Allocation Strategy + std::tuple, // Runtime Cache Path (TRT-RTX) + std::tuple, // Dynamic Shapes Kernel Specialization Strategy (TRT-RTX) + std::tuple>; // CUDA Graph Strategy (TRT-RTX) struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -134,7 +137,10 @@ struct TRTEngine : torch::CustomClassHolder { bool requires_output_allocator = false, const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + const std::string& runtime_cache_path = "", + int dynamic_shapes_kernel_strategy = 0, + int cuda_graph_strategy = 0); TRTEngine(std::vector serialized_info); @@ -149,7 +155,10 @@ struct TRTEngine : torch::CustomClassHolder { bool requires_output_allocator = false, const std::string& serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = - TRTEngine::ResourceAllocationStrategy::kStatic); + TRTEngine::ResourceAllocationStrategy::kStatic, + const std::string& runtime_cache_path = "", + int dynamic_shapes_kernel_strategy = 0, + int cuda_graph_strategy = 0); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; @@ -217,6 +226,32 @@ struct TRTEngine : torch::CustomClassHolder { ResourceAllocationStrategy resource_allocation_strategy = kStatic; void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); 
ResourceAllocationStrategy get_resource_allocation_strategy(); + + // TRT-RTX runtime config state. The plain fields are stored unconditionally so that + // serialization remains ABI-stable on non-RTX builds; the IRuntimeConfig / IRuntimeCache + // handles themselves only exist on RTX. + std::string runtime_cache_path = ""; + int dynamic_shapes_kernel_strategy = 0; // 0=lazy, 1=eager, 2=none + int cuda_graph_strategy = 0; // 0=disabled, 1=whole_graph_capture + +#ifdef TRT_MAJOR_RTX + std::shared_ptr<nvinfer1::IRuntimeConfig> runtime_config; + std::shared_ptr<nvinfer1::IRuntimeCache> runtime_cache; +#endif + + private: + // Single entry point that (re)creates exec_ctx. On RTX builds this also creates / reuses + // the IRuntimeConfig and applies all runtime config settings. + void recreate_execution_context(); + +#ifdef TRT_MAJOR_RTX + // Per-feature appliers invoked the first time recreate_execution_context() runs. Bodies + // are provided in follow-up commits that introduce each feature; keeping the declarations + // here lets the scaffolding land without behavior changes.
+ void apply_runtime_cache(); + void apply_dynamic_shapes_kernel_strategy(); + void apply_cuda_graph_strategy(); +#endif }; } // namespace runtime diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index e8f6217a21..20d7580b4c 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -150,6 +150,9 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); + m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); + m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); + m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index d8f71683d3..7e7a374460 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -16,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "8"; +const std::string ABI_VERSION = "9"; extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { @@ -39,6 +39,9 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, RESOURCE_ALLOCATION_STRATEGY_IDX, + RUNTIME_CACHE_PATH_IDX, + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, + CUDA_GRAPH_STRATEGY_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d04c294ad9..9e1f8f9c85 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ 
b/py/torch_tensorrt/dynamo/_compiler.py @@ -113,6 +113,8 @@ def cross_compile_for_windows( enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -356,6 +358,8 @@ def cross_compile_for_windows( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } @@ -487,6 +491,8 @@ def compile( cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, + dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -785,6 +791,8 @@ def compile( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, + "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": 
decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 8998479a63..5db1042183 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -70,6 +70,8 @@ ENABLE_RESOURCE_PARTITIONING = False CPU_MEMORY_BUDGET = None DYNAMICALLY_ALLOCATE_RESOURCES = False +DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 595f9dcb55..95d0cd88bd 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -124,6 +125,8 @@ class CompilationSettings: autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines + dynamic_shapes_kernel_specialization_strategy (str): TensorRT-RTX dynamic shapes kernel specialization strategy: "lazy" (default, compile specialized kernels in background and use fallbacks until ready), "eager" (compile specialized kernels synchronously, blocking first inference), or "none" (always use fallback kernels). Not used for standard TensorRT. + cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy: "disabled" (default) or "whole_graph_capture" (let TensorRT-RTX manage CUDA graph capture/replay internally). 
When set and combined with `torch_tensorrt.runtime.set_cudagraphs_mode(True)` on RTX, overrides manual capture. Not used for standard TensorRT. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. """ @@ -189,6 +192,10 @@ class CompilationSettings: enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES + dynamic_shapes_kernel_specialization_strategy: str = ( + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY + ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY decompose_attention: bool = DECOMPOSE_ATTENTION attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d77c0bf39f..e8edb4163d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -35,6 +35,10 @@ SERIALIZED_METADATA_IDX = -1 # Not implemented TARGET_PLATFORM_IDX = -1 # Not implemented REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented +RESOURCE_ALLOCATION_STRATEGY_IDX = -1 # Not implemented +RUNTIME_CACHE_PATH_IDX = -1 # Not implemented +DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = -1 # Not implemented +CUDA_GRAPH_STRATEGY_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: @@ -53,7 +57,22 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - SERIALIZATION_LEN = 
torch.ops.tensorrt.SERIALIZATION_LEN() # 11 + RUNTIME_CACHE_PATH_IDX = torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX() # 11 + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( + torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX() + ) # 12 + CUDA_GRAPH_STRATEGY_IDX = torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX() # 13 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 + +_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { + "lazy": 0, + "eager": 1, + "none": 2, +} +_CUDA_GRAPH_STRATEGY_MAP: Dict[str, int] = { + "disabled": 0, + "whole_graph_capture": 1, +} @for_all_methods(needs_torch_tensorrt_runtime) @@ -145,6 +164,11 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources + self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy self.symbolic_shape_expressions = symbolic_shape_expressions if ( @@ -203,6 +227,29 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + 
f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] + ) return engine_info From 01b9f386fa76cc0f364c67b8eba5bed052bdd784 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 21 Apr 2026 02:24:05 -0700 Subject: [PATCH 2/9] feat(runtime): add runtime cache to C++ runtime for TensorRT-RTX Implement TensorRT-RTX runtime cache persistence in the C++ runtime path (TorchTensorRTModule / TRTEngine). Mirrors the Python-runtime feature landed in pytorch/TensorRT#4180. What - apply_runtime_cache() (no-op stub from the prior commit) now creates an IRuntimeCache from the IRuntimeConfig, loads any existing cache file from the configured path, and attaches the cache to the config via IRuntimeConfig::setRuntimeCache (taken by const reference). - load_runtime_cache() reads the cache under an advisory shared lock (flock LOCK_SH) on POSIX. Concurrent readers coexist; transient failures downgrade to warnings so inference never blocks on cache IO. - save_runtime_cache() writes the serialized cache atomically via tmp-file + rename under an exclusive lock (flock LOCK_EX). The write path creates intermediate directories as needed. On Windows the save falls back to a best-effort write without advisory locking and emits a warning; LockFileEx support is a follow-up. - ~TRTEngine() now invokes save_runtime_cache() before tearing down the cache, config, and execution context so JIT compilation results survive process exits. Why - TensorRT-RTX JIT-compiles specialized kernels at inference time. The runtime cache lets those compilations persist across runs and across processes, which was measured at ~8x warm-vs-cold speedup in the Python-runtime implementation. - Without this commit, users relying on the C++ runtime (TorchScript deployments, use_python_runtime=False) would have no way to retain JIT work and would pay the cold-start cost on every process start. 
Tests - tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py exercises the C++ runtime path (use_python_runtime=False) with cache save on destructor, directory creation, warm-cache roundtrip correctness via cosine-similarity, and ABI/index registration. --- core/runtime/TRTEngine.cpp | 118 +++++++++++++- core/runtime/TRTEngine.h | 5 + .../runtime/test_000_runtime_cache_cpp.py | 146 ++++++++++++++++++ 3 files changed, 268 insertions(+), 1 deletion(-) create mode 100644 tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index eadd0398ce..c8e7fb0e26 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -12,6 +12,12 @@ #include "core/util/prelude.h" #include "torch/torch.h" +#if defined(TRT_MAJOR_RTX) && !defined(_WIN32) +#include <fcntl.h> +#include <sys/file.h> +#include <unistd.h> +#endif + namespace torch_tensorrt { namespace core { namespace runtime { @@ -283,6 +289,11 @@ TRTEngine::TRTEngine( TRTEngine::~TRTEngine() { torch::cuda::synchronize(device_info.id); +#ifdef TRT_MAJOR_RTX + save_runtime_cache(); + runtime_cache.reset(); + runtime_config.reset(); +#endif trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -568,7 +579,23 @@ void TRTEngine::recreate_execution_context() { #ifdef TRT_MAJOR_RTX void TRTEngine::apply_runtime_cache() { - // Body added in a follow-up commit that wires the TRT-RTX runtime cache.
+ if (runtime_cache_path.empty()) { + LOG_DEBUG("Runtime cache disabled (no path configured)."); + return; + } + runtime_cache = make_trt(runtime_config->createRuntimeCache()); + if (runtime_cache.get() == nullptr) { + LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + return; + } + load_runtime_cache(); + bool ok = runtime_config->setRuntimeCache(*runtime_cache); + if (!ok) { + LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); + runtime_cache.reset(); + return; + } + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); } void TRTEngine::apply_dynamic_shapes_kernel_strategy() { @@ -578,7 +605,96 @@ void TRTEngine::apply_dynamic_shapes_kernel_strategy() { void TRTEngine::apply_cuda_graph_strategy() { // Body added in a follow-up commit that wires the TRT-RTX native CUDA graph strategy. } + +void TRTEngine::load_runtime_cache() { + if (runtime_cache == nullptr || runtime_cache_path.empty()) { + return; + } + if (!std::filesystem::exists(runtime_cache_path)) { + LOG_DEBUG("No existing runtime cache at " << runtime_cache_path); + return; + } +#ifndef _WIN32 + int fd = ::open(runtime_cache_path.c_str(), O_RDONLY); + if (fd < 0) { + LOG_WARNING("Failed to open runtime cache for reading: " << runtime_cache_path); + return; + } + if (::flock(fd, LOCK_SH) != 0) { + LOG_WARNING("Failed to acquire shared lock on runtime cache; skipping load."); + ::close(fd); + return; + } +#endif + try { + std::ifstream f(runtime_cache_path, std::ios::binary); + std::vector<char> buf((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>()); + if (!buf.empty()) { + bool ok = runtime_cache->deserialize(buf.data(), buf.size()); + if (ok) { + LOG_INFO("Loaded runtime cache from " << runtime_cache_path << " (" << buf.size() << " bytes)"); + } else { + LOG_WARNING("runtime_cache->deserialize returned false for " << runtime_cache_path); + } + } + } catch (const std::exception& e) { +
LOG_WARNING("Failed to load runtime cache: " << e.what()); + } +#ifndef _WIN32 + ::flock(fd, LOCK_UN); + ::close(fd); +#endif +} + +void TRTEngine::save_runtime_cache() { + if (runtime_cache == nullptr || runtime_cache_path.empty()) { + return; + } + auto host_mem = make_trt(runtime_cache->serialize()); + if (host_mem.get() == nullptr || host_mem->size() == 0) { + return; + } + try { + std::filesystem::path path(runtime_cache_path); + if (path.has_parent_path()) { + std::filesystem::create_directories(path.parent_path()); + } + std::filesystem::path tmp_path = path; + tmp_path += ".tmp"; + +#ifndef _WIN32 + int fd = ::open(tmp_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (fd < 0) { + LOG_WARNING("Failed to open runtime cache tmp file for writing: " << tmp_path.string()); + return; + } + if (::flock(fd, LOCK_EX) != 0) { + LOG_WARNING("Failed to acquire exclusive lock on runtime cache tmp file; skipping save."); + ::close(fd); + return; + } + ssize_t written = ::write(fd, host_mem->data(), host_mem->size()); + ::flock(fd, LOCK_UN); + ::close(fd); + if (written != static_cast<ssize_t>(host_mem->size())) { + LOG_WARNING("Short write when saving runtime cache to " << tmp_path.string()); + return; + } +#else + // Windows: best-effort write without a cross-process lock. Follow-up: LockFileEx.
+ { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast<const char*>(host_mem->data()), host_mem->size()); + } + LOG_WARNING("Runtime cache save on Windows runs without advisory locking; concurrent writers may race."); #endif + std::filesystem::rename(tmp_path, path); + LOG_INFO("Saved runtime cache to " << runtime_cache_path << " (" << host_mem->size() << " bytes)"); + } catch (const std::exception& e) { + LOG_WARNING("Failed to save runtime cache: " << e.what()); + } +} +#endif // TRT_MAJOR_RTX } // namespace runtime } // namespace core diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 51712f7d28..a2d75f8630 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -251,6 +251,11 @@ struct TRTEngine : torch::CustomClassHolder { void apply_runtime_cache(); void apply_dynamic_shapes_kernel_strategy(); void apply_cuda_graph_strategy(); + + // Runtime cache persistence (RTX-only). Load is invoked from apply_runtime_cache(); save + // is invoked from the destructor before exec_ctx / runtime_config tear down.
+ void load_runtime_cache(); + void save_runtime_cache(); #endif }; diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py b/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py new file mode 100644 index 0000000000..a7a62ef131 --- /dev/null +++ b/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py @@ -0,0 +1,146 @@ +import gc +import os +import shutil +import tempfile +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _fresh_model_and_inputs(seed=0): + """Create a deterministic SimpleModel + input tensor pair.""" + torch.manual_seed(seed) + return SimpleModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] + + +def _compile_cpp(model, inputs, runtime_cache_path=None): + """Compile the given model through the C++ runtime path.""" + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + } + if runtime_cache_path is not None: + kwargs["runtime_cache_path"] = runtime_cache_path + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCacheCppPersistence(TestCase): + """Exercise C++-runtime runtime cache load/save against disk.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): 
+ shutil.rmtree(self.cache_dir, ignore_errors=True) + + def test_cache_saved_on_del(self): + model, inputs = _fresh_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertFalse( + os.path.isfile(self.cache_path), + "Cache should not exist before module cleanup", + ) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(self.cache_path), + "Cache file should be created after module cleanup", + ) + + def test_cache_file_nonempty(self): + model, inputs = _fresh_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertGreater( + os.path.getsize(self.cache_path), + 0, + "Cache file should have nonzero size", + ) + + def test_cache_roundtrip(self): + """Compile, infer, save. Then recompile same model+cache and verify correctness.""" + model, inputs = _fresh_model_and_inputs() + with torch.no_grad(): + ref_output = model(*inputs) + + compiled1 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + out1 = compiled1(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out1), + COSINE_THRESHOLD, + "First compiled output should match eager", + ) + del compiled1 + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + compiled2 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + out2 = compiled2(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out2), + COSINE_THRESHOLD, + "Second compiled output (warm cache) should still match eager", + ) + + def test_save_creates_directory(self): + nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") + model, inputs = _fresh_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=nested_path) + _ = compiled(*[inp.clone() for inp in 
inputs]) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(nested_path), + "Save should create intermediate directories", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestCppSerializationIndices(TestCase): + """Verify the new C++ serialization indices are registered by the runtime.""" + + def test_new_indices_registered(self): + self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9) + self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 14) + self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 11) + self.assertEqual( + int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 12 + ) + self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 13) + + +if __name__ == "__main__": + run_tests() From 481455f024fdb48261831c98bb468045d1d5b31d Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 21 Apr 2026 02:25:55 -0700 Subject: [PATCH 3/9] feat(runtime): add dynamic shapes kernel specialization strategy to C++ runtime Wire the dynamic_shapes_kernel_specialization_strategy compile setting into the C++ runtime path on TensorRT-RTX by filling in the apply_dynamic_shapes_kernel_strategy() body introduced in the scaffolding commit. What - apply_dynamic_shapes_kernel_strategy() now calls IRuntimeConfig::setDynamicShapesKernelSpecializationStrategy with the integer code (0=lazy, 1=eager, 2=none) that was validated at engine construction. - The setting is applied once when the IRuntimeConfig is first built inside recreate_execution_context(); the value is serialized with the engine so deserialized modules restore the same strategy. Why - "lazy" (the default) compiles specialized kernels in the background and uses fallbacks until they are ready - good for latency of the first call but hands-off for steady-state throughput. - "eager" compiles the specialized kernel synchronously on first use, blocking inference but eliminating the fallback phase. 
- "none" disables kernel specialization entirely and always uses the generic fallback. Useful in combination with outer CUDA graph capture where a stable set of kernels is required. Tests - tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py validates the setting default, the full {lazy, eager, none} matrix through the C++ runtime (use_python_runtime=False), dynamic shape traversal under "eager", and ValueError rejection of unknown strategy names at engine-packing time. --- core/runtime/TRTEngine.cpp | 4 +- ...test_000_dynamic_shapes_kernel_strategy.py | 133 ++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index c8e7fb0e26..d695667511 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -599,7 +599,9 @@ void TRTEngine::apply_runtime_cache() { } void TRTEngine::apply_dynamic_shapes_kernel_strategy() { - // Body added in a follow-up commit that wires the dynamic shapes kernel specialization strategy. 
+ runtime_config->setDynamicShapesKernelSpecializationStrategy( + static_cast(dynamic_shapes_kernel_strategy)); + LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << dynamic_shapes_kernel_strategy); } void TRTEngine::apply_cuda_graph_strategy() { diff --git a/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py new file mode 100644 index 0000000000..6761ca1c65 --- /dev/null +++ b/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py @@ -0,0 +1,133 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import ( + DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, +) +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class DynamicConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) + self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv2(torch.relu(self.conv1(x)))) + + +def _compile_cpp(strategy): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + + +class TestDynamicShapesKernelStrategySettings(TestCase): + """Setting-level validation that runs on every build (RTX and non-RTX).""" + + def test_default_value(self): + settings = CompilationSettings() + self.assertEqual( + settings.dynamic_shapes_kernel_specialization_strategy, + 
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + ) + + def test_settable_values(self): + for value in ("lazy", "eager", "none"): + settings = CompilationSettings( + dynamic_shapes_kernel_specialization_strategy=value + ) + self.assertEqual( + settings.dynamic_shapes_kernel_specialization_strategy, value + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel strategy is a TensorRT-RTX feature", +) +class TestDynamicShapesKernelStrategyCpp(TestCase): + """End-to-end: compile + infer through the C++ runtime with each strategy.""" + + def test_lazy(self): + compiled = _compile_cpp("lazy") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_eager(self): + compiled = _compile_cpp("eager") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_none(self): + compiled = _compile_cpp("none") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_dynamic_shape_with_eager(self): + """Exercise shape changes under eager kernel specialization.""" + compiled = _compile_cpp("eager") + for batch in (1, 2, 3, 4): + x = torch.randn(batch, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestDynamicShapesKernelStrategyInvalidValue(TestCase): + """Invalid strategy names are rejected at engine-packing time.""" + + def test_invalid_strategy_raises(self): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + 
min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", + ) + + +if __name__ == "__main__": + run_tests() From 2b630e8973b32768be9d6012160b7bcb3fcf1cf0 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 21 Apr 2026 02:29:22 -0700 Subject: [PATCH 4/9] feat(runtime): add TensorRT-RTX native CUDA graph strategy to C++ runtime Wire cuda_graph_strategy into the C++ runtime and make the execute_engine CUDA graph path TensorRT-RTX-aware. Fills in the apply_cuda_graph_strategy stub and adds coexistence handling for outer whole-graph capture. What - apply_cuda_graph_strategy() now calls IRuntimeConfig::setCudaGraphStrategy with either kDISABLED (default) or kWHOLE_GRAPH_CAPTURE. On RTX this hands capture/replay off to the TRT-RTX runtime, avoiding the lazy-kernel and dynamic-shape hazards of wrapping enqueueV3 in at::cuda::CUDAGraph. - is_monolithic_capturable(stream) returns whether an engine can safely be captured by an outer torch.cuda.CUDAGraph: RTX builds check IExecutionContext::isStreamCapturable and require a non-lazy kernel strategy; non-RTX builds always return true. - disable_rtx_native_cudagraphs() is a one-shot switch that turns off the engine internal capture and recreates the execution context so that outer stream captures contain the kernel launches directly. - execute_engine.cpp now computes effective_cudagraphs. On RTX, if a cuda_graph_strategy is set or SUBGRAPH cudagraphs is enabled, it bypasses the manual at::cuda::CUDAGraph path (the TRT-RTX runtime handles that inside enqueueV3). 
It also polls cudaStreamIsCapturing on the engine stream and, if an outer capture is already running, invokes disable_rtx_native_cudagraphs() so the outer capture proceeds without collision. Why - On TRT-RTX, the manual at::cuda::CUDAGraph wrapper around enqueueV3 can freeze fallback kernels in the captured graph (kLAZY specialisation would swap them later), and fails outright when the engine needs runtime allocation, DDS, control flow, or weight streaming. - Letting the TRT-RTX runtime own capture fixes both problems, and the outer-capture detection keeps the feature compatible with the existing CudaGraphsTorchTensorRTModule whole-graph wrapper without requiring it to know anything about RTX internals. Tests - tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py validates the setting default, both {disabled, whole_graph_capture} through the C++ runtime, the RTX-native override when set_cudagraphs_mode(True) is combined with a strategy, repeated inference correctness, and ValueError rejection of unknown strategy names. --- core/runtime/TRTEngine.cpp | 34 ++++- core/runtime/TRTEngine.h | 12 ++ core/runtime/execute_engine.cpp | 33 ++++- .../runtime/test_000_cuda_graph_strategy.py | 116 ++++++++++++++++++ 4 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index d695667511..b6f3ce6e81 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -552,6 +552,33 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt } } +bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { +#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION) + // "lazy" strategy (0) swaps specialized kernels in mid-run, which would invalidate a + // captured graph. Any other strategy (eager/none) combined with a capturable stream is + // safe for outer monolithic capture. 
+ return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != 0; +#else + (void)stream; + return true; +#endif +} + +void TRTEngine::disable_rtx_native_cudagraphs() { +#ifdef TRT_MAJOR_RTX + if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == 0) { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TRT-RTX native CUDA graph strategy on engine " + << name << " for the remainder of its lifetime."); + cuda_graph_strategy = 0; + apply_cuda_graph_strategy(); + recreate_execution_context(); + rtx_native_cudagraphs_disabled = true; +#endif +} + void TRTEngine::recreate_execution_context() { #ifdef TRT_MAJOR_RTX if (!runtime_config) { @@ -605,7 +632,12 @@ void TRTEngine::apply_dynamic_shapes_kernel_strategy() { } void TRTEngine::apply_cuda_graph_strategy() { - // Body added in a follow-up commit that wires the TRT-RTX native CUDA graph strategy. + bool ok = runtime_config->setCudaGraphStrategy( + cuda_graph_strategy == 1 ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED); + if (!ok) { + LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); + } } void TRTEngine::load_runtime_cache() { diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index a2d75f8630..7e6d357ec7 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -233,12 +233,24 @@ struct TRTEngine : torch::CustomClassHolder { std::string runtime_cache_path = ""; int dynamic_shapes_kernel_strategy = 0; // 0=lazy, 1=eager, 2=none int cuda_graph_strategy = 0; // 0=disabled, 1=whole_graph_capture + // One-shot flag: set the first time execute_engine detects an outer stream capture around + // this engine, at which point its TRT-RTX native CUDA graph capture is turned off so the + // two do not fight. The flag stays set for the remainder of the engine's lifetime. 
+ bool rtx_native_cudagraphs_disabled = false; #ifdef TRT_MAJOR_RTX std::shared_ptr runtime_config; std::shared_ptr runtime_cache; #endif + // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph + // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. + bool is_monolithic_capturable(cudaStream_t stream) const; + + // Disable TRT-RTX native CUDA graph capture on this engine (one-shot, invoked when an + // outer stream capture is detected around execute_engine). No-op on non-RTX. + void disable_rtx_native_cudagraphs(); + private: // Single entry point that (re)creates exec_ctx. On RTX builds this also creates / reuses // the IRuntimeConfig and applies all runtime config settings. diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 553469392b..0d0e6c1035 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -217,11 +217,29 @@ std::vector execute_engine(std::vector inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + // effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX + // builds we bypass that path whenever the engine has a cuda_graph_strategy set or the + // outer runtime has requested subgraph cudagraphs - the TRT-RTX runtime handles capture + // and replay internally inside enqueueV3. If an outer stream capture is already in + // progress (e.g. the caller wraps this module in CudaGraphsTorchTensorRTModule for + // whole-graph capture), RTX-native capture would conflict, so we disable it one-shot. 
+ bool effective_cudagraphs = cudagraphs_enabled; +#ifdef TRT_MAJOR_RTX + if (compiled_engine->cuda_graph_strategy != 0 || cudagraphs_enabled) { + effective_cudagraphs = false; + cudaStreamCaptureStatus capture_status; + cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); + if (capture_status != cudaStreamCaptureStatusNone) { + compiled_engine->disable_rtx_native_cudagraphs(); + } + } +#endif + bool shape_changed = _validate_shapes(inputs, compiled_engine); // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( - cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); + effective_cudagraphs, compiled_engine->use_pre_allocated_outputs, shape_changed); bool need_cudagraphs_record = std::get<0>(result); bool can_use_pre_allocated_outputs = std::get<1>(result); @@ -244,7 +262,8 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->input_profile_path); } - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues); + setup_input_tensors( + inputs, compiled_engine, effective_cudagraphs, need_cudagraphs_record, inputShapeTensorValues); // Check if input shapes can be inferred. 
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; std::vector names(io_size); @@ -276,7 +295,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (cudagraphs_enabled) { + if (effective_cudagraphs) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), @@ -316,8 +335,10 @@ std::vector execute_engine(std::vector inputs, c10::intr caller_exec_complete.record(compiled_engine->caller_stream); caller_exec_complete.block(compiled_engine->engine_stream); - if (!cudagraphs_enabled) { - // Direct execution uses the caller buffers directly + if (!effective_cudagraphs) { + // Direct execution uses the caller buffers directly. On TRT-RTX with a + // cuda_graph_strategy set, the engine captures/replays internally during + // this enqueueV3 call. compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { if (need_cudagraphs_record) { @@ -350,7 +371,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.record(compiled_engine->engine_stream); trt_exec_complete.block(compiled_engine->caller_stream); - if (cudagraphs_enabled) { + if (effective_cudagraphs) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { outputs[o].copy_(compiled_engine->output_buffers[o], false); diff --git a/tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py new file mode 100644 index 0000000000..8a2968b0d8 --- /dev/null +++ b/tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py @@ -0,0 +1,116 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from 
torch_tensorrt.dynamo._defaults import CUDA_GRAPH_STRATEGY +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class CudaGraphModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _compile_cpp(strategy): + model = CudaGraphModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + cuda_graph_strategy=strategy, + ) + torch._dynamo.reset() + return compiled, inputs + + +class TestCudaGraphStrategySettings(TestCase): + """Setting-level validation that runs on every build (RTX and non-RTX).""" + + def test_default_value(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, CUDA_GRAPH_STRATEGY) + + def test_settable_values(self): + for value in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=value) + self.assertEqual(settings.cuda_graph_strategy, value) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy is a TensorRT-RTX feature", +) +class TestCudaGraphStrategyCpp(TestCase): + """End-to-end: compile + infer through the C++ runtime with each strategy.""" + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_disabled(self): + compiled, inputs = _compile_cpp("disabled") + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_whole_graph_capture(self): + compiled, inputs = _compile_cpp("whole_graph_capture") + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + 
self.assertTrue(torch.isfinite(y).all().item()) + + def test_whole_graph_capture_with_subgraph_cudagraphs(self): + """Subgraph cudagraph mode + RTX strategy: RTX-native should take over without errors.""" + compiled, inputs = _compile_cpp("whole_graph_capture") + torchtrt.runtime.set_cudagraphs_mode(True) + y = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_repeated_inference(self): + """Repeated inference exercises the RTX-native capture/replay path.""" + compiled, inputs = _compile_cpp("whole_graph_capture") + ref = compiled(*[inp.clone() for inp in inputs]) + for _ in range(4): + out = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(out.shape, ref.shape) + self.assertTrue(torch.isfinite(out).all().item()) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestCudaGraphStrategyInvalidValue(TestCase): + """Invalid strategy names are rejected at engine-packing time.""" + + def test_invalid_strategy_raises(self): + model = CudaGraphModel().eval().cuda() + inputs = [torch.randn(2, 3, 16, 16).cuda()] + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + cuda_graph_strategy="not_a_real_strategy", + ) + + +if __name__ == "__main__": + run_tests() From 54f9ccda97e076c8574ce382beb310db639c2141 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 22 Apr 2026 12:41:12 -0700 Subject: [PATCH 5/9] refactor(runtime): extract TRTRuntimeConfig, address PR review Address the structural PR feedback by extracting TensorRT-RTX-specific IRuntimeConfig state into its own type and collapsing the per-feature appliers that previously scattered `#ifdef TRT_MAJOR_RTX` through TRTEngine. 
What - New core/runtime/TRTRuntimeConfig.{h,cpp} owns the IRuntimeConfig shared_ptr plus (on TRT-RTX) the IRuntimeCache, runtime-cache path, dynamic shapes kernel strategy, CUDA graph strategy, and the rtx_native_cudagraphs_disabled one-shot flag. All per-feature appliers live there as public members and are no-ops on non-RTX builds, keeping the only `#ifdef TRT_MAJOR_RTX` scatter contained in this new file. - Strategy fields are now strongly-typed enums (`DynamicShapesKernelStrategy`, `CudaGraphStrategyOption`) with matching `to_string`/`to_int` helpers, validated at engine construction via `to_dynamic_shapes_kernel_strategy` / `to_cuda_ graph_strategy_option` rather than raw int ranges. - `TRTEngine::recreate_execution_context` is now backend-agnostic: it calls `runtime_cfg.ensure_initialized`, applies the allocation strategy, and creates the execution context via `createExecutionContext(IRuntimeConfig*)`. Both standard TensorRT and TRT-RTX go through this uniform path; only the three RTX-only setters (`setRuntimeCache`, `setDynamicShapesKernel SpecializationStrategy`, `setCudaGraphStrategy`) stay behind an `#ifdef TRT_MAJOR_RTX` guard inside the struct. - `~TRTEngine` now wraps cleanup in try/catch and delegates cache persistence to `TRTRuntimeConfig::save_runtime_cache_nothrow`, so stack unwinding can no longer propagate a cache-save failure out of the destructor. - `save_runtime_cache_nothrow` uses `std::filesystem` + atomic `tmp+rename` only; file locking is out of scope for this PR and will be introduced in a follow-up once we pick a portable mechanism. - `is_monolithic_capturable` asserts `exec_ctx` is non-null; the three RTX-only appliers `TORCHTRT_ASSERT` that `config` is live before dereferencing. - `disable_rtx_native_cudagraphs` persists the runtime cache before flipping the strategy so any kernels compiled under the internal capture survive to the next reload. 
- `TRTEngine::to_str` now emits human-readable strategy names (via `to_string(enum)`) instead of integer codes. - New serialization indices (`RUNTIME_CACHE_PATH_IDX`, `DYNAMIC_ SHAPES_KERNEL_STRATEGY_IDX`, `CUDA_GRAPH_STRATEGY_IDX`) are now `#ifdef TRT_MAJOR_RTX`-gated in runtime.h, register_jit_hooks.cpp, the FlattenedState tuple, the serialize/deserialize constructors, and `__obj_flatten__`. Standard TRT builds keep `SERIALIZATION_LEN == 11` so engines serialized there do not carry RTX-only slots. - Python `_TorchTensorRTModule` reads the RTX-only index accessors and writes the RTX-only engine-info slots only when `ENABLED_FEATURES.tensorrt_rtx` is true. Standard TRT users see no new behavior at runtime. - Deduplicated `_compiler.py` arguments after rebase on upstream main where PR #4184 had already added `dynamic_shapes_kernel_specialization_strategy`. Kept one copy of each arg; `cuda_graph_strategy` is threaded through all three compile() entry points. Build + tests - RTX build on A100 / L40S: libtorchtrt.so and libtorchtrt_ runtime.so link clean, no `#ifdef` diagnostics. Pre-commit checks pass (clang-format, black, isort, ruff, mypy, typos, buildifier). - All 35 runtime-cache/strategy tests pass; regression across test_000_runtime_cache.py (Python runtime), test_002_cudagraphs_ cpp.py, test_005_dynamic_allocation.py is green. Addresses review comments on PR #4202: - Guarding of new IDX entries and Python accessors on TRT_MAJOR_RTX / ENABLED_FEATURES.tensorrt_rtx. - Encapsulation of RTX-specific state in a dedicated type with enumerated strategies and transparent standard-TRT/RTX behavior. - Destructor exception safety. - Unification of the execution-context creation path via IRuntimeConfig. - Removal of file locking for runtime-cache persistence. - Debug asserts before dereferencing the live IRuntimeConfig. - Human-readable to_str output. - save_runtime_cache invoked from disable_rtx_native_cudagraphs. 
--- core/runtime/BUILD | 6 +- core/runtime/TRTEngine.cpp | 239 +++--------------- core/runtime/TRTEngine.h | 57 ++--- core/runtime/TRTRuntimeConfig.cpp | 222 ++++++++++++++++ core/runtime/TRTRuntimeConfig.h | 93 +++++++ core/runtime/execute_engine.cpp | 15 +- core/runtime/register_jit_hooks.cpp | 2 + core/runtime/runtime.h | 2 + py/torch_tensorrt/dynamo/_compiler.py | 6 +- py/torch_tensorrt/dynamo/_defaults.py | 1 - .../dynamo/runtime/_TorchTensorRTModule.py | 69 ++--- ...egy.py => test_001_cuda_graph_strategy.py} | 0 12 files changed, 434 insertions(+), 278 deletions(-) create mode 100644 core/runtime/TRTRuntimeConfig.cpp create mode 100644 core/runtime/TRTRuntimeConfig.h rename tests/py/dynamo/runtime/{test_000_cuda_graph_strategy.py => test_001_cuda_graph_strategy.py} (100%) diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 19260149ae..61fcd7a283 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -1,6 +1,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_pkg//:pkg.bzl", "pkg_tar") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") + package(default_visibility = ["//visibility:public"]) config_setting( @@ -66,6 +67,7 @@ cc_library( "RTDevice.cpp", "TRTEngine.cpp", "TRTEngineProfiler.cpp", + "TRTRuntimeConfig.cpp", "execute_engine.cpp", "register_jit_hooks.cpp", "runtime.cpp", @@ -75,6 +77,7 @@ cc_library( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], linkopts = [ @@ -107,6 +110,7 @@ filegroup( "RTDevice.h", "TRTEngine.h", "TRTEngineProfiler.h", + "TRTRuntimeConfig.h", "runtime.h", ], visibility = ["//visibility:public"], @@ -121,6 +125,6 @@ pkg_tar( pkg_files( name = "include_pkg_files", srcs = [":include_files"], - visibility = ["//visibility:public"], prefix = "include/torch_tensorrt/core/runtime/", + visibility = ["//visibility:public"], ) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index b6f3ce6e81..66181a3c40 100644 --- a/core/runtime/TRTEngine.cpp +++ 
b/core/runtime/TRTEngine.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include "NvInfer.h" @@ -12,12 +11,6 @@ #include "core/util/prelude.h" #include "torch/torch.h" -#if defined(TRT_MAJOR_RTX) && !defined(_WIN32) -#include -#include -#include -#endif - namespace torch_tensorrt { namespace core { namespace runtime { @@ -102,10 +95,15 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic), + : ResourceAllocationStrategy::kStatic) +#ifdef TRT_MAJOR_RTX + , serialized_info[RUNTIME_CACHE_PATH_IDX], std::stoi(serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX])) {} + std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX]) +#endif + ) { +} TRTEngine::TRTEngine( const std::string& mod_name, @@ -121,16 +119,9 @@ TRTEngine::TRTEngine( const std::string& runtime_cache_path, int dynamic_shapes_kernel_strategy, int cuda_graph_strategy) { - this->runtime_cache_path = runtime_cache_path; - TORCHTRT_CHECK( - dynamic_shapes_kernel_strategy >= 0 && dynamic_shapes_kernel_strategy <= 2, - "Invalid dynamic_shapes_kernel_strategy: " << dynamic_shapes_kernel_strategy - << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); - this->dynamic_shapes_kernel_strategy = dynamic_shapes_kernel_strategy; - TORCHTRT_CHECK( - cuda_graph_strategy >= 0 && cuda_graph_strategy <= 1, - "Invalid cuda_graph_strategy: " << cuda_graph_strategy << ". 
Expected 0 (disabled) or 1 (whole_graph_capture)."); - this->cuda_graph_strategy = cuda_graph_strategy; + runtime_cfg.runtime_cache_path = runtime_cache_path; + runtime_cfg.dynamic_shapes_kernel_strategy = to_dynamic_shapes_kernel_strategy(dynamic_shapes_kernel_strategy); + runtime_cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(cuda_graph_strategy); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -288,12 +279,13 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { - torch::cuda::synchronize(device_info.id); -#ifdef TRT_MAJOR_RTX - save_runtime_cache(); - runtime_cache.reset(); - runtime_config.reset(); -#endif + // Destructors must not throw; `save_runtime_cache_nothrow` is itself no-throw but we + // wrap it defensively to keep stack unwinding safe in all circumstances. + try { + torch::cuda::synchronize(device_info.id); + runtime_cfg.save_runtime_cache_nothrow(); + } catch (...) { + } trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -453,12 +445,8 @@ std::string TRTEngine::to_str() const { ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; - ss << " Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; - ss << " Dynamic Shapes Kernel Strategy: " << dynamic_shapes_kernel_strategy - << " (0=lazy, 1=eager, 2=none)" << std::endl; - ss << " CUDA Graph Strategy: " << cuda_graph_strategy - << " (0=disabled, 1=whole_graph_capture)" << std::endl; // clang-format on + runtime_cfg.write_to_str(ss); return ss.str(); } @@ -502,10 +490,14 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]), - std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]), + std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]) +#ifdef TRT_MAJOR_RTX + , std::tuple("runtime_cache_path", serialized_info[RUNTIME_CACHE_PATH_IDX]), std::tuple("dynamic_shapes_kernel_strategy", serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX])); + std::tuple("cuda_graph_strategy", serialized_info[CUDA_GRAPH_STRATEGY_IDX]) +#endif + ); } std::vector TRTEngine::serialize() { @@ -530,9 +522,12 @@ std::vector TRTEngine::serialize() { serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
"1" : "0"; - serialized_info[RUNTIME_CACHE_PATH_IDX] = this->runtime_cache_path; - serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string(this->dynamic_shapes_kernel_strategy); - serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(this->cuda_graph_strategy); +#ifdef TRT_MAJOR_RTX + serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = + std::to_string(static_cast(runtime_cfg.dynamic_shapes_kernel_strategy)); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(static_cast(runtime_cfg.cuda_graph_strategy)); +#endif return serialized_info; } @@ -553,183 +548,29 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt } bool TRTEngine::is_monolithic_capturable(cudaStream_t stream) const { -#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION) - // "lazy" strategy (0) swaps specialized kernels in mid-run, which would invalidate a - // captured graph. Any other strategy (eager/none) combined with a capturable stream is - // safe for outer monolithic capture. - return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != 0; -#else - (void)stream; - return true; -#endif + return runtime_cfg.is_monolithic_capturable(exec_ctx.get(), stream); } void TRTEngine::disable_rtx_native_cudagraphs() { -#ifdef TRT_MAJOR_RTX - if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == 0) { - return; + bool was_disabled = runtime_cfg.rtx_native_cudagraphs_disabled; + runtime_cfg.disable_rtx_native_cudagraphs(name); + if (!was_disabled && runtime_cfg.rtx_native_cudagraphs_disabled) { + // The CUDA graph strategy on the IRuntimeConfig has been flipped; rebuild exec_ctx + // so the new strategy takes effect for subsequent enqueueV3 calls. 
+ recreate_execution_context(); } - LOG_WARNING( - "Outer CUDA stream capture detected; disabling TRT-RTX native CUDA graph strategy on engine " - << name << " for the remainder of its lifetime."); - cuda_graph_strategy = 0; - apply_cuda_graph_strategy(); - recreate_execution_context(); - rtx_native_cudagraphs_disabled = true; -#endif } void TRTEngine::recreate_execution_context() { -#ifdef TRT_MAJOR_RTX - if (!runtime_config) { - runtime_config = make_trt(cuda_engine->createRuntimeConfig()); - TORCHTRT_CHECK(runtime_config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); - apply_runtime_cache(); - apply_dynamic_shapes_kernel_strategy(); - apply_cuda_graph_strategy(); - } - runtime_config->setExecutionContextAllocationStrategy( + runtime_cfg.ensure_initialized(cuda_engine.get()); + runtime_cfg.set_execution_context_allocation_strategy( resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED : nvinfer1::ExecutionContextAllocationStrategy::kSTATIC); - exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_config.get())); -#else - if (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) { - exec_ctx = - make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); - } else { - exec_ctx = make_trt(cuda_engine->createExecutionContext()); - } -#endif + exec_ctx = make_trt(cuda_engine->createExecutionContext(runtime_cfg.config.get())); TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Unable to (re)create TensorRT execution context"); } -#ifdef TRT_MAJOR_RTX -void TRTEngine::apply_runtime_cache() { - if (runtime_cache_path.empty()) { - LOG_DEBUG("Runtime cache disabled (no path configured)."); - return; - } - runtime_cache = make_trt(runtime_config->createRuntimeCache()); - if (runtime_cache.get() == nullptr) { - LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); - return; - } - 
load_runtime_cache(); - bool ok = runtime_config->setRuntimeCache(*runtime_cache); - if (!ok) { - LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); - runtime_cache.reset(); - return; - } - LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); -} - -void TRTEngine::apply_dynamic_shapes_kernel_strategy() { - runtime_config->setDynamicShapesKernelSpecializationStrategy( - static_cast(dynamic_shapes_kernel_strategy)); - LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << dynamic_shapes_kernel_strategy); -} - -void TRTEngine::apply_cuda_graph_strategy() { - bool ok = runtime_config->setCudaGraphStrategy( - cuda_graph_strategy == 1 ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE - : nvinfer1::CudaGraphStrategy::kDISABLED); - if (!ok) { - LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); - } -} - -void TRTEngine::load_runtime_cache() { - if (runtime_cache == nullptr || runtime_cache_path.empty()) { - return; - } - if (!std::filesystem::exists(runtime_cache_path)) { - LOG_DEBUG("No existing runtime cache at " << runtime_cache_path); - return; - } -#ifndef _WIN32 - int fd = ::open(runtime_cache_path.c_str(), O_RDONLY); - if (fd < 0) { - LOG_WARNING("Failed to open runtime cache for reading: " << runtime_cache_path); - return; - } - if (::flock(fd, LOCK_SH) != 0) { - LOG_WARNING("Failed to acquire shared lock on runtime cache; skipping load."); - ::close(fd); - return; - } -#endif - try { - std::ifstream f(runtime_cache_path, std::ios::binary); - std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); - if (!buf.empty()) { - bool ok = runtime_cache->deserialize(buf.data(), buf.size()); - if (ok) { - LOG_INFO("Loaded runtime cache from " << runtime_cache_path << " (" << buf.size() << " bytes)"); - } else { - LOG_WARNING("runtime_cache->deserialize returned false for " << runtime_cache_path); - } - } - } catch (const std::exception& e) { - 
LOG_WARNING("Failed to load runtime cache: " << e.what()); - } -#ifndef _WIN32 - ::flock(fd, LOCK_UN); - ::close(fd); -#endif -} - -void TRTEngine::save_runtime_cache() { - if (runtime_cache == nullptr || runtime_cache_path.empty()) { - return; - } - auto host_mem = make_trt(runtime_cache->serialize()); - if (host_mem.get() == nullptr || host_mem->size() == 0) { - return; - } - try { - std::filesystem::path path(runtime_cache_path); - if (path.has_parent_path()) { - std::filesystem::create_directories(path.parent_path()); - } - std::filesystem::path tmp_path = path; - tmp_path += ".tmp"; - -#ifndef _WIN32 - int fd = ::open(tmp_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); - if (fd < 0) { - LOG_WARNING("Failed to open runtime cache tmp file for writing: " << tmp_path.string()); - return; - } - if (::flock(fd, LOCK_EX) != 0) { - LOG_WARNING("Failed to acquire exclusive lock on runtime cache tmp file; skipping save."); - ::close(fd); - return; - } - ssize_t written = ::write(fd, host_mem->data(), host_mem->size()); - ::flock(fd, LOCK_UN); - ::close(fd); - if (written != static_cast(host_mem->size())) { - LOG_WARNING("Short write when saving runtime cache to " << tmp_path.string()); - return; - } -#else - // Windows: best-effort write without a cross-process lock. Follow-up: LockFileEx. 
- { - std::ofstream out(tmp_path, std::ios::binary); - out.write(reinterpret_cast(host_mem->data()), host_mem->size()); - } - LOG_WARNING("Runtime cache save on Windows runs without advisory locking; concurrent writers may race."); -#endif - std::filesystem::rename(tmp_path, path); - LOG_INFO("Saved runtime cache to " << runtime_cache_path << " (" << host_mem->size() << " bytes)"); - } catch (const std::exception& e) { - LOG_WARNING("Failed to save runtime cache: " << e.what()); - } -} -#endif // TRT_MAJOR_RTX - } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 7e6d357ec7..5daa53081f 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -13,12 +13,14 @@ #include "torch/custom_class.h" #include "core/runtime/TRTEngineProfiler.h" +#include "core/runtime/TRTRuntimeConfig.h" #include "core/util/prelude.h" namespace torch_tensorrt { namespace core { namespace runtime { +#ifdef TRT_MAJOR_RTX using FlattenedState = std::tuple< std::tuple, // ABI_VERSION std::tuple, // name @@ -34,6 +36,20 @@ using FlattenedState = std::tuple< std::tuple, // Runtime Cache Path (TRT-RTX) std::tuple, // Dynamic Shapes Kernel Specialization Strategy (TRT-RTX) std::tuple>; // CUDA Graph Strategy (TRT-RTX) +#else +using FlattenedState = std::tuple< + std::tuple, // ABI_VERSION + std::tuple, // name + std::tuple, // device + std::tuple, // engine + std::tuple, // input binding names + std::tuple, // output binding names + std::tuple, // HW compatibility + std::tuple, // requires_output_allocator + std::tuple, // serialized metadata + std::tuple, // Platform + std::tuple>; // Resource Allocation Strategy +#endif struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -227,48 +243,23 @@ struct TRTEngine : torch::CustomClassHolder { void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy); 
ResourceAllocationStrategy get_resource_allocation_strategy(); - // TRT-RTX runtime config state. The plain fields are stored unconditionally so that - // serialization remains ABI-stable on non-RTX builds; the IRuntimeConfig / IRuntimeCache - // handles themselves only exist on RTX. - std::string runtime_cache_path = ""; - int dynamic_shapes_kernel_strategy = 0; // 0=lazy, 1=eager, 2=none - int cuda_graph_strategy = 0; // 0=disabled, 1=whole_graph_capture - // One-shot flag: set the first time execute_engine detects an outer stream capture around - // this engine, at which point its TRT-RTX native CUDA graph capture is turned off so the - // two do not fight. The flag stays set for the remainder of the engine's lifetime. - bool rtx_native_cudagraphs_disabled = false; - -#ifdef TRT_MAJOR_RTX - std::shared_ptr runtime_config; - std::shared_ptr runtime_cache; -#endif + // All TensorRT-RTX-specific IRuntimeConfig state lives here. On non-RTX builds this + // still owns a shared IRuntimeConfig (so the execution-context allocation strategy is + // applied via the uniform code path) but the RTX-only setters become no-ops. + TRTRuntimeConfig runtime_cfg; // Monolithic-capturability check used when this engine is wrapped by an outer whole-graph // capture (e.g. CudaGraphsTorchTensorRTModule). Non-RTX builds always return true. bool is_monolithic_capturable(cudaStream_t stream) const; - // Disable TRT-RTX native CUDA graph capture on this engine (one-shot, invoked when an - // outer stream capture is detected around execute_engine). No-op on non-RTX. + // Disable TensorRT-RTX native CUDA graph capture on this engine (one-shot, invoked when + // an outer stream capture is detected around execute_engine). No-op on non-RTX. void disable_rtx_native_cudagraphs(); private: - // Single entry point that (re)creates exec_ctx. On RTX builds this also creates / reuses - // the IRuntimeConfig and applies all runtime config settings. + // Single entry point that (re)creates exec_ctx. 
Also creates (once) the IRuntimeConfig + // owned by runtime_cfg and applies all runtime config settings. void recreate_execution_context(); - -#ifdef TRT_MAJOR_RTX - // Per-feature appliers invoked the first time recreate_execution_context() runs. Bodies - // are provided in follow-up commits that introduce each feature; keeping the declarations - // here lets the scaffolding land without behavior changes. - void apply_runtime_cache(); - void apply_dynamic_shapes_kernel_strategy(); - void apply_cuda_graph_strategy(); - - // Runtime cache persistence (RTX-only). Load is invoked from apply_runtime_cache(); save - // is invoked from the destructor before exec_ctx / runtime_config tear down. - void load_runtime_cache(); - void save_runtime_cache(); -#endif }; } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp new file mode 100644 index 0000000000..443ee75ea5 --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -0,0 +1,222 @@ +#include "core/runtime/TRTRuntimeConfig.h" + +#include +#include +#include +#include + +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +std::string to_string(DynamicShapesKernelStrategy s) { + switch (s) { + case DynamicShapesKernelStrategy::kLazy: + return "lazy"; + case DynamicShapesKernelStrategy::kEager: + return "eager"; + case DynamicShapesKernelStrategy::kNone: + return "none"; + } + return "unknown"; +} + +std::string to_string(CudaGraphStrategyOption s) { + switch (s) { + case CudaGraphStrategyOption::kDisabled: + return "disabled"; + case CudaGraphStrategyOption::kWholeGraphCapture: + return "whole_graph_capture"; + } + return "unknown"; +} + +DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(int v) { + TORCHTRT_CHECK( + v >= 0 && v <= 2, + "Invalid dynamic shapes kernel strategy value: " << v << ". 
Expected 0 (lazy), 1 (eager), or 2 (none)."); + return static_cast(v); +} + +CudaGraphStrategyOption to_cuda_graph_strategy_option(int v) { + TORCHTRT_CHECK( + v >= 0 && v <= 1, + "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); + return static_cast(v); +} + +void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { + if (config) { + return; + } + TORCHTRT_CHECK(cuda_engine != nullptr, "Cannot initialize TRTRuntimeConfig without a live ICudaEngine"); + config = make_trt(cuda_engine->createRuntimeConfig()); + TORCHTRT_CHECK(config.get() != nullptr, "Unable to create TensorRT IRuntimeConfig"); + +#ifdef TRT_MAJOR_RTX + // Runtime cache -- TRT-RTX only. + if (!runtime_cache_path.empty()) { + runtime_cache = make_trt(config->createRuntimeCache()); + if (runtime_cache.get() == nullptr) { + LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); + } else { + load_runtime_cache_nothrow(); + bool ok = config->setRuntimeCache(*runtime_cache); + if (!ok) { + LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); + runtime_cache.reset(); + } else { + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); + } + } + } else { + LOG_DEBUG("Runtime cache disabled (no path configured)."); + } + + // Dynamic shapes kernel specialization strategy -- TRT-RTX only. + config->setDynamicShapesKernelSpecializationStrategy( + static_cast(dynamic_shapes_kernel_strategy)); + LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); + + // CUDA graph strategy -- TRT-RTX only. + bool ok = config->setCudaGraphStrategy( + cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture + ? 
nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED); + if (!ok) { + LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); + } +#endif +} + +void TRTRuntimeConfig::set_execution_context_allocation_strategy( + nvinfer1::ExecutionContextAllocationStrategy strategy) { + TORCHTRT_ASSERT(config, "TRTRuntimeConfig::config must be initialized before setting allocation strategy"); + config->setExecutionContextAllocationStrategy(strategy); +} + +bool TRTRuntimeConfig::uses_internal_capture(bool cudagraphs_enabled) const { +#ifdef TRT_MAJOR_RTX + // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled + // strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the + // caller should skip its manual at::cuda::CUDAGraph wrapper because TRT-RTX's internal + // capture would collide with it. + return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; +#else + (void)cudagraphs_enabled; + return false; +#endif +} + +void TRTRuntimeConfig::disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept { +#ifdef TRT_MAJOR_RTX + if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { + return; + } + LOG_WARNING( + "Outer CUDA stream capture detected; disabling TensorRT-RTX native CUDA graph strategy on engine " + << engine_name << " for the remainder of its lifetime."); + // Persist any kernels the engine-internal capture has compiled so far; the outer + // capture will run without them otherwise, and we want future reloads to reuse them. 
+ save_runtime_cache_nothrow(); + cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + if (config) { + bool ok = config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED); + if (!ok) { + LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); + } + } + rtx_native_cudagraphs_disabled = true; +#else + (void)engine_name; +#endif +} + +bool TRTRuntimeConfig::is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const { +#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION) + TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); + // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates + // captured graphs. Other strategies (eager/none) are safe when the context reports the + // stream capturable. + return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; +#else + // isStreamCapturable is declared inside `#if ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION` + // in the TensorRT-RTX header; conservatively assume the engine is capturable when that + // feature flag is not enabled at compile time. 
+ (void)exec_ctx; + (void)stream; + return true; +#endif +} + +void TRTRuntimeConfig::load_runtime_cache_nothrow() noexcept { +#ifdef TRT_MAJOR_RTX + TORCHTRT_ASSERT(runtime_cache, "load_runtime_cache_nothrow requires runtime_cache to be initialized"); + if (runtime_cache_path.empty()) { + return; + } + try { + if (!std::filesystem::exists(runtime_cache_path)) { + LOG_DEBUG("No existing runtime cache at " << runtime_cache_path); + return; + } + std::ifstream f(runtime_cache_path, std::ios::binary); + std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + if (buf.empty()) { + return; + } + bool ok = runtime_cache->deserialize(buf.data(), buf.size()); + if (ok) { + LOG_INFO("Loaded runtime cache from " << runtime_cache_path << " (" << buf.size() << " bytes)"); + } else { + LOG_WARNING("runtime_cache->deserialize returned false for " << runtime_cache_path); + } + } catch (const std::exception& e) { + LOG_WARNING("Failed to load runtime cache: " << e.what()); + } catch (...) { + LOG_WARNING("Failed to load runtime cache (unknown exception)."); + } +#endif +} + +void TRTRuntimeConfig::save_runtime_cache_nothrow() noexcept { +#ifdef TRT_MAJOR_RTX + if (!runtime_cache || runtime_cache_path.empty()) { + return; + } + try { + auto host_mem = make_trt(runtime_cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; + } + std::filesystem::path path(runtime_cache_path); + if (path.has_parent_path()) { + std::filesystem::create_directories(path.parent_path()); + } + std::filesystem::path tmp_path = path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, path); + LOG_INFO("Saved runtime cache to " << runtime_cache_path << " (" << host_mem->size() << " bytes)"); + } catch (const std::exception& e) { + LOG_WARNING("Failed to save runtime cache: " << e.what()); + } catch (...) 
{ + LOG_WARNING("Failed to save runtime cache (unknown exception)."); + } +#endif +} + +void TRTRuntimeConfig::write_to_str(std::ostream& os) const { + os << " Runtime Cache Path: " << (runtime_cache_path.empty() ? "" : runtime_cache_path) << std::endl; + os << " Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; + os << " CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; +} + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h new file mode 100644 index 0000000000..9f22045be4 --- /dev/null +++ b/core/runtime/TRTRuntimeConfig.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "NvInfer.h" + +namespace torch_tensorrt { +namespace core { +namespace runtime { + +// TensorRT-RTX-only configuration for how shape-specialized kernels are compiled. +enum class DynamicShapesKernelStrategy : int32_t { + kLazy = 0, + kEager = 1, + kNone = 2, +}; + +// TensorRT-RTX-only configuration for how CUDA graph capture/replay is handled. +enum class CudaGraphStrategyOption : int32_t { + kDisabled = 0, + kWholeGraphCapture = 1, +}; + +std::string to_string(DynamicShapesKernelStrategy s); +std::string to_string(CudaGraphStrategyOption s); +DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(int v); +CudaGraphStrategyOption to_cuda_graph_strategy_option(int v); + +// Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the +// TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native +// CUDA graph strategy). All `#ifdef TRT_MAJOR_RTX` guards live in this file and its +// implementation so callers can treat this struct uniformly between RTX and standard +// TensorRT builds. +struct TRTRuntimeConfig { + // Settings - typically populated from engine deserialization before `ensure_initialized`. 
+ std::string runtime_cache_path = ""; + DynamicShapesKernelStrategy dynamic_shapes_kernel_strategy = DynamicShapesKernelStrategy::kLazy; + CudaGraphStrategyOption cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; + + // One-shot: set to true once an outer stream capture has been detected and the + // engine-internal CUDA graph strategy has been disabled for the remainder of the + // owning engine's lifetime. + bool rtx_native_cudagraphs_disabled = false; + + // Live resources. The IRuntimeConfig is lazy-constructed on first `ensure_initialized`. + std::shared_ptr config; +#ifdef TRT_MAJOR_RTX + std::shared_ptr runtime_cache; +#endif + + // Construct the IRuntimeConfig once and apply all TRT-RTX-specific settings. Safe to + // call multiple times; only the first call initializes and applies the RTX-only + // setters. On subsequent calls this is a no-op. + void ensure_initialized(nvinfer1::ICudaEngine* cuda_engine); + + // Apply (or re-apply) the execution context allocation strategy on the IRuntimeConfig. + // Available on both standard TensorRT and TensorRT-RTX via IRuntimeConfig. + void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy); + + // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the + // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always + // false on non-RTX builds. + bool uses_internal_capture(bool cudagraphs_enabled) const; + + // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream + // capture is detected around execute_engine, so the outer capture can contain the + // kernel launches directly. Saves the runtime cache before recreating the context so + // compiled kernels from the present run are preserved for future reloads. + void disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept; + + // Whether the execution context is safe to include in an outer monolithic capture. 
+ // Non-RTX builds always return true. + bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + + // Load the runtime cache from disk using std::filesystem. No-throw: errors log and + // return. Invoked internally from `ensure_initialized` when a cache path is set. + void load_runtime_cache_nothrow() noexcept; + + // Save the runtime cache to disk using std::filesystem (tmp + rename). No-throw: + // errors log and return, so it is safe to call from a destructor. + void save_runtime_cache_nothrow() noexcept; + + // Append a human-readable summary to a TRTEngine::to_str stream. + void write_to_str(std::ostream& os) const; +}; + +} // namespace runtime +} // namespace core +} // namespace torch_tensorrt diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 0d0e6c1035..2a71b7ebd3 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -218,14 +218,14 @@ std::vector execute_engine(std::vector inputs, c10::intr auto run_standard_execution = [&]() { bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); // effective_cudagraphs controls the manual at::cuda::CUDAGraph path below. On TRT-RTX - // builds we bypass that path whenever the engine has a cuda_graph_strategy set or the - // outer runtime has requested subgraph cudagraphs - the TRT-RTX runtime handles capture - // and replay internally inside enqueueV3. If an outer stream capture is already in - // progress (e.g. the caller wraps this module in CudaGraphsTorchTensorRTModule for - // whole-graph capture), RTX-native capture would conflict, so we disable it one-shot. + // builds the engine-internal runtime owns capture/replay inside enqueueV3 whenever the + // engine has a cuda_graph_strategy set or subgraph cudagraphs are enabled; the struct + // reports that via `uses_internal_capture` so the caller skips its manual wrapper. If + // an outer stream capture is already in progress (e.g. 
the caller wraps this module in + // CudaGraphsTorchTensorRTModule for whole-graph capture), engine-internal capture would + // collide, so we disable it one-shot here. bool effective_cudagraphs = cudagraphs_enabled; -#ifdef TRT_MAJOR_RTX - if (compiled_engine->cuda_graph_strategy != 0 || cudagraphs_enabled) { + if (compiled_engine->runtime_cfg.uses_internal_capture(cudagraphs_enabled)) { effective_cudagraphs = false; cudaStreamCaptureStatus capture_status; cudaStreamIsCapturing(compiled_engine->engine_stream.stream(), &capture_status); @@ -233,7 +233,6 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->disable_rtx_native_cudagraphs(); } } -#endif bool shape_changed = _validate_shapes(inputs, compiled_engine); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 20d7580b4c..ad49890307 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -150,9 +150,11 @@ TORCH_LIBRARY(tensorrt, m) { m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; }); +#ifdef TRT_MAJOR_RTX m.def("RUNTIME_CACHE_PATH_IDX", []() -> int64_t { return RUNTIME_CACHE_PATH_IDX; }); m.def("DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX", []() -> int64_t { return DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX; }); m.def("CUDA_GRAPH_STRATEGY_IDX", []() -> int64_t { return CUDA_GRAPH_STRATEGY_IDX; }); +#endif m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); return it->second; diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 7e7a374460..70c8aa8119 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -39,9 +39,11 @@ typedef enum { TARGET_PLATFORM_IDX, REQUIRES_OUTPUT_ALLOCATOR_IDX, 
RESOURCE_ALLOCATION_STRATEGY_IDX, +#ifdef TRT_MAJOR_RTX RUNTIME_CACHE_PATH_IDX, DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX, CUDA_GRAPH_STRATEGY_IDX, +#endif SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9e1f8f9c85..a50d55469f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -113,7 +113,6 @@ def cross_compile_for_windows( enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, @@ -358,7 +357,6 @@ def cross_compile_for_windows( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, @@ -491,7 +489,6 @@ def compile( cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, - dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = 
_defaults.ATTN_BIAS_IS_CAUSAL, @@ -791,7 +788,6 @@ def compile( "enable_resource_partitioning": enable_resource_partitioning, "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, - "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, @@ -1200,6 +1196,7 @@ def convert_exported_program_to_serialized_trt_engine( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, @@ -1451,6 +1448,7 @@ def convert_exported_program_to_serialized_trt_engine( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "cuda_graph_strategy": cuda_graph_strategy, "decompose_attention": decompose_attention, "attn_bias_is_causal": attn_bias_is_causal, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 5db1042183..a929b5ea1d 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -74,7 +74,6 @@ CUDA_GRAPH_STRATEGY = "disabled" DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True -DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e8edb4163d..0713b24f6c 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ 
b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -57,12 +57,15 @@ RESOURCE_ALLOCATION_STRATEGY_IDX = ( torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() ) # 10 - RUNTIME_CACHE_PATH_IDX = torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX() # 11 - DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( - torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX() - ) # 12 - CUDA_GRAPH_STRATEGY_IDX = torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX() # 13 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 14 + if ENABLED_FEATURES.tensorrt_rtx: + RUNTIME_CACHE_PATH_IDX = torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX() # 11 + DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX = ( + torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX() + ) # 12 + CUDA_GRAPH_STRATEGY_IDX = torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX() # 13 + SERIALIZATION_LEN = ( + torch.ops.tensorrt.SERIALIZATION_LEN() + ) # 14 (RTX) / 11 (standard) _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP: Dict[str, int] = { "lazy": 0, @@ -164,11 +167,12 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources - self.runtime_cache_path = settings.runtime_cache_path - self.dynamic_shapes_kernel_specialization_strategy = ( - settings.dynamic_shapes_kernel_specialization_strategy - ) - self.cuda_graph_strategy = settings.cuda_graph_strategy + if ENABLED_FEATURES.tensorrt_rtx: + self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy self.symbolic_shape_expressions = symbolic_shape_expressions if ( @@ -227,29 +231,30 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) - engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" - if ( - 
self.dynamic_shapes_kernel_specialization_strategy - not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP - ): - raise ValueError( - f"Invalid dynamic_shapes_kernel_specialization_strategy " - f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " - f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" - ) - engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( - _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + if ENABLED_FEATURES.tensorrt_rtx: + engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" + if ( self.dynamic_shapes_kernel_specialization_strategy - ] - ) - if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: - raise ValueError( - f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " - f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( + _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ + self.dynamic_shapes_kernel_specialization_strategy + ] + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) + engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( + _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] ) - engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( - _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] - ) return engine_info diff --git a/tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py similarity index 100% rename from tests/py/dynamo/runtime/test_000_cuda_graph_strategy.py rename to tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py From 1fa8c82ee84601b997310ee28454ef30000ff6ec Mon Sep 17 
00:00:00 2001 From: tejaswinp Date: Wed, 22 Apr 2026 12:43:08 -0700 Subject: [PATCH 6/9] test: consolidate C++ runtime tests, add model-level coverage Address PR review comments that asked that the new C++ runtime tests be folded into existing feature-level files rather than shipped as parallel `*_cpp.py` files. What - Merge `test_000_runtime_cache_cpp.py` into the existing `test_000_runtime_cache.py`. The file already covered the Python runtime path; two new classes (`TestRuntimeCacheCppPersistence`, `TestCppSerializationIndices`) cover the C++ runtime path via `use_python_runtime=False`, and the serialization-index assertions. Skip on non-RTX builds. - Fold the C++ runtime cases for dynamic shapes kernel specialization strategy into `test_001_dynamic_shapes_kernel_strategy.py` (introduced upstream in PR #4184). Two new classes (`TestDynamicShapesKernelStrategyCpp`, `TestDynamicShapesKernelStrategyCppInvalidValue`) exercise lazy/eager/none end-to-end and reject invalid strategy names. The pre-existing Python runtime tests remain untouched. - Rename `test_000_cuda_graph_strategy.py` to `test_001_cuda_graph_strategy.py` to match the `test_001_*` convention used for L1 RTX-only features. When upstream lands the Python runtime counterpart (PR #4187), both sets fold into the same file. - Add model-level tests: `test_runtime_cache_models.py` gains a `TestRuntimeCacheCppModels` class exercising ResNet18 through the C++ runtime with warm-cache roundtrip. `test_dynamic_shapes_kernel_strategy_models.py` gains `TestDynamicShapesKernelStrategyCppModels` covering lazy/eager/none on ResNet18 via the C++ runtime. Verified - 35 passed / 3 skipped in the runtime/ tests (merged file plus test_001 strategy files). - No regression in test_002_cudagraphs_cpp.py (8 passed) or test_005_dynamic_allocation.py (1 passed). 
Addresses PR #4202 review comments asking for test file merges and the addition of model-level runtime_cache_models.py / dynamic_shapes_kernel_strategy_models.py coverage. --- ...t_dynamic_shapes_kernel_strategy_models.py | 64 ++++++++ .../models/test_runtime_cache_models.py | 91 +++++++++++ ...test_000_dynamic_shapes_kernel_strategy.py | 133 ---------------- .../dynamo/runtime/test_000_runtime_cache.py | 136 ++++++++++++++++ .../runtime/test_000_runtime_cache_cpp.py | 146 ------------------ ...test_001_dynamic_shapes_kernel_strategy.py | 99 ++++++++++++ 6 files changed, 390 insertions(+), 279 deletions(-) delete mode 100644 tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py delete mode 100644 tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py index badfff81ea..fd3b9ee93d 100644 --- a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -129,5 +129,69 @@ def test_dynamic_batch_none(self): self._test_dynamic_batch_with_strategy("none") +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", +) +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +class TestDynamicShapesKernelStrategyCppModels(TestCase): + """End-to-end model tests with each strategy exercised through the C++ runtime.""" + + def tearDown(self): + torch._dynamo.reset() + + def _compile_and_verify_cpp(self, model, strategy): + input_tensor = torch.randn(4, 3, 224, 224).cuda() + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + 
max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"C++ runtime cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " + f"with strategy={strategy}", + ) + + def test_resnet18_lazy_strategy_cpp(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + self._compile_and_verify_cpp(model, "lazy") + + def test_resnet18_eager_strategy_cpp(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + self._compile_and_verify_cpp(model, "eager") + + def test_resnet18_none_strategy_cpp(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + self._compile_and_verify_cpp(model, "none") + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/models/test_runtime_cache_models.py b/tests/py/dynamo/models/test_runtime_cache_models.py index aecb2fbaa3..7ffae1f5ad 100644 --- a/tests/py/dynamo/models/test_runtime_cache_models.py +++ b/tests/py/dynamo/models/test_runtime_cache_models.py @@ -325,5 +325,96 @@ def forward(self, x): self.assertTrue(True, "Timing test completed (informational)") +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +class TestRuntimeCacheCppModels(TestCase): + """End-to-end model tests with runtime cache exercised through the C++ runtime.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() 
+ self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + torch._dynamo.reset() + + def test_resnet18_with_runtime_cache_cpp(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + runtime_cache_path=self.cache_path, + ) + + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"ResNet18 C++ runtime cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) + + # Verify the runtime cache is persisted on engine destruction. + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(self.cache_path), + "Runtime cache should be saved after ResNet18 C++-runtime inference", + ) + + def test_resnet18_cache_reuse_cpp(self): + """Warm-cache second compile should match eager output.""" + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + compile_kwargs = { + "ir": "dynamo", + "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + "runtime_cache_path": self.cache_path, + } + + compiled1 = torchtrt.compile(model, **compile_kwargs) + out1 = compiled1(input_tensor) + self.assertTrue( + cosine_similarity(ref_output, out1) > COSINE_THRESHOLD, + "First ResNet18 C++-runtime output should match eager", + ) + del compiled1 + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + compiled2 = 
torchtrt.compile(model, **compile_kwargs) + out2 = compiled2(input_tensor) + self.assertTrue( + cosine_similarity(ref_output, out2) > COSINE_THRESHOLD, + "Second ResNet18 C++-runtime output (warm cache) should match eager", + ) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py deleted file mode 100644 index 6761ca1c65..0000000000 --- a/tests/py/dynamo/runtime/test_000_dynamic_shapes_kernel_strategy.py +++ /dev/null @@ -1,133 +0,0 @@ -import unittest - -import torch -import torch_tensorrt as torchtrt -from torch.testing._internal.common_utils import TestCase, run_tests -from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo._defaults import ( - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, -) -from torch_tensorrt.dynamo._settings import CompilationSettings - - -class DynamicConvModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) - self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) - - def forward(self, x): - return torch.relu(self.conv2(torch.relu(self.conv1(x)))) - - -def _compile_cpp(strategy): - model = DynamicConvModel().eval().cuda() - inp = torchtrt.Input( - min_shape=(1, 3, 16, 16), - opt_shape=(2, 3, 16, 16), - max_shape=(4, 3, 16, 16), - dtype=torch.float32, - ) - compiled = torchtrt.compile( - model, - ir="dynamo", - inputs=[inp], - enabled_precisions={torch.float32}, - use_python_runtime=False, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, - ) - torch._dynamo.reset() - return compiled - - -class TestDynamicShapesKernelStrategySettings(TestCase): - """Setting-level validation that runs on every build (RTX and non-RTX).""" - - def test_default_value(self): - settings = CompilationSettings() - self.assertEqual( - settings.dynamic_shapes_kernel_specialization_strategy, - 
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, - ) - - def test_settable_values(self): - for value in ("lazy", "eager", "none"): - settings = CompilationSettings( - dynamic_shapes_kernel_specialization_strategy=value - ) - self.assertEqual( - settings.dynamic_shapes_kernel_specialization_strategy, value - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Dynamic shapes kernel strategy is a TensorRT-RTX feature", -) -class TestDynamicShapesKernelStrategyCpp(TestCase): - """End-to-end: compile + infer through the C++ runtime with each strategy.""" - - def test_lazy(self): - compiled = _compile_cpp("lazy") - x = torch.randn(2, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - def test_eager(self): - compiled = _compile_cpp("eager") - x = torch.randn(2, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - def test_none(self): - compiled = _compile_cpp("none") - x = torch.randn(2, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - def test_dynamic_shape_with_eager(self): - """Exercise shape changes under eager kernel specialization.""" - compiled = _compile_cpp("eager") - for batch in (1, 2, 3, 4): - x = torch.randn(batch, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -class TestDynamicShapesKernelStrategyInvalidValue(TestCase): - """Invalid strategy names are rejected at engine-packing time.""" - - def test_invalid_strategy_raises(self): - model = DynamicConvModel().eval().cuda() - inp = torchtrt.Input( - 
min_shape=(1, 3, 16, 16), - opt_shape=(2, 3, 16, 16), - max_shape=(4, 3, 16, 16), - dtype=torch.float32, - ) - with self.assertRaises((ValueError, RuntimeError)): - torchtrt.compile( - model, - ir="dynamo", - inputs=[inp], - enabled_precisions={torch.float32}, - use_python_runtime=False, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", - ) - - -if __name__ == "__main__": - run_tests() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index bad67db24c..fc7be8a979 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -11,6 +11,7 @@ from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity class SimpleModel(torch.nn.Module): @@ -283,5 +284,140 @@ def test_timing_cache_still_created(self): ) +class CppSimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv(x)) + + +def _fresh_cpp_model_and_inputs(seed=0): + """Create a deterministic CppSimpleModel + input tensor pair for C++-runtime tests.""" + torch.manual_seed(seed) + return CppSimpleModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] + + +def _compile_cpp(model, inputs, runtime_cache_path=None): + """Compile the given model through the C++ runtime path (use_python_runtime=False).""" + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": False, + "min_block_size": 1, + } + if runtime_cache_path is not None: + kwargs["runtime_cache_path"] = runtime_cache_path + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return 
compiled + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCacheCppPersistence(TestCase): + """Exercise the C++-runtime code path: load on engine setup, save on destructor.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def test_cache_saved_on_del(self): + model, inputs = _fresh_cpp_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertFalse( + os.path.isfile(self.cache_path), + "Cache should not exist before module cleanup", + ) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(self.cache_path), + "Cache file should be created after module cleanup", + ) + + def test_cache_file_nonempty(self): + model, inputs = _fresh_cpp_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertGreater( + os.path.getsize(self.cache_path), + 0, + "Cache file should have nonzero size", + ) + + def test_cache_roundtrip(self): + """Compile, infer, save. 
Then recompile same model+cache and verify correctness.""" + model, inputs = _fresh_cpp_model_and_inputs() + with torch.no_grad(): + ref_output = model(*inputs) + + compiled1 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + out1 = compiled1(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out1), + COSINE_THRESHOLD, + "First compiled output should match eager", + ) + del compiled1 + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + compiled2 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) + out2 = compiled2(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out2), + COSINE_THRESHOLD, + "Second compiled output (warm cache) should still match eager", + ) + + def test_save_creates_directory(self): + nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") + model, inputs = _fresh_cpp_model_and_inputs() + compiled = _compile_cpp(model, inputs, runtime_cache_path=nested_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(nested_path), + "Save should create intermediate directories", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "New serialization indices are registered only on TensorRT-RTX builds", +) +class TestCppSerializationIndices(TestCase): + """Verify the new RTX-only C++ serialization indices are registered by the runtime.""" + + def test_new_indices_registered(self): + self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9) + self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 14) + self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 11) + self.assertEqual( + int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 12 + ) + self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 
13) + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py b/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py deleted file mode 100644 index a7a62ef131..0000000000 --- a/tests/py/dynamo/runtime/test_000_runtime_cache_cpp.py +++ /dev/null @@ -1,146 +0,0 @@ -import gc -import os -import shutil -import tempfile -import unittest - -import torch -import torch_tensorrt as torchtrt -from torch.testing._internal.common_utils import TestCase, run_tests -from torch_tensorrt._features import ENABLED_FEATURES -from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity - - -class SimpleModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) - - def forward(self, x): - return torch.relu(self.conv(x)) - - -def _fresh_model_and_inputs(seed=0): - """Create a deterministic SimpleModel + input tensor pair.""" - torch.manual_seed(seed) - return SimpleModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] - - -def _compile_cpp(model, inputs, runtime_cache_path=None): - """Compile the given model through the C++ runtime path.""" - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "enabled_precisions": {torch.float32}, - "use_python_runtime": False, - "min_block_size": 1, - } - if runtime_cache_path is not None: - kwargs["runtime_cache_path"] = runtime_cache_path - compiled = torchtrt.compile(model, **kwargs) - torch._dynamo.reset() - return compiled - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", -) -class TestRuntimeCacheCppPersistence(TestCase): - """Exercise C++-runtime runtime cache load/save against disk.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - 
shutil.rmtree(self.cache_dir, ignore_errors=True) - - def test_cache_saved_on_del(self): - model, inputs = _fresh_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(self.cache_path), - "Cache should not exist before module cleanup", - ) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(self.cache_path), - "Cache file should be created after module cleanup", - ) - - def test_cache_file_nonempty(self): - model, inputs = _fresh_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertGreater( - os.path.getsize(self.cache_path), - 0, - "Cache file should have nonzero size", - ) - - def test_cache_roundtrip(self): - """Compile, infer, save. Then recompile same model+cache and verify correctness.""" - model, inputs = _fresh_model_and_inputs() - with torch.no_grad(): - ref_output = model(*inputs) - - compiled1 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - out1 = compiled1(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out1), - COSINE_THRESHOLD, - "First compiled output should match eager", - ) - del compiled1 - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - - compiled2 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - out2 = compiled2(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out2), - COSINE_THRESHOLD, - "Second compiled output (warm cache) should still match eager", - ) - - def test_save_creates_directory(self): - nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - model, inputs = _fresh_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=nested_path) - _ = compiled(*[inp.clone() for inp in inputs]) 
- del compiled - gc.collect() - self.assertTrue( - os.path.isfile(nested_path), - "Save should create intermediate directories", - ) - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -class TestCppSerializationIndices(TestCase): - """Verify the new C++ serialization indices are registered by the runtime.""" - - def test_new_indices_registered(self): - self.assertEqual(int(torch.ops.tensorrt.ABI_VERSION()), 9) - self.assertEqual(int(torch.ops.tensorrt.SERIALIZATION_LEN()), 14) - self.assertEqual(int(torch.ops.tensorrt.RUNTIME_CACHE_PATH_IDX()), 11) - self.assertEqual( - int(torch.ops.tensorrt.DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX()), 12 - ) - self.assertEqual(int(torch.ops.tensorrt.CUDA_GRAPH_STRATEGY_IDX()), 13) - - -if __name__ == "__main__": - run_tests() diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 8c0a12cbdf..598efa71cc 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -142,5 +142,104 @@ def test_setting_ignored_on_non_rtx(self): self.assertEqual(output.shape, (2, 3)) +class DynamicConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) + self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv2(torch.relu(self.conv1(x)))) + + +def _compile_cpp(strategy): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return 
compiled + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Dynamic shapes kernel strategy is a TensorRT-RTX feature", +) +class TestDynamicShapesKernelStrategyCpp(TestCase): + """End-to-end: compile + infer through the C++ runtime with each strategy.""" + + def test_lazy(self): + compiled = _compile_cpp("lazy") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_eager(self): + compiled = _compile_cpp("eager") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_none(self): + compiled = _compile_cpp("none") + x = torch.randn(2, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) + self.assertTrue(torch.isfinite(y).all().item()) + + def test_dynamic_shape_with_eager(self): + """Exercise shape changes under eager kernel specialization.""" + compiled = _compile_cpp("eager") + for batch in (1, 2, 3, 4): + x = torch.randn(batch, 3, 16, 16, device="cuda") + y = compiled(x) + self.assertEqual(tuple(y.shape), (batch, 8, 16, 16)) + + +@unittest.skipIf( + not ENABLED_FEATURES.torch_tensorrt_runtime, + "C++ runtime is not available", +) +class TestDynamicShapesKernelStrategyCppInvalidValue(TestCase): + """Invalid strategy names are rejected at engine-packing time on the C++ runtime path.""" + + def test_invalid_strategy_raises(self): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + with self.assertRaises((ValueError, RuntimeError)): + torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + 
use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy="not_a_real_strategy", + ) + + if __name__ == "__main__": run_tests() From a4989c760a308eb434c6deb04e09bacc87c18701 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 22 Apr 2026 14:14:14 -0700 Subject: [PATCH 7/9] refactor(runtime): second-round review refinements on TRTRuntimeConfig Follow-up to 54f9ccda9 / 1fa8c82ee addressing the second batch of PR #4202 review feedback. Pure refactor with no user-visible behavior change; all tests green on A100 (35 passed / 3 skipped + 9 regression passed). TRTEngine - Constructor signature simplified: three separate `runtime_cache_path` / `dynamic_shapes_kernel_strategy` / `cuda_graph_strategy` parameters collapsed into a single `TRTRuntimeConfig runtime_cfg` sink parameter. The forwarding ctor std::moves it into the primary ctor, which std::moves it into the member. - String sink parameters (mod_name, serialized_engine, serialized_metadata) taken by value and moved into members / slugify. - Deserialization constructor routes through the new free function make_runtime_config_from_serialized, which internalizes the TRT_MAJOR_RTX-gated index reads so the constructor itself stays unguarded. - FlattenedState uses a single TRTRTX_FLATTENED_STATE_EXTRAS macro for the three RTX-only tuple entries instead of duplicating the first eleven entries across two branches. - Destructor restored to the pre-refactor structure: torch::cuda::synchronize runs outside a try block and runtime_cfg.save_runtime_cache (now noexcept by signature) is called directly. Exception safety is guaranteed by the member's type, not by a defensive try/catch. - __obj_flatten__ and serialize cast enum values via std::underlying_type_t<...> instead of int so serialization stays in lockstep with any future underlying-type change on the enums. TRTRuntimeConfig - Conversion helpers take std::underlying_type_t (the declared 32-bit integer type) instead of raw int. 
Callers at serialization boundaries explicitly std::stoi / static_cast into the right type. - [[nodiscard]] added to to_string, to_dynamic_shapes_kernel_strategy, to_cuda_graph_strategy_option, uses_internal_capture, is_monolithic_ capturable, to_str, and make_runtime_config_from_serialized. - to_string default cases now TORCHTRT_CHECK(false, ...) with the unexpected integer value; std::unreachable is C++23. - set_execution_context_allocation_strategy is now const. - Cache I/O split into two layers: - Free functions load_runtime_cache(path, cache) and save_runtime_cache(path, cache) perform the raw std::filesystem I/O and use TORCHTRT_CHECK on failure -- exception-propagating, easier to test in isolation. - Member TRTRuntimeConfig::save_runtime_cache() is a noexcept wrapper that calls the free function and swallows exceptions via try/catch -- safe from a destructor. The _nothrow suffix is dropped from the member name (the signature now carries that contract). - write_to_str(ostream&) replaced by two functions: a const-correct to_str() -> std::string, and a free operator<<(ostream&, const TRTRuntimeConfig&) that wraps it with "Runtime cfg { ... }" delimiters. TRTEngine::to_str streams the config via the free operator. Python - _settings.py: removed a duplicated dynamic_shapes_kernel_ specialization_strategy field and its duplicated docstring left over from the upstream rebase of PR #4184 into our changes. Covers review comments 3126538200, 3126541782, 3126547529, 3126549147, 3126682329, 3126683329, 3126693226, 3126715369, 3126725953, 3126736626, 3126738422, 3126745230, 3126747553, 3126749405, 3126764831, 3126772536, 3126786564, 3126803652, 3126816780, 3126818065, 3126818561, 3126819429, 3126823781, 3126840987, 3126846827. 
--- core/runtime/TRTEngine.cpp | 65 +++++------- core/runtime/TRTEngine.h | 47 ++++----- core/runtime/TRTRuntimeConfig.cpp | 142 ++++++++++++++++---------- core/runtime/TRTRuntimeConfig.h | 52 +++++++--- py/torch_tensorrt/dynamo/_settings.py | 4 - 5 files changed, 167 insertions(+), 143 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 66181a3c40..2d32e19c87 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -55,32 +55,28 @@ void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims } TRTEngine::TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, + std::string serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - const std::string& runtime_cache_path, - int dynamic_shapes_kernel_strategy, - int cuda_graph_strategy) + TRTRuntimeConfig runtime_cfg) : TRTEngine( "deserialized_trt", - serialized_engine, + std::move(serialized_engine), cuda_device, _in_binding_names, _out_binding_names, target_platform, hardware_compatible, requires_output_allocator, - serialized_metadata, + std::move(serialized_metadata), resource_allocation_strategy, - runtime_cache_path, - dynamic_shapes_kernel_strategy, - cuda_graph_strategy) {} + std::move(runtime_cfg)) {} TRTEngine::TRTEngine(std::vector serialized_info) : TRTEngine( @@ -95,33 +91,22 @@ TRTEngine::TRTEngine(std::vector serialized_info) serialized_info[SERIALIZED_METADATA_IDX], (static_cast(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? 
ResourceAllocationStrategy::kDynamic - : ResourceAllocationStrategy::kStatic) -#ifdef TRT_MAJOR_RTX - , - serialized_info[RUNTIME_CACHE_PATH_IDX], - std::stoi(serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX]), - std::stoi(serialized_info[CUDA_GRAPH_STRATEGY_IDX]) -#endif - ) { -} + : ResourceAllocationStrategy::kStatic), + make_runtime_config_from_serialized(serialized_info)) {} TRTEngine::TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& _in_binding_names, const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, bool requires_output_allocator, - const std::string& serialized_metadata, + std::string serialized_metadata, const ResourceAllocationStrategy resource_allocation_strategy, - const std::string& runtime_cache_path, - int dynamic_shapes_kernel_strategy, - int cuda_graph_strategy) { - runtime_cfg.runtime_cache_path = runtime_cache_path; - runtime_cfg.dynamic_shapes_kernel_strategy = to_dynamic_shapes_kernel_strategy(dynamic_shapes_kernel_strategy); - runtime_cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(cuda_graph_strategy); + TRTRuntimeConfig runtime_cfg) { + this->runtime_cfg = std::move(runtime_cfg); TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: " @@ -132,7 +117,7 @@ TRTEngine::TRTEngine( auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible); TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); - this->serialized_metadata = serialized_metadata; + this->serialized_metadata = std::move(serialized_metadata); this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); @@ 
-140,7 +125,7 @@ TRTEngine::TRTEngine( rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger())); - name = slugify(mod_name); + name = slugify(std::move(mod_name)); cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size())); TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine"); @@ -279,13 +264,10 @@ TRTEngine::TRTEngine( } TRTEngine::~TRTEngine() { - // Destructors must not throw; `save_runtime_cache_nothrow` is itself no-throw but we - // wrap it defensively to keep stack unwinding safe in all circumstances. - try { - torch::cuda::synchronize(device_info.id); - runtime_cfg.save_runtime_cache_nothrow(); - } catch (...) { - } + torch::cuda::synchronize(device_info.id); + // Marked noexcept by the type system, so safe to invoke from a destructor without + // explicit try/catch; any I/O error is logged internally. + runtime_cfg.save_runtime_cache(); trt_engine_profiler.reset(); exec_ctx.reset(); cuda_engine.reset(); @@ -445,8 +427,8 @@ std::string TRTEngine::to_str() const { ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; + ss << runtime_cfg; // clang-format on - runtime_cfg.write_to_str(ss); return ss.str(); } @@ -524,9 +506,10 @@ std::vector TRTEngine::serialize() { this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? 
"1" : "0"; #ifdef TRT_MAJOR_RTX serialized_info[RUNTIME_CACHE_PATH_IDX] = runtime_cfg.runtime_cache_path; - serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = - std::to_string(static_cast(runtime_cfg.dynamic_shapes_kernel_strategy)); - serialized_info[CUDA_GRAPH_STRATEGY_IDX] = std::to_string(static_cast(runtime_cfg.cuda_graph_strategy)); + serialized_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = std::to_string( + static_cast>(runtime_cfg.dynamic_shapes_kernel_strategy)); + serialized_info[CUDA_GRAPH_STRATEGY_IDX] = + std::to_string(static_cast>(runtime_cfg.cuda_graph_strategy)); #endif return serialized_info; diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 5daa53081f..d3ef259e32 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -21,22 +21,17 @@ namespace core { namespace runtime { #ifdef TRT_MAJOR_RTX -using FlattenedState = std::tuple< - std::tuple, // ABI_VERSION - std::tuple, // name - std::tuple, // device - std::tuple, // engine - std::tuple, // input binding names - std::tuple, // output binding names - std::tuple, // HW compatibility - std::tuple, // requires_output_allocator - std::tuple, // serialized metadata - std::tuple, // Platform - std::tuple, // Resource Allocation Strategy - std::tuple, // Runtime Cache Path (TRT-RTX) - std::tuple, // Dynamic Shapes Kernel Specialization Strategy (TRT-RTX) - std::tuple>; // CUDA Graph Strategy (TRT-RTX) +// Extra FlattenedState entries for TensorRT-RTX-only fields. Leading comma so this +// macro can be dropped directly into the std::tuple parameter pack after the final +// shared entry without duplicating the per-entry type in both branches. 
+#define TRTRTX_FLATTENED_STATE_EXTRAS \ + , std::tuple /* Runtime Cache Path */ \ + , std::tuple /* Dynamic Shapes Kernel Strategy */ \ + , std::tuple /* CUDA Graph Strategy */ #else +#define TRTRTX_FLATTENED_STATE_EXTRAS +#endif + using FlattenedState = std::tuple< std::tuple, // ABI_VERSION std::tuple, // name @@ -48,8 +43,8 @@ using FlattenedState = std::tuple< std::tuple, // requires_output_allocator std::tuple, // serialized metadata std::tuple, // Platform - std::tuple>; // Resource Allocation Strategy -#endif + std::tuple /* Resource Allocation Strategy */ + TRTRTX_FLATTENED_STATE_EXTRAS>; struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -144,37 +139,33 @@ struct TRTEngine : torch::CustomClassHolder { ~TRTEngine(); TRTEngine( - const std::string& serialized_engine, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - const std::string& runtime_cache_path = "", - int dynamic_shapes_kernel_strategy = 0, - int cuda_graph_strategy = 0); + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine(std::vector serialized_info); TRTEngine( - const std::string& mod_name, - const std::string& serialized_engine, + std::string mod_name, + std::string serialized_engine, const RTDevice& cuda_device, const std::vector& in_binding_names, const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, bool requires_output_allocator = false, - const std::string& serialized_metadata = "", + std::string serialized_metadata = "", 
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy = TRTEngine::ResourceAllocationStrategy::kStatic, - const std::string& runtime_cache_path = "", - int dynamic_shapes_kernel_strategy = 0, - int cuda_graph_strategy = 0); + TRTRuntimeConfig runtime_cfg = TRTRuntimeConfig{}); TRTEngine& operator=(const TRTEngine& other); std::string to_str() const; diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 443ee75ea5..61c1307fac 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include +#include "core/runtime/runtime.h" #include "core/util/prelude.h" namespace torch_tensorrt { @@ -20,7 +22,10 @@ std::string to_string(DynamicShapesKernelStrategy s) { case DynamicShapesKernelStrategy::kNone: return "none"; } - return "unknown"; + TORCHTRT_CHECK( + false, + "Unexpected DynamicShapesKernelStrategy value: " + << static_cast>(s)); } std::string to_string(CudaGraphStrategyOption s) { @@ -30,17 +35,19 @@ std::string to_string(CudaGraphStrategyOption s) { case CudaGraphStrategyOption::kWholeGraphCapture: return "whole_graph_capture"; } - return "unknown"; + TORCHTRT_CHECK( + false, + "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); } -DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(int v) { +DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(std::underlying_type_t v) { TORCHTRT_CHECK( v >= 0 && v <= 2, "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); return static_cast(v); } -CudaGraphStrategyOption to_cuda_graph_strategy_option(int v) { +CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { TORCHTRT_CHECK( v >= 0 && v <= 1, "Invalid CUDA graph strategy value: " << v << ". 
Expected 0 (disabled) or 1 (whole_graph_capture)."); @@ -62,7 +69,11 @@ void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { if (runtime_cache.get() == nullptr) { LOG_WARNING("Failed to create TensorRT IRuntimeCache; runtime cache will be skipped."); } else { - load_runtime_cache_nothrow(); + try { + load_runtime_cache(runtime_cache_path, runtime_cache.get()); + } catch (const std::exception& e) { + LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); + } bool ok = config->setRuntimeCache(*runtime_cache); if (!ok) { LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); @@ -92,7 +103,7 @@ void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { } void TRTRuntimeConfig::set_execution_context_allocation_strategy( - nvinfer1::ExecutionContextAllocationStrategy strategy) { + nvinfer1::ExecutionContextAllocationStrategy strategy) const { TORCHTRT_ASSERT(config, "TRTRuntimeConfig::config must be initialized before setting allocation strategy"); config->setExecutionContextAllocationStrategy(strategy); } @@ -120,7 +131,7 @@ void TRTRuntimeConfig::disable_rtx_native_cudagraphs(const std::string& engine_n << engine_name << " for the remainder of its lifetime."); // Persist any kernels the engine-internal capture has compiled so far; the outer // capture will run without them otherwise, and we want future reloads to reuse them. 
- save_runtime_cache_nothrow(); + save_runtime_cache(); cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; if (config) { bool ok = config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED); @@ -151,70 +162,93 @@ bool TRTRuntimeConfig::is_monolithic_capturable(nvinfer1::IExecutionContext* exe #endif } -void TRTRuntimeConfig::load_runtime_cache_nothrow() noexcept { +void TRTRuntimeConfig::save_runtime_cache() noexcept { #ifdef TRT_MAJOR_RTX - TORCHTRT_ASSERT(runtime_cache, "load_runtime_cache_nothrow requires runtime_cache to be initialized"); - if (runtime_cache_path.empty()) { + if (!runtime_cache || runtime_cache_path.empty()) { return; } try { - if (!std::filesystem::exists(runtime_cache_path)) { - LOG_DEBUG("No existing runtime cache at " << runtime_cache_path); - return; - } - std::ifstream f(runtime_cache_path, std::ios::binary); - std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); - if (buf.empty()) { - return; - } - bool ok = runtime_cache->deserialize(buf.data(), buf.size()); - if (ok) { - LOG_INFO("Loaded runtime cache from " << runtime_cache_path << " (" << buf.size() << " bytes)"); - } else { - LOG_WARNING("runtime_cache->deserialize returned false for " << runtime_cache_path); - } + runtime::save_runtime_cache(runtime_cache_path, runtime_cache.get()); } catch (const std::exception& e) { - LOG_WARNING("Failed to load runtime cache: " << e.what()); + LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); } catch (...) { - LOG_WARNING("Failed to load runtime cache (unknown exception)."); + LOG_WARNING("Failed to save runtime cache (unknown exception)."); } #endif } -void TRTRuntimeConfig::save_runtime_cache_nothrow() noexcept { +std::string TRTRuntimeConfig::to_str() const { + std::ostringstream os; + os << "Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; + os << "Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; + os << "CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; + return os.str(); +} + +void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { #ifdef TRT_MAJOR_RTX - if (!runtime_cache || runtime_cache_path.empty()) { + TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); + if (!std::filesystem::exists(path)) { + LOG_DEBUG("No existing runtime cache at " << path); return; } - try { - auto host_mem = make_trt(runtime_cache->serialize()); - if (!host_mem || host_mem->size() == 0) { - return; - } - std::filesystem::path path(runtime_cache_path); - if (path.has_parent_path()) { - std::filesystem::create_directories(path.parent_path()); - } - std::filesystem::path tmp_path = path; - tmp_path += ".tmp"; - { - std::ofstream out(tmp_path, std::ios::binary); - out.write(reinterpret_cast(host_mem->data()), host_mem->size()); - } - std::filesystem::rename(tmp_path, path); - LOG_INFO("Saved runtime cache to " << runtime_cache_path << " (" << host_mem->size() << " bytes)"); - } catch (const std::exception& e) { - LOG_WARNING("Failed to save runtime cache: " << e.what()); - } catch (...) 
{ - LOG_WARNING("Failed to save runtime cache (unknown exception)."); + std::ifstream f(path, std::ios::binary); + std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + if (buf.empty()) { + return; + } + bool ok = cache->deserialize(buf.data(), buf.size()); + TORCHTRT_CHECK(ok, "IRuntimeCache::deserialize returned false for " << path); + LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); +#else + (void)path; + (void)cache; +#endif +} + +void save_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { +#ifdef TRT_MAJOR_RTX + TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); + auto host_mem = make_trt(cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; } + std::filesystem::path fs_path(path); + if (fs_path.has_parent_path()) { + std::filesystem::create_directories(fs_path.parent_path()); + } + std::filesystem::path tmp_path = fs_path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, fs_path); + LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); +#else + (void)path; + (void)cache; +#endif +} + +TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info) { + TRTRuntimeConfig cfg; +#ifdef TRT_MAJOR_RTX + cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; + cfg.dynamic_shapes_kernel_strategy = + to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); + cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); +#else + (void)info; #endif + return cfg; } -void TRTRuntimeConfig::write_to_str(std::ostream& os) const { - os << " Runtime Cache Path: " << (runtime_cache_path.empty() ? 
"" : runtime_cache_path) << std::endl; - os << " Dynamic Shapes Kernel Strategy: " << to_string(dynamic_shapes_kernel_strategy) << std::endl; - os << " CUDA Graph Strategy: " << to_string(cuda_graph_strategy) << std::endl; +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg) { + os << "Runtime cfg {" << std::endl; + os << cfg.to_str(); + os << "}" << std::endl; + return os; } } // namespace runtime diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 9f22045be4..13d6c87f85 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include "NvInfer.h" @@ -25,10 +27,14 @@ enum class CudaGraphStrategyOption : int32_t { kWholeGraphCapture = 1, }; -std::string to_string(DynamicShapesKernelStrategy s); -std::string to_string(CudaGraphStrategyOption s); -DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(int v); -CudaGraphStrategyOption to_cuda_graph_strategy_option(int v); +// Conversion helpers. Signatures use the enum's underlying type (int32_t) rather than +// raw `int` so call sites pass validated strategy codes directly without implicit +// narrowing. +[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s); +[[nodiscard]] std::string to_string(CudaGraphStrategyOption s); +[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( + std::underlying_type_t v); +[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v); // Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the // TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native @@ -59,12 +65,12 @@ struct TRTRuntimeConfig { // Apply (or re-apply) the execution context allocation strategy on the IRuntimeConfig. // Available on both standard TensorRT and TensorRT-RTX via IRuntimeConfig. 
- void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy); + void set_execution_context_allocation_strategy(nvinfer1::ExecutionContextAllocationStrategy strategy) const; // Returns true if the TensorRT-RTX runtime owns capture/replay for this engine so the // caller should bypass its own at::cuda::CUDAGraph capture around enqueueV3. Always // false on non-RTX builds. - bool uses_internal_capture(bool cudagraphs_enabled) const; + [[nodiscard]] bool uses_internal_capture(bool cudagraphs_enabled) const; // One-shot: disable engine-internal CUDA graph capture. Invoked when an outer stream // capture is detected around execute_engine, so the outer capture can contain the @@ -74,20 +80,34 @@ struct TRTRuntimeConfig { // Whether the execution context is safe to include in an outer monolithic capture. // Non-RTX builds always return true. - bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; + [[nodiscard]] bool is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const; - // Load the runtime cache from disk using std::filesystem. No-throw: errors log and - // return. Invoked internally from `ensure_initialized` when a cache path is set. - void load_runtime_cache_nothrow() noexcept; + // Save the runtime cache to disk. Signature is `noexcept` so this is safe from a + // destructor. The underlying file I/O is performed by free functions declared below + // (non-noexcept, exception-leaky for easier testing); this member wraps them and + // swallows any exceptions. + void save_runtime_cache() noexcept; - // Save the runtime cache to disk using std::filesystem (tmp + rename). No-throw: - // errors log and return, so it is safe to call from a destructor. - void save_runtime_cache_nothrow() noexcept; - - // Append a human-readable summary to a TRTEngine::to_str stream. 
- void write_to_str(std::ostream& os) const; + // Returns a human-readable summary of the runtime config. + [[nodiscard]] std::string to_str() const; }; +// Free-function I/O helpers. Declared outside TRTRuntimeConfig so they can be tested +// independently of a live TRTEngine and without the noexcept suppression of the member +// wrappers. +// +// These perform raw file I/O and may throw on failure; the member wrappers +// (`save_runtime_cache`, `ensure_initialized`'s load step) catch and log instead. +void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache); +void save_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache); + +// Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the +// RTX-only indices only on RTX builds; standard TRT builds return a default-initialized +// struct. +[[nodiscard]] TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info); + +std::ostream& operator<<(std::ostream& os, const TRTRuntimeConfig& cfg); + } // namespace runtime } // namespace core } // namespace torch_tensorrt diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 95d0cd88bd..eb4f4e07e7 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -125,7 +125,6 @@ class CompilationSettings: autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None. 
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines - dynamic_shapes_kernel_specialization_strategy (str): TensorRT-RTX dynamic shapes kernel specialization strategy: "lazy" (default, compile specialized kernels in background and use fallbacks until ready), "eager" (compile specialized kernels synchronously, blocking first inference), or "none" (always use fallback kernels). Not used for standard TensorRT. cuda_graph_strategy (str): TensorRT-RTX CUDA graph strategy: "disabled" (default) or "whole_graph_capture" (let TensorRT-RTX manage CUDA graph capture/replay internally). When set and combined with `torch_tensorrt.runtime.set_cudagraphs_mode(True)` on RTX, overrides manual capture. Not used for standard TensorRT. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. 
@@ -192,9 +191,6 @@ class CompilationSettings: enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES - dynamic_shapes_kernel_specialization_strategy: str = ( - DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY - ) cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY decompose_attention: bool = DECOMPOSE_ATTENTION attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL From 612556ba0fb307783e9229b5f4434fcb334cc90b Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Wed, 22 Apr 2026 22:04:01 -0700 Subject: [PATCH 8/9] refactor(runtime): third-round review polish + cross-backend verification Follow-up to a4989c760 addressing the third batch of comments on PR #4202 plus verification that the non-RTX (standard TensorRT) build path still compiles and tests correctly skip RTX-only suites. Reviewer feedback - FlattenedState: the TRTRTX_FLATTENED_STATE_EXTRAS macro is inlined directly into the tuple parameter pack with a nested `#ifdef TRT_MAJOR_RTX`; no preprocessor macro is introduced, per the reviewer's "Inline and fix" note. - TRTEngine::to_str now calls `runtime_cfg.to_str()` directly rather than relying on the free `operator<<` framing; keeps the engine's existing two-space indentation consistent. - TRTRuntimeConfig free-function I/O helpers (`load_runtime_cache`, `save_runtime_cache`) moved to an anonymous namespace inside TRTRuntimeConfig.cpp and removed from the public header; the member wrapper `TRTRuntimeConfig::save_runtime_cache()` stays in the header (noexcept, catches exceptions from the raw helper). Renamed the internal free save helper to `save_runtime_cache_impl` to avoid clashing with the member of the same name. 
- Enum conversion helpers `to_string(...)` / `to_dynamic_shapes_kernel_strategy` / `to_cuda_graph_strategy_option` moved to anonymous namespace in the cpp; nothing outside this translation unit needs them now that TRTEngine holds a TRTRuntimeConfig directly. - Replaced `(void)param;` suppression pattern with `TORCHTRT_UNUSED` on the parameter declaration in five places. - Removed the nested `defined(ENABLE_FEATURE_DISABLE_RUNTIME_ ALLOCATION)` guard on `isStreamCapturable`. Instead, the Bazel rule for `//core/runtime:runtime` now sets `ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION` as a local_define for the `:rtx_win` and `:rtx_x86_64` configs so the RTX header's feature gate is always on when we're building for RTX, matching the reviewer's invariant. Cross-backend - Python `_TorchTensorRTModule._pack_engine_info` now always validates `dynamic_shapes_kernel_specialization_strategy` and `cuda_graph_strategy` against the allowed name lists, regardless of whether the build is RTX or standard TRT. The engine-info serialization slots are only written on RTX, but the validation runs universally so typos surface early on any backend. Build + test - RTX A100: 35 passed / 3 skipped on new + merged suites; 9 passed regression (test_002_cudagraphs_cpp.py + test_005_dynamic_ allocation.py). Wheel `torch_tensorrt_rtx-2.12.0.dev0+a4989c760`. - Standard TRT A100: wheel `torch_tensorrt-2.12.0.dev0+a4989c760` builds clean without `--use-rtx`. Import smoke shows `tensorrt_rtx=False`, `SERIALIZATION_LEN=11`. 7 passed / 31 skipped (all skips with clean "Runtime cache is only available with TensorRT-RTX" / "CUDA graph strategy is a TensorRT-RTX feature" messages); 9 regression passed. Covers review comments 3126975981, 3127004055, 3127028393, 3127038410, 3127076231, and 3127100282. 
--- core/runtime/BUILD | 7 + core/runtime/TRTEngine.cpp | 2 +- core/runtime/TRTEngine.h | 22 ++- core/runtime/TRTRuntimeConfig.cpp | 127 +++++++++--------- core/runtime/TRTRuntimeConfig.h | 18 --- .../dynamo/runtime/_TorchTensorRTModule.py | 49 ++++--- 6 files changed, 105 insertions(+), 120 deletions(-) diff --git a/core/runtime/BUILD b/core/runtime/BUILD index 61fcd7a283..796b0d3c2d 100644 --- a/core/runtime/BUILD +++ b/core/runtime/BUILD @@ -83,6 +83,13 @@ cc_library( linkopts = [ "-lstdc++fs", ], + local_defines = select({ + # TensorRT-RTX builds: opt into feature-gated APIs that the runtime layer + # depends on (e.g. IExecutionContext::isStreamCapturable). + ":rtx_win": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + ":rtx_x86_64": ["ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION"], + "//conditions:default": [], + }), deps = [ "//core/plugins:torch_tensorrt_plugins", "//core/util:prelude", diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 2d32e19c87..2ba42ed954 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -427,7 +427,7 @@ std::string TRTEngine::to_str() const { ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl; ss << " Target Platform: " << target_platform << std::endl; ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl; - ss << runtime_cfg; + ss << runtime_cfg.to_str(); // clang-format on return ss.str(); } diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index d3ef259e32..6ad5b2a3f2 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -20,18 +20,6 @@ namespace torch_tensorrt { namespace core { namespace runtime { -#ifdef TRT_MAJOR_RTX -// Extra FlattenedState entries for TensorRT-RTX-only fields. 
Leading comma so this -// macro can be dropped directly into the std::tuple parameter pack after the final -// shared entry without duplicating the per-entry type in both branches. -#define TRTRTX_FLATTENED_STATE_EXTRAS \ - , std::tuple /* Runtime Cache Path */ \ - , std::tuple /* Dynamic Shapes Kernel Strategy */ \ - , std::tuple /* CUDA Graph Strategy */ -#else -#define TRTRTX_FLATTENED_STATE_EXTRAS -#endif - using FlattenedState = std::tuple< std::tuple, // ABI_VERSION std::tuple, // name @@ -43,8 +31,14 @@ using FlattenedState = std::tuple< std::tuple, // requires_output_allocator std::tuple, // serialized metadata std::tuple, // Platform - std::tuple /* Resource Allocation Strategy */ - TRTRTX_FLATTENED_STATE_EXTRAS>; + std::tuple // Resource Allocation Strategy +#ifdef TRT_MAJOR_RTX + , + std::tuple, // Runtime Cache Path (TRT-RTX) + std::tuple, // Dynamic Shapes Kernel Strategy (TRT-RTX) + std::tuple // CUDA Graph Strategy (TRT-RTX) +#endif + >; struct TorchTRTRuntimeStates { // Indicates whether CUDAGraphs were enabled in the previous execute_engine diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index 61c1307fac..cdf8b63f09 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -13,7 +13,12 @@ namespace torch_tensorrt { namespace core { namespace runtime { -std::string to_string(DynamicShapesKernelStrategy s) { +// File-local helpers. Kept out of the header because they are only used by this +// translation unit -- TRTEngine now consumes a TRTRuntimeConfig directly and does not +// need the enum conversion helpers. 
+namespace { + +[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s) { switch (s) { case DynamicShapesKernelStrategy::kLazy: return "lazy"; @@ -28,7 +33,7 @@ std::string to_string(DynamicShapesKernelStrategy s) { << static_cast>(s)); } -std::string to_string(CudaGraphStrategyOption s) { +[[nodiscard]] std::string to_string(CudaGraphStrategyOption s) { switch (s) { case CudaGraphStrategyOption::kDisabled: return "disabled"; @@ -40,20 +45,64 @@ std::string to_string(CudaGraphStrategyOption s) { "Unexpected CudaGraphStrategyOption value: " << static_cast>(s)); } -DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy(std::underlying_type_t v) { +[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( + std::underlying_type_t v) { TORCHTRT_CHECK( v >= 0 && v <= 2, "Invalid dynamic shapes kernel strategy value: " << v << ". Expected 0 (lazy), 1 (eager), or 2 (none)."); return static_cast(v); } -CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { +[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v) { TORCHTRT_CHECK( v >= 0 && v <= 1, "Invalid CUDA graph strategy value: " << v << ". Expected 0 (disabled) or 1 (whole_graph_capture)."); return static_cast(v); } +#ifdef TRT_MAJOR_RTX +// Raw cache I/O helpers. Exception-propagating; the caller wraps in try/catch at the +// TRTRuntimeConfig member level. Kept file-local because the IRuntimeCache type is +// itself TensorRT-RTX-only and tests reach this path through the member wrappers. 
+void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); + if (!std::filesystem::exists(path)) { + LOG_DEBUG("No existing runtime cache at " << path); + return; + } + std::ifstream f(path, std::ios::binary); + std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + if (buf.empty()) { + return; + } + bool ok = cache->deserialize(buf.data(), buf.size()); + TORCHTRT_CHECK(ok, "IRuntimeCache::deserialize returned false for " << path); + LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); +} + +void save_runtime_cache_impl(const std::string& path, nvinfer1::IRuntimeCache* cache) { + TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); + auto host_mem = make_trt(cache->serialize()); + if (!host_mem || host_mem->size() == 0) { + return; + } + std::filesystem::path fs_path(path); + if (fs_path.has_parent_path()) { + std::filesystem::create_directories(fs_path.parent_path()); + } + std::filesystem::path tmp_path = fs_path; + tmp_path += ".tmp"; + { + std::ofstream out(tmp_path, std::ios::binary); + out.write(reinterpret_cast(host_mem->data()), host_mem->size()); + } + std::filesystem::rename(tmp_path, fs_path); + LOG_INFO("Saved runtime cache to " << path << " (" << host_mem->size() << " bytes)"); +} +#endif // TRT_MAJOR_RTX + +} // namespace + void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { if (config) { return; @@ -108,7 +157,7 @@ void TRTRuntimeConfig::set_execution_context_allocation_strategy( config->setExecutionContextAllocationStrategy(strategy); } -bool TRTRuntimeConfig::uses_internal_capture(bool cudagraphs_enabled) const { +bool TRTRuntimeConfig::uses_internal_capture(TORCHTRT_UNUSED bool cudagraphs_enabled) const { #ifdef TRT_MAJOR_RTX // On TRT-RTX the internal runtime handles capture/replay whenever a non-disabled // 
strategy is set, or when subgraph cudagraphs are enabled globally. In both cases the @@ -116,12 +165,11 @@ bool TRTRuntimeConfig::uses_internal_capture(bool cudagraphs_enabled) const { // capture would collide with it. return cuda_graph_strategy != CudaGraphStrategyOption::kDisabled || cudagraphs_enabled; #else - (void)cudagraphs_enabled; return false; #endif } -void TRTRuntimeConfig::disable_rtx_native_cudagraphs(const std::string& engine_name) noexcept { +void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std::string& engine_name) noexcept { #ifdef TRT_MAJOR_RTX if (rtx_native_cudagraphs_disabled || cuda_graph_strategy == CudaGraphStrategyOption::kDisabled) { return; @@ -140,24 +188,19 @@ void TRTRuntimeConfig::disable_rtx_native_cudagraphs(const std::string& engine_n } } rtx_native_cudagraphs_disabled = true; -#else - (void)engine_name; #endif } -bool TRTRuntimeConfig::is_monolithic_capturable(nvinfer1::IExecutionContext* exec_ctx, cudaStream_t stream) const { -#if defined(TRT_MAJOR_RTX) && defined(ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION) +bool TRTRuntimeConfig::is_monolithic_capturable( + TORCHTRT_UNUSED nvinfer1::IExecutionContext* exec_ctx, + TORCHTRT_UNUSED cudaStream_t stream) const { +#ifdef TRT_MAJOR_RTX TORCHTRT_ASSERT(exec_ctx != nullptr, "is_monolithic_capturable requires a live IExecutionContext"); // "lazy" kernel specialization swaps specialized kernels in mid-run, which invalidates // captured graphs. Other strategies (eager/none) are safe when the context reports the // stream capturable. return exec_ctx->isStreamCapturable(stream) && dynamic_shapes_kernel_strategy != DynamicShapesKernelStrategy::kLazy; #else - // isStreamCapturable is declared inside `#if ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION` - // in the TensorRT-RTX header; conservatively assume the engine is capturable when that - // feature flag is not enabled at compile time. 
- (void)exec_ctx; - (void)stream; return true; #endif } @@ -168,7 +211,7 @@ void TRTRuntimeConfig::save_runtime_cache() noexcept { return; } try { - runtime::save_runtime_cache(runtime_cache_path, runtime_cache.get()); + save_runtime_cache_impl(runtime_cache_path, runtime_cache.get()); } catch (const std::exception& e) { LOG_WARNING("Failed to save runtime cache to " << runtime_cache_path << ": " << e.what()); } catch (...) { @@ -185,61 +228,13 @@ std::string TRTRuntimeConfig::to_str() const { return os.str(); } -void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { -#ifdef TRT_MAJOR_RTX - TORCHTRT_CHECK(cache != nullptr, "load_runtime_cache requires a non-null IRuntimeCache"); - if (!std::filesystem::exists(path)) { - LOG_DEBUG("No existing runtime cache at " << path); - return; - } - std::ifstream f(path, std::ios::binary); - std::vector buf((std::istreambuf_iterator(f)), std::istreambuf_iterator()); - if (buf.empty()) { - return; - } - bool ok = cache->deserialize(buf.data(), buf.size()); - TORCHTRT_CHECK(ok, "IRuntimeCache::deserialize returned false for " << path); - LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); -#else - (void)path; - (void)cache; -#endif -} - -void save_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) { -#ifdef TRT_MAJOR_RTX - TORCHTRT_CHECK(cache != nullptr, "save_runtime_cache requires a non-null IRuntimeCache"); - auto host_mem = make_trt(cache->serialize()); - if (!host_mem || host_mem->size() == 0) { - return; - } - std::filesystem::path fs_path(path); - if (fs_path.has_parent_path()) { - std::filesystem::create_directories(fs_path.parent_path()); - } - std::filesystem::path tmp_path = fs_path; - tmp_path += ".tmp"; - { - std::ofstream out(tmp_path, std::ios::binary); - out.write(reinterpret_cast(host_mem->data()), host_mem->size()); - } - std::filesystem::rename(tmp_path, fs_path); - LOG_INFO("Saved runtime cache to " << path << " (" << 
host_mem->size() << " bytes)"); -#else - (void)path; - (void)cache; -#endif -} - -TRTRuntimeConfig make_runtime_config_from_serialized(const std::vector& info) { +TRTRuntimeConfig make_runtime_config_from_serialized(TORCHTRT_UNUSED const std::vector& info) { TRTRuntimeConfig cfg; #ifdef TRT_MAJOR_RTX cfg.runtime_cache_path = info[RUNTIME_CACHE_PATH_IDX]; cfg.dynamic_shapes_kernel_strategy = to_dynamic_shapes_kernel_strategy(std::stoi(info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX])); cfg.cuda_graph_strategy = to_cuda_graph_strategy_option(std::stoi(info[CUDA_GRAPH_STRATEGY_IDX])); -#else - (void)info; #endif return cfg; } diff --git a/core/runtime/TRTRuntimeConfig.h b/core/runtime/TRTRuntimeConfig.h index 13d6c87f85..e964706c2e 100644 --- a/core/runtime/TRTRuntimeConfig.h +++ b/core/runtime/TRTRuntimeConfig.h @@ -27,15 +27,6 @@ enum class CudaGraphStrategyOption : int32_t { kWholeGraphCapture = 1, }; -// Conversion helpers. Signatures use the enum's underlying type (int32_t) rather than -// raw `int` so call sites pass validated strategy codes directly without implicit -// narrowing. -[[nodiscard]] std::string to_string(DynamicShapesKernelStrategy s); -[[nodiscard]] std::string to_string(CudaGraphStrategyOption s); -[[nodiscard]] DynamicShapesKernelStrategy to_dynamic_shapes_kernel_strategy( - std::underlying_type_t v); -[[nodiscard]] CudaGraphStrategyOption to_cuda_graph_strategy_option(std::underlying_type_t v); - // Encapsulates the nvinfer1::IRuntimeConfig owned by a TRTEngine along with the // TensorRT-RTX-specific state (runtime cache, dynamic shapes kernel strategy, native // CUDA graph strategy). All `#ifdef TRT_MAJOR_RTX` guards live in this file and its @@ -92,15 +83,6 @@ struct TRTRuntimeConfig { [[nodiscard]] std::string to_str() const; }; -// Free-function I/O helpers. Declared outside TRTRuntimeConfig so they can be tested -// independently of a live TRTEngine and without the noexcept suppression of the member -// wrappers. 
-// -// These perform raw file I/O and may throw on failure; the member wrappers -// (`save_runtime_cache`, `ensure_initialized`'s load step) catch and log instead. -void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache); -void save_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache); - // Construct a TRTRuntimeConfig from a flattened serialization vector. Reads the // RTX-only indices only on RTX builds; standard TRT builds return a default-initialized // struct. diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 0713b24f6c..79c14ddb9d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -167,12 +167,14 @@ def __init__( self.engine = None self.requires_output_allocator = requires_output_allocator self.dynamically_allocate_resources = settings.dynamically_allocate_resources - if ENABLED_FEATURES.tensorrt_rtx: - self.runtime_cache_path = settings.runtime_cache_path - self.dynamic_shapes_kernel_specialization_strategy = ( - settings.dynamic_shapes_kernel_specialization_strategy - ) - self.cuda_graph_strategy = settings.cuda_graph_strategy + # TensorRT-RTX-only runtime config mirror. The engine-info serialization slots + # only exist on RTX builds (see below), but we validate the strategy names on + # every build so typos are caught regardless of backend. 
+ self.runtime_cache_path = settings.runtime_cache_path + self.dynamic_shapes_kernel_specialization_strategy = ( + settings.dynamic_shapes_kernel_specialization_strategy + ) + self.cuda_graph_strategy = settings.cuda_graph_strategy self.symbolic_shape_expressions = symbolic_shape_expressions if ( @@ -231,27 +233,32 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str( int(self.dynamically_allocate_resources) ) - if ENABLED_FEATURES.tensorrt_rtx: + # Validate TensorRT-RTX strategy names on every build so typos are caught + # regardless of backend. The engine-info slots themselves only exist on RTX + # builds and are written below, but the validation is cheap and catches user + # errors early. + if ENABLED_FEATURES.tensorrt_rtx and self.runtime_cache_path is not None: engine_info[RUNTIME_CACHE_PATH_IDX] = self.runtime_cache_path or "" - if ( - self.dynamic_shapes_kernel_specialization_strategy - not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP - ): - raise ValueError( - f"Invalid dynamic_shapes_kernel_specialization_strategy " - f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " - f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" - ) + if ( + self.dynamic_shapes_kernel_specialization_strategy + not in _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP + ): + raise ValueError( + f"Invalid dynamic_shapes_kernel_specialization_strategy " + f"{self.dynamic_shapes_kernel_specialization_strategy!r}; expected one of " + f"{list(_DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP.keys())}" + ) + if self.cuda_graph_strategy not in _CUDA_GRAPH_STRATEGY_MAP: + raise ValueError( + f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " + f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" + ) + if ENABLED_FEATURES.tensorrt_rtx: engine_info[DYNAMIC_SHAPES_KERNEL_STRATEGY_IDX] = str( _DYNAMIC_SHAPES_KERNEL_STRATEGY_MAP[ self.dynamic_shapes_kernel_specialization_strategy ] ) - if self.cuda_graph_strategy not in 
_CUDA_GRAPH_STRATEGY_MAP: - raise ValueError( - f"Invalid cuda_graph_strategy {self.cuda_graph_strategy!r}; expected one of " - f"{list(_CUDA_GRAPH_STRATEGY_MAP.keys())}" - ) engine_info[CUDA_GRAPH_STRATEGY_IDX] = str( _CUDA_GRAPH_STRATEGY_MAP[self.cuda_graph_strategy] ) From e8521233c139ca21e67b4884ed2e4aee840120c8 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 23 Apr 2026 02:16:58 -0700 Subject: [PATCH 9/9] refactor(runtime): fourth-round review polish + test deduplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to 612556ba0 addressing the latest batch of comments on pytorch/TensorRT PR #4202. Two categories of changes: Reviewer-suggested C++ simplifications (TRTRuntimeConfig.cpp) - load_runtime_cache: inlined the deserialize() call directly into TORCHTRT_CHECK instead of going through an intermediate bool. - ensure_initialized / setRuntimeCache: flipped the if/else so the success branch comes first and the warning + reset lands in the else, matching the reviewer's diff suggestion. - ensure_initialized / setCudaGraphStrategy: inlined the call into the if-condition and dropped the intermediate `bool ok` local. - disable_rtx_native_cudagraphs: same shape fix for the disable-path setCudaGraphStrategy call. Runtime cache durability (TRTEngine.cpp) - recreate_execution_context now flushes the runtime cache before rebuilding the IExecutionContext. The destructor already saves at teardown, but recreate can happen mid-lifetime around profiling toggles and allocator changes; without flushing there, a process kill between an allocator flip and teardown would lose any kernels compiled during the previous context. No-op on standard TensorRT and when no cache path is configured. Test deduplication (tests/py/dynamo/**/test_*{runtime_cache,dynamic_ shapes_kernel_strategy}*.py) Reviewer asked to stop copy-pasting bodies between the Python- and C++-runtime test classes. 
The persistence, model, and dynamic-shape suites now share one parameterized body that runs on both runtimes: - test_000_runtime_cache.py: TestRuntimeCachePersistence holds the single body; parameterized.expand(_RUNTIMES) fans out over ("python", True) and ("cpp", False). The CppPersistence class, its helpers, and CppSimpleModel are gone; a shared ConvModel with seeded init drives both paths. The C++ parameter skips itself via self.skipTest when torch_tensorrt_runtime is off. - test_001_dynamic_shapes_kernel_strategy.py: the lazy/eager/none test trio in TestDynamicShapesKernelStrategyCpp collapses into a single parameterized test_strategy_inference. Same parameter sweep on TestDynamicShapesKernelStrategySetup.test_strategy_ applied. - test_runtime_cache_models.py: TestRuntimeCacheModels, TestRuntimeCacheDynamicShapes, and TestRuntimeCachePerformance are parameterized over (runtime, use_python_runtime); the Cpp* sibling class is removed. - test_dynamic_shapes_kernel_strategy_models.py: one parameter product (strategy × runtime) drives both the resnet18 and dynamic-batch tests; the Cpp* sibling class is removed. Net: ~200 fewer lines of test code, same coverage, plus symmetry between Python- and C++-runtime test execution. Build + verification - RTX A100 (ipp1-2162, cuda13.0 dev container), wheel torch_tensorrt_rtx-2.12.0.dev0+612556ba0. - runtime/test_000_runtime_cache.py + runtime/test_001_dynamic_shapes_kernel_strategy.py + runtime/test_001_cuda_graph_strategy.py: 36 passed / 3 skipped (up from 35 pre-dedup — the param expansion picks up one extra per-runtime variant on the strategy applied test). - runtime/test_005_dynamic_allocation.py + runtime/test_002_cudagraphs_cpp.py: 9 passed (regression clean). - Model-level subset (resnet18 + dynamic-batch sweep across both runtimes and all three strategies): 10 passed. 
- Dedicated C++-runtime verification script confirms that use_python_runtime=False produces TorchTensorRTModule (not PythonTorchTensorRTModule), and that the runtime cache is populated and flushed through the C++ path (file size > 0 on engine destruction). Covers review comments 3128480385, 3128493651, 3128747920, 3128754155, 3128759096, and 3128764510. --- core/runtime/TRTEngine.cpp | 6 + core/runtime/TRTRuntimeConfig.cpp | 26 +- ...t_dynamic_shapes_kernel_strategy_models.py | 183 +++++-------- .../models/test_runtime_cache_models.py | 241 +++++----------- .../dynamo/runtime/test_000_runtime_cache.py | 258 +++++++----------- ...test_001_dynamic_shapes_kernel_strategy.py | 115 ++++---- 6 files changed, 305 insertions(+), 524 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 2ba42ed954..51efa2388a 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -545,6 +545,12 @@ void TRTEngine::disable_rtx_native_cudagraphs() { } void TRTEngine::recreate_execution_context() { + // Flush any kernels the previous execution context may have compiled into the + // runtime cache before creating the replacement. The destructor also saves, but + // doing it here guards against losing compiled kernels across profiling toggles, + // allocator changes, or process kills that happen between allocator changes and + // teardown. No-op on standard TensorRT or when no cache path is configured. 
+ runtime_cfg.save_runtime_cache(); runtime_cfg.ensure_initialized(cuda_engine.get()); runtime_cfg.set_execution_context_allocation_strategy( resource_allocation_strategy == ResourceAllocationStrategy::kDynamic diff --git a/core/runtime/TRTRuntimeConfig.cpp b/core/runtime/TRTRuntimeConfig.cpp index cdf8b63f09..0804a0a7fa 100644 --- a/core/runtime/TRTRuntimeConfig.cpp +++ b/core/runtime/TRTRuntimeConfig.cpp @@ -75,8 +75,7 @@ void load_runtime_cache(const std::string& path, nvinfer1::IRuntimeCache* cache) if (buf.empty()) { return; } - bool ok = cache->deserialize(buf.data(), buf.size()); - TORCHTRT_CHECK(ok, "IRuntimeCache::deserialize returned false for " << path); + TORCHTRT_CHECK(cache->deserialize(buf.data(), buf.size()), "IRuntimeCache::deserialize returned false for " << path); LOG_INFO("Loaded runtime cache from " << path << " (" << buf.size() << " bytes)"); } @@ -123,12 +122,11 @@ void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { } catch (const std::exception& e) { LOG_WARNING("Failed to load runtime cache from " << runtime_cache_path << ": " << e.what()); } - bool ok = config->setRuntimeCache(*runtime_cache); - if (!ok) { + if (config->setRuntimeCache(*runtime_cache)) { + LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); + } else { LOG_WARNING("Failed to attach runtime cache to IRuntimeConfig; cache will be unused."); runtime_cache.reset(); - } else { - LOG_DEBUG("TensorRT-RTX runtime cache configured at " << runtime_cache_path); } } } else { @@ -141,11 +139,10 @@ void TRTRuntimeConfig::ensure_initialized(nvinfer1::ICudaEngine* cuda_engine) { LOG_DEBUG("Dynamic shapes kernel specialization strategy set to " << to_string(dynamic_shapes_kernel_strategy)); // CUDA graph strategy -- TRT-RTX only. - bool ok = config->setCudaGraphStrategy( - cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture - ? 
nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE - : nvinfer1::CudaGraphStrategy::kDISABLED); - if (!ok) { + if (!config->setCudaGraphStrategy( + cuda_graph_strategy == CudaGraphStrategyOption::kWholeGraphCapture + ? nvinfer1::CudaGraphStrategy::kWHOLE_GRAPH_CAPTURE + : nvinfer1::CudaGraphStrategy::kDISABLED)) { LOG_WARNING("Failed to set CUDA graph strategy; continuing with default."); } #endif @@ -181,11 +178,8 @@ void TRTRuntimeConfig::disable_rtx_native_cudagraphs(TORCHTRT_UNUSED const std:: // capture will run without them otherwise, and we want future reloads to reuse them. save_runtime_cache(); cuda_graph_strategy = CudaGraphStrategyOption::kDisabled; - if (config) { - bool ok = config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED); - if (!ok) { - LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); - } + if (config && !config->setCudaGraphStrategy(nvinfer1::CudaGraphStrategy::kDISABLED)) { + LOG_WARNING("Failed to update CUDA graph strategy on IRuntimeConfig after disable."); } rtx_native_cudagraphs_disabled = true; #endif diff --git a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py index fd3b9ee93d..e3d438c3ae 100644 --- a/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py +++ b/tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py @@ -3,10 +3,43 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Combinations of (strategy, runtime_name, use_python_runtime). Tests use parameterized +# so the strategy sweep runs on both runtimes with a single test body. 
+_STRATEGY_RUNTIMES = [ + ("lazy_python", "lazy", True), + ("eager_python", "eager", True), + ("none_python", "none", True), + ("lazy_cpp", "lazy", False), + ("eager_cpp", "eager", False), + ("none_cpp", "none", False), +] + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + + +def _compile_with_strategy( + model, inputs, *, use_python_runtime, strategy, enabled_precisions +): + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions=enabled_precisions, + use_python_runtime=use_python_runtime, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -17,17 +50,18 @@ "torchvision is not installed", ) class TestDynamicShapesKernelStrategyModels(TestCase): - """End-to-end model tests with different kernel specialization strategies.""" + """End-to-end model tests with each strategy across both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_resnet18_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + import torchvision.models as models - def _compile_and_verify(self, model, strategy): + model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(4, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 224, 224), opt_shape=(4, 3, 224, 224), @@ -35,10 +69,9 @@ def _compile_and_verify(self, model, strategy): dtype=torch.float32, ) ], + use_python_runtime=use_python_runtime, + strategy=strategy, enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, - 
dynamic_shapes_kernel_specialization_strategy=strategy, ) ref_output = model(input_tensor) trt_output = compiled(input_tensor) @@ -46,39 +79,21 @@ def _compile_and_verify(self, model, strategy): self.assertTrue( cos_sim > COSINE_THRESHOLD, f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " - f"with strategy={strategy}", + f"(strategy={strategy}, python_runtime={use_python_runtime})", ) - def test_resnet18_lazy_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "lazy") - - def test_resnet18_eager_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "eager") - - def test_resnet18_none_strategy(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify(model, "none") - @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", ) class TestDynamicShapesKernelStrategyDynamic(TestCase): - """Tests kernel specialization strategies with dynamic input shapes.""" + """Tests kernel specialization strategies with dynamic input shapes, both runtimes.""" - def tearDown(self): - torch._dynamo.reset() + @parameterized.expand(_STRATEGY_RUNTIMES) + def test_dynamic_batch_with_strategy(self, _name, strategy, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) - def _test_dynamic_batch_with_strategy(self, strategy): class ConvModel(torch.nn.Module): def __init__(self): super().__init__() @@ -90,10 +105,9 @@ def forward(self, x): model = ConvModel().eval().cuda() - compiled = torchtrt.compile( + compiled = _compile_with_strategy( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 32, 32), opt_shape=(4, 3, 32, 32), @@ -101,96 +115,21 @@ def forward(self, x): dtype=torch.float32, ) ], + 
use_python_runtime=use_python_runtime, + strategy=strategy, enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, ) for batch_size in (1, 4, 8): - with self.subTest(batch_size=batch_size, strategy=strategy): - input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() - ref_output = model(input_tensor) - trt_output = compiled(input_tensor) - cos_sim = cosine_similarity(ref_output, trt_output) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - f"BS={batch_size}, strategy={strategy}: cosine similarity " - f"{cos_sim} below threshold {COSINE_THRESHOLD}", - ) - - def test_dynamic_batch_lazy(self): - self._test_dynamic_batch_with_strategy("lazy") - - def test_dynamic_batch_eager(self): - self._test_dynamic_batch_with_strategy("eager") - - def test_dynamic_batch_none(self): - self._test_dynamic_batch_with_strategy("none") - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Dynamic shapes kernel specialization strategy requires TensorRT-RTX", -) -@unittest.skipIf( - not importlib.util.find_spec("torchvision"), - "torchvision is not installed", -) -class TestDynamicShapesKernelStrategyCppModels(TestCase): - """End-to-end model tests with each strategy exercised through the C++ runtime.""" - - def tearDown(self): - torch._dynamo.reset() - - def _compile_and_verify_cpp(self, model, strategy): - input_tensor = torch.randn(4, 3, 224, 224).cuda() - compiled = torchtrt.compile( - model, - ir="dynamo", - inputs=[ - torchtrt.Input( - min_shape=(1, 3, 224, 224), - opt_shape=(4, 3, 224, 224), - max_shape=(8, 3, 224, 224), - dtype=torch.float32, - ) - ], - enabled_precisions={torch.float32}, - use_python_runtime=False, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, - ) - ref_output = model(input_tensor) - trt_output = compiled(input_tensor) - cos_sim = 
cosine_similarity(ref_output, trt_output) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - f"C++ runtime cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} " - f"with strategy={strategy}", - ) - - def test_resnet18_lazy_strategy_cpp(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify_cpp(model, "lazy") - - def test_resnet18_eager_strategy_cpp(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify_cpp(model, "eager") - - def test_resnet18_none_strategy_cpp(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - self._compile_and_verify_cpp(model, "none") + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size}, strategy={strategy}, python_runtime={use_python_runtime}: " + f"cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) if __name__ == "__main__": diff --git a/tests/py/dynamo/models/test_runtime_cache_models.py b/tests/py/dynamo/models/test_runtime_cache_models.py index 7ffae1f5ad..55b11b623e 100644 --- a/tests/py/dynamo/models/test_runtime_cache_models.py +++ b/tests/py/dynamo/models/test_runtime_cache_models.py @@ -8,10 +8,32 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +# Parameterize end-to-end cache tests over both runtime paths. The C++ variant is +# skipped inside the test body when the C++ runtime is not available. 
+_RUNTIMES = [("python", True), ("cpp", False)] + + +def _compile(model, inputs, *, use_python_runtime, runtime_cache_path): + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": use_python_runtime, + "min_block_size": 1, + "runtime_cache_path": runtime_cache_path, + } + return torchtrt.compile(model, **kwargs) + + +def _skip_if_cpp_unavailable(testcase, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + testcase.skipTest("C++ runtime is not available") + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, @@ -22,7 +44,7 @@ "torchvision is not installed", ) class TestRuntimeCacheModels(TestCase): - """End-to-end model tests with runtime cache enabled.""" + """End-to-end model tests with runtime cache enabled — both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -32,19 +54,18 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_resnet18_with_runtime_cache(self): + @parameterized.expand(_RUNTIMES) + def test_resnet18_with_runtime_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() input_tensor = torch.randn(1, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) @@ -57,7 +78,6 @@ def test_resnet18_with_runtime_cache(self): f"ResNet18 cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", ) - # Verify runtime cache is saved on cleanup del compiled gc.collect() self.assertTrue( @@ -65,8 +85,10 @@ def 
test_resnet18_with_runtime_cache(self): "Runtime cache should be saved after ResNet18 inference", ) - def test_resnet18_cache_reuse(self): - """Compile + infer twice with same cache path. Second run should load cached data.""" + @parameterized.expand(_RUNTIMES) + def test_resnet18_cache_reuse(self, _name, use_python_runtime): + """Compile + infer twice with same cache path. Second run loads cached data.""" + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = models.resnet18(pretrained=True).eval().cuda() @@ -74,16 +96,13 @@ def test_resnet18_cache_reuse(self): ref_output = model(input_tensor) compile_kwargs = { - "ir": "dynamo", "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } # First compilation — cold cache - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) _ = compiled1(input_tensor) del compiled1 gc.collect() @@ -92,7 +111,7 @@ def test_resnet18_cache_reuse(self): cache_size_1 = os.path.getsize(self.cache_path) # Second compilation — warm cache - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) output2 = compiled2(input_tensor) cos_sim = cosine_similarity(ref_output, output2) @@ -104,23 +123,21 @@ def test_resnet18_cache_reuse(self): del compiled2 gc.collect() cache_size_2 = os.path.getsize(self.cache_path) - # Cache should exist and be non-empty after both runs self.assertGreater(cache_size_1, 0) self.assertGreater(cache_size_2, 0) - def test_mobilenet_v2_with_runtime_cache(self): + @parameterized.expand(_RUNTIMES) + def test_mobilenet_v2_with_runtime_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) import torchvision.models as models model = 
models.mobilenet_v2(pretrained=True).eval().cuda() input_tensor = torch.randn(1, 3, 224, 224).cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) @@ -143,7 +160,7 @@ def test_mobilenet_v2_with_runtime_cache(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheDynamicShapes(TestCase): - """Tests runtime cache with dynamic input shapes.""" + """Tests runtime cache with dynamic input shapes, exercised on both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -153,7 +170,10 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_dynamic_batch_with_cache(self): + @parameterized.expand(_RUNTIMES) + def test_dynamic_batch_with_cache(self, _name, use_python_runtime): + _skip_if_cpp_unavailable(self, use_python_runtime) + class ConvModel(torch.nn.Module): def __init__(self): super().__init__() @@ -165,10 +185,9 @@ def forward(self, x): model = ConvModel().eval().cuda() - compiled = torchtrt.compile( + compiled = _compile( model, - ir="dynamo", - inputs=[ + [ torchtrt.Input( min_shape=(1, 3, 32, 32), opt_shape=(4, 3, 32, 32), @@ -176,39 +195,28 @@ def forward(self, x): dtype=torch.float32, ) ], - enabled_precisions={torch.float32}, - use_python_runtime=True, - min_block_size=1, + use_python_runtime=use_python_runtime, runtime_cache_path=self.cache_path, ) - # Test with batch size 1 - input_bs1 = torch.randn(1, 3, 32, 32).cuda() - ref_bs1 = model(input_bs1) - out_bs1 = compiled(input_bs1) - cos_sim_1 = cosine_similarity(ref_bs1, out_bs1) - self.assertTrue( - cos_sim_1 > COSINE_THRESHOLD, - f"BS=1 cosine similarity {cos_sim_1} below threshold", - ) + for 
batch_size in (1, 4): + input_tensor = torch.randn(batch_size, 3, 32, 32).cuda() + ref_output = model(input_tensor) + out = compiled(input_tensor) + cos_sim = cosine_similarity(ref_output, out) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"BS={batch_size} cosine similarity {cos_sim} below threshold", + ) - # Test with batch size 4 - input_bs4 = torch.randn(4, 3, 32, 32).cuda() - ref_bs4 = model(input_bs4) - out_bs4 = compiled(input_bs4) - cos_sim_4 = cosine_similarity(ref_bs4, out_bs4) - self.assertTrue( - cos_sim_4 > COSINE_THRESHOLD, - f"BS=4 cosine similarity {cos_sim_4} below threshold", - ) - - # Verify cache is saved del compiled gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - def test_cache_valid_across_shapes(self): + @parameterized.expand(_RUNTIMES) + def test_cache_valid_across_shapes(self, _name, use_python_runtime): """Save cache from one shape, load and verify it works with another shape in range.""" + _skip_if_cpp_unavailable(self, use_python_runtime) class SimpleConv(torch.nn.Module): def __init__(self): @@ -221,7 +229,6 @@ def forward(self, x): model = SimpleConv().eval().cuda() compile_kwargs = { - "ir": "dynamo", "inputs": [ torchtrt.Input( min_shape=(1, 3, 16, 16), @@ -230,14 +237,12 @@ def forward(self, x): dtype=torch.float32, ) ], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } # First run with batch=2 — saves cache - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) input_bs2 = torch.randn(2, 3, 16, 16).cuda() _ = compiled1(input_bs2) del compiled1 @@ -246,7 +251,7 @@ def forward(self, x): self.assertTrue(os.path.isfile(self.cache_path)) # Second run with batch=3 — loads same cache - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) input_bs3 = torch.randn(3, 3, 16, 16).cuda() ref_bs3 = 
model(input_bs3) out_bs3 = compiled2(input_bs3) @@ -273,8 +278,10 @@ def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) torch._dynamo.reset() - def test_warmup_timing(self): - """Measure cold vs warm cache inference time. Informational only — no strict pass/fail.""" + @parameterized.expand(_RUNTIMES) + def test_warmup_timing(self, _name, use_python_runtime): + """Measure cold vs warm cache inference time. Informational — no strict assertion.""" + _skip_if_cpp_unavailable(self, use_python_runtime) class MLP(torch.nn.Module): def __init__(self): @@ -290,16 +297,12 @@ def forward(self, x): input_tensor = torch.randn(16, 256).cuda() compile_kwargs = { - "ir": "dynamo", "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - "enabled_precisions": {torch.float32}, - "use_python_runtime": True, - "min_block_size": 1, + "use_python_runtime": use_python_runtime, "runtime_cache_path": self.cache_path, } - # Cold cache compilation + inference - compiled1 = torchtrt.compile(model, **compile_kwargs) + compiled1 = _compile(model, **compile_kwargs) torch.cuda.synchronize() start = time.perf_counter() _ = compiled1(input_tensor) @@ -309,112 +312,18 @@ def forward(self, x): gc.collect() torch._dynamo.reset() - # Warm cache compilation + inference - compiled2 = torchtrt.compile(model, **compile_kwargs) + compiled2 = _compile(model, **compile_kwargs) torch.cuda.synchronize() start = time.perf_counter() _ = compiled2(input_tensor) torch.cuda.synchronize() warm_time = time.perf_counter() - start - print(f"\n Cold cache first inference: {cold_time*1000:.1f}ms") - print(f" Warm cache first inference: {warm_time*1000:.1f}ms") - print(f" Speedup: {cold_time/warm_time:.2f}x") - - # No strict assertion — just log for visibility + print(f"\n [{_name}] Cold cache first inference: {cold_time*1000:.1f}ms") + print(f" [{_name}] Warm cache first inference: {warm_time*1000:.1f}ms") + print(f" [{_name}] Speedup: {cold_time/warm_time:.2f}x") self.assertTrue(True, 
"Timing test completed (informational)") -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - "Runtime cache is only available with TensorRT-RTX", -) -@unittest.skipIf( - not importlib.util.find_spec("torchvision"), - "torchvision is not installed", -) -class TestRuntimeCacheCppModels(TestCase): - """End-to-end model tests with runtime cache exercised through the C++ runtime.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) - torch._dynamo.reset() - - def test_resnet18_with_runtime_cache_cpp(self): - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - input_tensor = torch.randn(1, 3, 224, 224).cuda() - - compiled = torchtrt.compile( - model, - ir="dynamo", - inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - enabled_precisions={torch.float32}, - use_python_runtime=False, - min_block_size=1, - runtime_cache_path=self.cache_path, - ) - - ref_output = model(input_tensor) - trt_output = compiled(input_tensor) - - cos_sim = cosine_similarity(ref_output, trt_output) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - f"ResNet18 C++ runtime cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", - ) - - # Verify the runtime cache is persisted on engine destruction. 
- del compiled - gc.collect() - self.assertTrue( - os.path.isfile(self.cache_path), - "Runtime cache should be saved after ResNet18 C++-runtime inference", - ) - - def test_resnet18_cache_reuse_cpp(self): - """Warm-cache second compile should match eager output.""" - import torchvision.models as models - - model = models.resnet18(pretrained=True).eval().cuda() - input_tensor = torch.randn(1, 3, 224, 224).cuda() - ref_output = model(input_tensor) - - compile_kwargs = { - "ir": "dynamo", - "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], - "enabled_precisions": {torch.float32}, - "use_python_runtime": False, - "min_block_size": 1, - "runtime_cache_path": self.cache_path, - } - - compiled1 = torchtrt.compile(model, **compile_kwargs) - out1 = compiled1(input_tensor) - self.assertTrue( - cosine_similarity(ref_output, out1) > COSINE_THRESHOLD, - "First ResNet18 C++-runtime output should match eager", - ) - del compiled1 - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - - compiled2 = torchtrt.compile(model, **compile_kwargs) - out2 = compiled2(input_tensor) - self.assertTrue( - cosine_similarity(ref_output, out2) > COSINE_THRESHOLD, - "Second ResNet18 C++-runtime output (warm cache) should match eager", - ) - - if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index fc7be8a979..dc23847870 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -1,5 +1,4 @@ import gc -import logging import os import shutil import tempfile @@ -7,10 +6,10 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH -from torch_tensorrt.dynamo._settings 
import CompilationSettings from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity @@ -19,31 +18,50 @@ def forward(self, x): return torch.relu(x) + 1.0 -class TwoLayerModel(torch.nn.Module): +class ConvModel(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(8, 8) + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) def forward(self, x): - return torch.relu(self.linear(x)) + return torch.relu(self.conv(x)) -def _compile_simple(runtime_cache_path=None): - """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] +def _fresh_conv_model_and_inputs(seed=0): + """Deterministic ConvModel + input pair for end-to-end cache tests on either runtime.""" + torch.manual_seed(seed) + return ConvModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] + + +def _compile(model, inputs, *, use_python_runtime, runtime_cache_path=None): + """Compile `model` through either runtime. 
Returns the compiled module.""" kwargs = { "ir": "dynamo", "inputs": inputs, "enabled_precisions": {torch.float32}, - "use_python_runtime": True, + "use_python_runtime": use_python_runtime, "min_block_size": 1, } if runtime_cache_path is not None: kwargs["runtime_cache_path"] = runtime_cache_path compiled = torchtrt.compile(model, **kwargs) torch._dynamo.reset() - return compiled, inputs + return compiled + + +def _compile_simple(runtime_cache_path=None): + """Compile the SimpleModel on the Python runtime (used by Python-only setup tests).""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + return ( + _compile( + model, + inputs, + use_python_runtime=True, + runtime_cache_path=runtime_cache_path, + ), + inputs, + ) def _find_python_trt_module(compiled): @@ -52,18 +70,23 @@ def _find_python_trt_module(compiled): PythonTorchTensorRTModule, ) - for name, mod in compiled.named_modules(): + for _name, mod in compiled.named_modules(): if isinstance(mod, PythonTorchTensorRTModule): return mod return None +# Parameterize end-to-end cache persistence tests over both runtime paths. The C++ +# variant is skipped inside the test body when the C++ runtime is not available. 
+_RUNTIMES = [("python", True), ("cpp", False)] + + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheSetup(TestCase): - """Tests that runtime config and cache are correctly created for RTX.""" + """Python-runtime-only setup checks: the compiled module exposes a live runtime cache.""" def test_runtime_config_created(self): compiled, _ = _compile_simple() @@ -78,7 +101,6 @@ def test_context_created_successfully(self): compiled, inputs = _compile_simple() mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod.context, "execution context should be created") - # Verify inference works output = compiled(*[inp.clone() for inp in inputs]) self.assertEqual(output.shape, inputs[0].shape) @@ -103,7 +125,7 @@ def test_runtime_cache_path_custom(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCachePersistence(TestCase): - """Tests that runtime cache is correctly saved to and loaded from disk.""" + """Load-on-setup / save-on-destructor contract, exercised on both runtimes.""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -112,9 +134,20 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.cache_dir, ignore_errors=True) - def test_cache_saved_on_del(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) - # Run inference to populate the cache + def _skip_if_cpp_unavailable(self, use_python_runtime): + if not use_python_runtime and not ENABLED_FEATURES.torch_tensorrt_runtime: + self.skipTest("C++ runtime is not available") + + @parameterized.expand(_RUNTIMES) + def test_cache_saved_on_del(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) self.assertFalse( 
os.path.isfile(self.cache_path), @@ -127,8 +160,16 @@ def test_cache_saved_on_del(self): "Cache file should be created after module cleanup", ) - def test_cache_file_nonempty(self): - compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + @parameterized.expand(_RUNTIMES) + def test_cache_file_nonempty(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) del compiled gc.collect() @@ -138,30 +179,54 @@ def test_cache_file_nonempty(self): "Cache file should have nonzero size", ) - def test_cache_roundtrip(self): - """Compile, infer, save. Then compile again with same cache path and verify correctness.""" - model = SimpleModel().eval().cuda() - inputs = [torch.randn(2, 3).cuda()] - ref_output = model(*inputs) + @parameterized.expand(_RUNTIMES) + def test_cache_roundtrip(self, _name, use_python_runtime): + """Populate + save, then recompile and confirm correctness against eager output.""" + self._skip_if_cpp_unavailable(use_python_runtime) + model, inputs = _fresh_conv_model_and_inputs() + with torch.no_grad(): + ref_output = model(*inputs) - # First compilation — populates and saves cache - compiled1, _ = _compile_simple(runtime_cache_path=self.cache_path) - _ = compiled1(*[inp.clone() for inp in inputs]) + compiled1 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out1 = compiled1(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out1), + COSINE_THRESHOLD, + "First compiled output should match eager", + ) del compiled1 gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - # Second compilation — should load cached data - compiled2, _ = _compile_simple(runtime_cache_path=self.cache_path) - 
output = compiled2(*[inp.clone() for inp in inputs]) - max_diff = float(torch.max(torch.abs(ref_output - output))) - self.assertAlmostEqual( - max_diff, 0, places=3, msg="Output mismatch after cache roundtrip" + compiled2 = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=self.cache_path, + ) + out2 = compiled2(*[inp.clone() for inp in inputs]) + self.assertGreater( + cosine_similarity(ref_output, out2), + COSINE_THRESHOLD, + "Second compiled output (warm cache) should still match eager", ) - def test_save_creates_directory(self): + @parameterized.expand(_RUNTIMES) + def test_save_creates_directory(self, _name, use_python_runtime): + self._skip_if_cpp_unavailable(use_python_runtime) nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - compiled, inputs = _compile_simple(runtime_cache_path=nested_path) + model, inputs = _fresh_conv_model_and_inputs() + compiled = _compile( + model, + inputs, + use_python_runtime=use_python_runtime, + runtime_cache_path=nested_path, + ) _ = compiled(*[inp.clone() for inp in inputs]) del compiled gc.collect() @@ -176,7 +241,7 @@ def test_save_creates_directory(self): "Runtime cache is only available with TensorRT-RTX", ) class TestRuntimeCacheConcurrency(TestCase): - """Tests that file locking works for concurrent access.""" + """Tests that file locking works for concurrent access (Python runtime only).""" def setUp(self): self.cache_dir = tempfile.mkdtemp() @@ -192,7 +257,6 @@ def test_filelock_works(self): del compiled gc.collect() self.assertTrue(os.path.isfile(self.cache_path)) - # Verify we can acquire a lock on the same path (no deadlock) from filelock import FileLock lock = FileLock(self.cache_path + ".lock") @@ -202,14 +266,12 @@ def test_filelock_works(self): def test_sequential_save_load(self): """Two modules saving and loading from the same path should not corrupt data.""" - # First module saves compiled1, inputs = 
_compile_simple(runtime_cache_path=self.cache_path) _ = compiled1(*[inp.clone() for inp in inputs]) del compiled1 gc.collect() size1 = os.path.getsize(self.cache_path) - # Second module saves (overwrites) compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) _ = compiled2(*[inp.clone() for inp in inputs]) del compiled2 @@ -228,7 +290,6 @@ class TestTimingCacheSkipped(TestCase): """Tests that timing cache is correctly skipped for RTX builds.""" def setUp(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) @@ -273,7 +334,6 @@ def test_no_runtime_config_for_standard_trt(self): ) def test_timing_cache_still_created(self): - # Clean up any pre-existing timing cache if os.path.isfile(TIMING_CACHE_PATH): os.remove(TIMING_CACHE_PATH) compiled, inputs = _compile_simple() @@ -284,120 +344,6 @@ def test_timing_cache_still_created(self): ) -class CppSimpleModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) - - def forward(self, x): - return torch.relu(self.conv(x)) - - -def _fresh_cpp_model_and_inputs(seed=0): - """Create a deterministic CppSimpleModel + input tensor pair for C++-runtime tests.""" - torch.manual_seed(seed) - return CppSimpleModel().eval().cuda(), [torch.randn(2, 3, 16, 16).cuda()] - - -def _compile_cpp(model, inputs, runtime_cache_path=None): - """Compile the given model through the C++ runtime path (use_python_runtime=False).""" - kwargs = { - "ir": "dynamo", - "inputs": inputs, - "enabled_precisions": {torch.float32}, - "use_python_runtime": False, - "min_block_size": 1, - } - if runtime_cache_path is not None: - kwargs["runtime_cache_path"] = runtime_cache_path - compiled = torchtrt.compile(model, **kwargs) - torch._dynamo.reset() - return compiled - - -@unittest.skipIf( - not ENABLED_FEATURES.torch_tensorrt_runtime, - "C++ runtime is not available", -) -@unittest.skipIf( - not ENABLED_FEATURES.tensorrt_rtx, - 
"Runtime cache is only available with TensorRT-RTX", -) -class TestRuntimeCacheCppPersistence(TestCase): - """Exercise the C++-runtime code path: load on engine setup, save on destructor.""" - - def setUp(self): - self.cache_dir = tempfile.mkdtemp() - self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") - - def tearDown(self): - shutil.rmtree(self.cache_dir, ignore_errors=True) - - def test_cache_saved_on_del(self): - model, inputs = _fresh_cpp_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - self.assertFalse( - os.path.isfile(self.cache_path), - "Cache should not exist before module cleanup", - ) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(self.cache_path), - "Cache file should be created after module cleanup", - ) - - def test_cache_file_nonempty(self): - model, inputs = _fresh_cpp_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertGreater( - os.path.getsize(self.cache_path), - 0, - "Cache file should have nonzero size", - ) - - def test_cache_roundtrip(self): - """Compile, infer, save. 
Then recompile same model+cache and verify correctness.""" - model, inputs = _fresh_cpp_model_and_inputs() - with torch.no_grad(): - ref_output = model(*inputs) - - compiled1 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - out1 = compiled1(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out1), - COSINE_THRESHOLD, - "First compiled output should match eager", - ) - del compiled1 - gc.collect() - self.assertTrue(os.path.isfile(self.cache_path)) - - compiled2 = _compile_cpp(model, inputs, runtime_cache_path=self.cache_path) - out2 = compiled2(*[inp.clone() for inp in inputs]) - self.assertGreater( - cosine_similarity(ref_output, out2), - COSINE_THRESHOLD, - "Second compiled output (warm cache) should still match eager", - ) - - def test_save_creates_directory(self): - nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") - model, inputs = _fresh_cpp_model_and_inputs() - compiled = _compile_cpp(model, inputs, runtime_cache_path=nested_path) - _ = compiled(*[inp.clone() for inp in inputs]) - del compiled - gc.collect() - self.assertTrue( - os.path.isfile(nested_path), - "Save should create intermediate directories", - ) - - @unittest.skipIf( not ENABLED_FEATURES.torch_tensorrt_runtime, "C++ runtime is not available", diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 598efa71cc..d514be86d1 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -2,16 +2,29 @@ import torch import torch_tensorrt as torchtrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._settings import CompilationSettings +_STRATEGIES = [("lazy",), ("eager",), ("none",)] + class 
SimpleModel(torch.nn.Module): def forward(self, x): return torch.relu(x) + 1.0 +class DynamicConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) + self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) + + def forward(self, x): + return torch.relu(self.conv2(torch.relu(self.conv1(x)))) + + def _compile_simple(**extra_kwargs): """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" model = SimpleModel().eval().cuda() @@ -36,13 +49,34 @@ def _compile_simple(**extra_kwargs): return compiled +def _compile_cpp(strategy): + model = DynamicConvModel().eval().cuda() + inp = torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[inp], + enabled_precisions={torch.float32}, + use_python_runtime=False, + min_block_size=1, + dynamic_shapes_kernel_specialization_strategy=strategy, + ) + torch._dynamo.reset() + return compiled + + def _find_python_trt_module(compiled): """Walk the compiled graph module to find PythonTorchTensorRTModule instances.""" from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( PythonTorchTensorRTModule, ) - for name, mod in compiled.named_modules(): + for _name, mod in compiled.named_modules(): if isinstance(mod, PythonTorchTensorRTModule): return mod return None @@ -55,6 +89,12 @@ def _find_python_trt_module(compiled): class TestDynamicShapesKernelStrategySetup(TestCase): """Tests that the dynamic shapes kernel specialization strategy is correctly applied.""" + _EXPECTED_ENUM = { + "lazy": "LAZY", + "eager": "EAGER", + "none": "NONE", + } + def test_default_strategy_is_lazy(self): import tensorrt as trt @@ -67,28 +107,21 @@ def test_default_strategy_is_lazy(self): trt.DynamicShapesKernelSpecializationStrategy.LAZY, ) - def test_eager_strategy(self): + @parameterized.expand(_STRATEGIES) + def test_strategy_applied(self, 
strategy): import tensorrt as trt compiled = _compile_simple( - dynamic_shapes_kernel_specialization_strategy="eager" - ) - mod = _find_python_trt_module(compiled) - self.assertIsNotNone(mod) - self.assertEqual( - mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.EAGER, + dynamic_shapes_kernel_specialization_strategy=strategy ) - - def test_none_strategy(self): - import tensorrt as trt - - compiled = _compile_simple(dynamic_shapes_kernel_specialization_strategy="none") mod = _find_python_trt_module(compiled) self.assertIsNotNone(mod) self.assertEqual( mod.runtime_config.dynamic_shapes_kernel_specialization_strategy, - trt.DynamicShapesKernelSpecializationStrategy.NONE, + getattr( + trt.DynamicShapesKernelSpecializationStrategy, + self._EXPECTED_ENUM[strategy], + ), ) def test_context_created_with_each_strategy(self): @@ -101,7 +134,6 @@ def test_context_created_with_each_strategy(self): self.assertIsNotNone( mod.context, f"Execution context should be created for {strategy}" ) - # Test inference with multiple dynamic batch sizes for bs in (1, 2, 4): output = compiled(torch.randn(bs, 3).cuda()) self.assertEqual(output.shape, (bs, 3)) @@ -137,42 +169,10 @@ def test_setting_ignored_on_non_rtx(self): mod.runtime_config, "runtime_config should be None for standard TRT", ) - # Inference should still work output = compiled(torch.randn(2, 3).cuda()) self.assertEqual(output.shape, (2, 3)) -class DynamicConvModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1) - self.conv2 = torch.nn.Conv2d(16, 8, 3, padding=1) - - def forward(self, x): - return torch.relu(self.conv2(torch.relu(self.conv1(x)))) - - -def _compile_cpp(strategy): - model = DynamicConvModel().eval().cuda() - inp = torchtrt.Input( - min_shape=(1, 3, 16, 16), - opt_shape=(2, 3, 16, 16), - max_shape=(4, 3, 16, 16), - dtype=torch.float32, - ) - compiled = torchtrt.compile( - model, - 
ir="dynamo", - inputs=[inp], - enabled_precisions={torch.float32}, - use_python_runtime=False, - min_block_size=1, - dynamic_shapes_kernel_specialization_strategy=strategy, - ) - torch._dynamo.reset() - return compiled - - @unittest.skipIf( not ENABLED_FEATURES.torch_tensorrt_runtime, "C++ runtime is not available", @@ -184,22 +184,9 @@ def _compile_cpp(strategy): class TestDynamicShapesKernelStrategyCpp(TestCase): """End-to-end: compile + infer through the C++ runtime with each strategy.""" - def test_lazy(self): - compiled = _compile_cpp("lazy") - x = torch.randn(2, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - def test_eager(self): - compiled = _compile_cpp("eager") - x = torch.randn(2, 3, 16, 16, device="cuda") - y = compiled(x) - self.assertEqual(tuple(y.shape), (2, 8, 16, 16)) - self.assertTrue(torch.isfinite(y).all().item()) - - def test_none(self): - compiled = _compile_cpp("none") + @parameterized.expand(_STRATEGIES) + def test_strategy_inference(self, strategy): + compiled = _compile_cpp(strategy) x = torch.randn(2, 3, 16, 16, device="cuda") y = compiled(x) self.assertEqual(tuple(y.shape), (2, 8, 16, 16))