diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index f17c375489..460100184a 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -271,11 +271,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", - ) def test_dynamo_compile_with_custom_engine_cache(self): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -347,11 +342,6 @@ def test_dynamo_compile_with_custom_engine_cache(self): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) - @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", - ) def test_dynamo_compile_change_input_shape(self): """Runs compilation 3 times, the cache should miss each time""" model = models.resnet18(pretrained=True).eval().to("cuda") @@ -673,8 +663,7 @@ def forward(self, c, d): ) @unittest.skipIf( torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", + "Engine caching compilation time assertion is unreliable with TensorRT-RTX", ) def test_caching_small_model(self): from torch_tensorrt.dynamo._refit import refit_module_weights diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index 6bf1b58f71..dc2b01ea65 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -4,7 +4,6 @@ import shutil import unittest -import tensorrt as trt import torch import torch_tensorrt as torch_trt from torch.testing._internal.common_utils import TestCase @@ -13,6 +12,8 @@ from torch_tensorrt.dynamo._refit import refit_module_weights from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity +import tensorrt as trt # isort: skip # must import after torch_tensorrt to resolve tensorrt_rtx alias + assertions = unittest.TestCase() if importlib.util.find_spec("torchvision"): @@ -278,9 +279,8 @@ def test_engine_caching_saves_weight_stripped_engine(self): "torchvision is not installed", ) @unittest.skipIf( - torch_trt.ENABLED_FEATURES.tensorrt_rtx, - # TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752 - "There is bug in refit, so we skip the test for now", + torch_trt.ENABLED_FEATURES.tensorrt_rtx, + "Engine caching compilation time assertion is unreliable with TensorRT-RTX", ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda")