# Patch summary (free text ahead of the first `diff --git` header).
#
# ext/MatrixAlgebraKitAMDGPUExt, ext/MatrixAlgebraKitCUDAExt:
#   * additionally import `svd_pullback!` from MatrixAlgebraKit;
#   * replace the positional `svd_rank(S, rank_atol)` method with a keyword form
#     `svd_rank(S; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))` that
#     returns 0 instead of `nothing` when `findlast` finds no entry ≥ rank_atol;
#   * add an `svd_pullback!` method for a GPU `ΔA` with a GPU index vector `ind` that
#     forwards to `svd_pullback!` with `collect(ind)`.
# ext/MatrixAlgebraKitMooncakeExt:
#   * mark `zero!(::AbstractArray)` as a Mooncake reverse-mode primitive and add an
#     `rrule!!` that copies `A`, zeroes it, and in the adjoint restores `A` from the copy
#     and zeroes `dA`.
# src/pullbacks/svd.jl (check_and_prepare_svd_cotangents, only partly visible here):
#   * `norm(bc, Inf)` becomes `maximum(abs, bc)`;
#   * the scalar loop over `indS` becomes `_ind_intersect`-based index selection plus a
#     `maximum(abs, ...; init = ...)` reduction.
# test/...:
#   * add CUDA cases guarded by `T ∈ BLASFloats && CUDA.functional()`;
#   * use `maximum(S) / 2` instead of `S[1] / 2` as the `trunctol` threshold.
#
# NOTE(review): the removed loop wrote `ΔS₁[i] = real(ΔS[j])` for each
# `(j, i) in enumerate(indS)` with `i <= r` (ΔS indexed by position j, ΔS₁ by value i) and
# folded `abs(ΔS[j])` into Δgauge otherwise. The replacement indexes ΔS by the values
# returned from `_ind_intersect` and fills `ΔS₁[1:length(good_indS)]`. `_ind_intersect` is
# not part of this patch; confirm both forms agree when `indS` is not the prefix `1:k`.
#
# NOTE(review): the line structure of this patch was rebuilt from a whitespace-collapsed
# copy. Tokens and hunk line counts match the collapsed text, but indentation and the
# position of blank context lines are assumed (including the blank trailing context line
# of the last hunk) - verify against the target tree or apply whitespace-tolerantly.
diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
index 4efffbcc7..0bdb10497 100644
--- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
+++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTrun
 using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
 import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
 import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
-import MatrixAlgebraKit: _sylvester, svd_rank
+import MatrixAlgebraKit: _sylvester, svd_rank, svd_pullback!
 using AMDGPU
 using LinearAlgebra
 using LinearAlgebra: BlasFloat
@@ -185,6 +185,12 @@ function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
     return ROCArray(hX)
 end
 
-svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
+function svd_rank(S::AnyROCVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
+    return something(findlast(≥(rank_atol), S), 0)
+end
+
+function svd_pullback!(ΔA::AnyROCMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyROCVector; kwargs...)
+    return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
+end
 
 end
diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
index 25a739df0..f9a021644 100644
--- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
+++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
 using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
 import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
 import MatrixAlgebraKit: heevj!, heevd!, geev!
-import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
+import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!
 using CUDA, CUDA.cuBLAS
 using CUDA: i32
 using LinearAlgebra
@@ -197,6 +197,12 @@ function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
     return CuArray(hX)
 end
 
-svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
+function svd_rank(S::AnyCuVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
+    return something(findlast(≥(rank_atol), S), 0)
+end
+
+function svd_pullback!(ΔA::AnyCuMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyCuVector; kwargs...)
+    return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
+end
 
 end
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index a1d5d534d..e2ae96c11 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -15,6 +15,20 @@ using LinearAlgebra
 
 Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent
 
+# needed for GPU tests because Mooncake can't differentiate through CUDA kernels
+@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray}
+function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
+    A, dA = arrayify(A_dA)
+    Ac = copy(A)
+    zero!(A)
+    function zero_adjoint(::NoRData)
+        copy!(A, Ac)
+        zero!(dA)
+        return NoRData(), NoRData()
+    end
+    return A_dA, zero_adjoint
+end
+
 # two-argument in-place factorizations like LQ, QR, EIG
 for (f!, f, pb, adj) in (
         (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl
index a6e27104a..de1f91fa8 100755
--- a/src/pullbacks/svd.jl
+++ b/src/pullbacks/svd.jl
@@ -85,19 +85,17 @@ function check_and_prepare_svd_cotangents(
     bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v
         return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v)
     end
-    Δgauge = max(Δgauge, norm(bc, Inf))
+    Δgauge = max(Δgauge, maximum(abs, bc))
 
     if !iszerotangent(ΔSmat)
         ΔS = diagview(ΔSmat)
         length(indS) == length(ΔS) || throw(DimensionMismatch(lazy"length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))"))
+        bad_indS = _ind_intersect((r + 1):length(ΔS), indS)
+        good_indS = _ind_intersect(1:r, indS)
         ΔS₁ = zero(S₁)
-        for (j, i) in enumerate(indS)
-            if i <= r
-                ΔS₁[i] = real(ΔS[j])
-            else
-                Δgauge = max(Δgauge, abs(ΔS[j]))
-            end
-        end
+        ΔS₁[1:length(good_indS)] .= real.(ΔS[good_indS])
+        badΔS₁ = view(ΔS, bad_indS)
+        Δgauge = max(Δgauge, maximum(abs, badΔS₁; init = abs(zero(eltype(ΔS)))))
     else
         ΔS₁ = nothing
     end
diff --git a/test/mooncake/svd.jl b/test/mooncake/svd.jl
index f096fdb8e..2b63198f4 100644
--- a/test/mooncake/svd.jl
+++ b/test/mooncake/svd.jl
@@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
             TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
         end
     end
+    if T ∈ BLASFloats && CUDA.functional()
+        TestSuite.test_mooncake_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
+        if m == n
+            AT = Diagonal{T, CuVector{T}}
+            TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
+        end
+    end
 end
diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl
index 530663558..2131aa8d5 100644
--- a/test/testsuite/enzyme/svd.jl
+++ b/test/testsuite/enzyme/svd.jl
@@ -69,7 +69,7 @@ function test_enzyme_svd_trunc(
         end
         @testset "trunctol" begin
             S = svd_vals(A, alg)
-            trunc = trunctol(atol = S[1] / 2)
+            trunc = trunctol(atol = maximum(S) / 2)
             truncalg = TruncatedAlgorithm(alg, trunc)
             USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
             test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl
index 23b1d70b0..5ac79744e 100644
--- a/test/testsuite/mooncake/svd.jl
+++ b/test/testsuite/mooncake/svd.jl
@@ -141,7 +141,7 @@ function test_mooncake_svd_trunc(
 
         @testset "trunctol" begin
             S = svd_vals(A)
-            trunc = trunctol(atol = S[1] / 2)
+            trunc = trunctol(atol = maximum(S) / 2)
             alg_trunc = TruncatedAlgorithm(alg, trunc)
             USVᴴ, USVᴴtrunc, ΔUSVᴴ_arrays, ΔUSVᴴtrunc_arrays = ad_svd_trunc_setup(A, alg_trunc)
 