Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
)
def test_dynamo_compile_with_custom_engine_cache(self):
model = models.resnet18(pretrained=True).eval().to("cuda")

Expand Down Expand Up @@ -347,11 +342,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
)
def test_dynamo_compile_change_input_shape(self):
"""Runs compilation 3 times, the cache should miss each time"""
model = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -673,8 +663,7 @@ def forward(self, c, d):
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
Comment thread
lanluo-nvidia marked this conversation as resolved.
)
def test_caching_small_model(self):
from torch_tensorrt.dynamo._refit import refit_module_weights
Expand Down
8 changes: 4 additions & 4 deletions tests/py/dynamo/models/test_weight_stripped_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import shutil
import unittest

import tensorrt as trt
import torch
import torch_tensorrt as torch_trt
from torch.testing._internal.common_utils import TestCase
Expand All @@ -13,6 +12,8 @@
from torch_tensorrt.dynamo._refit import refit_module_weights
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

import tensorrt as trt # isort: skip # must import after torch_tensorrt to resolve tensorrt_rtx alias

assertions = unittest.TestCase()

if importlib.util.find_spec("torchvision"):
Expand Down Expand Up @@ -278,9 +279,8 @@ def test_engine_caching_saves_weight_stripped_engine(self):
"torchvision is not installed",
)
@unittest.skipIf(
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
# TODO: need to fix this https://github.com/pytorch/TensorRT/issues/3752
"There is bug in refit, so we skip the test for now",
torch_trt.ENABLED_FEATURES.tensorrt_rtx,
"Engine caching compilation time assertion is unreliable with TensorRT-RTX",
)
def test_dynamo_compile_with_refittable_weight_stripped_engine(self):
pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down
Loading