From 0f142de57a25cd6d22a9202564941f925f90ff58 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 09:22:26 -0400 Subject: [PATCH 1/3] Working EIG + CUDA AD --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 8 ++++++++ src/MatrixAlgebraKit.jl | 2 +- src/pullbacks/eig.jl | 7 ++++--- test/mooncake/eig.jl | 5 +++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index caf13b678..b9563a7df 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -8,7 +8,11 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! import MatrixAlgebraKit: heevj!, heevd!, geev! +<<<<<<< HEAD import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback! +======= +import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eig_pullback! +>>>>>>> 292f243 (Working EIG + CUDA AD) using CUDA, CUDA.cuBLAS using CUDA: i32 using LinearAlgebra @@ -209,4 +213,8 @@ function eigh_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs. return eigh_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...) end +function eig_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs...) + return eig_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...) +end + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 22bf79e9c..25dc60c1b 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -5,7 +5,7 @@ using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! using LinearAlgebra: sylvester, lu!, diagm using LinearAlgebra: isposdef, issymmetric -using LinearAlgebra: Diagonal, diag, diagind, isdiag +using LinearAlgebra: Diagonal, Hermitian, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 0d6c2cfde..e8eaeeaed 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -30,7 +30,7 @@ function check_and_prepare_eig_cotangents( bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end - Δgauge = norm(bc, Inf) + Δgauge = maximum(abs, bc; init = abs(zero(eltype(D)))) Δgauge ≤ gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" @@ -41,7 +41,8 @@ function check_and_prepare_eig_cotangents( if !iszerotangent(ΔDmat) ΔD = diagview(ΔDmat) length(indD) == length(ΔD) || throw(DimensionMismatch()) - view(diagview(VᴴAΔV), indD) .+= ΔD + # needed to avoid GPUCompiler errors + VᴴAΔV[diagind(VᴴAΔV)[indD]] .+= ΔD else ΔD = nothing end @@ -243,7 +244,7 @@ function remove_eig_gauge_dependence!( Ddiag = diagview(D) gaugepart = V' * ΔV gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 - ViG = V / LinearAlgebra.cholesky!(V' * V) + ViG = V / LinearAlgebra.cholesky!(Hermitian(V' * V)) mul!(ΔV, ViG, gaugepart, -1, 1) return ΔV end diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index a0e606941..105021b98 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -18,4 +18,9 @@ for T in (BLASFloats..., GenericFloats...) AT = Diagonal{T, Vector{T}} TestSuite.test_mooncake_eig(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end + if T ∈ BLASFloats && CUDA.functional() + TestSuite.test_mooncake_eig(CuMatrix{T}, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + AT = Diagonal{T, CuVector{T}} + TestSuite.test_mooncake_eig(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end end From 033ba1835aaf61b85ec864b0fbdaf7013e355b05 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 10:02:19 -0400 Subject: [PATCH 2/3] Formatter --- src/pullbacks/eig.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index e8eaeeaed..cce8577bc 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -41,7 +41,7 @@ function check_and_prepare_eig_cotangents( if !iszerotangent(ΔDmat) ΔD = diagview(ΔDmat) length(indD) == length(ΔD) || throw(DimensionMismatch()) - # needed to avoid GPUCompiler errors + # needed to avoid GPUCompiler errors VᴴAΔV[diagind(VᴴAΔV)[indD]] .+= ΔD else ΔD = nothing From b928170613b290c4099e6885d83a87e2e76b8246 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 20 May 2026 14:32:13 -0400 Subject: [PATCH 3/3] Use different RNG seed --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 6 +----- test/mooncake/eig.jl | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index b9563a7df..558fcec45 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -8,11 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! import MatrixAlgebraKit: heevj!, heevd!, geev! -<<<<<<< HEAD -import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback! -======= -import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eig_pullback! ->>>>>>> 292f243 (Working EIG + CUDA AD) +import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback! using CUDA, CUDA.cuBLAS using CUDA: i32 using LinearAlgebra diff --git a/test/mooncake/eig.jl b/test/mooncake/eig.jl index 105021b98..c3d395bca 100644 --- a/test/mooncake/eig.jl +++ b/test/mooncake/eig.jl @@ -12,7 +12,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) - TestSuite.seed_rng!(1234) + TestSuite.seed_rng!(12345) if !is_buildkite TestSuite.test_mooncake_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}}