Skip to content

Commit 10f2c9f

Browse files
authored
Add fallback for +(::BlockDiagonal,::Diagonal) when blocks are not square (#120)
* Add fallback to `+(::BlockDiagonal,::Diagonal)` for nonsquare blocks * Add tests * Bump version
1 parent 960aa87 commit 10f2c9f

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockDiagonals"
22
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
33
authors = ["Invenia Technical Computing Corporation"]
4-
version = "0.1.37"
4+
version = "0.1.38"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/base_maths.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@ function Base.:+(B::BlockDiagonal, M::StridedMatrix)
4242
return A
4343
end
4444

45-
function Base.:+(B::BlockDiagonal, M::Diagonal)::BlockDiagonal
45+
function Base.:+(B::BlockDiagonal, M::Diagonal)
4646
size(B) == size(M) || throw(DimensionMismatch("dimensions must match"))
47+
if !all(is_square, blocks(B))
48+
return Matrix(B) + M # Fallback on the generic Base method
49+
end
4750
A = copy(B)
4851
d = diag(M)
4952
row = 1

test/base_maths.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ using Test
1919
b64 = BlockDiagonal([rand(rng, 2, 2), rand(rng, 2, 2)])
2020
b32 = BlockDiagonal([rand(rng, Float32, 2, 2), rand(rng, Float32, 2, 2)])
2121

22+
bns = BlockDiagonal([rand(rng, N1, N2), rand(rng, N2, N3), rand(rng, N3, N1)])
23+
2224
@testset "Addition" begin
2325
@testset "BlockDiagonal + BlockDiagonal" begin
2426
@test b1 + b1 isa BlockDiagonal
@@ -58,6 +60,15 @@ using Test
5860
@test D + b1 isa BlockDiagonal
5961
@test D + b1 == D + Matrix(b1)
6062
@test_throws DimensionMismatch D′ + b1
63+
64+
# Non-square blocks
65+
@test D + bns isa Matrix
66+
@test D + bns == D + Matrix(bns)
67+
@test_throws DimensionMismatch D′ + bns
68+
69+
@test bns + D isa Matrix
70+
@test bns + D == D + Matrix(bns)
71+
@test_throws DimensionMismatch bns + D′
6172
end
6273

6374
@testset "BlockDiagonal + UniformScaling" begin
@@ -69,11 +80,11 @@ using Test
6980
@test 5I + b1 == 5I + Matrix(b1)
7081
end
7182
end # Addition
72-
83+
7384
@testset "Subtraction" begin
7485
@test -b1 isa BlockDiagonal
7586
@test b1 - b1 isa BlockDiagonal
76-
87+
7788
@test -b1 == -Matrix(b1)
7889
@test b1 - b1 == Matrix(b1) - Matrix(b1)
7990
@test Matrix(b1) - b2 == Matrix(b1) - Matrix(b2)

0 commit comments

Comments
 (0)