Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 251 additions & 38 deletions src/python/slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,29 @@
#include "base.h"
#include "slice.h"
#include "meta.h"
#include <algorithm>
#include <vector>

/// Holds metadata about slicing component
struct Component {
enum Type { None, Integer, Slice, Advanced, Ellipsis };

Type type;
Py_ssize_t start, step, slice_size, size;
nb::object object;

Component(Py_ssize_t start, Py_ssize_t step, Py_ssize_t slice_size,
// Constructor for None indices
Component(Type t)
: type(t), start(0), step(1), slice_size(1), size(1) { }

// Constructor for integer and slice indices
Component(Type t, Py_ssize_t start, Py_ssize_t step, Py_ssize_t slice_size,
Py_ssize_t size)
: start(start), step(step), slice_size(slice_size), size(size) { }
: type(t), start(start), step(step), slice_size(slice_size), size(size) { }

Component(nb::handle h, Py_ssize_t slice_size, Py_ssize_t size)
: start(0), step(1), slice_size(slice_size), size(size),
// Constructor for advanced indices (array indexing)
Component(Type t, nb::handle h, Py_ssize_t slice_size, Py_ssize_t size)
: type(t), start(0), step(1), slice_size(slice_size), size(size),
object(nb::borrow(h)) { }
};

Expand Down Expand Up @@ -72,11 +82,15 @@ slice_index(const nb::type_object_t<ArrayBase> &dtype,
indices_len = nb::len(indices);

std::vector<Component> components;
components.reserve(shape_len);
components.reserve(indices_len); // May include None indices

// First pass: parse indices
nb::list basic_shapes; // Shapes from basic indexing (slices)
size_t advanced_size = 0; // Size of advanced index arrays (all must be same)

for (nb::handle h : indices) {
if (h.is_none()) {
shape_out.append(1);
components.emplace_back(Component::None);
continue;
}

Expand All @@ -96,15 +110,14 @@ slice_index(const nb::type_object_t<ArrayBase> &dtype,
"bounds for axis %zu with size %zd.",
v, components.size(), size);

components.emplace_back(v, 1, 1, size);
components.emplace_back(Component::Integer, v, 1, 1, size);
continue;
} else if (tp.is(&PySlice_Type)) {
Py_ssize_t start, stop, step;
size_t slice_length;
nb::detail::slice_compute(h.ptr(), size, start, stop, step, slice_length);
components.emplace_back(start, step, (Py_ssize_t) slice_length, size);
shape_out.append(slice_length);
size_out *= slice_length;
components.emplace_back(Component::Slice, start, step, (Py_ssize_t) slice_length, size);
basic_shapes.append(slice_length);
continue;
} else if (is_drjit_type(tp)) {
const ArraySupplement *s2 = &supp(tp);
Expand Down Expand Up @@ -138,9 +151,21 @@ slice_index(const nb::type_object_t<ArrayBase> &dtype,
if (!o.type().is(dtype))
o = dtype(o);

components.emplace_back(o, slice_size, size);
shape_out.append(slice_size);
size_out *= slice_size;
components.emplace_back(Component::Advanced, o, slice_size, size);

// Track the maximum size for broadcasting
// PyTorch/NumPy broadcast all advanced indices to the same shape
if (advanced_size == 0) {
advanced_size = slice_size;
} else if (slice_size != 1 && advanced_size != 1 && advanced_size != slice_size) {
// Broadcasting rules: sizes must be 1 or equal
nb::raise("drjit.slice_index(): advanced index arrays with shapes %zu and %zu "
"cannot be broadcast together.", advanced_size, slice_size);
} else if (slice_size > advanced_size) {
// Update to the larger size (broadcasting smaller arrays to match)
advanced_size = slice_size;
}

continue;
}
} else if (tp.is(&PyEllipsis_Type)) {
Expand All @@ -151,9 +176,8 @@ slice_index(const nb::type_object_t<ArrayBase> &dtype,
if (shape_offset >= shape_len)
nb::detail::fail("slice_index(): internal error.");
size = nb::cast<Py_ssize_t>(shape[shape_offset++]);
components.emplace_back(0, 1, size, size);
shape_out.append(size);
size_out *= size;
components.emplace_back(Component::Slice, 0, 1, size, size);
basic_shapes.append(size);
}
continue;
}
Expand All @@ -167,43 +191,232 @@ slice_index(const nb::type_object_t<ArrayBase> &dtype,
// Implicit ellipsis at the end
while (shape_offset != shape_len) {
Py_ssize_t size = nb::cast<Py_ssize_t>(shape[shape_offset++]);
components.emplace_back(0, 1, size, size);
shape_out.append(size);
size_out *= size;
components.emplace_back(Component::Slice, 0, 1, size, size);
basic_shapes.append(size);
}

// Build output shape following PyTorch/NumPy advanced indexing rules:
// - None indices create new dimensions of size 1 at their positions
// - Integer indices reduce dimensions (don't appear in output)
// - Advanced indices: if consecutive, stay in place; if non-consecutive, move to front
shape_out.clear();

// Check if there are advanced indices and if they're consecutive
int first_adv = -1, last_adv = -1;
for (size_t i = 0; i < components.size(); ++i) {
if (components[i].type == Component::Advanced) {
if (first_adv == -1) first_adv = i;
last_adv = i;
}
}

bool has_advanced = (first_adv != -1);
bool consecutive = true;
if (has_advanced) {
for (int i = first_adv; i <= last_adv; ++i) {
if (components[i].type == Component::None) continue; // None doesn't break consecutiveness
if (components[i].type != Component::Advanced) {
consecutive = false;
break;
}
}
}

// Build output shape based on index arrangement
if (has_advanced && consecutive) {
// Advanced indices are consecutive: replace all with a single dimension
bool advanced_added = false;
for (const auto &comp : components) {
if (comp.type == Component::None) {
shape_out.append(1);
} else if (comp.type == Component::Slice) {
shape_out.append(comp.slice_size);
} else if (comp.type == Component::Advanced) {
if (!advanced_added) {
// All consecutive advanced indices produce a single dimension
shape_out.append(advanced_size);
advanced_added = true;
}
// Subsequent advanced indices don't add dimensions
}
// Integer indices don't contribute
}
} else if (has_advanced && !consecutive) {
// Advanced indices are non-consecutive: move to front
shape_out.append(advanced_size);
for (const auto &comp : components) {
if (comp.type == Component::None) {
shape_out.append(1);
} else if (comp.type == Component::Slice) {
shape_out.append(comp.slice_size);
}
// Integer and Advanced (already added) don't contribute here
}
} else {
// No advanced indexing: process each index type in order
for (const auto &comp : components) {
if (comp.type == Component::None) {
shape_out.append(1);
} else if (comp.type == Component::Slice) {
shape_out.append(comp.slice_size);
}
// Integer indices don't contribute to shape
}
}

// Calculate total size from the actual output shape
size_out = 1;
for (nb::handle h : shape_out)
size_out *= nb::cast<size_t>(h);

nb::object index = arange(dtype, 0, size_out, 1),
index_out;

nb::object active = nb::borrow(Py_True);
if (size_out) {
size_out = 1;
// Unified algorithm that handles both basic and advanced indexing
index_out = dtype(0);

// Calculate the stride multiplier for the input tensor dimensions
// Skip None components as they don't correspond to input dimensions
size_t input_stride = 1;
std::vector<size_t> input_strides;
for (auto it = components.rbegin(); it != components.rend(); ++it) {
const Component &c = *it;
nb::object index_next, index_rem;

if (it + 1 != components.rend()) {
index_next = index.floor_div(dtype(c.slice_size));
index_rem = fma(index_next, dtype(uint32_t(-c.slice_size)), index);
if (it->type == Component::None) {
input_strides.push_back(0); // Placeholder for None
} else {
input_strides.push_back(input_stride);
input_stride *= it->size;
}
}
std::reverse(input_strides.begin(), input_strides.end());

// Decompose output index according to output shape
nb::object remaining = index;
std::vector<nb::object> output_dim_indices;

// Decompose based on actual output shape (in reverse order)
for (size_t i = nb::len(shape_out); i > 0; --i) {
size_t dim_size = nb::cast<size_t>(shape_out[i - 1]);
nb::object dim_idx;
if (i > 1) {
nb::object quotient = remaining.floor_div(dtype(dim_size));
dim_idx = remaining - quotient * dtype(dim_size);
remaining = quotient;
} else {
index_rem = index;
dim_idx = remaining;
}
output_dim_indices.insert(output_dim_indices.begin(), dim_idx);
}

nb::object index_val;
if (!c.object.is_valid())
index_val = fma(index_rem, dtype(uint32_t(c.step * size_out)),
dtype(uint32_t(c.start * size_out)));
else
index_val = gather(dtype, c.object, index_rem, active,
ReduceMode::Auto) *
dtype(uint32_t(size_out));
// Check if there are advanced indices and if they're consecutive
int first_adv = -1, last_adv = -1;
for (size_t i = 0; i < components.size(); ++i) {
if (components[i].type == Component::Advanced) {
if (first_adv == -1) first_adv = i;
last_adv = i;
}
}

bool has_advanced = (first_adv != -1);
bool consecutive = true;
if (has_advanced) {
for (int i = first_adv; i <= last_adv; ++i) {
if (components[i].type == Component::None) continue;
if (components[i].type != Component::Advanced) {
consecutive = false;
break;
}
}
}

// Extract advanced_idx and basic indices from output_dim_indices
nb::object advanced_idx = dtype(0);
std::vector<nb::object> basic_dim_indices;
size_t output_idx = 0;
bool advanced_found = false;

if (has_advanced && consecutive) {
// Advanced indices are consecutive: they stay in their natural position
for (const auto &comp : components) {
if (comp.type == Component::None) {
output_idx++;
} else if (comp.type == Component::Advanced) {
if (!advanced_found) {
advanced_idx = output_dim_indices[output_idx];
advanced_found = true;
}
output_idx++;
} else if (comp.type == Component::Slice) {
basic_dim_indices.push_back(output_dim_indices[output_idx]);
output_idx++;
}
}
} else if (has_advanced && !consecutive) {
// Advanced indices are non-consecutive: they're moved to the front
advanced_idx = output_dim_indices[0];
output_idx = 1;
for (const auto &comp : components) {
if (comp.type == Component::None) {
if (output_idx < output_dim_indices.size()) {
output_idx++;
}
} else if (comp.type == Component::Slice) {
if (output_idx < output_dim_indices.size()) {
basic_dim_indices.push_back(output_dim_indices[output_idx]);
output_idx++;
}
}
}
} else {
// No advanced indexing: just map output dimensions to input
for (const auto &comp : components) {
if (comp.type == Component::None) {
output_idx++;
} else if (comp.type == Component::Slice) {
if (output_idx < output_dim_indices.size()) {
basic_dim_indices.push_back(output_dim_indices[output_idx]);
output_idx++;
}
}
}
}

// Map output indices back to input dimensions
size_t basic_idx_counter = 0;
for (size_t i = 0; i < components.size(); ++i) {
const Component &c = components[i];

// Skip None indices as they don't correspond to input dimensions
if (c.type == Component::None)
continue;

index_out += index_val;
nb::object dim_index;

if (c.type == Component::Advanced) {
// Advanced index: use the advanced_idx to gather from the index array
// Handle broadcasting: if the index array has size 1, broadcast it
if (c.slice_size == 1) {
dim_index = gather(dtype, c.object, dtype(0), active, ReduceMode::Auto);
} else {
dim_index = gather(dtype, c.object, advanced_idx, active, ReduceMode::Auto);
}
} else if (c.type == Component::Integer) {
// Integer index
dim_index = dtype(c.start);
} else if (c.type == Component::Slice) {
// Basic slice: get the dimension index and apply slice transformation
if (basic_idx_counter < basic_dim_indices.size()) {
dim_index = basic_dim_indices[basic_idx_counter];
dim_index = fma(dim_index, dtype(uint32_t(c.step)), dtype(uint32_t(c.start)));
basic_idx_counter++;
} else {
dim_index = dtype(c.start);
}
}

index = std::move(index_next);
size_out *= c.size;
// Add contribution to output index
index_out += dim_index * dtype(uint32_t(input_strides[i]));
}
} else {
index_out = dtype();
Expand Down
7 changes: 4 additions & 3 deletions tests/test_freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3427,11 +3427,12 @@ def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32):

frozen = dr.freeze(func, auto_opaque=auto_opaque)

for i in range(3):
for i in range(4):
shape = ((i + 5), 10)
x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape=shape)
row = dr.arange(mod.UInt32, i + 4)
col = dr.arange(mod.UInt32, 3) + 1
# Both row and col must have the same length for advanced indexing
row = dr.arange(mod.UInt32, i+2)
col = dr.arange(mod.UInt32, i+2) + 1

res = frozen(x, row, col)
ref = func(x, row, col)
Expand Down
Loading