Skip to content

Commit d67c199

Browse files
committed
inv, size, and \ for QR objects
Fixes JuliaArrays#1192
1 parent e23a2f5 commit d67c199

File tree

2 files changed

+168
-1
lines changed

2 files changed

+168
-1
lines changed

src/qr.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import Base: \, size, inv
2+
import LinearAlgebra: ldiv!, checksquare
3+
14
# define our own struct since LinearAlgebra.QR are restricted to Matrix
25
struct QR{Q,R,P}
36
Q::Q
@@ -245,3 +248,83 @@ end
245248
# end
246249
#end
247250

251+
252+
size(F::QR) = size(F.Q)
253+
254+
is_identity_perm(p::AbstractVector{T}) where {T<:Integer} = all(i->i==p[i], first(axes(p)))
255+
256+
ldiv!(x::StaticArray, F::QR, y::AbstractVecOrMat) = (x .= F \ y) # Compatibility. Note that \ already allocates 0 bytes.
257+
258+
259+
function \(F::QR, y::AbstractVecOrMat)
260+
checksquare(F.R)
261+
v = F.Q' * y
262+
263+
x = backsub(F.R, v)
264+
265+
invpivot(x, F.p)
266+
end
267+
268+
269+
@inline function invpivot(x::AbstractVecOrMat, p)
270+
if is_identity_perm(p)
271+
x
272+
else
273+
extra = ntuple(_ -> Colon(), ndims(x) - 1)
274+
getindex(x, invperm(p), extra...)
275+
end
276+
end
277+
278+
279+
# Simple back substitution for an upper–triangular system R*x = y.
280+
function backsub(R::StaticMatrix{r,c,T}, y::AbstractVector{T}) where {r,c,T}
281+
x = MVector{c,T}(undef)
282+
backsub!(x, R, y)
283+
SVector(x)
284+
end
285+
286+
function backsub(R::StaticMatrix{r,c,T}, y::AbstractMatrix{T}) where {r,c,T}
287+
x = MMatrix{c,size(y,2),T}(undef)
288+
for i in 1:size(y,2)
289+
@views backsub!(x[:,i], R, y[:,i])
290+
end
291+
SMatrix(x)
292+
end
293+
294+
@inline function backsub!(x::StaticVector{c}, R::StaticMatrix{r,c,T}, y::AbstractVector{T}) where {r,c,T}
295+
Base.@boundscheck Base.checkbounds(x, Base.OneTo(r))
296+
Base.@boundscheck Base.checkbounds(y, Base.OneTo(r))
297+
298+
@inbounds for i in r:-1:1
299+
s = zero(T)
300+
for j in i+1:c
301+
s += R[i, j] * x[j]
302+
end
303+
x[i] = (y[i] - s) / R[i, i]
304+
end
305+
if r < c
306+
@inbounds for i in r+1:c
307+
x[i] = 0
308+
end
309+
end
310+
x
311+
end
312+
313+
314+
function inv(F::QR)
315+
checksquare(F.Q)
316+
n = checksquare(F.R)
317+
318+
T = eltype(F.R)
319+
320+
# Compute inverse of R via back substitution on each column of the identity.
321+
R_inv_cols = ntuple(j -> begin
322+
# Build the j-th unit vector.
323+
e_j = SVector{n, T}(ntuple(i -> i == j ? one(T) : zero(T), n))
324+
backsub(F.R, e_j)
325+
end, n)
326+
327+
R⁻¹ = hcat(R_inv_cols...)
328+
A⁻¹ = R⁻¹ * F.Q'
329+
invpivot(A⁻¹, F.p)
330+
end

test/qr.jl

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,94 @@ Random.seed!(42)
6969
end
7070
end
7171

72+
7273
@testset "QR method ambiguity" begin
7374
# Issue #931; just test that methods do not throw an ambiguity error when called
7475
A = @SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0]
7576
@test isa(qr(A), StaticArrays.QR)
7677
@test isa(qr(A, Val(true)), StaticArrays.QR)
7778
@test isa(qr(A, Val(false)), StaticArrays.QR)
78-
end
79+
end
80+
81+
82+
@testset "#1192 The following functions are available for the QR objects: inv, size, and \\." begin
83+
@testset "pivot=$pivot" for pivot in [Val(true), Val(false)] #, ColumnNorm()]
84+
y = @SVector rand(5)
85+
Y = @SMatrix rand(5,5)
86+
A = @SMatrix rand(5,5)
87+
A_over = @SMatrix rand(5,6)
88+
A_under = @SMatrix rand(5,4)
89+
90+
F = qr(A, pivot)
91+
F_over = qr(A_over, pivot)
92+
F_under = qr(A_under, pivot)
93+
94+
@testset "size" begin
95+
@test size(A) == (5,5)
96+
@test size(A_over) == (5,6)
97+
@test size(A_under) == (5,4)
98+
end
99+
100+
@testset "square inversion" begin
101+
A_inv = inv(F)
102+
@test inv(F) * A I(5)
103+
@test inv(F) inv(qr(Matrix(A)))
104+
@test_throws DimensionMismatch inv(F_under)
105+
@test_throws DimensionMismatch inv(F_over)
106+
end
107+
108+
@testset "solve linear system" begin
109+
x = Matrix(A) \ Vector(y)
110+
@test x A \ y F \ y F \ Vector(y)
111+
112+
x_under = Matrix(A_under) \ Vector(y)
113+
@test x_under == A_under \ y
114+
@test x_under F_under \ y
115+
@test F_under \ y == F_under \ Vector(y)
116+
117+
x_over = Matrix(A_over) \ Vector(y)
118+
@test x_over A_over \ y
119+
@test A_over * x_over y
120+
121+
@test_throws DimensionMismatch F_over \ y
122+
@test_throws DimensionMismatch qr(Matrix(A_over)) \ y
123+
end
124+
125+
@testset "solve several linear systems" begin
126+
@test F \ Y A \ Y
127+
@test F_under \ Y A_under \ Y
128+
end
129+
130+
@testset "ldiv!" begin
131+
x = @MVector zeros(5)
132+
ldiv!(x, F, y)
133+
@test x A \ y
134+
135+
X = @MMatrix zeros(5,5)
136+
Y = @SMatrix rand(5,5)
137+
ldiv!(X, F, Y)
138+
@test X A \ Y
139+
end
140+
141+
@testset "invperm" begin
142+
x = @SVector [10,15,3,7]
143+
p = @SVector [4,2,1,3]
144+
@test x == x[p][invperm(p)]
145+
@test StaticArrays.is_identity_perm(p[invperm(p)])
146+
@test_throws Union{BoundsError,ArgumentError} invperm(x)
147+
end
148+
149+
@testset "10x faster" begin
150+
time_to_test = @elapsed (function()
151+
y2 = @SVector rand(50)
152+
A2 = @SMatrix rand(50,5)
153+
F2 = qr(A2, pivot)
154+
155+
min_time_to_solve = minimum(@elapsed(A2 \ y2) for _ in 1:1_000)
156+
min_time_to_solve_qr = minimum(@elapsed(F2 \ y2) for _ in 1:1_000)
157+
@test 10min_time_to_solve_qr < min_time_to_solve
158+
end)()
159+
@test time_to_test < 10
160+
end
161+
end
162+
end

0 commit comments

Comments
 (0)