@@ -22,3 +22,38 @@ function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagon
2222 return Matrix (B), Matrix_pullback
2323end
2424
25+ # multiplication
26+ function ChainRulesCore. rrule (
27+ :: typeof (* ),
28+ bm:: BlockDiagonal{T, V} ,
29+ v:: StridedVector{T}
30+ ) where {T<: Union{Real, Complex} , V<: Matrix{T} }
31+
32+ y = bm * v
33+
34+ # needed for computing Δ * v' blockwise
35+ nrows = size .(bm. blocks, 1 )
36+ ncols = size .(bm. blocks, 2 )
37+ row_idxs = cumsum (nrows) .- nrows .+ 1
38+ col_idxs = cumsum (ncols) .- ncols .+ 1
39+
40+ function bm_vector_mul_pullback (Δ)
41+ Δblocks = map (eachindex (nrows)) do i
42+ block_rows = row_idxs[i]: (row_idxs[i] + nrows[i] - 1 )
43+ block_cols = col_idxs[i]: (col_idxs[i] + ncols[i] - 1 )
44+ return InplaceableThunk (
45+ @thunk (Δ[block_rows] * v[block_cols]' ),
46+ X̄ -> mul! (X̄, Δ[block_rows], v[block_cols]' , true , true )
47+ )
48+ end
49+ return (
50+ NO_FIELDS,
51+ Composite {BlockDiagonal{T, V}} (;blocks= Δblocks),
52+ InplaceableThunk (
53+ @thunk (bm' * Δ),
54+ X̄ -> mul! (X̄, bm' , Δ, true , true )
55+ ),
56+ )
57+ end
58+ return y, bm_vector_mul_pullback
59+ end
0 commit comments