|
1 |
| -@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = solve(Size(a), Size(b), a, b) |
| 1 | +@inline (\)(a::StaticMatrix, b::StaticVecOrMat) = _solve(Size(a), Size(b), a, b) |
2 | 2 |
|
3 |
| -@inline function solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
| 3 | +@inline function _solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
4 | 4 | @inbounds return similar_type(b, typeof(a[1] \ b[1]))(a[1] \ b[1])
|
5 | 5 | end
|
6 | 6 |
|
7 |
| -@inline function solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
| 7 | +@inline function _solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
8 | 8 | d = det(a)
|
9 | 9 | T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
|
10 | 10 | @inbounds return similar_type(b, T)((a[2,2]*b[1] - a[1,2]*b[2])/d,
|
11 | 11 | (a[1,1]*b[2] - a[2,1]*b[1])/d)
|
12 | 12 | end
|
13 | 13 |
|
14 |
| -@inline function solve(::Size{(3,3)}, ::Size{(3,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
| 14 | +@inline function _solve(::Size{(3,3)}, ::Size{(3,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
15 | 15 | d = det(a)
|
16 | 16 | T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
|
17 | 17 | @inbounds return similar_type(b, T)(
|
|
28 | 28 |
|
29 | 29 | for Sa in [(2,2), (3,3)] # not needed for Sa = (1, 1);
|
30 | 30 | @eval begin
|
31 |
| - @inline function solve(::Size{$Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb} |
| 31 | + @inline function _solve(::Size{$Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {Sb, Ta, Tb} |
32 | 32 | d = det(a)
|
33 | 33 | T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/d)
|
34 | 34 | c = similar(b, T)
|
35 | 35 | for col = 1:Sb[2]
|
36 |
| - @inbounds c[:, col] = solve(Size($Sa), Size($Sa[1],), a, b[:, col]) |
| 36 | + @inbounds c[:, col] = _solve(Size($Sa), Size($Sa[1],), a, b[:, col]) |
37 | 37 | end
|
38 | 38 | return similar_type(b, T)(c)
|
39 | 39 | end
|
|
42 | 42 |
|
43 | 43 |
|
44 | 44 |
|
45 |
| -@generated function solve(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} |
46 |
| - if Sa[end] != Sb[1] |
47 |
| - throw(DimensionMismatch("right hand side B needs first dimension of size $(Sa[end]), has size $Sb")) |
| 45 | +@generated function _solve(::Size{Sa}, ::Size{Sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVecOrMat{Tb}) where {Sa, Sb, Ta, Tb} |
| 46 | + if Sa[1] != Sb[1] |
| 47 | + return quote |
| 48 | + throw(DimensionMismatch("Left and right hand side first dimensions do not match in backdivide (got sizes $Sa and $Sb)")) |
| 49 | + end |
48 | 50 | end
|
49 |
| - LinearAlgebra.checksquare(a) |
50 |
| - if prod(Sa) ≤ 14*14 |
| 51 | + if prod(Sa) ≤ 14*14 && Sa[1] == Sa[2] |
| 52 | + # TODO: Consider triangular special cases as in Base? |
51 | 53 | quote
|
52 | 54 | @_inline_meta
|
53 | 55 | LUp = lu(a)
|
54 | 56 | LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
|
55 | 57 | end
|
| 58 | + # TODO: Could also use static QR here if `a` is nonsquare. |
| 59 | + # Requires that we implement \(::StaticArrays.QR,::StaticVecOrMat) |
56 | 60 | else
|
| 61 | + # Fall back to LinearAlgebra, but carry across the statically known size. |
| 62 | + outsize = length(Sb) == 1 ? Size(Sa[2]) : Size(Sa[2],Sb[end]) |
57 | 63 | quote
|
58 | 64 | @_inline_meta
|
59 | 65 | T = typeof((one(Ta)*zero(Tb) + one(Ta)*zero(Tb))/one(Ta))
|
60 |
| - similar_type(b, T)(Matrix(a) \ b) |
| 66 | + similar_type(b, T, $outsize)(Matrix(a) \ b) |
61 | 67 | end
|
62 | 68 | end
|
63 | 69 | end
|
0 commit comments