Skip to content

Commit db7103c

Browse files
authored
Merge pull request #77 from invenia/mz/cr1
Update to ChainRules 1.0
2 parents 795774f + 965528d commit db7103c

File tree

4 files changed

+38
-5
lines changed

4 files changed

+38
-5
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.20"
4+
version = "0.1.21"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -10,8 +10,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111

1212
[compat]
13-
ChainRulesCore = "0.9.44, 0.10"
14-
ChainRulesTestUtils = "0.7.10"
13+
ChainRulesCore = "1"
14+
ChainRulesTestUtils = "1"
1515
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12"
1616
FiniteDifferences = "0.12.3"
1717
julia = "1"

src/BlockDiagonals.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using FillArrays: Zeros
66
using FiniteDifferences
77
using LinearAlgebra
88

9+
import ChainRulesCore.ProjectTo
10+
911
export BlockDiagonal, blocks
1012
export blocksize, blocksizes, nblocks
1113

src/chainrules.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# constructor
22
_BlockDiagonal_pullback::Tangent) = (NoTangent(), Δ.blocks)
33
_BlockDiagonal_pullback::AbstractThunk) = _BlockDiagonal_pullback(unthunk(Δ))
4+
_BlockDiagonal_pullback::BlockDiagonal) = (NoTangent(), Δ.blocks)
45
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V}
56
return BlockDiagonal(blocks), _BlockDiagonal_pullback
67
end
@@ -48,14 +49,34 @@ function ChainRulesCore.rrule(
4849
block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1)
4950
block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1)
5051
return InplaceableThunk(
52+
-> mul!(X̄, ȳ[block_rows], v[block_cols]', true, true),
5153
@thunk(ȳ[block_rows] * v[block_cols]'),
52-
-> mul!(X̄, ȳ[block_rows], v[block_cols]', true, true)
5354
)
5455
end
5556

5657
b̄m = Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks)
57-
= InplaceableThunk(@thunk(bm' * ȳ), -> mul!(X̄, bm', ȳ, true, true))
58+
= InplaceableThunk(X̄ -> mul!(X̄, bm', ȳ, true, true), @thunk(bm' *))
5859
return NoTangent(), b̄m, v̄
5960
end
6061
return y, bm_vector_mul_pullback
6162
end
63+
64+
function ProjectTo(b::BlockDiagonal)
65+
blocks = map(ProjectTo, b.blocks)
66+
return ProjectTo{BlockDiagonal}(; blocks=blocks, blocksizes=blocksizes(b))
67+
end
68+
69+
function (project::ProjectTo{BlockDiagonal})(dx)
70+
# prepare to index into the dense array
71+
nrows = first.(project.blocksizes)
72+
ncols = last.(project.blocksizes)
73+
row_idxs = cumsum(nrows) .- nrows .+ 1
74+
col_idxs = cumsum(ncols) .- ncols .+ 1
75+
# project each block individually
76+
blocks = map(eachindex(nrows)) do i
77+
block_rows = row_idxs[i]:(row_idxs[i] + nrows[i] - 1)
78+
block_cols = col_idxs[i]:(col_idxs[i] + ncols[i] - 1)
79+
project.blocks[i](dx[block_rows, block_cols])
80+
end
81+
return BlockDiagonal(blocks)
82+
end

test/chainrules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
@testset "BlockDiagonal" begin
33
x = [randn(1, 2), randn(2, 2)]
44
test_rrule(BlockDiagonal, x)
5+
test_rrule(BlockDiagonal, x; output_tangent=Tangent{BlockDiagonal}(;blocks=x))
56
end
67

78
@testset "Matrix" begin
@@ -14,4 +15,13 @@
1415
v = rand(6)
1516
test_rrule(*, D, v)
1617
end
18+
19+
@testset "ProjectTo" begin
20+
bd = BlockDiagonal([ones(2, 2), ones(3, 3)])
21+
project = ProjectTo(bd)
22+
@test project(ones(5, 5)) == bd
23+
@test project(adjoint(ones(5, 5))) == bd
24+
@test project(Diagonal(ones(5))) isa BlockDiagonal
25+
@test project(Diagonal(ones(5))) == Diagonal(ones(5))
26+
end
1727
end

0 commit comments

Comments
 (0)