|
# constructor
# Pullback for the `BlockDiagonal(blocks)` constructor rule. Each method returns a
# 2-tuple: `NoTangent()` for the type argument, then the `blocks` property of the
# incoming cotangent as the cotangent of the `blocks` argument.
function _BlockDiagonal_pullback(Δ::Tangent)
    return NoTangent(), Δ.blocks
end

# A thunked cotangent is unthunked and dispatched again on the resulting value.
function _BlockDiagonal_pullback(Δ::AbstractThunk)
    return _BlockDiagonal_pullback(unthunk(Δ))
end

# The cotangent may itself be a `BlockDiagonal`; its blocks are passed through.
function _BlockDiagonal_pullback(Δ::BlockDiagonal)
    return NoTangent(), Δ.blocks
end
# Reverse rule for constructing a `BlockDiagonal` from a vector of blocks: returns the
# constructed matrix together with `_BlockDiagonal_pullback`.
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
    primal = BlockDiagonal(blocks)
    return primal, _BlockDiagonal_pullback
end
@@ -48,14 +49,34 @@ function ChainRulesCore.rrule( |
48 | 49 | block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1) |
49 | 50 | block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1) |
50 | 51 | return InplaceableThunk( |
| 52 | + X̄ -> mul!(X̄, ȳ[block_rows], v[block_cols]', true, true), |
51 | 53 | @thunk(ȳ[block_rows] * v[block_cols]'), |
52 | | - X̄ -> mul!(X̄, ȳ[block_rows], v[block_cols]', true, true) |
53 | 54 | ) |
54 | 55 | end |
55 | 56 |
|
56 | 57 | b̄m = Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks) |
57 | | - v̄ = InplaceableThunk(@thunk(bm' * ȳ), X̄ -> mul!(X̄, bm', ȳ, true, true)) |
| 58 | + v̄ = InplaceableThunk(X̄ -> mul!(X̄, bm', ȳ, true, true), @thunk(bm' * ȳ)) |
58 | 59 | return NoTangent(), b̄m, v̄ |
59 | 60 | end |
60 | 61 | return y, bm_vector_mul_pullback |
61 | 62 | end |
| 63 | + |
"""
    ProjectTo(b::BlockDiagonal)

Build a projector for `b`. It stores one `ProjectTo` per block of `b` (under `blocks`)
and the result of `blocksizes(b)` (under `blocksizes`).
"""
function ChainRulesCore.ProjectTo(b::BlockDiagonal)
    # Extended by qualified name, like the `ChainRulesCore.rrule` methods in this file,
    # so the method is valid whether or not `ProjectTo` was explicitly imported.
    blocks = map(ProjectTo, b.blocks)
    return ProjectTo{BlockDiagonal}(; blocks=blocks, blocksizes=blocksizes(b))
end
| 68 | + |
"""
    (project::ProjectTo{BlockDiagonal})(dx)

Project `dx` onto a `BlockDiagonal`. For each entry of `project.blocksizes`, the
matching diagonal block is sliced out of `dx` with two-dimensional range indexing and
passed through the corresponding projector in `project.blocks`. Entries of `dx`
outside those diagonal blocks are not read.
"""
function (project::ProjectTo{BlockDiagonal})(dx)
    heights = first.(project.blocksizes)
    widths = last.(project.blocksizes)
    # The running totals give the last row/column of every block; the first one is
    # recovered by stepping back over the block's own extent.
    row_stops = cumsum(heights)
    col_stops = cumsum(widths)
    projected = map(eachindex(heights)) do k
        rows = (row_stops[k] - heights[k] + 1):row_stops[k]
        cols = (col_stops[k] - widths[k] + 1):col_stops[k]
        project.blocks[k](dx[rows, cols])
    end
    return BlockDiagonal(projected)
end
0 commit comments