Skip to content

Commit 5415ea5

Browse files
author
Chris Foster
committed
Fix A\B where A is nonsquare
Fixes #606
1 parent ee01ea1 commit 5415ea5

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

src/solve.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = solve(Size(a), Size(b), a, b)
1+
@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b)
22

3-
@inline function solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
3+
@inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
44
@inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1])
55
end
66

7-
@inline function solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
7+
@inline function _solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
88
d = det(a)
99
T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
1010
@inbounds return similar_type(b, T)((a[2,2]*b[1] - a[1,2]*b[2])/d,
1111
(a[1,1]*b[2] - a[2,1]*b[1])/d)
1212
end
1313

14-
@inline function solve(::Size{(3,3)}, ::Size{(3,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
14+
@inline function _solve(::Size{(3,3)}, ::Size{(3,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
1515
d = det(a)
1616
T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
1717
@inbounds return similar_type(b, T)(
@@ -28,12 +28,12 @@ end
2828

2929
for Sa in [(2,2), (3,3)] # not needed for Sa = (1, 1);
3030
@eval begin
31-
@inline function solve(::Size{$Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb}
31+
@inline function _solve(::Size{$Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb}
3232
d = det(a)
3333
T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
3434
c = similar(b, T)
3535
for col = 1:Sb[2]
36-
@inbounds c[:, col] = solve(Size($Sa), Size($Sa[1],), a, b[:, col])
36+
@inbounds c[:, col] = _solve(Size($Sa), Size($Sa[1],), a, b[:, col])
3737
end
3838
return similar_type(b, T)(c)
3939
end
@@ -42,22 +42,28 @@ end
4242

4343

4444

45-
@generated function solve(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb}
46-
if Sa[end] != Sb[1]
47-
throw(DimensionMismatch("right hand side B needs first dimension of size $(Sa[end]), has size $Sb"))
45+
@generated function _solve(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb}
46+
if Sa[1] != Sb[1]
47+
return quote
48+
throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)"))
49+
end
4850
end
49-
LinearAlgebra.checksquare(a)
50-
if prod(Sa) 14*14
51+
if prod(Sa) 14*14 && Sa[1] == Sa[2]
52+
# TODO: Consider triangular special cases as in Base?
5153
quote
5254
@_inline_meta
5355
LUp = lu(a)
5456
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
5557
end
58+
# TODO: Could also use static QR here if `a` is nonsquare.
59+
# Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat)
5660
else
61+
# Fall back to LinearAlgebra, but carry across the statically known size.
62+
outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end])
5763
quote
5864
@_inline_meta
5965
T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/one(Ta))
60-
similar_type(b, T)(Matrix(a) \ b)
66+
similar_type(b, T, $outsize)(Matrix(a) \ b)
6167
end
6268
end
6369
end

test/solve.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ end
4141

4242
end
4343

44+
# Solve with non-square left hand sides (#606)
45+
m1 = @SMatrix[0.2 0.3
46+
0.0 0.1
47+
0.5 0.1]
48+
m2 = @SVector[1,2,3]
49+
@test @inferred(m1\m2) Array(m1)\Array(m2)
50+
m2 = @SMatrix[1 4
51+
2 5
52+
3 6]
53+
@test @inferred(m1\m2) Array(m1)\Array(m2)
54+
4455
@testset "Mixed static/dynamic" begin
4556
m2 = @SMatrix([0.2 0.3; 0.0 0.1])
4657
for m1 in (@SMatrix([1.0 0; 0 1.0]), @SMatrix([1.0 0; 1.0 1.0]),

0 commit comments

Comments
 (0)