diff --git a/docs/cpp.rst b/docs/cpp.rst index 9079640d2..2a7e879b0 100644 --- a/docs/cpp.rst +++ b/docs/cpp.rst @@ -494,11 +494,11 @@ trees: template void visit_jit_pairs(T &v0, T &v1) { if constexpr (dr::is_jit_v && dr::depth_v == 1) { - /// Do something with 'v0' and 'v1' + // Do something with 'v0' and 'v1' } else if constexpr (dr::is_traversable_v) { - /// Recurse and try again if the object is traversable + // Recurse and try again if the object is traversable dr::traverse_2( - /// Extract the fields of 'v0' and 'v1' + // Extract the fields of 'v0' and 'v1' dr::fields(v0), dr::fields(v1), // .. and call the following lambda function on them [&](auto &x, auto &y) { visit_jit_pairs(x, y); } diff --git a/include/drjit/local.h b/include/drjit/local.h index 7e2fe176a..76d73e468 100644 --- a/include/drjit/local.h +++ b/include/drjit/local.h @@ -16,8 +16,18 @@ NAMESPACE_BEGIN(drjit) -template struct Local { - using Index = uint32_t; +template , + typename SFINAE = int> +struct Local { + static constexpr size_t Size = Size_; + static_assert(Size != Dynamic, "Scalar local arrays are only fixed size. " + "If you meant to instantiate a JIT variant " + "or a DRJIT_STRUCT you may have forgotten to " + "add the Index template parameter."); + using Value = Value_; + using Index = Index_; using Mask = bool; Local() = default; @@ -28,16 +38,10 @@ template struct Local { } ~Local() = default; - Local(const Local &) = delete; - Local(Local &&l) { - for (size_t i = 0; i < Size; ++i) - m_data[i] = l.m_data[i]; - } - Local &operator=(const Local &) = delete; - Local &operator=(Local &&l) { - for (size_t i = 0; i < Size; ++i) - m_data[i] = l.m_data[i]; - } + Local(const Local &) = default; + Local(Local &&l) = default; + Local &operator=(const Local &) = default; + Local &operator=(Local &&l) = default; Value read(const Index &offset, const Mask &active = true) const { if (active) @@ -57,50 +61,178 @@ template struct Local { Value m_data[Size]; }; -template struct Local> { - static constexpr JitBackend Backend = backend_v; - using Index = uint32_array_t; - using Mask = mask_t; - Local() { - m_index = jit_array_create(Backend, Value::Type, 1, Size); +NAMESPACE_BEGIN(detail) + +template +void init_impl(const T &value, const size_t size, vector& arrays) { + if constexpr (is_jit_v && depth_v == 1) { + uint32_t result; + if (!value.empty()) { + uint32_t i1 = value.index(); + size_t width = jit_var_size(i1); + uint32_t i2 = jit_array_create( + backend_v, var_type>::value, + width, size); + result = jit_array_init(i2, i1); + jit_var_dec_ref(i2); + } else { + result = jit_array_create( + backend_v, var_type>::value, + 1, size); + } + arrays.push_back(result); + } else if constexpr (is_traversable_v) { + // Recurse and try again if the object is traversable + traverse_1(fields(value), + [&](auto &v) { init_impl(v, size, arrays); }); } - - Local(const Value &value) { - uint32_t tmp = jit_array_create(Backend, Value::Type, 1, Size); - m_index = jit_array_init(tmp, value.index()); - jit_var_dec_ref(tmp); +} + +template +void read_impl(T &result, + const uint32_t &offset, + const uint32_t &active, + const vector &arrays, + size_t &counter) { + if constexpr (is_jit_v && depth_v == 1) { + if (counter >= arrays.size()) + jit_raise("Local::read(): internal error, ran out of " + "variable arrays!"); + result = T::steal(jit_array_read(arrays[counter++], offset, active)); + } else if constexpr (is_traversable_v) { + // Recurse and try again if the object is traversable + traverse_1(fields(result), [&](auto &r) { + read_impl(r, offset, active, arrays, counter); + }); + } +} + +template +void write_impl(const uint32_t &offset, + const T &value, + const uint32_t &active, + vector &arrays, + size_t &counter) { + if constexpr (is_jit_v && depth_v == 1) { + if (counter >= arrays.size()) + jit_raise("Local::write(): internal error, ran out of " + "variable arrays!"); + + if (value.index_ad()) + jit_raise("Local memory writes are not differentiable. You " + "must use 'drjit.detach()' to disable gradient " + "tracking of the written value."); + + uint32_t result = + jit_array_write(arrays[counter], offset, value.index(), active); + jit_var_dec_ref(arrays[counter]); + arrays[counter++] = result; + + } else if constexpr (is_traversable_v) { + // Recurse and try again if the object is traversable + traverse_1(fields(value), + [&](auto &v) { write_impl(offset, v, active, arrays, counter); }); + } +} + +NAMESPACE_END(detail) + + +/** + * \brief Local memory implemented on top of drjit-core jit_array_* + * \details The array `value` of static or dynamic width will be used + * to initialize the entries of local memory with length `Size`. + * `Size` can be drjit::Dynamic, in which case a call to resize will + * be required before usage. + */ +template +struct Local || (is_array_v && is_drjit_struct_v)>> +{ + static constexpr JitBackend Backend = backend_v; + static constexpr size_t Size = Size_; + using Value = Value_; + using Index = Index_; + using Mask = mask_t; + + /** + * \brief Allocate local memory + * \param value optional inital value (also used when resizing dynamic memory) + */ + Local(Value value = empty()) + : m_size(Size == Dynamic ? 1 : Size), m_value(value) { + detail::init_impl(m_value, m_size, m_arrays); } - ~Local() { jit_var_dec_ref(m_index); } - Local(const Local &) = delete; - Local(Local &&l) { - m_index = l.m_index; - l.m_index = 0; + ~Local() { + for (uint32_t index : m_arrays) + jit_var_dec_ref(index); + } + Local(const Local &l) { + *this = l; + } + Local(Local &&l) = default; + + Local &operator=(const Local &l) { + for (uint32_t index : m_arrays) + jit_var_dec_ref(index); + m_size = l.m_size; + m_value = l.m_value; + m_arrays = l.m_arrays; + for (uint32_t index : m_arrays) + jit_var_inc_ref(index); + return *this; } - Local &operator=(const Local &) = delete; Local &operator=(Local &&l) { - jit_var_dec_ref(m_index); - m_index = l.m_index; - l.m_index = 0; + for (uint32_t index : m_arrays) + jit_var_dec_ref(index); + m_size = std::move(l.m_size); + m_value = std::move(l.m_value); + m_arrays = std::move(l.m_arrays); + return *this; } Value read(const Index &offset, const Mask &active = true) const { - return Value::steal(jit_array_read(m_index, offset.index(), active.index())); + Value result; + size_t counter = 0; + detail::read_impl(result, offset.index(), active.index(), m_arrays, counter); + + if (counter != m_arrays.size()) + jit_raise( + "Local::read(): internal error, did not access all variable " + "arrays!"); + + return result; } void write(const Index &offset, const Value &value, const Mask &active = true) { - uint32_t new_index = jit_array_write(m_index, offset.index(), - value.index(), active.index()); - jit_var_dec_ref(m_index); - m_index = new_index; + size_t counter = 0; + detail::write_impl(offset.index(), value, active.index(), m_arrays, counter); + + if (counter != m_arrays.size()) + jit_raise( + "Local.write(): internal error, did not access all variable " + "arrays!"); } - size_t size() const { return Size; } + size_t size() { return m_size; } + + /** + * Reserve a new array of `length` and discard any current contents + */ + void resize(size_t size) { + for (uint32_t index : m_arrays) + jit_var_dec_ref(index); + m_arrays.clear(); + m_size = size; + detail::init_impl(m_value, m_size, m_arrays); + } private: - uint32_t m_index; + size_t m_size; + Value m_value; + vector m_arrays; }; - NAMESPACE_END(drjit) diff --git a/tests/local_ext.cpp b/tests/local_ext.cpp index 02eade5df..d186f2839 100644 --- a/tests/local_ext.cpp +++ b/tests/local_ext.cpp @@ -1,20 +1,144 @@ #define NB_INTRUSIVE_EXPORT NB_IMPORT #include +#include #include +#include +#include namespace nb = nanobind; namespace dr = drjit; +using nb::literals::operator""_a; + +template +auto bind_local(nb::module_ &m, const dr::string& name) { + auto c = nb::class_(m, name.c_str()) + .def(nb::init<>()) + .def(nb::init()) + .def("__len__", &Local::size) + .def("read", &Local::read, "index"_a, "active"_a = true) + .def("write", &Local::write, "offset"_a, "value"_a, "active"_a = true); + + if constexpr (dr::is_jit_v) + c = c.def("resize", &Local::resize); + + m.def(("test_" + name + "_loop").c_str(), []() { + auto initial = Local(); + auto counter = int32_t(0); + + if constexpr (dr::is_jit_v) + initial.resize(10); + + dr::tie(initial, counter) = dr::while_loop( + dr::make_tuple(initial, counter), + [](const Local& l, const int32_t& i) { + DRJIT_MARK_USED(l); + return i < 5; + }, + [](Local& l, int32_t& i) { + l.write(i, dr::full(i)); + auto written = l.read(i); + DRJIT_MARK_USED(written); + i += 1; + } + ); + for(unsigned int i = 0; i < 5; ++i) { + auto value = initial.read(i); + if(dr::any(value != dr::full(i))) { + jit_raise("Index %d doesn't match %s", i, dr::string(value).c_str()); + } + } + }); + + m.def(("test_" + name + "_loop_struct").c_str(), []() { + auto initial = Local(); + auto counter = int32_t(0); + + if constexpr (dr::is_jit_v) + initial.resize(10); + + + struct LoopState { + Local l; + int32_t counter; + + DRJIT_STRUCT(LoopState, l, counter) + } ls = { initial, counter }; + + dr::tie(ls) = dr::while_loop( + dr::make_tuple(ls), + [](const LoopState &ls) { return ls.counter < 5; }, + [](LoopState &ls) { + ls.l.write(ls.counter, dr::full(ls.counter)); + auto written = ls.l.read(ls.counter); + DRJIT_MARK_USED(written); + ls.counter += 1; + } + ); + for(unsigned int i = 0; i < 5; ++i) { + auto value = ls.l.read(i); + if(dr::any(value != dr::full(i))) { + jit_raise("Index %d doesn't match %s", i, dr::string(value).c_str()); + } + } + }); + + return c; +} + template void bind(nb::module_ &m) { using UInt32 = dr::uint32_array_t; + using Bool = dr::bool_array_t; - m.def("lookup", [](UInt32 offset, Float value, UInt32 offset2) { - dr::Local local(3.f); - local.write(offset, value); - return local.read(offset2); - }); + using Local10 = dr::Local; + using LocalDyn = dr::Local; + + bind_local(m, "Local10"); + + if constexpr (dr::is_jit_v) + bind_local(m, "LocalDyn"); + + struct MyStruct + { + Float value; + UInt32 priority; + DRJIT_STRUCT(MyStruct, value, priority) + + Bool operator!=(const MyStruct& other) const { + return priority != other.priority; + } + + MyStruct(int i) : value(i), priority(i) {} + }; + + auto mystruct = nb::class_(m, "MyStruct") + .def(nb::init<>()) + .def(nb::init()) + .def_rw("value", &MyStruct::value) + .def_rw("priority", &MyStruct::priority) + .def(nb::self != nb::self); + + nb::handle u32; + if constexpr (dr::is_array_v) + u32 = nb::type(); + else + u32 = nb::handle((PyObject *) &PyLong_Type); + + nb::handle f32; + if constexpr (dr::is_array_v) + f32 = nb::type(); + else + f32 = nb::handle((PyObject *) &PyFloat_Type); + + nb::dict fields; + fields["value"] = f32; + fields["priority"] = u32; + mystruct.attr("DRJIT_STRUCT") = fields; + + using LocalStruct10 = dr::Local; + bind_local(m, "LocalStruct10"); } NB_MODULE(local_ext, m) { diff --git a/tests/test_local_ext.py b/tests/test_local_ext.py index ba458f728..8d31fb2a0 100644 --- a/tests/test_local_ext.py +++ b/tests/test_local_ext.py @@ -12,10 +12,105 @@ def get_pkg(t): elif backend == dr.JitBackend.Invalid: return m.scalar +def is_constant_valued(local, value): + for i in range(len(local)): + assert dr.allclose(local.read(i), value) + + +@pytest.test_arrays('float32,shape=(*)') +def test01_initialization(t): + pkg = get_pkg(t) + initial = 25.4 + + l10 = pkg.Local10() + assert len(l10) == 10 + + with pytest.raises(AssertionError): + is_constant_valued(l10, initial) + + l10 = pkg.Local10(initial) + assert len(l10) == 10 + + is_constant_valued(l10, initial) + + +@pytest.test_arrays('float32,is_jit,shape=(*)') +def test02_dynamic_initialization(t): + pkg = get_pkg(t) + initial = dr.full(t, 25.4, 15) + + ldyn = pkg.LocalDyn(initial) + assert len(ldyn) == 1 + is_constant_valued(ldyn, initial) + + ldyn.resize(20) + assert len(ldyn) == 20 + is_constant_valued(ldyn, initial) + + +@pytest.test_arrays('float32,shape=(*)') +def test03_write_read(t): + pkg = get_pkg(t) + width = 20 + + local = pkg.Local10() + + for i in range(len(local)): + value = dr.arange(t, i, i+width) + if dr.backend_v(t) == dr.JitBackend.Invalid: + value = value[0] + local.write(i, value) + + sum = dr.zeros(t, width) + for i in range(len(local)): + sum += local.read(i) + + expected = dr.sum(dr.arange(t, len(local))) + (dr.arange(t, width) * len(local)) + + if dr.backend_v(t) == dr.JitBackend.Invalid: + sum = sum[0] + expected = expected[0] + + assert dr.allclose(sum, expected) + @pytest.test_arrays('float32,shape=(*)') -def test01_lookup(t): +def test04_struct(t): pkg = get_pkg(t) + width = 20 if dr.backend_v(t) == dr.JitBackend else 1 + + values = dr.ones(pkg.MyStruct, width) + local = pkg.LocalStruct10(values) + + def validate_index(idx, value): + struct = local.read(idx) + assert dr.width(struct) == width + assert dr.allclose(struct.value, value) + assert dr.allclose(struct.priority, value) + + for i in range(10): + validate_index(i, 1) + + values = dr.zeros(pkg.MyStruct, width) + local.write(0, values) + + validate_index(0, 0) + for i in range(1,10): + validate_index(i, 1) + + if dr.backend_v(t) == dr.JitBackend: + with pytest.raises(RuntimeError, match="out of bounds"): + validate_index(10, 1) + +@pytest.test_arrays('float32,shape=(*)') +def test05_loop(t): + pkg = get_pkg(t) + pkg.test_Local10_loop() + pkg.test_Local10_loop_struct() + + if dr.backend_v(t) == dr.JitBackend: + pkg.test_LocalDyn_loop() + pkg.test_LocalDyn_loop_struct() - assert dr.all(pkg.lookup(4, 5.0, 2) == 3.0) - assert dr.all(pkg.lookup(4, 5.0, 4) == 5.0) + pkg.test_LocalStruct10_loop() + pkg.test_LocalStruct10_loop_struct()