feat: add dynamic shapes kernel specialization strategy for TRT-RTX #4184
Merged: lanluo-nvidia merged 2 commits into pytorch:main from tp5uiuc:feat/trtrtx-dynamic-shapes-strategy on Apr 21, 2026.
tests/py/dynamo/models/test_dynamic_shapes_kernel_strategy_models.py (133 additions, 0 deletions)
```python
import importlib
import unittest

import torch
import torch_tensorrt as torchtrt
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity


@unittest.skipIf(
    not ENABLED_FEATURES.tensorrt_rtx,
    "Dynamic shapes kernel specialization strategy requires TensorRT-RTX",
)
@unittest.skipIf(
    not importlib.util.find_spec("torchvision"),
    "torchvision is not installed",
)
class TestDynamicShapesKernelStrategyModels(TestCase):
    """End-to-end model tests with different kernel specialization strategies."""

    def tearDown(self):
        torch._dynamo.reset()

    def _compile_and_verify(self, model, strategy):
        input_tensor = torch.randn(4, 3, 224, 224).cuda()
        compiled = torchtrt.compile(
            model,
            ir="dynamo",
            inputs=[
                torchtrt.Input(
                    min_shape=(1, 3, 224, 224),
                    opt_shape=(4, 3, 224, 224),
                    max_shape=(8, 3, 224, 224),
                    dtype=torch.float32,
                )
            ],
            enabled_precisions={torch.float32},
            use_python_runtime=True,
            min_block_size=1,
            dynamic_shapes_kernel_specialization_strategy=strategy,
        )
        ref_output = model(input_tensor)
        trt_output = compiled(input_tensor)
        cos_sim = cosine_similarity(ref_output, trt_output)
        self.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            f"Cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD} "
            f"with strategy={strategy}",
        )

    def test_resnet18_lazy_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "lazy")

    def test_resnet18_eager_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "eager")

    def test_resnet18_none_strategy(self):
        import torchvision.models as models

        model = models.resnet18(pretrained=True).eval().cuda()
        self._compile_and_verify(model, "none")


@unittest.skipIf(
    not ENABLED_FEATURES.tensorrt_rtx,
    "Dynamic shapes kernel specialization strategy requires TensorRT-RTX",
)
class TestDynamicShapesKernelStrategyDynamic(TestCase):
    """Tests kernel specialization strategies with dynamic input shapes."""

    def tearDown(self):
        torch._dynamo.reset()

    def _test_dynamic_batch_with_strategy(self, strategy):
        class ConvModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        model = ConvModel().eval().cuda()

        compiled = torchtrt.compile(
            model,
            ir="dynamo",
            inputs=[
                torchtrt.Input(
                    min_shape=(1, 3, 32, 32),
                    opt_shape=(4, 3, 32, 32),
                    max_shape=(8, 3, 32, 32),
                    dtype=torch.float32,
                )
            ],
            enabled_precisions={torch.float32},
            use_python_runtime=True,
            min_block_size=1,
            dynamic_shapes_kernel_specialization_strategy=strategy,
        )

        for batch_size in (1, 4, 8):
            with self.subTest(batch_size=batch_size, strategy=strategy):
                input_tensor = torch.randn(batch_size, 3, 32, 32).cuda()
                ref_output = model(input_tensor)
                trt_output = compiled(input_tensor)
                cos_sim = cosine_similarity(ref_output, trt_output)
                self.assertTrue(
                    cos_sim > COSINE_THRESHOLD,
                    f"BS={batch_size}, strategy={strategy}: cosine similarity "
                    f"{cos_sim} below threshold {COSINE_THRESHOLD}",
                )

    def test_dynamic_batch_lazy(self):
        self._test_dynamic_batch_with_strategy("lazy")

    def test_dynamic_batch_eager(self):
        self._test_dynamic_batch_with_strategy("eager")

    def test_dynamic_batch_none(self):
        self._test_dynamic_batch_with_strategy("none")


if __name__ == "__main__":
    run_tests()
```
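Note that every batch size the dynamic-shape test feeds the compiled module (1, 4, and 8) falls inside the `[min_shape, max_shape]` optimization profile declared in `torchtrt.Input`; inputs outside that range would not be valid for the engine. This invariant can be sketched with a small stdlib-only helper (`shape_in_profile` is hypothetical, not part of the PR):

```python
def shape_in_profile(shape, min_shape, max_shape):
    """Return True if every dimension of `shape` lies within the
    inclusive [min_shape, max_shape] optimization profile bounds."""
    return all(lo <= s <= hi for s, lo, hi in zip(shape, min_shape, max_shape))


# Profile used by the dynamic-batch test above.
MIN_SHAPE = (1, 3, 32, 32)
MAX_SHAPE = (8, 3, 32, 32)

# All batch sizes exercised by the test are within the profile.
for bs in (1, 4, 8):
    assert shape_in_profile((bs, 3, 32, 32), MIN_SHAPE, MAX_SHAPE)

# A batch of 16 would fall outside the declared max_shape.
assert not shape_in_profile((16, 3, 32, 32), MIN_SHAPE, MAX_SHAPE)
```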
Can we add a warning or check in case the user configured `dynamic_shapes_kernel_specialization_strategy` in TensorRT?
This is a good suggestion, Lan. I have a follow-up task to emit user warnings for this, so that it's easier to review the change/behavior. I will put it in then.
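The validation discussed above might look something like the following. This is a hedged sketch only: the helper name `validate_kernel_specialization_strategy`, the `tensorrt_rtx_enabled` flag, and the exact set of valid strategy names ("lazy", "eager", "none", inferred from the tests) are assumptions, not the PR's actual implementation.

```python
import warnings

# Strategy names inferred from the test file; the real set may differ.
_VALID_STRATEGIES = {"lazy", "eager", "none"}


def validate_kernel_specialization_strategy(strategy, tensorrt_rtx_enabled):
    """Hypothetical check: reject unknown strategy names, and warn (then
    ignore the setting) when the strategy is configured but the build is
    standard TensorRT rather than TensorRT-RTX."""
    if strategy is None:
        return None
    if strategy not in _VALID_STRATEGIES:
        raise ValueError(
            f"Unknown dynamic_shapes_kernel_specialization_strategy: {strategy!r}; "
            f"expected one of {sorted(_VALID_STRATEGIES)}"
        )
    if not tensorrt_rtx_enabled:
        warnings.warn(
            "dynamic_shapes_kernel_specialization_strategy is only supported "
            "with TensorRT-RTX; the setting will be ignored.",
            UserWarning,
        )
        return None
    return strategy
```

A compile path could call this once on its settings, so a misconfigured standard-TensorRT run surfaces a `UserWarning` instead of silently dropping the option.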