Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTrun
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, heev!, heevx!
import MatrixAlgebraKit: _sylvester, svd_rank
import MatrixAlgebraKit: _sylvester, svd_rank, svd_pullback!
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand Down Expand Up @@ -185,6 +185,12 @@ function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
return ROCArray(hX)
end

svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
function svd_rank(S::AnyROCVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
return something(findlast(≥(rank_atol), S), 0)
Comment thread
lkdvos marked this conversation as resolved.
end

function svd_pullback!(ΔA::AnyROCMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyROCVector; kwargs...)
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
end

end
10 changes: 8 additions & 2 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!
using CUDA, CUDA.cuBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -197,6 +197,12 @@ function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
return CuArray(hX)
end

svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)
function svd_rank(S::AnyCuVector; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S))
return something(findlast(≥(rank_atol), S), 0)
Comment thread
lkdvos marked this conversation as resolved.
end
Comment thread
kshyatt marked this conversation as resolved.

function svd_pullback!(ΔA::AnyCuMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyCuVector; kwargs...)
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, collect(ind); kwargs...)
end

end
14 changes: 14 additions & 0 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ using LinearAlgebra

Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent

# needed for GPU tests because Mooncake can't differentiate through CUDA kernels
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(zero!), AbstractArray}
function Mooncake.rrule!!(::CoDual{typeof(zero!)}, A_dA::CoDual)
A, dA = arrayify(A_dA)
Ac = copy(A)
zero!(A)
function zero_adjoint(::NoRData)
copy!(A, Ac)
zero!(dA)
return NoRData(), NoRData()
end
return A_dA, zero_adjoint
end

# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
14 changes: 6 additions & 8 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,17 @@ function check_and_prepare_svd_cotangents(
bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v
return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v)
end
Δgauge = max(Δgauge, norm(bc, Inf))
Δgauge = max(Δgauge, maximum(abs, bc))

if !iszerotangent(ΔSmat)
ΔS = diagview(ΔSmat)
length(indS) == length(ΔS) || throw(DimensionMismatch(lazy"length of selected S values ($(length(indS))) does not match length of ΔS ($(length(ΔS)))"))
bad_indS = _ind_intersect((r + 1):length(ΔS), indS)
good_indS = _ind_intersect(1:r, indS)
ΔS₁ = zero(S₁)
for (j, i) in enumerate(indS)
if i <= r
ΔS₁[i] = real(ΔS[j])
else
Δgauge = max(Δgauge, abs(ΔS[j]))
end
end
ΔS₁[1:length(good_indS)] .= real.(ΔS[good_indS])
Comment thread
kshyatt marked this conversation as resolved.
badΔS₁ = view(ΔS, bad_indS)
Δgauge = max(Δgauge, maximum(abs, badΔS₁; init = abs(zero(eltype(ΔS)))))
else
ΔS₁ = nothing
end
Expand Down
7 changes: 7 additions & 0 deletions test/mooncake/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,11 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
if T ∈ BLASFloats && CUDA.functional()
TestSuite.test_mooncake_svd(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
if m == n
AT = Diagonal{T, CuVector{T}}
TestSuite.test_mooncake_svd(AT, m; atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
end
2 changes: 1 addition & 1 deletion test/testsuite/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function test_enzyme_svd_trunc(
end
@testset "trunctol" begin
S = svd_vals(A, alg)
trunc = trunctol(atol = S[1] / 2)
trunc = trunctol(atol = maximum(S) / 2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hooray for maximum.

truncalg = TruncatedAlgorithm(alg, trunc)
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
Expand Down
2 changes: 1 addition & 1 deletion test/testsuite/mooncake/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ function test_mooncake_svd_trunc(

@testset "trunctol" begin
S = svd_vals(A)
trunc = trunctol(atol = S[1] / 2)
trunc = trunctol(atol = maximum(S) / 2)
alg_trunc = TruncatedAlgorithm(alg, trunc)

USVᴴ, USVᴴtrunc, ΔUSVᴴ_arrays, ΔUSVᴴtrunc_arrays = ad_svd_trunc_setup(A, alg_trunc)
Expand Down
Loading