Skip to content

Commit f0ec36c

Browse files
authored
Test accepting thunks passed to pullbacks (#73)
1 parent 4cf486b commit f0ec36c

File tree

3 files changed

+31
-27
lines changed

3 files changed

+31
-27
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.18"
4+
version = "0.1.19"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111

1212
[compat]
1313
ChainRulesCore = "0.9.44, 0.10"
14-
ChainRulesTestUtils = "0.6.3, 0.7"
14+
ChainRulesTestUtils = "0.7.10"
1515
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
1616
FiniteDifferences = "0.12.3"
1717
julia = "1"

src/chainrules.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
# constructor
2+
_BlockDiagonal_pullback::Tangent) = (NoTangent(), Δ.blocks)
3+
_BlockDiagonal_pullback::AbstractThunk) = _BlockDiagonal_pullback(unthunk(Δ))
24
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
3-
BlockDiagonal_pullback::Tangent) = (NoTangent(), Δ.blocks)
4-
return BlockDiagonal(blocks), BlockDiagonal_pullback
5+
return BlockDiagonal(blocks), _BlockDiagonal_pullback
56
end
67

78
# densification
9+
function _densification_pullback(Ȳ::Matrix, T, nrows, ncols)
10+
row_idxs = cumsum(nrows) .- nrows .+ 1
11+
col_idxs = cumsum(ncols) .- ncols .+ 1
12+
13+
Δblocks = map(eachindex(nrows)) do n
14+
block_rows = row_idxs[n]:(row_idxs[n] + nrows[n] - 1)
15+
block_cols = col_idxs[n]:(col_idxs[n] + ncols[n] - 1)
16+
return Ȳ[block_rows, block_cols]
17+
end
18+
return (NoTangent(), Tangent{T}(blocks=Δblocks))
19+
end
20+
function _densification_pullback(Ȳ::AbstractThunk, T, nrows, ncols)
21+
return _densification_pullback(unthunk(Ȳ), T, nrows, ncols)
22+
end
823
function ChainRulesCore.rrule(::Type{<:Base.Matrix}, B::T) where {T<:BlockDiagonal}
924
nrows = size.(B.blocks, 1)
1025
ncols = size.(B.blocks, 2)
11-
function Matrix_pullback::Matrix)
12-
row_idxs = cumsum(nrows) .- nrows .+ 1
13-
col_idxs = cumsum(ncols) .- ncols .+ 1
14-
15-
Δblocks = map(eachindex(nrows)) do n
16-
block_rows = row_idxs[n]:(row_idxs[n] + nrows[n] - 1)
17-
block_cols = col_idxs[n]:(col_idxs[n] + ncols[n] - 1)
18-
return Δ[block_rows, block_cols]
19-
end
20-
return (NoTangent(), Tangent{T}(blocks=Δblocks))
21-
end
22-
return Matrix(B), Matrix_pullback
26+
densification_pullback(ȳ) = _densification_pullback(ȳ, T, nrows, ncols)
27+
return Matrix(B), densification_pullback
2328
end
2429

2530
# multiplication
@@ -37,23 +42,20 @@ function ChainRulesCore.rrule(
3742
row_idxs = cumsum(nrows) .- nrows .+ 1
3843
col_idxs = cumsum(ncols) .- ncols .+ 1
3944

40-
function bm_vector_mul_pullback(Δ)
45+
function bm_vector_mul_pullback(Δy)
46+
= unthunk(Δy)
4147
Δblocks = map(eachindex(nrows)) do i
4248
block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1)
4349
block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1)
4450
return InplaceableThunk(
45-
@thunk(Δ[block_rows] * v[block_cols]'),
46-
-> mul!(X̄, Δ[block_rows], v[block_cols]', true, true)
51+
@thunk([block_rows] * v[block_cols]'),
52+
-> mul!(X̄, [block_rows], v[block_cols]', true, true)
4753
)
4854
end
49-
return (
50-
NoTangent(),
51-
Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks),
52-
InplaceableThunk(
53-
@thunk(bm' * Δ),
54-
-> mul!(X̄, bm', Δ, true, true)
55-
),
56-
)
55+
56+
b̄m = Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks)
57+
= InplaceableThunk(@thunk(bm' * ȳ), X̄ -> mul!(X̄, bm', ȳ, true, true))
58+
return NoTangent(), b̄m, v̄
5759
end
5860
return y, bm_vector_mul_pullback
5961
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using FiniteDifferences # For overloading to_vec
66
using Test
77
using LinearAlgebra
88

9+
push!(ChainRulesTestUtils.TRANSFORMS_TO_ALT_TANGENTS, x -> @thunk(x))
10+
911
@testset "BlockDiagonals" begin
1012
# The doctests fail version other than 64bit julia 1.6.x, due to printing differences
1113
Sys.WORD_SIZE == 64 && v"1.6" <= VERSION < v"1.7" && doctest(BlockDiagonals)

0 commit comments

Comments
 (0)