9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -93,6 +93,7 @@ def cross_compile_for_windows(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -174,6 +175,7 @@ def cross_compile_for_windows(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -339,6 +341,7 @@ def cross_compile_for_windows(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -451,6 +454,7 @@ def compile(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -547,6 +551,7 @@ def compile(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
Collaborator: Can we add a warning or check for the case where a user configures dynamic_shapes_kernel_specialization_strategy with standard TensorRT?

Contributor Author: This is a good suggestion, Lan. I have a follow-up task to emit user warnings for:

  1. timing cache used in TRT-RTX
  2. runtime cache used in standard TRT
  3. dynamic shape strategy used in standard TRT
  4. cudagraphs flag used in standard TRT
so that it's easier to review the change in behavior. I will put them in then.
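
For reference, a minimal sketch of what such a backend-mismatch warning could look like; the helper name, its call site, and the exact conditions are assumptions for illustration and are not part of this PR:

```python
# Hypothetical follow-up sketch, not part of this PR: warn when a setting that is
# only honored by one backend is configured while building with the other.
import logging

from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo._defaults import DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
from torch_tensorrt.dynamo._settings import CompilationSettings

logger = logging.getLogger(__name__)


def _warn_on_backend_specific_settings(settings: CompilationSettings) -> None:
    if not ENABLED_FEATURES.tensorrt_rtx:
        # Standard TensorRT ignores the RTX-only runtime specialization knob.
        if (
            settings.dynamic_shapes_kernel_specialization_strategy
            != DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
        ):
            logger.warning(
                "dynamic_shapes_kernel_specialization_strategy is only used by "
                "TensorRT-RTX and will be ignored by standard TensorRT"
            )
```

The same pattern would presumably extend to the timing cache, runtime cache, and CUDA graph flags listed above.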

lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -755,6 +760,7 @@ def compile(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
@@ -1169,6 +1175,7 @@ def convert_exported_program_to_serialized_trt_engine(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
@@ -1246,6 +1253,7 @@ def convert_exported_program_to_serialized_trt_engine(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
@@ -1420,6 +1428,7 @@ def convert_exported_program_to_serialized_trt_engine(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -72,6 +72,7 @@
DYNAMICALLY_ALLOCATE_RESOURCES = False
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"

if platform.system() == "Linux":
import pwd
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -23,6 +23,7 @@
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
DYNAMICALLY_ALLOCATE_RESOURCES,
ENABLE_AUTOCAST,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
@@ -100,6 +101,7 @@ class CompilationSettings:
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
@@ -154,6 +156,9 @@ class CompilationSettings:
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
runtime_cache_path: str = RUNTIME_CACHE_PATH
dynamic_shapes_kernel_specialization_strategy: str = (
    DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
)
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
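
For illustration, a minimal sketch of the new field on CompilationSettings; constructing the settings object directly is only for demonstration, since in normal use the value flows in through compile():

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

# Accepted values: "lazy" (default), "eager", "none".
settings = CompilationSettings(dynamic_shapes_kernel_specialization_strategy="eager")
assert settings.dynamic_shapes_kernel_specialization_strategy == "eager"
```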
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -27,6 +27,15 @@
logger = logging.getLogger(__name__)


def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any:
    """Map strategy string to TRT enum. Only called on RTX builds."""
    return {
        "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY,
        "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER,
        "none": trt.DynamicShapesKernelSpecializationStrategy.NONE,
    }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY)


class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc]
    def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None:
        trt.IOutputAllocator.__init__(self)
@@ -345,6 +354,14 @@ def _setup_runtime_config(self) -> None:
    self.runtime_config.set_execution_context_allocation_strategy(
        trt.ExecutionContextAllocationStrategy.STATIC
    )
    self.runtime_config.dynamic_shapes_kernel_specialization_strategy = (
        _get_dynamic_shapes_kernel_strategy(
            self.settings.dynamic_shapes_kernel_specialization_strategy
        )
    )
    logger.info(
        f"Dynamic shapes kernel specialization strategy: {self.settings.dynamic_shapes_kernel_specialization_strategy}"
    )
    self.runtime_cache = self.runtime_config.create_runtime_cache()
    self._load_runtime_cache()
    self.runtime_config.set_runtime_cache(self.runtime_cache)
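
From the user's side, the strategy is selected through the existing dynamo compile path. A minimal sketch follows, assuming a TensorRT-RTX build and the Python runtime and mirroring the new tests below; note that unrecognized strings fall back to the LAZY enum in the helper above:

```python
import torch
import torch_tensorrt as torchtrt

# Toy model purely for illustration.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.ReLU(),
).eval().cuda()

compiled = torchtrt.compile(
    model,
    ir="dynamo",
    inputs=[
        torchtrt.Input(
            min_shape=(1, 3, 32, 32),
            opt_shape=(4, 3, 32, 32),
            max_shape=(8, 3, 32, 32),
            dtype=torch.float32,
        )
    ],
    use_python_runtime=True,
    min_block_size=1,
    # "lazy" (default), "eager", or "none"; applied to the runtime config above.
    dynamic_shapes_kernel_specialization_strategy="eager",
)
out = compiled(torch.randn(4, 3, 32, 32).cuda())
```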
133 changes: 133 additions & 0 deletions tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py
@@ -0,0 +1,133 @@
import importlib
import unittest

import torch
import torch_tensorrt as torchtrt
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity


@unittest.skipIf(
    not ENABLED_FEATURES.tensorrt_rtx,
    "Dynamic shapes kernel specialization strategy requires TensorRT-RTX",
)
@unittest.skipIf(
    not importlib.util.find_spec("torchvision"),
    "torchvision is not installed",
)
class TestDynamicShapesKernelStrategyModels(TestCase):
    """End-to-end model tests with different kernel specialization strategies."""

    def tearDown(self):
        torch._dynamo.reset()

    def _compile_and_verify(self, model, strategy):
        input_tensor = torch.randn(4, 3, 224, 224).cuda()
        compiled = torchtrt.compile(
            model,
            ir="dynamo",
            inputs=[
                torchtrt.Input(
                    min_shape=(1, 3, 224, 224),
                    opt_shape=(4, 3, 224, 224),
                    max_shape=(8, 3, 224, 224),
                    dtype=torch.float32,
                )
            ],
            enabled_precisions={torch.float32},
            use_python_runtime=True,
            min_block_size=1,
            dynamic_shapes_kernel_specialization_strategy=strategy,
        )
        ref_output = model(input_tensor)
        trt_output = compiled(input_tensor)
        cos_sim = cosine_similarity(ref_output, trt_output)
        self.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} "
            f"with strategy={strategy}",
        )

    def test_resnet18_lazy_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "lazy")

    def test_resnet18_eager_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "eager")

    def test_resnet18_none_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "none")


@unittest.skipIf(
    not ENABLED_FEATURES.tensorrt_rtx,
    "Dynamic shapes kernel specialization strategy requires TensorRT-RTX",
)
class TestDynamicShapesKernelStrategyDynamic(TestCase):
    """Tests kernel specialization strategies with dynamic input shapes."""

    def tearDown(self):
        torch._dynamo.reset()

    def _test_dynamic_batch_with_strategy(self, strategy):
        class ConvModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        model = ConvModel().eval().cuda()

        compiled = torchtrt.compile(
            model,
            ir="dynamo",
            inputs=[
                torchtrt.Input(
                    min_shape=(1, 3, 32, 32),
                    opt_shape=(4, 3, 32, 32),
                    max_shape=(8, 3, 32, 32),
                    dtype=torch.float32,
                )
            ],
            enabled_precisions={torch.float32},
            use_python_runtime=True,
            min_block_size=1,
            dynamic_shapes_kernel_specialization_strategy=strategy,
        )

        for batch_size in (1, 4, 8):
            with self.subTest(batch_size=batch_size, strategy=strategy):
                input_tensor = torch.randn(batch_size, 3, 32, 32).cuda()
                ref_output = model(input_tensor)
                trt_output = compiled(input_tensor)
                cos_sim = cosine_similarity(ref_output, trt_output)
                self.assertTrue(
                    cos_sim > COSINE_THRESHOLD,
                    f"BS={batch_size}, strategy={strategy}: cosine similarity "
                    f"{cos_sim} below threshold {COSINE_THRESHOLD}",
                )

    def test_dynamic_batch_lazy(self):
        self._test_dynamic_batch_with_strategy("lazy")

    def test_dynamic_batch_eager(self):
        self._test_dynamic_batch_with_strategy("eager")

    def test_dynamic_batch_none(self):
        self._test_dynamic_batch_with_strategy("none")


if __name__ == "__main__":
    run_tests()