diff --git a/Project.toml b/Project.toml index af43fc925..d7ea297cc 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" @@ -32,6 +34,8 @@ TensorKitAdaptExt = "Adapt" TensorKitAMDGPUExt = "AMDGPU" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" +TensorKitEnzymeExt = "Enzyme" +TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils" TensorKitFiniteDifferencesExt = "FiniteDifferences" TensorKitMooncakeExt = "Mooncake" @@ -44,6 +48,8 @@ AMDGPU = "2" CUDA = "6" ChainRulesCore = "1" Dictionaries = "0.4" +Enzyme = "0.13.146" +EnzymeTestUtils = "0.2.5" FiniteDifferences = "0.12" LRUCache = "1.0.2" LinearAlgebra = "1" @@ -55,8 +61,11 @@ Random = "1" ScopedValues = "1.3.0" Strided = "2" TensorKitSectors = "0.3.7" -TensorOperations = "5.1" +TensorOperations = "5.5.2" TupleTools = "1.5" VectorInterface = "0.4.8, 0.5" cuTENSOR = "6" julia = "1.10" + +[sources] +VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"} diff --git a/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl new file mode 100644 index 000000000..338f92137 --- /dev/null +++ b/ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl @@ -0,0 +1,9 @@ +module TensorKitEnzymeExt + +using Enzyme +using TensorKit +import TensorKit as TK + +include("utility.jl") + +end diff --git a/ext/TensorKitEnzymeExt/utility.jl b/ext/TensorKitEnzymeExt/utility.jl new file mode 100644 index 000000000..3106861da --- /dev/null +++ b/ext/TensorKitEnzymeExt/utility.jl @@ -0,0 +1,38 @@ +# Projection +# ---------- +pullback_dα(α::Const, C::Const, A) = nothing +pullback_dα(α::Const, C::Annotation, A) = nothing +pullback_dα(α::Annotation, C::Const, A) = zero(α.val) +pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval)) + +pullback_dβ(β::Const, C::Const, Ccache) = nothing +pullback_dβ(β::Const, C::Annotation, Ccache) = nothing +pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val) +pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval)) + +pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β)) + +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + +# Ignore derivatives +# ------------------ + +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true +@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true + +@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing +@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing diff --git a/ext/TensorKitEnzymeTestUtilsExt.jl b/ext/TensorKitEnzymeTestUtilsExt.jl new file mode 100644 index 000000000..ca0a42e89 --- /dev/null +++ b/ext/TensorKitEnzymeTestUtilsExt.jl @@ -0,0 +1,60 @@ +module TensorKitEnzymeTestUtilsExt + +using TensorKit +using EnzymeTestUtils +using EnzymeTestUtils: Enzyme +import EnzymeTestUtils: to_vec, from_vec, rand_tangent + +function EnzymeTestUtils.to_vec(x::TensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + has_seen = haskey(seen_vecs, x) + is_const = Enzyme.Compiler.guaranteed_const(Core.Typeof(x)) + if has_seen || is_const + x_vec = Float32[] + else + vec_of_vecs = [b * TensorKit.sqrtdim(c) for (c, b) in blocks(x)] + x_vec, back = to_vec(vec_of_vecs) + seen_vecs[x] = x_vec + end + function TensorMap_from_vec(x_vec_new::AbstractVector, seen_xs::EnzymeTestUtils.AliasDict) + if xor(has_seen, haskey(seen_xs, x)) + throw(ErrorException("Arrays must be reconstructed in the same order as they are vectorized.")) + end + has_seen && return seen_xs[x] + is_const && return x + + x_new = similar(x) + xvec_of_vecs = back(x_vec_new) + for (i, (c, b)) in enumerate(blocks(x_new)) + scale!(b, xvec_of_vecs[i], TensorKit.invsqrtdim(c)) + end + if Core.Typeof(x_new) != Core.Typeof(x) + x_new = Core.Typeof(x)(x_new) + end + seen_xs[x] = x_new + return x_new + end + return x_vec, TensorMap_from_vec +end +function EnzymeTestUtils.to_vec(t::TensorKit.AdjointTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(parent(t), seen_vecs) + return parent_vec, adjoint ∘ parent_t +end +function EnzymeTestUtils.to_vec(t::TensorKit.DiagonalTensorMap, seen_vecs::EnzymeTestUtils.AliasDict) + parent_vec, parent_t = to_vec(TensorMap(t), seen_vecs) + return parent_vec, TensorKit.DiagonalTensorMap ∘ parent_t +end + +# generate random tangents for testing +function EnzymeTestUtils.rand_tangent(rng, t::TensorMap) + return TensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t)) +end + +function EnzymeTestUtils.rand_tangent(rng, t::TensorKit.AdjointTensorMap) + return adjoint(rand_tangent(rng, parent(t))) +end + +function EnzymeTestUtils.rand_tangent(rng, t::DiagonalTensorMap) + return DiagonalTensorMap(EnzymeTestUtils.rand_tangent(rng, t.data), space(t, 1)) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 18af8af80..56d11900a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -36,6 +38,8 @@ AllocCheck = "0.2" ChainRulesTestUtils = "1" Combinatorics = "1" GPUArrays = "11.3.1" +Enzyme = "0.13.134" +EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10, 0.11" ParallelTestRunner = "2" Test = "1" diff --git a/test/enzyme-vectorinterface/add.jl b/test/enzyme-vectorinterface/add.jl new file mode 100644 index 000000000..7dc3c76b0 --- /dev/null +++ b/test/enzyme-vectorinterface/add.jl @@ -0,0 +1,44 @@ +using Test, TestExtras +using TensorKit, Enzyme, EnzymeTestUtils +using TensorOperations +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface (add!) $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + for TC in (Duplicated, Const), TA in (Duplicated, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! TC $TC TA $TA no α no β") + EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA); atol, rtol, testset_name = "add! TC $TC TA $TA no α no β") + for Tα in (Active, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα no β") + for Tβ in (Active, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα Tβ $Tβ") + end + end + for Tα in (Duplicated, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα no β") + for Tβ in (Duplicated, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_forward(add!, TC, (C, TC), (A, TA), (α, Tα), (β, Tβ); atol, rtol, testset_name = "add! TC $TC TA $TA Tα $Tα Tβ $Tβ") + end + end + end +end diff --git a/test/enzyme-vectorinterface/inner.jl b/test/enzyme-vectorinterface/inner.jl new file mode 100644 index 000000000..b2cb17e2d --- /dev/null +++ b/test/enzyme-vectorinterface/inner.jl @@ -0,0 +1,25 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random, FiniteDifferences + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + @testset for TC in (Duplicated, Const), TA in (Duplicated, Const), f in (identity, adjoint) + atol = default_tol(T) + rtol = default_tol(T) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + for RT in (Active, Const) + EnzymeTestUtils.test_reverse(inner, RT, (f(C), TC), (f(A), TA); atol, rtol) + end + for RT in (Duplicated, Const) + EnzymeTestUtils.test_forward(inner, RT, (f(C), TC), (f(A), TA); atol, rtol) + end + end + end +end diff --git a/test/enzyme-vectorinterface/scale.jl b/test/enzyme-vectorinterface/scale.jl new file mode 100644 index 000000000..622cca252 --- /dev/null +++ b/test/enzyme-vectorinterface/scale.jl @@ -0,0 +1,40 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Enzyme, EnzymeTestUtils +using Random + +spacelist = ad_spacelist(fast_tests) +eltypes = (Float64, ComplexF64) + +@testset "Enzyme - VectorInterface (scale!)" begin + @timedtestset "$(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + α = randn(T) + @testset for TC in (Duplicated,) + for Tα in (Active, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(scale!, TC, (C, TC), (α, Tα); atol, rtol) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(scale!, TC, (C', TC), (α, Tα); atol, rtol) + @testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint)) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_reverse(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol) + end + end + for Tα in (Duplicated, Const) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_forward(scale!, TC, (C, TC), (α, Tα); atol, rtol) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_forward(scale!, TC, (C', TC), (α, Tα); atol, rtol) + @testset for TA in (Duplicated,), (fc, fa) in ((identity, identity), (adjoint, adjoint)) + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + EnzymeTestUtils.test_forward(scale!, TC, (fc(C), TC), (fa(A), TA), (α, Tα); atol, rtol) + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 881d538a8..7e76a60e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ end if (Sys.isapple() && get(ENV, "CI", "false") == "true") || !isempty(VERSION.prerelease) filter!(!startswith("chainrules") ∘ first, testsuite) filter!(!startswith("mooncake") ∘ first, testsuite) + filter!(!startswith("enzyme") ∘ first, testsuite) end args = parse_args(ARGS; custom = ["fast"])