diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index f6226385de..37a7c42252 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -197,8 +197,13 @@ def slice_scatter_decomposition( dim_size = input_tensor.shape[dim] device_input_tensor = input_tensor.device - start = 0 if start is None else start # Ensure start is int - start = get_positive_dim(start, input_tensor.shape[dim]) + if start is None: + start = 0 + elif isinstance(start, int): + start = get_positive_dim(start, dim_size) + elif isinstance(start, torch.SymInt): + if start < 0: + start = start + dim_size if end is None: # Ensure end is int end = dim_size end = (