11# constructor
2+ _BlockDiagonal_pullback (Δ:: Tangent ) = (NoTangent (), Δ. blocks)
3+ _BlockDiagonal_pullback (Δ:: AbstractThunk ) = _BlockDiagonal_pullback (unthunk (Δ))
24function ChainRulesCore. rrule (:: Type{<:BlockDiagonal} , blocks:: Vector{V} ) where {V}
3- BlockDiagonal_pullback (Δ:: Tangent ) = (NoTangent (), Δ. blocks)
4- return BlockDiagonal (blocks), BlockDiagonal_pullback
5+ return BlockDiagonal (blocks), _BlockDiagonal_pullback
56end
67
78# densification
9+ function _densification_pullback (Ȳ:: Matrix , T, nrows, ncols)
10+ row_idxs = cumsum (nrows) .- nrows .+ 1
11+ col_idxs = cumsum (ncols) .- ncols .+ 1
12+
13+ Δblocks = map (eachindex (nrows)) do n
14+ block_rows = row_idxs[n]: (row_idxs[n] + nrows[n] - 1 )
15+ block_cols = col_idxs[n]: (col_idxs[n] + ncols[n] - 1 )
16+ return Ȳ[block_rows, block_cols]
17+ end
18+ return (NoTangent (), Tangent {T} (blocks= Δblocks))
19+ end
20+ function _densification_pullback (Ȳ:: AbstractThunk , T, nrows, ncols)
21+ return _densification_pullback (unthunk (Ȳ), T, nrows, ncols)
22+ end
823function ChainRulesCore. rrule (:: Type{<:Base.Matrix} , B:: T ) where {T<: BlockDiagonal }
924 nrows = size .(B. blocks, 1 )
1025 ncols = size .(B. blocks, 2 )
11- function Matrix_pullback (Δ:: Matrix )
12- row_idxs = cumsum (nrows) .- nrows .+ 1
13- col_idxs = cumsum (ncols) .- ncols .+ 1
14-
15- Δblocks = map (eachindex (nrows)) do n
16- block_rows = row_idxs[n]: (row_idxs[n] + nrows[n] - 1 )
17- block_cols = col_idxs[n]: (col_idxs[n] + ncols[n] - 1 )
18- return Δ[block_rows, block_cols]
19- end
20- return (NoTangent (), Tangent {T} (blocks= Δblocks))
21- end
22- return Matrix (B), Matrix_pullback
26+ densification_pullback (ȳ) = _densification_pullback (ȳ, T, nrows, ncols)
27+ return Matrix (B), densification_pullback
2328end
2429
2530# multiplication
@@ -37,23 +42,20 @@ function ChainRulesCore.rrule(
3742 row_idxs = cumsum (nrows) .- nrows .+ 1
3843 col_idxs = cumsum (ncols) .- ncols .+ 1
3944
40- function bm_vector_mul_pullback (Δ)
45+ function bm_vector_mul_pullback (Δy)
46+ ȳ = unthunk (Δy)
4147 Δblocks = map (eachindex (nrows)) do i
4248 block_rows = row_idxs[i]: (row_idxs[i] + nrows[i] - 1 )
4349 block_cols = col_idxs[i]: (col_idxs[i] + ncols[i] - 1 )
4450 return InplaceableThunk (
45- @thunk (Δ [block_rows] * v[block_cols]' ),
46- X̄ -> mul! (X̄, Δ [block_rows], v[block_cols]' , true , true )
51+ @thunk (ȳ [block_rows] * v[block_cols]' ),
52+ X̄ -> mul! (X̄, ȳ [block_rows], v[block_cols]' , true , true )
4753 )
4854 end
49- return (
50- NoTangent (),
51- Tangent {BlockDiagonal{T, V}} (;blocks= Δblocks),
52- InplaceableThunk (
53- @thunk (bm' * Δ),
54- X̄ -> mul! (X̄, bm' , Δ, true , true )
55- ),
56- )
55+
56+ b̄m = Tangent {BlockDiagonal{T, V}} (;blocks= Δblocks)
57+ v̄ = InplaceableThunk (@thunk (bm' * ȳ), X̄ -> mul! (X̄, bm' , ȳ, true , true ))
58+ return NoTangent (), b̄m, v̄
5759 end
5860 return y, bm_vector_mul_pullback
5961end
0 commit comments