Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -32,6 +34,8 @@ TensorKitAdaptExt = "Adapt"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -44,6 +48,8 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.146"
EnzymeTestUtils = "0.2.5"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
Expand All @@ -55,8 +61,11 @@ Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.1"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"

[sources]
VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"}
9 changes: 9 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK

include("utility.jl")

end
38 changes: 38 additions & 0 deletions ext/TensorKitEnzymeExt/utility.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Projection
# ----------
pullback_dα(α::Const, C::Const, A) = nothing
pullback_dα(α::Const, C::Annotation, A) = nothing
pullback_dα(α::Annotation, C::Const, A) = zero(α.val)
pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval))

pullback_dβ(β::Const, C::Const, Ccache) = nothing
pullback_dβ(β::Const, C::Annotation, Ccache) = nothing
pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val)
pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval))

pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β))

"""
project_scalar(x::Number, dx::Number)

Project a computed tangent `dx` onto the correct tangent type for `x`.
For example, we might compute a complex `dx` but only require the real part.
"""
project_scalar(x::Number, dx::Number) = oftype(x, dx)
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

# Ignore derivatives
# ------------------

@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true

@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing
@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing
60 changes: 60 additions & 0 deletions ext/TensorKitEnzymeTestUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module TensorKitEnzymeTestUtilsExt

using TensorKit
using EnzymeTestUtils
using EnzymeTestUtils: Enzyme
import EnzymeTestUtils: to_vec, from_vec, rand_tangent

function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
has_seen = haskey(seen_vecs, x)
is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x))
if has_seen || is_const
x_vec = Float32[]
else
vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)]
x_vec, back = to_vec(vec_of_vecs)
seen_vecs[x] = x_vec
end
function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict)
if xor(has_seen, haskey(seen_xs, x))
throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized."))
end
has_seen && return seen_xs[x]
is_const && return x

x_new = similar(x)
xvec_of_vecs = back(x_vec_new)
for (i, (c, b)) in enumerate(blocks(x_new))
scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c))
end
if Core.Typeof(x_new) != Core.Typeof(x)
x_new = Core.Typeof(x)(x_new)
end
seen_xs[x] = x_new
return x_new
end
return x_vec, TensorMap_from_vec
end
function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
parent_vec, parent_t = to_vec(parent(t), seen_vecs)
return parent_vec, adjoint ∘ parent_t
end
function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict)
parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs)
return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t
end

# generate random tangents for testing
function EnzymeTestUtils.rand_tangent(rng, t::TensorMap)
return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t))
end

function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap)
return adjoint(rand_tangent(rng, parent(t)))
end

function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap)
return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1))
end

end
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand Down Expand Up @@ -36,6 +38,8 @@ AllocCheck = "0.2"
ChainRulesTestUtils = "1"
Combinatorics = "1"
GPUArrays = "11.3.1"
Enzyme = "0.13.134"
EnzymeTestUtils = "0.2.5"
JET = "0.9, 0.10, 0.11"
ParallelTestRunner = "2"
Test = "1"
Expand Down
44 changes: 44 additions & 0 deletions test/enzyme-vectorinterface/add.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Test, TestExtras
using TensorKit, Enzyme, EnzymeTestUtils
using TensorOperations
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface (add!) $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)

C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
α = randn(T)
β = randn(T)

for TC in (Duplicated, Const), TA in (Duplicated, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! TC $TC TA $TA no α no β")
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! TC $TC TA $TA no α no β")
for Tα in (Active, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα no β")
for Tβ in (Active, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα Tβ $Tβ")
end
end
for Tα in (Duplicated, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα no β")
for Tβ in (Duplicated, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα Tβ $Tβ")
end
end
end
end
25 changes: 25 additions & 0 deletions test/enzyme-vectorinterface/inner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Enzyme, EnzymeTestUtils
using Random, FiniteDifferences

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
@testset for TC in (Duplicated, Const), TA in (Duplicated, Const), f in (identity, adjoint)
atol = default_tol(T)
rtol = default_tol(T)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
for RT in (Active, Const)
EnzymeTestUtils.test_reverse(inner, RT, (f(C), TC), (f(A), TA); atol, rtol)
end
for RT in (Duplicated, Const)
EnzymeTestUtils.test_forward(inner, RT, (f(C), TC), (f(A), TA); atol, rtol)
end
end
end
end
40 changes: 40 additions & 0 deletions test/enzyme-vectorinterface/scale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Enzyme, EnzymeTestUtils
using Random

spacelist = ad_spacelist(fast_tests)
eltypes = (Float64, ComplexF64)

@testset "Enzyme - VectorInterface (scale!)" begin
@timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
atol = default_tol(T)
rtol = default_tol(T)
α = randn(T)
@testset for TC in (Duplicated,)
for Tα in (Active, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol)
@testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint))
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_reverse(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol)
end
end
for Tα in (Duplicated, Const)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_forward(scale!, TC, (C, TC), (α, Tα); atol, rtol)
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_forward(scale!, TC, (C', TC), (α, Tα); atol, rtol)
@testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint))
C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5])
EnzymeTestUtils.test_forward(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol)
end
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
if (Sys.isapple() && get(ENV, "CI", "false") == "true") || !isempty(VERSION.prerelease)
filter!(!startswith("chainrules") ∘ first, testsuite)
filter!(!startswith("mooncake") ∘ first, testsuite)
filter!(!startswith("enzyme") ∘ first, testsuite)
end

args = parse_args(ARGS; custom = ["fast"])
Expand Down
Loading