Skip to content

Commit bc803bf

Browse files
committed
add ordinary least square with lapack backend
1 parent 6b49aee commit bc803bf

File tree

8 files changed

+2921
-4457
lines changed

8 files changed

+2921
-4457
lines changed

src/bindings/lapack-cffi.lisp

Lines changed: 2812 additions & 4457 deletions
Large diffs are not rendered by default.

src/extensions/lapack/lapack-bindings.lisp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
matrix-class type (lapack-routine "getrf"))
2929
(generate-lapack-lu-solve-for-type
3030
matrix-class type (lapack-routine "getrs"))
31+
(generate-lapack-lsd-for-type
32+
matrix-class type (lapack-routine "gelsd"))
3133
(generate-lapack-inv-for-type
3234
matrix-class type
3335
(lapack-routine "getrf") (lapack-routine "getri"))

src/extensions/lapack/lapack-generics.lisp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
(magicl:extend-function (magicl:lu-solve lapack-lu-solve :lapack) (lu ipiv b))
1919

20+
(magicl:extend-function (magicl:lsd lapack-lsd :lapack) (a b rcond))
21+
2022
(magicl:extend-function (magicl:inv lapack-inv :lapack) (matrix))
2123

2224
(magicl:extend-function (magicl:csd-blocks csd-blocks-extension :lapack) (matrix p q))

src/extensions/lapack/lapack-templates.lisp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,81 @@
149149
info)
150150
b-tensor)))
151151

152+
(defun generate-lapack-lsd-for-type (class type lstsq-function)
153+
`(defmethod lapack-lsd ((a ,class) (b-tensor ,class) rcond)
154+
(let* ((m (nrows a))
155+
(n (ncols a)))
156+
(policy-cond:with-expectations (> speed safety)
157+
((assertion (cl:= m (nrows b-tensor))))
158+
(let* ((nrhs (ncols b-tensor))
159+
(a (magicl::storage (deep-copy-tensor (if (eql :row-major (layout a)) (transpose a) a))))
160+
(lda m)
161+
(ldb (max m n))
162+
(b (make-array (* ldb nrhs) :element-type ',type))
163+
(s (make-array (min m n) :element-type ',type))
164+
(rcond (or rcond
165+
(* (max m n)
166+
,(cond ((or (eq type 'single-float)
167+
(equal type '(complex single-float)))
168+
'single-float-epsilon)
169+
((or (eq type 'double-float)
170+
(equal type '(complex double-float)))
171+
'double-float-epsilon)
172+
(t (error "Unknown type for lapack-lsd: ~a" type))))))
173+
(rank 0)
174+
(work1 (make-array 1 :element-type ',type))
175+
(work nil)
176+
(lwork -1)
177+
(iwork1 (make-array 1 :element-type '(signed-byte 32)))
178+
(iwork nil)
179+
(info 0))
180+
(setf (subseq b 0)
181+
(magicl::storage
182+
(if (eql :row-major (layout b-tensor)) (transpose b-tensor) b-tensor)))
183+
;; Perform work size query with work of length 1
184+
(,lstsq-function
185+
m
186+
n
187+
nrhs
188+
a
189+
lda
190+
b
191+
ldb
192+
s
193+
rcond
194+
rank
195+
work1
196+
lwork
197+
iwork1
198+
info)
199+
(setf lwork (round (realpart (aref work1 0))))
200+
(setf work (make-array (max 1 lwork) :element-type ',type))
201+
(setf iwork (make-array (max 1 (aref iwork1 0)) :element-type '(signed-byte 32)))
202+
;; Perform actual operation with correct work size
203+
(,lstsq-function
204+
m
205+
n
206+
nrhs
207+
a
208+
lda
209+
b
210+
ldb
211+
s
212+
rcond
213+
rank
214+
work
215+
lwork
216+
iwork
217+
info)
218+
(let* ((sol (magicl::from-storage b (list ldb nrhs) :layout :column-major))
219+
(x (slice sol (list 0 0) (list n nrhs))))
220+
(cond ((> m n)
221+
;; r^2 is given by sum of squares of n+1:m rows in B
222+
(let ((r (slice sol (list n 0) (list m nrhs))))
223+
(values x (row-matrix->vector (@ (ones (list 1 (- m n)) :type ',type) (.* r r))))))
224+
(t
225+
(values x (zeros (list nrhs) :type ',type))))))))))
226+
152227
(defun generate-lapack-inv-for-type (class type lu-function inv-function)
153228
`(defmethod lapack-inv ((a ,class))
154229
(let ((a-tensor (deep-copy-tensor a)))

src/functions.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ gges
1616
gees
1717
getrf
1818
getrs
19+
gelsd
1920
getri
2021
gesvd
2122
geev

src/high-level/matrix.lisp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ So assuming P is a permutation matrix representing IPIV, we have
567567
(define-backend-function lu-solve (lu ipiv b)
568568
"Solve the system AX=B, where A is a square matrix, B is a compatibly shaped matrix, and A has PLU factorization indicated by the permutation vector IPIV and lower & upper triangular portions of the argument LU.")
569569

570+
(define-backend-function lsd (a b &optional rcond)
571+
"Solve the linear least square problem: argmin_X ||B-AX||^2 using the singular value decomposition (SVD) of A. A is a M-by-N matrix, and B is a M-by-NRHS right hand side matrix. Return two values: the solution as a N-by-NRHS solution matrix, and the sum of squared residuals as a NRHS vector. The effective rank of A is determined by treating as zero those singular values which are less than RCOND times the largest singular value.")
572+
570573
(define-extensible-function (csd-blocks csd-blocks-lisp) (matrix p q)
571574
(:documentation "Compute the cosine-sine decomposition of the matrix MATRIX and return the result as blocks. See LISP-CSD-BLOCKS for mathematical details.
572575
@@ -677,3 +680,11 @@ NOTE: If H is not Hermitian, the behavior is undefined.")
677680
:matrix A)))
678681
(let ((rmat (lu-solve lu ipiv bmat)))
679682
(from-storage (storage rmat) (shape b))))))
683+
684+
(defun ols (a b &optional rcond)
685+
"Attempt to solve the ordinary least square problem argmin_X ||B-AX||^2. Returns X and sum of squared residuals."
686+
(cond ((vector-p b)
687+
(multiple-value-bind (x r) (lsd a (vector->column-matrix b) rcond)
688+
(values (column-matrix->vector x) (tref r 0))))
689+
(t
690+
(lsd a b rcond))))

