Make svd_pullback! GPU-compatible #232

Merged: Jutho merged 12 commits into main from ksh/cusvd on May 20, 2026

@kshyatt (Member) commented May 18, 2026

Should fix #228

We need to collect the inds for the truncated tests because otherwise we trigger scalar indexing :(

The solution for the pullback itself is a bit rough so if someone has a nicer list comprehension, feel free to suggest!

@kshyatt kshyatt requested a review from Jutho May 18, 2026 13:48
Comment thread test/testsuite/mooncake/svd.jl Outdated
@lkdvos (Member) left a comment

It seems a bit much to force a copy of ind across the board, even in non-GPU context. Is there any way we can only selectively do this, or find a different approach?
I'm also slightly confused by the collect, I thought we did our best to keep everything on GPU precisely to allow the GPU kernels to do their work, and having ind::Vector{Int} combined with values::CuVector was bad?

Comment thread test/testsuite/mooncake/svd.jl Outdated
Comment thread src/pullbacks/svd.jl Outdated
Comment on lines +93 to +96
good_indS = findall(i -> i <= r, indS)
bad_indS = setdiff(1:length(indS), good_indS)
ΔS₁[indS[good_indS]] = real.(ΔS[good_indS])
Δgauge = max(Δgauge, mapreduce(abs, max, ΔS[bad_indS]))
Member

Would it help if we re-use the machinery we have for the TruncationIntersection, which I think we already made work, and then something like this? (untested)

bad_indS = _ind_intersect(r+1:length(ΔS), indS)
ΔS₁ = real(ΔS)
badΔS₁ = view(ΔS₁, bad_indS)
Δgauge = max(Δgauge, maximum(abs, badΔS₁))
badΔS₁ .= 0

Member Author

I'll test this now

Member Author

ΔS₁ ends up with the wrong length with this, I'll try to fix...

Member Author

Not sure this is much nicer but it's passing again at least 😓
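
For reference, the four lines originally under review in this thread can be exercised on plain CPU arrays. This is only a sketch under stated assumptions: the wrapper function `split_cotangent` is a hypothetical name, the sample values are made up, and it assumes `r` is the numerical rank, `indS` lists the indices of the retained singular values, and `ΔS₁` starts as a zero vector of length `r`. It does not reproduce the version that was eventually merged.

```julia
# CPU-only sketch of the reviewed lines; `split_cotangent` is a hypothetical wrapper.
# Assumptions: `r` is the numerical rank, `indS` holds the indices of the singular
# values kept by truncation, and `ΔS` is the cotangent of those kept values.
function split_cotangent(ΔS, indS, r)
    ΔS₁ = zeros(real(eltype(ΔS)), r)
    Δgauge = zero(real(eltype(ΔS)))
    # Positions in `indS` that point at singular values within the rank ...
    good_indS = findall(i -> i <= r, indS)
    # ... and the remaining positions, which point beyond the rank.
    bad_indS = setdiff(1:length(indS), good_indS)
    # Scatter the real part of the "good" cotangents into their original slots.
    ΔS₁[indS[good_indS]] = real.(ΔS[good_indS])
    # The "bad" cotangents only contribute to the gauge-violation estimate.
    Δgauge = max(Δgauge, mapreduce(abs, max, ΔS[bad_indS]))
    return ΔS₁, Δgauge
end

# Made-up sample: rank 3, truncation kept singular values 1, 2 and 4.
ΔS₁, Δgauge = split_cotangent(ComplexF64[0.5, -0.25, 1.0e-3], [1, 2, 4], 3)
```

In this sample the third kept value has index 4 > `r`, so its cotangent is not written into `ΔS₁` and only feeds `Δgauge`.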

Comment thread ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl Outdated
@kshyatt (Member, Author) commented May 18, 2026

> I'm also slightly confused by the collect, I thought we did our best to keep everything on GPU precisely to allow the GPU kernels to do their work, and having ind::Vector{Int} combined with values::CuVector was bad?

I think in general we want to keep the inds on the GPU, but in this case, where we're indexing a CPU array (axes(A, dim)), it's a problem. I should be able to make a pass-through in the GPU extensions that doesn't force the CPU ones to copy.
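
One way such a pass-through could look is sketched below. Everything here is an assumption for illustration: `host_indices` is a hypothetical helper name that does not come from the PR, and `FakeDeviceVector` is a stand-in type so that the sketch runs without a GPU. In a real extension the second method would presumably be defined for the device vector type and copy the indices to the host.

```julia
# Stand-in for a device-resident index vector, so the sketch runs on the CPU.
struct FakeDeviceVector{T}
    data::Vector{T}
end

# `host_indices` is a hypothetical helper name, not taken from the PR.
# Generic fallback: CPU indices pass through untouched, so no copy is made.
host_indices(ind) = ind
# "Extension" method: device indices are copied to a host Vector.
host_indices(ind::FakeDeviceVector) = copy(ind.data)

A = rand(4, 6)
cpu_ind = [1, 3, 5]
gpu_ind = FakeDeviceVector([1, 3, 5])

# `axes(A, 2)` is a CPU range, so it is indexed with host-side indices in both cases.
kept_cpu = axes(A, 2)[host_indices(cpu_ind)]
kept_gpu = axes(A, 2)[host_indices(gpu_ind)]
```

The design point is that the copy lives only in the method for the device type, so CPU callers keep their original index vector.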

@github-actions (Bot, Contributor) commented May 18, 2026

Your PR no longer requires formatting changes. Thank you for your contribution!

@codecov (Bot) commented May 18, 2026

Codecov Report

❌ Patch coverage is 82.60870% with 4 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...ixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 0.00% | 4 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
| --- | --- |
| ...MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 76.66% <100.00%> (+2.98%) ⬆️ |
| ...gebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl | 63.68% <100.00%> (+0.64%) ⬆️ |
| src/pullbacks/svd.jl | 93.29% <100.00%> (+0.51%) ⬆️ |
| ...ixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl | 64.91% <0.00%> (-3.61%) ⬇️ |

... and 1 file with indirect coverage changes

@kshyatt (Member, Author) commented May 19, 2026

Moved away from view, as the broadcast from SubArray to SubArray was causing compilation errors. The current change allows a sample PEPSKit script from @GlebFedorovich to run successfully 🚀

Comment thread ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl Outdated
Comment thread src/pullbacks/svd.jl Outdated
Comment thread ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@lkdvos lkdvos changed the title Fix and test AD rules for SVD Fix and test AD rules for SVD on GPU May 19, 2026
@lkdvos lkdvos changed the title Fix and test AD rules for SVD on GPU Make svd_pullback! GPU-compatible May 19, 2026
@kshyatt kshyatt enabled auto-merge (squash) May 19, 2026 14:16
Comment thread ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Comment thread ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Comment thread src/pullbacks/svd.jl
Comment thread src/pullbacks/svd.jl Outdated
@testset "trunctol" begin
S = svd_vals(A, alg)
- trunc = trunctol(atol = S[1] / 2)
+ trunc = trunctol(atol = maximum(S) / 2)
Member

Hooray for maximum.
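
A small CPU sketch of why the two forms are interchangeable here. It uses `svdvals` from the LinearAlgebra standard library as a stand-in for the package's `svd_vals`, which is an assumption made only to keep the example self-contained. Singular values come back sorted in descending order, so the first entry is also the maximum. On device arrays, reading a single element such as `S[1]` is typically the kind of scalar indexing mentioned earlier in this conversation, whereas `maximum(S)` is a reduction.

```julia
using LinearAlgebra

# Stand-in for the package's `svd_vals`; the matrix is made up for illustration.
A = [4.0 0.0; 0.0 2.0]
S = svdvals(A)                  # singular values, sorted in descending order

atol_indexed = S[1] / 2         # reads one element by index
atol_reduced = maximum(S) / 2   # reduction over the whole vector
```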

@Jutho Jutho disabled auto-merge May 19, 2026 16:16
@Jutho (Member) commented May 19, 2026

I did have some final suggestions so I disabled auto-merge.

kshyatt and others added 2 commits May 20, 2026 08:31
Co-authored-by: Jutho <Jutho@users.noreply.github.com>
@Jutho (Member) commented May 20, 2026

Ok, failures on windows-latest seem unrelated, I'll merge.

@Jutho Jutho merged commit 92f576a into main May 20, 2026
33 of 36 checks passed
@Jutho Jutho deleted the ksh/cusvd branch May 20, 2026 09:12
@lkdvos lkdvos mentioned this pull request May 20, 2026
Development

Successfully merging this pull request may close these issues.

SVD pullback rules aren't GPU-friendly