diff --git a/src/python/slice.cpp b/src/python/slice.cpp
index fc309e07d..178e5f088 100644
--- a/src/python/slice.cpp
+++ b/src/python/slice.cpp
@@ -17,19 +17,29 @@
 #include "base.h"
 #include "slice.h"
 #include "meta.h"
+#include <algorithm>
 #include <vector>

 /// Holds metadata about slicing component
 struct Component {
+    enum Type { None, Integer, Slice, Advanced, Ellipsis };
+
+    Type type;
     Py_ssize_t start, step, slice_size, size;
     nb::object object;

-    Component(Py_ssize_t start, Py_ssize_t step, Py_ssize_t slice_size,
+    // Constructor for None indices
+    Component(Type t)
+        : type(t), start(0), step(1), slice_size(1), size(1) { }
+
+    // Constructor for integer and slice indices
+    Component(Type t, Py_ssize_t start, Py_ssize_t step, Py_ssize_t slice_size,
               Py_ssize_t size)
-        : start(start), step(step), slice_size(slice_size), size(size) { }
+        : type(t), start(start), step(step), slice_size(slice_size), size(size) { }

-    Component(nb::handle h, Py_ssize_t slice_size, Py_ssize_t size)
-        : start(0), step(1), slice_size(slice_size), size(size),
+    // Constructor for advanced indices (array indexing)
+    Component(Type t, nb::handle h, Py_ssize_t slice_size, Py_ssize_t size)
+        : type(t), start(0), step(1), slice_size(slice_size), size(size),
           object(nb::borrow(h)) { }
 };

@@ -72,11 +82,15 @@ slice_index(const nb::type_object_t<dr::ArrayBase> &dtype,
     indices_len = nb::len(indices);

     std::vector<Component> components;
-    components.reserve(shape_len);
+    components.reserve(indices_len); // May include None indices
+
+    // First pass: parse the indices
+    nb::list basic_shapes;    // Shapes from basic indexing (slices)
+    size_t advanced_size = 0; // Broadcast size of the advanced index arrays

     for (nb::handle h : indices) {
         if (h.is_none()) {
-            shape_out.append(1);
+            components.emplace_back(Component::None);
             continue;
         }

@@ -96,15 +110,14 @@ slice_index(const nb::type_object_t<dr::ArrayBase> &dtype,
                           "bounds for axis %zu with size %zd.",
                           v, components.size(), size);

-                components.emplace_back(v, 1, 1, size);
+                components.emplace_back(Component::Integer, v, 1, 1, size);
                 continue;
             } else if (tp.is(&PySlice_Type)) {
                 Py_ssize_t start, stop, step;
                 size_t slice_length;
                 nb::detail::slice_compute(h.ptr(), size, start, stop, step, slice_length);
-                components.emplace_back(start, step, (Py_ssize_t) slice_length, size);
-                shape_out.append(slice_length);
-                size_out *= slice_length;
+                components.emplace_back(Component::Slice, start, step, (Py_ssize_t) slice_length, size);
+                basic_shapes.append(slice_length);
                 continue;
             } else if (is_drjit_type(tp)) {
                 const ArraySupplement *s2 = &supp(tp);
@@ -138,9 +151,21 @@
                 if (!o.type().is(dtype))
                     o = dtype(o);

-                components.emplace_back(o, slice_size, size);
-                shape_out.append(slice_size);
-                size_out *= slice_size;
+                components.emplace_back(Component::Advanced, o, slice_size, size);
+
+                // Track the maximum size for broadcasting:
+                // PyTorch/NumPy broadcast all advanced indices to the same shape
+                if (advanced_size == 0) {
+                    advanced_size = slice_size;
+                } else if (slice_size != 1 && advanced_size != 1 && advanced_size != slice_size) {
+                    // Broadcasting rules: sizes must be 1 or equal
+                    nb::raise("drjit.slice_index(): advanced index arrays with shapes %zu and %zu "
+                              "cannot be broadcast together.", advanced_size, slice_size);
+                } else if (slice_size > advanced_size) {
+                    // Update to the larger size (broadcasting smaller arrays to match)
+                    advanced_size = slice_size;
+                }
+
                 continue;
             }
         } else if (tp.is(&PyEllipsis_Type)) {
@@ -151,9 +176,8 @@
                 if (shape_offset >= shape_len)
                     nb::detail::fail("slice_index(): internal error.");
                 size = nb::cast<Py_ssize_t>(shape[shape_offset++]);
-                components.emplace_back(0, 1, size, size);
-                shape_out.append(size);
-                size_out *= size;
+                components.emplace_back(Component::Slice, 0, 1, size, size);
+                basic_shapes.append(size);
             }
             continue;
         }
@@ -167,43 +191,232 @@
     // Implicit ellipsis at the end
     while (shape_offset != shape_len) {
         Py_ssize_t size = nb::cast<Py_ssize_t>(shape[shape_offset++]);
-        components.emplace_back(0, 1, size, size);
-        shape_out.append(size);
-        size_out *= size;
+        components.emplace_back(Component::Slice, 0, 1, size, size);
+        basic_shapes.append(size);
+    }
+
+    // Build output shape following PyTorch/NumPy advanced indexing rules:
+    // - None indices create new dimensions of size 1 at their positions
+    // - Integer indices reduce dimensions (don't appear in output)
+    // - Advanced indices: if consecutive, stay in place; if non-consecutive, move to front
+    shape_out.clear();
+
+    // Check if there are advanced indices and if they're consecutive
+    int first_adv = -1, last_adv = -1;
+    for (size_t i = 0; i < components.size(); ++i) {
+        if (components[i].type == Component::Advanced) {
+            if (first_adv == -1) first_adv = i;
+            last_adv = i;
+        }
+    }
+
+    bool has_advanced = (first_adv != -1);
+    bool consecutive = true;
+    if (has_advanced) {
+        for (int i = first_adv; i <= last_adv; ++i) {
+            if (components[i].type == Component::None) continue; // None doesn't break consecutiveness
+            if (components[i].type != Component::Advanced) {
+                consecutive = false;
+                break;
+            }
+        }
+    }
+
+    // Build output shape based on index arrangement
+    if (has_advanced && consecutive) {
+        // Advanced indices are consecutive: replace all with a single dimension
+        bool advanced_added = false;
+        for (const auto &comp : components) {
+            if (comp.type == Component::None) {
+                shape_out.append(1);
+            } else if (comp.type == Component::Slice) {
+                shape_out.append(comp.slice_size);
+            } else if (comp.type == Component::Advanced) {
+                if (!advanced_added) {
+                    // All consecutive advanced indices produce a single dimension
+                    shape_out.append(advanced_size);
+                    advanced_added = true;
+                }
+                // Subsequent advanced indices don't add dimensions
+            }
+            // Integer indices don't contribute
+        }
+    } else if (has_advanced && !consecutive) {
+        // Advanced indices are non-consecutive: move to front
+        shape_out.append(advanced_size);
+        for (const auto &comp : components) {
+            if (comp.type == Component::None) {
+                shape_out.append(1);
+            } else if (comp.type == Component::Slice) {
+                shape_out.append(comp.slice_size);
+            }
+            // Integer and Advanced (already added) don't contribute here
+        }
+    } else {
+        // No advanced indexing: process each index type in order
+        for (const auto &comp : components) {
+            if (comp.type == Component::None) {
+                shape_out.append(1);
+            } else if (comp.type == Component::Slice) {
+                shape_out.append(comp.slice_size);
+            }
+            // Integer indices don't contribute to shape
+        }
+    }
+
+    // Calculate total size from the actual output shape
+    size_out = 1;
+    for (nb::handle h : shape_out)
+        size_out *= nb::cast<size_t>(h);
+
     nb::object index = arange(dtype, 0, size_out, 1), index_out;
     nb::object active = nb::borrow(Py_True);

     if (size_out) {
-        size_out = 1;
+        // Unified algorithm that handles both basic and advanced indexing
         index_out = dtype(0);

+        // Calculate the stride multiplier for the input tensor dimensions
+        // Skip None components as they don't correspond to input dimensions
+        size_t input_stride = 1;
+        std::vector<size_t> input_strides;
+        for (auto it = components.rbegin(); it != components.rend(); ++it) {
-            const Component &c = *it;
-            nb::object index_next, index_rem;
-
-            if (it + 1 != components.rend()) {
-                index_next = index.floor_div(dtype(c.slice_size));
-                index_rem = fma(index_next, dtype(uint32_t(-c.slice_size)), index);
+            if (it->type == Component::None) {
+                input_strides.push_back(0); // Placeholder for None
+            } else {
+                input_strides.push_back(input_stride);
+                input_stride *= it->size;
+            }
+        }
+        std::reverse(input_strides.begin(), input_strides.end());
+
+        // Decompose output index according to output shape
+        nb::object remaining = index;
+        std::vector<nb::object> output_dim_indices;
+
+        // Decompose based on actual output shape (in reverse order)
+        for (size_t i = nb::len(shape_out); i > 0; --i) {
+            size_t dim_size = nb::cast<size_t>(shape_out[i - 1]);
+            nb::object dim_idx;
+            if (i > 1) {
+                nb::object quotient = remaining.floor_div(dtype(dim_size));
+                dim_idx = remaining - quotient * dtype(dim_size);
+                remaining = quotient;
             } else {
-                index_rem = index;
+                dim_idx = remaining;
             }
+            output_dim_indices.insert(output_dim_indices.begin(), dim_idx);
+        }

-            nb::object index_val;
-            if (!c.object.is_valid())
-                index_val = fma(index_rem, dtype(uint32_t(c.step * size_out)),
-                                dtype(uint32_t(c.start * size_out)));
-            else
-                index_val = gather(dtype, c.object, index_rem, active,
-                                   ReduceMode::Auto) *
-                            dtype(uint32_t(size_out));
+        // Check if there are advanced indices and if they're consecutive
+        int first_adv = -1, last_adv = -1;
+        for (size_t i = 0; i < components.size(); ++i) {
+            if (components[i].type == Component::Advanced) {
+                if (first_adv == -1) first_adv = i;
+                last_adv = i;
+            }
+        }
+
+        bool has_advanced = (first_adv != -1);
+        bool consecutive = true;
+        if (has_advanced) {
+            for (int i = first_adv; i <= last_adv; ++i) {
+                if (components[i].type == Component::None) continue;
+                if (components[i].type != Component::Advanced) {
+                    consecutive = false;
+                    break;
+                }
+            }
+        }
+
+        // Extract advanced_idx and basic indices from output_dim_indices
+        nb::object advanced_idx = dtype(0);
+        std::vector<nb::object> basic_dim_indices;
+        size_t output_idx = 0;
+        bool advanced_found = false;
+
+        if (has_advanced && consecutive) {
+            // Advanced indices are consecutive: they stay in their natural position
+            for (const auto &comp : components) {
+                if (comp.type == Component::None) {
+                    output_idx++;
+                } else if (comp.type == Component::Advanced) {
+                    if (!advanced_found) {
+                        // Only the first advanced index consumes an output dimension;
+                        // the whole consecutive block was merged into one dimension above
+                        advanced_idx = output_dim_indices[output_idx];
+                        advanced_found = true;
+                        output_idx++;
+                    }
+                } else if (comp.type == Component::Slice) {
+                    basic_dim_indices.push_back(output_dim_indices[output_idx]);
+                    output_idx++;
+                }
+            }
+        } else if (has_advanced && !consecutive) {
+            // Advanced indices are non-consecutive: they're moved to the front
+            advanced_idx = output_dim_indices[0];
+            output_idx = 1;
+            for (const auto &comp : components) {
+                if (comp.type == Component::None) {
+                    if (output_idx < output_dim_indices.size()) {
+                        output_idx++;
+                    }
+                } else if (comp.type == Component::Slice) {
+                    if (output_idx < output_dim_indices.size()) {
+                        basic_dim_indices.push_back(output_dim_indices[output_idx]);
+                        output_idx++;
+                    }
+                }
+            }
+        } else {
+            // No advanced indexing: just map output dimensions to input
+            for (const auto &comp : components) {
+                if (comp.type == Component::None) {
+                    output_idx++;
+                } else if (comp.type == Component::Slice) {
+                    if (output_idx < output_dim_indices.size()) {
+                        basic_dim_indices.push_back(output_dim_indices[output_idx]);
+                        output_idx++;
+                    }
+                }
+            }
+        }
+
+        // Map output indices back to input dimensions
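+        // Worked example (cf. the updated check(shape=(3, 7), ...) in tests/test_tensor.py):
+        // a (3, 7) tensor indexed with (t(0), t(0, 3, 6)) has two consecutive advanced
+        // indices that broadcast to size 3, so shape_out = (3,) and the input strides
+        // are (7, 1). Output entry k then reads input element
+        // gather(t(0), 0) * 7 + gather(t(0, 3, 6), k) * 1, i.e. index_out = (0, 3, 6).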
+ size_t basic_idx_counter = 0; + for (size_t i = 0; i < components.size(); ++i) { + const Component &c = components[i]; + + // Skip None indices as they don't correspond to input dimensions + if (c.type == Component::None) + continue; - index_out += index_val; + nb::object dim_index; + + if (c.type == Component::Advanced) { + // Advanced index: use the advanced_idx to gather from the index array + // Handle broadcasting: if the index array has size 1, broadcast it + if (c.slice_size == 1) { + dim_index = gather(dtype, c.object, dtype(0), active, ReduceMode::Auto); + } else { + dim_index = gather(dtype, c.object, advanced_idx, active, ReduceMode::Auto); + } + } else if (c.type == Component::Integer) { + // Integer index + dim_index = dtype(c.start); + } else if (c.type == Component::Slice) { + // Basic slice: get the dimension index and apply slice transformation + if (basic_idx_counter < basic_dim_indices.size()) { + dim_index = basic_dim_indices[basic_idx_counter]; + dim_index = fma(dim_index, dtype(uint32_t(c.step)), dtype(uint32_t(c.start))); + basic_idx_counter++; + } else { + dim_index = dtype(c.start); + } + } - index = std::move(index_next); - size_out *= c.size; + // Add contribution to output index + index_out += dim_index * dtype(uint32_t(input_strides[i])); } } else { index_out = dtype(); diff --git a/tests/test_freeze.py b/tests/test_freeze.py index 225944ea7..98ba33a92 100644 --- a/tests/test_freeze.py +++ b/tests/test_freeze.py @@ -3427,11 +3427,12 @@ def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32): frozen = dr.freeze(func, auto_opaque=auto_opaque) - for i in range(3): + for i in range(4): shape = ((i + 5), 10) x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape=shape) - row = dr.arange(mod.UInt32, i + 4) - col = dr.arange(mod.UInt32, 3) + 1 + # Both row and col must have the same length for advanced indexing + row = dr.arange(mod.UInt32, i+2) + col = dr.arange(mod.UInt32, i+2) + 1 res = frozen(x, row, col) ref = func(x, row, col) diff --git a/tests/test_pytorch_indexing.py b/tests/test_pytorch_indexing.py new file mode 100644 index 000000000..06e0876e0 --- /dev/null +++ b/tests/test_pytorch_indexing.py @@ -0,0 +1,529 @@ +""" +Test PyTorch-compatible tensor indexing behavior. + +This test suite ensures that Dr.Jit tensor indexing is strictly compatible +with PyTorch, particularly for the critical requirement that integer indexing +returns 0-D tensors (not Python scalars). 
+""" + +import drjit as dr +import pytest +import sys + +# Optional PyTorch dependency +try: + import torch + import numpy as np + + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + + +def skip_if_no_torch(): + """Skip test if PyTorch is not available.""" + if not HAS_TORCH: + pytest.skip("PyTorch not available") + + +# Helper functions for conversion +def drjit_to_torch(drjit_tensor): + """Convert Dr.Jit tensor to PyTorch tensor.""" + if HAS_TORCH: + np_array = ( + drjit_tensor.numpy() + if hasattr(drjit_tensor, "numpy") + else np.array(drjit_tensor) + ) + return torch.from_numpy(np_array).float() + return None + + +def assert_shape_equal(dr_tensor, pt_tensor, msg=""): + """Assert that Dr.Jit and PyTorch tensors have equal shapes.""" + dr_shape = dr_tensor.shape + pt_shape = tuple(pt_tensor.shape) + assert ( + dr_shape == pt_shape + ), f"{msg}\nDr.Jit shape: {dr_shape}, PyTorch shape: {pt_shape}" + + +def assert_values_equal(dr_tensor, pt_tensor, rtol=1e-5, atol=1e-7, msg=""): + """Assert that Dr.Jit and PyTorch tensors have equal values.""" + if HAS_TORCH: + dr_np = ( + dr_tensor.numpy() if hasattr(dr_tensor, "numpy") else np.array(dr_tensor) + ) + pt_np = pt_tensor.detach().cpu().numpy() + np.testing.assert_allclose(dr_np, pt_np, rtol=rtol, atol=atol, err_msg=msg) + + +# ============================================================================= +# Basic Integer Indexing Tests +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test01_single_int_index_1d_returns_0d(t): + """Test that single integer index on 1D tensor returns 0-D tensor.""" + skip_if_no_torch() + + # Create test data + data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + # Test positive index + dr_result = dr_tensor[5] + pt_result = pt_tensor[5] + + # CRITICAL: Should return 0-D tensor, not scalar + assert dr_result.ndim == 0, f"Expected ndim=0, got {dr_result.ndim}" + assert dr_result.shape == (), f"Expected shape=(), got {dr_result.shape}" + assert pt_result.ndim == 0, "PyTorch should also return 0-D tensor" + + assert_shape_equal(dr_result, pt_result, "Single int index shape mismatch") + assert_values_equal(dr_result, pt_result, msg="Single int index value mismatch") + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test02_negative_int_index_1d_returns_0d(t): + """Test that negative integer index on 1D tensor returns 0-D tensor.""" + skip_if_no_torch() + + data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + # Test negative indices + for idx in [-1, -5, -10]: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert ( + dr_result.ndim == 0 + ), f"Index {idx}: Expected ndim=0, got {dr_result.ndim}" + assert ( + dr_result.shape == () + ), f"Index {idx}: Expected shape=(), got {dr_result.shape}" + assert_shape_equal(dr_result, pt_result, f"Negative index {idx} shape mismatch") + assert_values_equal( + dr_result, pt_result, msg=f"Negative index {idx} value mismatch" + ) + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test03_multi_int_index_2d_returns_0d(t): + """Test that multiple integer indices on 2D tensor return 0-D tensor.""" + skip_if_no_torch() + + data = list(range(20)) + dr_tensor = t(data, shape=(4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(4, 5) + + # Test various index combinations + test_cases = [(0, 0), (2, 3), (3, 4), (-1, 
-1), (-2, 3)] + + for i, j in test_cases: + dr_result = dr_tensor[i, j] + pt_result = pt_tensor[i, j] + + assert ( + dr_result.ndim == 0 + ), f"Index ({i}, {j}): Expected ndim=0, got {dr_result.ndim}" + assert ( + dr_result.shape == () + ), f"Index ({i}, {j}): Expected shape=(), got {dr_result.shape}" + assert_shape_equal(dr_result, pt_result, f"Index ({i}, {j}) shape mismatch") + assert_values_equal( + dr_result, pt_result, msg=f"Index ({i}, {j}) value mismatch" + ) + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test04_single_int_index_2d_reduces_dim(t): + """Test that single integer index on 2D tensor reduces dimension.""" + skip_if_no_torch() + + data = list(range(20)) + dr_tensor = t(data, shape=(4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(4, 5) + + # Single index should return 1D tensor + for idx in [0, 2, -1]: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert ( + dr_result.ndim == 1 + ), f"Index {idx}: Expected ndim=1, got {dr_result.ndim}" + assert dr_result.shape == ( + 5, + ), f"Index {idx}: Expected shape=(5,), got {dr_result.shape}" + assert_shape_equal(dr_result, pt_result, f"Single index {idx} shape mismatch") + assert_values_equal( + dr_result, pt_result, msg=f"Single index {idx} value mismatch" + ) + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test05_multi_int_index_3d_returns_0d(t): + """Test that full indexing on 3D tensor returns 0-D tensor.""" + skip_if_no_torch() + + data = list(range(60)) + dr_tensor = t(data, shape=(3, 4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(3, 4, 5) + + test_cases = [(0, 0, 0), (1, 2, 3), (2, 3, 4), (-1, -1, -1)] + + for i, j, k in test_cases: + dr_result = dr_tensor[i, j, k] + pt_result = pt_tensor[i, j, k] + + assert ( + dr_result.ndim == 0 + ), f"Index ({i}, {j}, {k}): Expected ndim=0, got {dr_result.ndim}" + assert ( + dr_result.shape == () + ), f"Index ({i}, {j}, {k}): Expected shape=(), got {dr_result.shape}" + assert_shape_equal( + dr_result, pt_result, f"Index ({i}, {j}, {k}) shape mismatch" + ) + assert_values_equal( + dr_result, pt_result, msg=f"Index ({i}, {j}, {k}) value mismatch" + ) + + +# ============================================================================= +# Slicing Tests +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test06_slice_1d(t): + """Test basic slicing on 1D tensor.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + test_slices = [ + slice(2, 7), # [2:7] + slice(None, 5), # [:5] + slice(3, None), # [3:] + slice(None, None, 2), # [::2] + ] + + # PyTorch doesn't support negative step slices, so test Dr.Jit independently + drjit_only_slices = [ + slice(8, 2, -1), # [8:2:-1] + slice(None, None, -1), # [::-1] + ] + + for s in test_slices: + dr_result = dr_tensor[s] + pt_result = pt_tensor[s] + + assert_shape_equal(dr_result, pt_result, f"Slice {s} shape mismatch") + assert_values_equal(dr_result, pt_result, msg=f"Slice {s} value mismatch") + + # Test Dr.Jit-only slices (negative steps) + for s in drjit_only_slices: + dr_result = dr_tensor[s] + # Just verify it doesn't crash and returns reasonable shape + assert dr_result.ndim == 1, f"Slice {s} should preserve 1D" + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test07_slice_2d(t): + """Test slicing on 2D tensor.""" + skip_if_no_torch() + + data = list(range(20)) + dr_tensor = t(data, 
shape=(4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(4, 5) + + test_cases = [ + (slice(1, 3), slice(None)), # [1:3, :] + (slice(None), slice(2, 4)), # [:, 2:4] + (slice(1, 3), slice(2, 4)), # [1:3, 2:4] + ] + + # PyTorch doesn't support negative steps + drjit_only_cases = [ + (slice(None, None, -1), slice(None)), # [::-1, :] + ] + + for idx in test_cases: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert_shape_equal(dr_result, pt_result, f"Slice {idx} shape mismatch") + assert_values_equal(dr_result, pt_result, msg=f"Slice {idx} value mismatch") + + # Test Dr.Jit-only cases (negative steps) + for idx in drjit_only_cases: + dr_result = dr_tensor[idx] + # Just verify it doesn't crash and returns reasonable shape + assert dr_result.ndim == 2, f"Slice {idx} should preserve 2D" + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test08_mixed_int_slice(t): + """Test mixing integer and slice indices.""" + skip_if_no_torch() + + data = list(range(20)) + dr_tensor = t(data, shape=(4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(4, 5) + + test_cases = [ + (0, slice(None)), # [0, :] + (slice(None), 0), # [:, 0] + (2, slice(1, 4)), # [2, 1:4] + (slice(1, 3), 2), # [1:3, 2] + ] + + for idx in test_cases: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert_shape_equal(dr_result, pt_result, f"Mixed index {idx} shape mismatch") + assert_values_equal( + dr_result, pt_result, msg=f"Mixed index {idx} value mismatch" + ) + + +# ============================================================================= +# Ellipsis Tests +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test09_ellipsis(t): + """Test ellipsis (...) indexing.""" + skip_if_no_torch() + + data = list(range(60)) + dr_tensor = t(data, shape=(3, 4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(3, 4, 5) + + test_cases = [ + (Ellipsis,), # [...] + (Ellipsis, 0), # [..., 0] + (0, Ellipsis), # [0, ...] 
+ (1, Ellipsis, 2), # [1, ..., 2] + ] + + for idx in test_cases: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert_shape_equal(dr_result, pt_result, f"Ellipsis {idx} shape mismatch") + assert_values_equal(dr_result, pt_result, msg=f"Ellipsis {idx} value mismatch") + + +# ============================================================================= +# None/newaxis Tests +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test10_newaxis(t): + """Test None/newaxis indexing.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + test_cases = [ + (None, slice(None)), # [None, :] + (slice(None), None), # [:, None] + (None, slice(None), None), # [None, :, None] + ] + + for idx in test_cases: + dr_result = dr_tensor[idx] + pt_result = pt_tensor[idx] + + assert_shape_equal(dr_result, pt_result, f"Newaxis {idx} shape mismatch") + assert_values_equal(dr_result, pt_result, msg=f"Newaxis {idx} value mismatch") + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test11_empty_slice(t): + """Test slicing that produces empty tensor.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + dr_result = dr_tensor[5:5] + pt_result = pt_tensor[5:5] + + assert_shape_equal(dr_result, pt_result, "Empty slice shape mismatch") + assert dr_result.shape == (0,), f"Expected shape=(0,), got {dr_result.shape}" + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test12_single_element_tensor(t): + """Test indexing single element tensor.""" + skip_if_no_torch() + + dr_tensor = t([42.0]) + pt_tensor = torch.tensor([42.0], dtype=torch.float32) + + dr_result = dr_tensor[0] + pt_result = pt_tensor[0] + + assert dr_result.ndim == 0, f"Expected ndim=0, got {dr_result.ndim}" + assert dr_result.shape == (), f"Expected shape=(), got {dr_result.shape}" + assert_shape_equal(dr_result, pt_result, "Single element shape mismatch") + assert_values_equal(dr_result, pt_result, msg="Single element value mismatch") + + +# ============================================================================= +# Array Indexing Tests (using Dr.Jit integer arrays) +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test13_array_index_1d(t): + """Test array indexing on 1D tensor.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + # Create index array + indices = [0, 2, 4, 6, 8] + index_type = dr.uint32_array_t(dr.array_t(t)) + dr_indices = index_type(indices) + pt_indices = torch.tensor(indices, dtype=torch.long) + + dr_result = dr_tensor[dr_indices] + pt_result = pt_tensor[pt_indices] + + assert_shape_equal(dr_result, pt_result, "Array index shape mismatch") + assert_values_equal(dr_result, pt_result, msg="Array index value mismatch") + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test14_array_index_2d(t): + """Test array indexing on 2D tensor (first dimension).""" + skip_if_no_torch() + + data = list(range(20)) + dr_tensor = t(data, shape=(4, 5)) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(4, 5) + + # Create 
index array + indices = [0, 2, 3] + index_type = dr.uint32_array_t(dr.array_t(t)) + dr_indices = index_type(indices) + pt_indices = torch.tensor(indices, dtype=torch.long) + + dr_result = dr_tensor[dr_indices] + pt_result = pt_tensor[pt_indices] + + assert_shape_equal(dr_result, pt_result, "Array index 2D shape mismatch") + assert dr_result.shape == (3, 5), f"Expected shape=(3, 5), got {dr_result.shape}" + assert_values_equal(dr_result, pt_result, msg="Array index 2D value mismatch") + + +# ============================================================================= +# Assignment Tests +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test15_setitem_single_element(t): + """Test assigning to single element.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + dr_tensor[5] = 100.0 + pt_tensor[5] = 100.0 + + assert_values_equal(dr_tensor, pt_tensor, msg="Single element assignment mismatch") + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test16_setitem_slice(t): + """Test assigning to slice.""" + skip_if_no_torch() + + data = list(range(10)) + dr_tensor = t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32) + + dr_tensor[2:7] = 100.0 + pt_tensor[2:7] = 100.0 + + assert_values_equal(dr_tensor, pt_tensor, msg="Slice assignment mismatch") + + +# ============================================================================= +# Comprehensive Compatibility Test +# ============================================================================= + + +@pytest.test_arrays("is_tensor, float32, is_jit") +def test17_comprehensive_indexing_compatibility(t): + """Comprehensive test covering multiple indexing scenarios.""" + skip_if_no_torch() + + # Test with different tensor shapes + test_configs = [ + (10,), # 1D + (4, 5), # 2D + (2, 3, 4), # 3D + ] + + for shape in test_configs: + size = 1 + for dim in shape: + size *= dim + + data = list(range(size)) + dr_tensor = t(data, shape=shape) if len(shape) > 1 else t(data) + pt_tensor = torch.tensor(data, dtype=torch.float32).reshape(shape) + + # Test 1: Verify shapes match + assert_shape_equal(dr_tensor, pt_tensor, f"Initial shape mismatch for {shape}") + + # Test 2: Full slice should preserve shape + dr_result = dr_tensor[...] + pt_result = pt_tensor[...] 
+ assert_shape_equal(dr_result, pt_result, f"Ellipsis shape mismatch for {shape}") + + # Test 3: Integer index on first dimension + dr_result = dr_tensor[0] + pt_result = pt_tensor[0] + assert_shape_equal( + dr_result, pt_result, f"First dim index shape mismatch for {shape}" + ) + + # Test 4: If multidimensional, test full integer indexing + if len(shape) > 1: + full_idx = tuple([0] * len(shape)) + dr_result = dr_tensor[full_idx] + pt_result = pt_tensor[full_idx] + assert dr_result.ndim == 0, f"Full index should return 0-D for {shape}" + assert_shape_equal( + dr_result, pt_result, f"Full index shape mismatch for {shape}" + ) + diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 4f57f3bf4..212b4cea1 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -63,7 +63,7 @@ def check(shape, indices, shape_out, index_out): check(shape=(3, 7), indices=(t(0), slice(0, 7, 3)), shape_out=(1, 3), index_out=t(0, 3, 6)) check(shape=(3, 7), indices=(t(0), t(0, 3, 6)), - shape_out=(1, 3), index_out=t(0, 3, 6)) + shape_out=(3,), index_out=t(0, 3, 6)) # Consecutive advanced indices → single dimension check(shape=(3, 7), indices=(2, slice(None, None, None)), shape_out=(7,), index_out=t(14, 15, 16, 17, 18, 19, 20)) check(shape=(3, 7), indices=(slice(None, None, None), 2), @@ -686,3 +686,120 @@ def test23_item_array(t): with pytest.raises(RuntimeError, match='can only convert arrays of length 1'): t([]).item() + + +@pytest.test_arrays('is_tensor, -bool') +def test24_pytorch_compat_scalar_indexing(t): + """ + Test PyTorch-compatible indexing behavior: integer indexing should + return 0-D tensors, not Python scalars. + + This is critical for PyTorch compatibility, as PyTorch always returns + tensors (even 0-D) from indexing operations, unlike NumPy which returns + Python scalars. 
+ """ + # Test 1: Single integer index on 1D tensor returns 0-D tensor + v = t([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + result = v[5] + assert result.ndim == 0, f"Single int index should return 0-D tensor, got ndim={result.ndim}" + assert result.shape == (), f"Single int index should return shape=(), got {result.shape}" + + # Test 2: Negative index also returns 0-D tensor + result = v[-1] + assert result.ndim == 0, f"Negative int index should return 0-D tensor, got ndim={result.ndim}" + assert result.shape == (), f"Negative int index should return shape=(), got {result.shape}" + + # Test 3: Multiple integer indices on 2D tensor return 0-D tensor + v2 = t(list(range(20)), shape=(4, 5)) + result = v2[2, 3] + assert result.ndim == 0, f"Multi-int index should return 0-D tensor, got ndim={result.ndim}" + assert result.shape == (), f"Multi-int index should return shape=(), got {result.shape}" + + # Test 4: Single index on 2D tensor reduces dimension (returns 1D) + result = v2[2] + assert result.ndim == 1, f"Single index on 2D should return 1-D tensor, got ndim={result.ndim}" + assert result.shape == (5,), f"Single index on 2D should return shape=(5,), got {result.shape}" + + # Test 5: Full indexing on 3D tensor returns 0-D tensor + v3 = t(list(range(60)), shape=(3, 4, 5)) + result = v3[1, 2, 3] + assert result.ndim == 0, f"Full 3D index should return 0-D tensor, got ndim={result.ndim}" + assert result.shape == (), f"Full 3D index should return shape=(), got {result.shape}" + + # Test 6: Verify the 0-D tensor contains the correct value + # For consistency check, compare with item() method + v_simple = t([42]) + result_0d = v_simple[0] + assert result_0d.ndim == 0 + # The 0-D tensor should have the same value as item() + assert result_0d.item() == 42 + + +@pytest.test_arrays("float32, jit") +def test24_multidim_scalar(t): + pytest.importorskip("torch") + + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + TensorXf = mod.TensorXf + Float = mod.Float + + shape = (10, 10, 10) + rng = dr.rng() + + x = rng.random(Float, dr.prod(shape)) + x = TensorXf(x, shape) + x_torch = x.torch() + + ref = x_torch[:, 1, :] + res = x[:, 1, :] + + assert dr.allclose(res, ref) + + +@pytest.test_arrays("float32, jit") +def test25_multidim_advanced(t): + pytest.importorskip("torch") + + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + TensorXf = mod.TensorXf + Float = mod.Float + + shape = (10, 10, 10) + rng = dr.rng() + index = dr.arange(UInt32, 5) + 1 + + x = rng.random(Float, dr.prod(shape)) + x = TensorXf(x, shape) + x_torch = x.torch() + + index_torch = index.torch().long() + ref = x_torch[index_torch, :, index_torch] + res = x[index, :, index] + + assert dr.allclose(res, ref) + + +@pytest.test_arrays("float32, jit") +def test26_4d_advanced(t): + pytest.importorskip("torch") + + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + TensorXf = mod.TensorXf + Float = mod.Float + + shape = (10, 10, 10, 10) + rng = dr.rng() + index = dr.arange(UInt32, 5) + 1 + + x = rng.random(Float, dr.prod(shape)) + x = TensorXf(x, shape) + x_torch = x.torch() + + index_torch = index.torch().long() + ref = x_torch[index_torch, :, index_torch, :] + res = x[index, :, index, :] + + assert dr.allclose(res, ref)