src/packages.lisp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@
173173
#:inv
174174
#:lu
175175
#:lu-solve
176+
#:lsd
176177
#:csd-blocks
177178
#:csd
178179
#:svd
@@ -186,6 +187,7 @@
186187
#:logm
187188
#:expih
188189
#:linear-solve
190+
#:ols
189191

190192
#:polynomial
191193
#:make-polynomial

tests/high-level-tests.lisp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,19 @@
292292

293293
(signals magicl::rank-deficiency-error
294294
(magicl:linear-solve (magicl:ones '(3 3)) (magicl:ones '(3)))))
295+
296+
(deftest test-ols ()
297+
"Check that we can solve ordinary least square."
298+
(let ((A (magicl:from-list '((1d0 1d0) (2d0 1d0) (3d0 1d0) (4d0 1d0)) '(4 2)))
299+
(b (magicl:from-list '(0d0 1d0 5d0 6d0) '(4))))
300+
(multiple-value-bind (x r) (magicl:ols A b)
301+
(is (magicl:= x (magicl:from-list '(2.2d0 -2.5d0) '(2))))
302+
(is (< (abs (- r 1.8d0))) (* 1.0d2 double-float-epsilon)))
303+
304+
(signals error (magicl:ols A (magicl:from-list '(0d0 1d0 5d0) '(3)))))
305+
306+
(let ((A (magicl:from-list '((1d0 1d0) (2d0 1d0) (3d0 1d0) (4d0 1d0)) '(4 2)))
307+
(b (magicl:from-list '((0d0 0d0) (1d0 1d0) (5d0 2d0) (6d0 3d0)) '(4 2))))
308+
(multiple-value-bind (x r) (magicl:ols A b)
309+
(is (magicl:= x (magicl:from-list '((2.2d0 1d0) (-2.5d0 -1d0)) '(2 2))))
310+
(is (magicl:= r (magicl:from-list '(1.8d0 0.0d0) '(2)))))))

0 commit comments

Comments
 (0)