|
149 | 149 | info)
|
150 | 150 | b-tensor)))
|
151 | 151 |
|
(defun generate-lapack-lsd-for-type (class type lstsq-function)
  "Return a DEFMETHOD form that specializes LAPACK-LSD on two arguments of CLASS.

TYPE is the array element type used for the buffers handed to LSTSQ-FUNCTION,
and LSTSQ-FUNCTION is the symbol naming the least-squares routine to call.
An unsupported TYPE signals an error when this generator runs.

The generated method takes A (m x n), B-TENSOR (m x nrhs) and RCOND.  When RCOND
is NIL it defaults to (max m n) times the float epsilon matching TYPE.  The
routine is called twice: once with LWORK = -1 as a workspace-size query, then
again with work arrays of the reported sizes.

The method returns two values:
  1. X, the first n rows of the solution buffer (n x nrhs).
  2. A vector of length nrhs.  When m > n it holds, per column, the sum of the
     squared entries of rows n..m-1 of the solution buffer; otherwise zeros."
  ;; NOTE(review): the LAPACK complex xGELSD routines take an additional RWORK
  ;; argument and a real-valued S array.  The argument list below follows the
  ;; real-valued routines -- confirm the binding passed as LSTSQ-FUNCTION
  ;; matches before instantiating this template for complex types.
  `(defmethod lapack-lsd ((a ,class) (b-tensor ,class) rcond)
     (let* ((m (nrows a))
            (n (ncols a)))
       (policy-cond:with-expectations (> speed safety)
           ((assertion (cl:= m (nrows b-tensor))))
         (let* ((nrhs (ncols b-tensor))
                ;; Work on a copy of A's storage; the routine is given this
                ;; buffer directly.
                (a (magicl::storage (deep-copy-tensor (if (eql :row-major (layout a)) (transpose a) a))))
                (lda m)
                ;; B doubles as the output buffer, so it needs room for
                ;; max(m, n) rows per column.
                (ldb (max m n))
                (b-storage (magicl::storage
                            (if (eql :row-major (layout b-tensor)) (transpose b-tensor) b-tensor)))
                (b (make-array (* ldb nrhs) :element-type ',type))
                (s (make-array (min m n) :element-type ',type))
                (rcond (or rcond
                           (* (max m n)
                              ,(cond ((or (eq type 'single-float)
                                          (equal type '(complex single-float)))
                                      'single-float-epsilon)
                                     ((or (eq type 'double-float)
                                          (equal type '(complex double-float)))
                                      'double-float-epsilon)
                                     (t (error "Unknown type for lapack-lsd: ~a" type))))))
                (rank 0)
                (work1 (make-array 1 :element-type ',type))
                (work nil)
                (lwork -1)
                (iwork1 (make-array 1 :element-type '(signed-byte 32)))
                (iwork nil)
                (info 0))
           ;; Copy the right-hand side column by column.  The source columns
           ;; are M elements apart, but the destination columns are LDB
           ;; elements apart; a single contiguous copy would misplace every
           ;; column after the first whenever n > m.
           (dotimes (j nrhs)
             (replace b b-storage
                      :start1 (* j ldb)
                      :start2 (* j m)
                      :end2 (* (1+ j) m)))
           ;; Perform work size query with work of length 1
           (,lstsq-function
            m
            n
            nrhs
            a
            lda
            b
            ldb
            s
            rcond
            rank
            work1
            lwork
            iwork1
            info)
           (setf lwork (round (realpart (aref work1 0))))
           (setf work (make-array (max 1 lwork) :element-type ',type))
           (setf iwork (make-array (max 1 (aref iwork1 0)) :element-type '(signed-byte 32)))
           ;; Perform actual operation with correct work size
           (,lstsq-function
            m
            n
            nrhs
            a
            lda
            b
            ldb
            s
            rcond
            rank
            work
            lwork
            iwork
            info)
           (let* ((sol (magicl::from-storage b (list ldb nrhs) :layout :column-major))
                  (x (slice sol (list 0 0) (list n nrhs))))
             (cond ((> m n)
                    ;; r^2 is given by sum of squares of n+1:m rows in B
                    (let ((r (slice sol (list n 0) (list m nrhs))))
                      (values x (row-matrix->vector (@ (ones (list 1 (- m n)) :type ',type) (.* r r))))))
                   (t
                    (values x (zeros (list nrhs) :type ',type))))))))))
| 226 | + |
152 | 227 | (defun generate-lapack-inv-for-type (class type lu-function inv-function)
|
153 | 228 | `(defmethod lapack-inv ((a ,class))
|
154 | 229 | (let ((a-tensor (deep-copy-tensor a)))
|
|
0 commit comments