Skip to content

Commit 86c57e3

Browse files
mjp98 and mzgubic authored
Specialise LinearAlgebra.lmul! for LowerTriangular blockdiagonal matrices (#119)
* Extend `lmul!` for `LowerTriangular`
* Add tests for `lmul!` for `LowerTriangular` blockdiagonals
* Update src/linalg.jl Co-authored-by: Miha Zgubic <[email protected]>
* copy all test inputs for safety Co-authored-by: Miha Zgubic <[email protected]>
* use `return` keyword Co-authored-by: Miha Zgubic <[email protected]>
* Add link to slow sampling issue
* Bump version
* Add `BenchmarkTools.jl` as a test dependency
* Add allocation test for `lmul!`
* Up allocation bound to 320 for Julia 1.0 on x64 Co-authored-by: Miha Zgubic <[email protected]>
1 parent 10f2c9f commit 86c57e3

File tree

3 files changed

+44
-10
lines changed

3 files changed

+44
-10
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.38"
4+
version = "0.1.39"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -18,11 +18,12 @@ PDMats = "0.11"
1818
julia = "1"
1919

2020
[extras]
21+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2122
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2223
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2324
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
2425
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2526
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627

2728
[targets]
28-
test = ["ChainRulesTestUtils", "Documenter", "PDMats", "Random", "Test"]
29+
test = ["BenchmarkTools", "ChainRulesTestUtils", "Documenter", "PDMats", "Random", "Test"]

src/linalg.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
"""
4646
eigen_blockwise(B::BlockDiagonal, args...; kwargs...) -> values, vectors
4747
48-
Computes the eigendecomposition for each block separately and keeps the block diagonal
48+
Computes the eigendecomposition for each block separately and keeps the block diagonal
4949
structure in the matrix of eigenvectors. Hence any parameters given are applied to each
5050
eigendecomposition separately, but there is e.g. no global sorting of eigenvalues.
5151
"""
@@ -58,16 +58,16 @@ function eigen_blockwise(B::BlockDiagonal, args...; kwargs...)
5858
values = promote([e.values for e in eigens]...)
5959
vectors = promote([e.vectors for e in eigens]...)
6060
vcat(values...), BlockDiagonal([vectors...])
61-
end
61+
end
6262

6363
## This function never keeps the block diagonal structure
6464
function LinearAlgebra.eigen(B::BlockDiagonal, args...; kwargs...)
6565
values, vectors = eigen_blockwise(B, args...; kwargs...)
6666
vectors = Matrix(vectors) # always convert to avoid type instability (also it speeds up the permutation step)
6767
@static if VERSION > v"1.2.0-DEV.275"
6868
Eigen(LinearAlgebra.sorteig!(values, vectors, kwargs...)...)
69-
else
70-
Eigen(values, vectors)
69+
else
70+
Eigen(values, vectors)
7171
end
7272
end
7373

@@ -157,6 +157,21 @@ function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number,
157157
return C
158158
end
159159

160+
# Resolves MvNormal slow sampling issue https://github.com/invenia/BlockDiagonals.jl/issues/116
161+
function LinearAlgebra.lmul!(B::LowerTriangular{<:Any,<:BlockDiagonal}, vm::StridedVecOrMat)
162+
# BlockDiagonals with non-square blocks
163+
if !all(is_square, blocks(parent(B)))
164+
return lmul!(LowerTriangular(Matrix(B)), vm) # Fallback on the generic LinearAlgebra method
165+
end
166+
row_i = 1
167+
for block in blocks(parent(B))
168+
nrow = size(block, 1)
169+
@views lmul!(LowerTriangular(block), vm[row_i:(row_i + nrow - 1), :])
170+
row_i += nrow
171+
end
172+
return vm
173+
end
174+
160175
function LinearAlgebra.:\(B::BlockDiagonal, vm::AbstractVecOrMat)
161176
row_i = 1
162177
# BlockDiagonals with non-square blocks

test/linalg.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using BenchmarkTools
12
using BlockDiagonals
23
using BlockDiagonals: svd_blockwise, eigen_blockwise
34
using LinearAlgebra
@@ -78,7 +79,7 @@ end
7879
evals, evecs = eigen(Matrix(B))
7980

8081
@test E isa Eigen
81-
@test Matrix(E) ≈ B
82+
@test Matrix(E) ≈ B
8283

8384
# There is no test like @test eigen(B) == eigen(Matrix(B))
8485
# 1. this fails in the complex case. Probably a convergence thing that leads to ever so slight differences
@@ -88,7 +89,7 @@ end
8889
@static if VERSION < v"1.2"
8990
# pre-v1.2 we need to sort the eigenvalues consistently
9091
# Since eigenvalues may be complex here, I use this function, which works for this test.
91-
# This test is already somewhat fragile w. r. t. degenerate eigenvalues
92+
# This test is already somewhat fragile w. r. t. degenerate eigenvalues
9293
# and this just makes this a little worse.
9394
perm_bd = sortperm(real.(evals_bd) + 100*imag.(evals_bd))
9495
evals_bd = evals_bd[perm_bd]
@@ -131,7 +132,7 @@ end
131132
E = Eigen(block_vals, blocks(vecs)[i])
132133
evals_bd, evecs_bd = E
133134
evals, evecs = eigen(block)
134-
135+
135136
@test block ≈ Matrix(E)
136137

137138
@static if VERSION < v"1.2"
@@ -144,7 +145,7 @@ end
144145
evals = evals[perm]
145146
evecs = evecs[:, perm]
146147
end
147-
148+
148149
@test evals_bd ≈ evals
149150
@test all(min.(abs.(evecs_bd - evecs), abs.(evecs_bd + evecs)) .< 1e-13)
150151
end
@@ -245,6 +246,23 @@ end
245246
end
246247
end
247248
end # SVD
249+
@testset "Left multiplication" begin
250+
N1 = 20
251+
N2 = 8
252+
N3 = 5
253+
N4 = N1 + N3 - N2
254+
A = BlockDiagonal([rand(rng, N1, N1), rand(rng, N2, N2)])
255+
B = BlockDiagonal([rand(rng, N1, N2), rand(rng, N3, N4)])
256+
x = rand(rng, N1 + N2)
257+
y = rand(rng, N2 + N4)
258+
259+
@testset "Lower triangular" begin
260+
@test lmul!(LowerTriangular(A), copy(x)) ≈ lmul!(LowerTriangular(Matrix(A)), copy(x))
261+
@test lmul!(LowerTriangular(B), copy(y)) ≈ lmul!(LowerTriangular(Matrix(B)), copy(y))
262+
cx = copy(x)
263+
@test 320 >= @ballocated lmul!($(LowerTriangular(A)), $cx)
264+
end
265+
end
248266
@testset "Left division" begin
249267
N1 = 20
250268
N2 = 8

0 commit comments

Comments
 (0)