Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- 'lts'
- '1'
- 'nightly'
Expand Down
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
VectorInterfaceChainRulesCoreExt = "ChainRulesCore"
VectorInterfaceMooncakeExt = "Mooncake"

[compat]
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
LinearAlgebra = "1"
Mooncake = "0.5"
Random = "1"
Test = "1"
TestExtras = "0.2,0.3"
julia = "1"
Comment thread
lkdvos marked this conversation as resolved.
Expand All @@ -25,8 +29,10 @@ julia = "1"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[targets]
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"]
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Random"]
173 changes: 173 additions & 0 deletions ext/VectorInterfaceMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
module VectorInterfaceMooncakeExt

using VectorInterface
using Mooncake
using Mooncake: @is_primitive, DefaultCtx,
NoFData, NoRData, NoTangent,
CoDual, Dual, arrayify, primal, extract

# Projection
# ----------
"""
project_scalar(x::Number, dx::Number)

Project a computed tangent `dx` onto the correct tangent type for `x`.
For example, we might compute a complex `dx` but only require the real part.
"""
project_scalar(x::Number, dx::Number) = oftype(x, dx)
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

_needs_tangent(x) = _needs_tangent(typeof(x))
_needs_tangent(::Type{T}) where {T <: Number} =
Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData

# scale
# -----
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number}
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData()
scale!(ΔC, conj(α))
return NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α, Δα = extract(α_Δα)

if !isa(Δα, NoTangent)
add!(ΔC, C, Δα, α)
else
scale!(ΔC, α)
end
scale!(C, α)

return C_ΔC
end

@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, A, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
add!(ΔA, ΔC, conj(α))
Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
zerovector!(ΔC)
return NoRData(), NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α, Δα = extract(α_Δα)

scale!(ΔC, ΔA, α)
!isa(Δα, NoTangent) && add!(ΔC, A, Δα, One())
scale!(C, A, α)
return C_ΔC
end

# add
# ---

@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number}

function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)
β = primal(β_Δβ)

# primal call
C_cache = copy(C)
add!(C, A, α, β)

function add_pullback(::NoRData)
copy!(C, C_cache)

Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData()
add!(ΔA, ΔC, conj(α))
scale!(ΔC, conj(β))

return NoRData(), NoRData(), NoRData(), Δαr, Δβr
end

return C_ΔC, add_pullback
end

function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α, Δα = extract(α_Δα)
β, Δβ = extract(β_Δβ)
add!(ΔC, ΔA, α, β)
!isa(Δα, NoTangent) && add!(ΔC, A, Δα, One())
!isa(Δβ, NoTangent) && add!(ΔC, C, Δβ, One())
add!(C, A, α, β)
return C_ΔC
end


# inner
# -----

@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray}

function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray})
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)

# primal call
s = inner(A, B)

function inner_pullback(Δs)
add!(ΔA, B, conj(Δs))
add!(ΔB, A, Δs)
return NoRData(), NoRData(), NoRData()
end

return CoDual(s, NoFData()), inner_pullback
end

function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray})
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)

s = inner(A, B)
Δs = inner(A, ΔB) + inner(ΔA, B)

return Dual(s, Δs)
end

end
88 changes: 88 additions & 0 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
module MooncakeTests

using VectorInterface
using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec
using Test, TestExtras
using Mooncake
using Random

rng = Random.default_rng()

precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32))
precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64))

eltypes = (Float32, Float64, ComplexF64)

@testset "scale ($T)" for T in eltypes
n = 12
atol = rtol = n * precision(T)

# Vector
x = randn(T, n)
y = randn(T, n)
α = randn(T)
Mooncake.TestUtils.test_rule(rng, scale, x, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, x, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, y, x, α; atol, rtol, is_primitive = false)

# MinimalMVec
mx = MinimalMVec(x)
my = MinimalMVec(y)
Mooncake.TestUtils.test_rule(rng, scale, mx, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, mx, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, my, mx, α; atol, rtol, is_primitive = false)

# MinimalSVec
mx = MinimalSVec(x)
my = MinimalSVec(y)
Mooncake.TestUtils.test_rule(rng, scale, mx, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, mx, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!!, my, mx, α; atol, rtol, is_primitive = false)
end

@testset "add pullbacks ($T)" for T in eltypes
n = 12
atol = rtol = n * precision(T)

# Vector
x = randn(T, n)
y = randn(T, n)
α = randn(T)
β = randn(T)
Mooncake.TestUtils.test_rule(rng, add, y, x, α, β; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!!, y, x, α, β; atol, rtol, is_primitive = false)

# MinimalMVec
mx = MinimalMVec(x)
my = MinimalMVec(y)
Mooncake.TestUtils.test_rule(rng, add, my, mx, α, β; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false)

# MinimalSVec
mx = MinimalSVec(x)
my = MinimalSVec(y)
Mooncake.TestUtils.test_rule(rng, add, my, mx, α, β; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false)
end

@testset "inner pullbacks ($T)" for T in eltypes
n = 12
atol = rtol = n * precision(T)

# Vector
x = randn(T, n)
y = randn(T, n)
Mooncake.TestUtils.test_rule(rng, inner, x, y; atol, rtol, is_primitive = false)

# MinimalMVec
mx = MinimalMVec(x)
my = MinimalMVec(y)
Mooncake.TestUtils.test_rule(rng, inner, mx, my; atol, rtol, is_primitive = false)

# MinimalSVec
mx = MinimalSVec(x)
my = MinimalSVec(y)
Mooncake.TestUtils.test_rule(rng, inner, mx, my; atol, rtol, is_primitive = false)
end

end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,10 @@ end
@static if isdefined(Base, :get_extension) && isempty(VERSION.prerelease)
println("Testing AD rules")
println("================")
println("Testing ChainRules")
println("==================")
include("chainrules.jl")
println("Testing Mooncake")
println("==================")
include("mooncake.jl")
end
Loading