From c7e4444f1d8adaab4c893045cf951c8eda01b998 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 10:40:56 +0200 Subject: [PATCH 1/5] Add Mooncake fwd and rvs rules --- Project.toml | 8 +- ext/VectorInterfaceMooncakeExt.jl | 181 ++++++++++++++++++++++++++++++ test/mooncake.jl | 97 ++++++++++++++++ test/runtests.jl | 5 + 4 files changed, 290 insertions(+), 1 deletion(-) create mode 100644 ext/VectorInterfaceMooncakeExt.jl create mode 100644 test/mooncake.jl diff --git a/Project.toml b/Project.toml index cc279e2..f12b8a6 100644 --- a/Project.toml +++ b/Project.toml @@ -8,15 +8,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] VectorInterfaceChainRulesCoreExt = "ChainRulesCore" +VectorInterfaceMooncakeExt = "Mooncake" [compat] Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" LinearAlgebra = "1" +Mooncake = "0.5" +Random = "1" Test = "1" TestExtras = "0.2,0.3" julia = "1" @@ -25,8 +29,10 @@ julia = "1" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [targets] -test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"] +test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Random"] diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl new file mode 100644 index 0000000..eb40e03 --- /dev/null +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -0,0 +1,181 @@ +module VectorInterfaceMooncakeExt + +using VectorInterface +using Mooncake +using Mooncake: @is_primitive, DefaultCtx, + NoFData, NoRData, NoTangent, + CoDual, Dual, arrayify, primal, extract + +# Projection +# ---------- +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{T}) where {T <: Number} = + Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData + +# scale +# ----- +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number} +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData() + scale!(ΔC, conj(α)) + return NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α, Δα = extract(α_Δα) + + if !isa(Δα, NoTangent) + add!(ΔC, C, Δα, α) + else + scale!(ΔC, α) + end + scale!(C, α) + + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, A, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + add!(ΔA, ΔC, conj(α)) + Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() + zerovector!(ΔC) + return NoRData(), NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = extract(α_Δα) + + scale!(ΔC, ΔA, α) + if !isa(Δα, NoTangent) + add!(ΔC, A, Δα, One()) + end + scale!(C, A, α) + return C_ΔC +end + +# add +# --- + +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number} + +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + β = primal(β_Δβ) + + # primal call + C_cache = copy(C) + add!(C, A, α, β) + + function add_pullback(::NoRData) + copy!(C, C_cache) + + Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() + Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() + add!(ΔA, ΔC, conj(α)) + scale!(ΔC, conj(β)) + + return NoRData(), NoRData(), NoRData(), Δαr, Δβr + end + + return C_ΔC, add_pullback +end + +function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = extract(α_Δα) + β, Δβ = extract(β_Δβ) + add!(ΔC, ΔA, α, β) + if isa(Δβ, NoTangent) && !isa(Δα, NoTangent) + add!(ΔC, A, Δα, One()) + elseif isa(Δα, NoTangent) && !isa(Δβ, NoTangent) + add!(ΔC, C, Δβ, One()) + elseif !isa(Δα, NoTangent) && !isa(Δβ, NoTangent) + add!(ΔC, A, Δα, One()) + add!(ΔC, C, Δβ, One()) + end + add!(C, A, α, β) + return C_ΔC +end + + +# inner +# ----- + +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} + +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + # primal call + s = inner(A, B) + + function inner_pullback(Δs) + add!(ΔA, B, conj(Δs)) + add!(ΔB, A, Δs) + return NoRData(), NoRData(), NoRData() + end + + return CoDual(s, NoFData()), inner_pullback +end + +function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + s = inner(A, B) + Δs = inner(A, ΔB) + inner(ΔA, B) + + return Dual(s, Δs) +end + +end diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 0000000..714f958 --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,97 @@ +module MooncakeTests + +using VectorInterface +using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec +using Test, TestExtras +using Mooncake +using Random + +rng = Random.default_rng() + +precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32)) +precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64)) + +# Small adaptations to make tests work with MinimalVec +#=function ChainRulesTestUtils.test_approx(::AbstractZero, x::MinimalVec, msg = ""; kwargs...) + return test_approx(zerovector(x), x, msg; kwargs...) +end +function ChainRulesTestUtils.test_approx(x::MinimalVec, ::AbstractZero, msg = ""; kwargs...) + return test_approx(x, zerovector(x), msg; kwargs...) +end +Base.collect(x::MinimalVec) = x.vec=# + +eltypes = (Float32, Float64, ComplexF64) + +@testset "scale ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + α = randn(T) + Mooncake.TestUtils.test_rule(rng, scale, x, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, x, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, y, x, α; atol, rtol, is_primitive = false) + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + Mooncake.TestUtils.test_rule(rng, scale, mx, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, mx, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, my, mx, α; atol, rtol, is_primitive = false) + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + Mooncake.TestUtils.test_rule(rng, scale, mx, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, mx, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!!, my, mx, α; atol, rtol, is_primitive = false) +end + +@testset "add pullbacks ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + α = randn(T) + β = randn(T) + Mooncake.TestUtils.test_rule(rng, add, y, x, α, β; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!!, y, x, α, β; atol, rtol, is_primitive = false) + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + Mooncake.TestUtils.test_rule(rng, add, my, mx, α, β; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false) + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + Mooncake.TestUtils.test_rule(rng, add, my, mx, α, β; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false) +end + +@testset "inner pullbacks ($T)" for T in eltypes + n = 12 + atol = rtol = n * precision(T) + + # Vector + x = randn(T, n) + y = randn(T, n) + Mooncake.TestUtils.test_rule(rng, inner, x, y; atol, rtol, is_primitive = false) + + # MinimalMVec + mx = MinimalMVec(x) + my = MinimalMVec(y) + Mooncake.TestUtils.test_rule(rng, inner, mx, my; atol, rtol, is_primitive = false) + + # MinimalSVec + mx = MinimalSVec(x) + my = MinimalSVec(y) + Mooncake.TestUtils.test_rule(rng, inner, mx, my; atol, rtol, is_primitive = false) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 94710df..2e53df6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,5 +38,10 @@ end @static if isdefined(Base, :get_extension) && isempty(VERSION.prerelease) println("Testing AD rules") println("================") + println("Testing ChainRules") + println("==================") include("chainrules.jl") + println("Testing Mooncake") + println("==================") + include("mooncake.jl") end From f779b68e8ec70185822002f23facb013ad3ab7d6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 14:14:24 +0200 Subject: [PATCH 2/5] Formatter --- ext/VectorInterfaceMooncakeExt.jl | 4 ++-- test/mooncake.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index eb40e03..64c984d 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -3,8 +3,8 @@ module VectorInterfaceMooncakeExt using VectorInterface using Mooncake using Mooncake: @is_primitive, DefaultCtx, - NoFData, NoRData, NoTangent, - CoDual, Dual, arrayify, primal, extract + NoFData, NoRData, NoTangent, + CoDual, Dual, arrayify, primal, extract # Projection # ---------- diff --git a/test/mooncake.jl b/test/mooncake.jl index 714f958..566b091 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -3,7 +3,7 @@ module MooncakeTests using VectorInterface using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec using Test, TestExtras -using Mooncake +using Mooncake using Random rng = Random.default_rng() From 7747d2fdefaf71e9681302e40491f8f4ab86adfa Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 16:21:01 +0200 Subject: [PATCH 3/5] Apply Jutho's cleanup suggestions Co-authored-by: Jutho --- ext/VectorInterfaceMooncakeExt.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 64c984d..57c26b8 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -88,9 +88,7 @@ function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α, Δα = extract(α_Δα) scale!(ΔC, ΔA, α) - if !isa(Δα, NoTangent) - add!(ΔC, A, Δα, One()) - end + !isa(Δα, NoTangent) && add!(ΔC, A, Δα, One()) scale!(C, A, α) return C_ΔC end @@ -132,14 +130,8 @@ function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ α, Δα = extract(α_Δα) β, Δβ = extract(β_Δβ) add!(ΔC, ΔA, α, β) - if isa(Δβ, NoTangent) && !isa(Δα, NoTangent) - add!(ΔC, A, Δα, One()) - elseif isa(Δα, NoTangent) && !isa(Δβ, NoTangent) - add!(ΔC, C, Δβ, One()) - elseif !isa(Δα, NoTangent) && !isa(Δβ, NoTangent) - add!(ΔC, A, Δα, One()) - add!(ΔC, C, Δβ, One()) - end + !isa(Δα, NoTangent) && add!(ΔC, A, Δα, One()) + !isa(Δβ, NoTangent) && add!(ΔC, C, Δβ, One()) add!(C, A, α, β) return C_ΔC end From 6c990e379f9b491fbc060e44db22fc407c97be78 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 16:24:15 +0200 Subject: [PATCH 4/5] Remove 1.6 --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e5e3b1c..5a2ed93 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,7 +19,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - 'lts' - '1' - 'nightly' From c250dc0b7d0b9a32e309e0261e2881546104e1cc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 16:24:31 +0200 Subject: [PATCH 5/5] Remove unneeded commented out stuff --- test/mooncake.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index 566b091..d8466d1 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -11,15 +11,6 @@ rng = Random.default_rng() precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32)) precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64)) -# Small adaptations to make tests work with MinimalVec -#=function ChainRulesTestUtils.test_approx(::AbstractZero, x::MinimalVec, msg = ""; kwargs...) - return test_approx(zerovector(x), x, msg; kwargs...) -end -function ChainRulesTestUtils.test_approx(x::MinimalVec, ::AbstractZero, msg = ""; kwargs...) - return test_approx(x, zerovector(x), msg; kwargs...) -end -Base.collect(x::MinimalVec) = x.vec=# - eltypes = (Float32, Float64, ComplexF64) @testset "scale ($T)" for T in eltypes