From fe174e7639f7632a9be1857d433bcb49c2b36303 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Thu, 16 Apr 2026 00:01:44 -0700 Subject: [PATCH 1/2] fix(test): enable TRT-RTX refit and engine cache tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Now that #4181 removed the RTX-specific batch norm workaround that bypassed constant folding, the refit bug (#3752) is resolved — eps constants are no longer created as separate CONSTANT layers on RTX. Remove the RTX skip decorators from: - test_dynamo_compile_with_refittable_weight_stripped_engine - test_dynamo_compile_with_custom_engine_cache - test_dynamo_compile_change_input_shape Keep the RTX skip on test_caching_small_model, which fails a timing assertion (cached compilation is slower than uncached on RTX). Update the skip message to reflect the actual reason. Fix import ordering in test_weight_stripped_engine.py: tensorrt must be imported after torch_tensorrt so the tensorrt_rtx module alias is resolved correctly. 
Fixes #3752 --- tests/py/dynamo/models/test_engine_cache.py | 13 +------------ .../py/dynamo/models/test_weight_stripped_engine.py | 8 ++------ 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index f17c375489..460100184a 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -271,11 +271,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", - ) def test_dynamo_compile_with_custom_engine_cache(self): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -347,11 +342,6 @@ def test_dynamo_compile_with_custom_engine_cache(self): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", - ) def test_dynamo_compile_change_input_shape(self): """Runs compilation 3 times, the cache should miss each time""" model = models.resnet18(pretrained=True).eval().to("cuda") @@ -673,8 +663,7 @@ def forward(self, c, d): ) @unittest.skipIf( torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", + "Engine caching compilation time assertion is unreliable with TensorRT-RTX", ) def test_caching_small_model(self): from torch_tensorrt.dynamo._refit import refit_module_weights diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 
6bf1b58f71..57ad38fd08 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -4,7 +4,6 @@ import shutil import unittest -import tensorrt as trt import torch import torch_tensorrt as torch_trt from torch.testing._internal.common_utils import TestCase @@ -13,6 +12,8 @@ from torch_tensorrt.dynamo._refit import refit_module_weights from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +import tensorrt as trt  # isort: skip  # must import after torch_tensorrt to resolve tensorrt_rtx alias + assertions = unittest.TestCase() if importlib.util.find_spec("torchvision"): @@ -277,11 +278,6 @@ def test_engine_caching_saves_weight_stripped_engine(self): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", - ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) From f3e7ccd8e0030989aa9031fead918b5d67c48e53 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 17 Apr 2026 14:30:24 -0700 Subject: [PATCH 2/2] fix(test): skip weight-stripped engine timing test on TRT-RTX Signed-off-by: tejaswinp --- tests/py/dynamo/models/test_weight_stripped_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 57ad38fd08..dc2b01ea65 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -278,6 +278,10 @@ def test_engine_caching_saves_weight_stripped_engine(self): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) + @unittest.skipIf( + 
torch_trt.ENABLED_FEATURES.tensorrt_rtx, + "Engine caching compilation time assertion is unreliable with TensorRT-RTX", + ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)