
Conversation

@DoeringChristian (Contributor) commented Nov 24, 2025

This PR makes advanced tensor indexing compatible with PyTorch and NumPy.

When indexing a tensor with index arrays, Dr.Jit previously gathered every combination of the index entries (a Cartesian product). In the following example, this indexes the elements at [0, 0], [0, 2], [2, 0], and [2, 2].

In [1]: import drjit as dr
In [2]: from drjit.auto import TensorXf, Float, UInt32
In [3]: x = TensorXf(dr.arange(Float, 16), (4, 4))
In [4]: x
Out[4]: 
[[0, 1, 2, 3],
 [4, 5, 6, 7],
 [8, 9, 10, 11],
 [12, 13, 14, 15]]
In [5]: x[UInt32(0, 2), UInt32(0, 2)]
Out[5]: 
[[0, 2],
 [8, 10]]
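
For comparison, the previous behavior corresponds to NumPy's outer-product indexing via np.ix_. This NumPy transcript is only an illustration added for context, not part of the PR:

In [1]: import numpy as np
In [2]: x = np.arange(16, dtype=np.float32).reshape(4, 4)
In [3]: x[np.ix_([0, 2], [0, 2])]
Out[3]: 
array([[ 0.,  2.],
       [ 8., 10.]], dtype=float32)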

In contrast, PyTorch and NumPy handle this case differently: they broadcast the index arrays against each other and match their entries element-wise, so the same expression only returns the entries at coordinates [0, 0] and [2, 2]. This PR changes the behavior of Dr.Jit's tensor indexing to match PyTorch and NumPy.

In [1]: import drjit as dr
In [2]: from drjit.auto import TensorXf, Float, UInt32
In [3]: x = TensorXf(dr.arange(Float, 16), (4, 4))
In [4]: x
Out[4]: 
[[0, 1, 2, 3],
 [4, 5, 6, 7],
 [8, 9, 10, 11],
 [12, 13, 14, 15]]
In [5]: x[UInt32(0, 2), UInt32(0, 2)]
Out[5]: [0, 10]
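
For reference, NumPy produces the same result for this expression, since its advanced indexing matches the index arrays element-wise (again, an illustrative comparison rather than part of the PR):

In [1]: import numpy as np
In [2]: x = np.arange(16, dtype=np.float32).reshape(4, 4)
In [3]: x[np.array([0, 2]), np.array([0, 2])]
Out[3]: array([ 0., 10.], dtype=float32)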

The PR also adds tests that ensure compatibility with the PyTorch indexing mechanism.
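
A minimal sketch of the kind of reference check such a test might perform against PyTorch (illustrative only, not the PR's actual test code):

In [1]: import torch
In [2]: x = torch.arange(16.).reshape(4, 4)
In [3]: x[torch.tensor([0, 2]), torch.tensor([0, 2])]
Out[3]: tensor([ 0., 10.])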
