@@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
261
261
262
262
commit! (cmdbuf)
263
263
264
+ wait_completed (cmdbuf)
265
+
266
+ return B
267
+ end
268
+
269
+
270
# Solve `A * X = B` for an LU factorization `A` whose factors live in a `MtlMatrix`.
#
# `B` is left untouched: the right-hand side is duplicated and the duplicate is
# handed to the in-place `LinearAlgebra.ldiv!` method for the same argument types.
#
# Returns a new array with the same type and shape as `B` holding the solution.
function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
                            B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # `copy` already yields an independent array for a container of plain floats;
    # `deepcopy` would additionally walk the array wrapper's internal fields.
    X = copy(B)
    LinearAlgebra.ldiv!(A, X)
    return X
end
275
+
276
+
277
# Solve `A * X = B` in place for an LU factorization `A`, using `MPSMatrixSolveLU`.
#
# Arguments
# - `A`: LU factorization with `MtlMatrix` factors and `UInt32` pivots.
# - `B`: right-hand side (vector or matrix); overwritten with the solution.
#
# Returns `B`.
#
# Throws `DimensionMismatch` when the factor matrix is not square or when its order
# differs from the number of rows of `B`.
function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)

    # Validate up front so a mismatch is reported clearly instead of surfacing
    # indirectly from the pivot `reshape` below or reaching the MPS kernel.
    order = LinearAlgebra.checksquare(A.factors)
    order == M || throw(DimensionMismatch(
        "LU factor has order $order, but the right-hand side has $M rows"))

    dev = current_device()
    queue = global_queue(dev)

    At = similar(A.factors)
    Bt = similar(B, (N, M))
    # Pivots are shifted down by one before being handed to MPS
    # (presumably Julia's 1-based indices -> 0-based indices; confirm against MPS docs).
    P = reshape((A.ipiv .- UInt32(1)), (1, M))
    X = similar(B, (N, M))

    # The kernel is fed transposed copies of the factors and of the right-hand side
    # (presumably to bridge Julia's column-major layout and MPS's row-major one).
    transpose!(At, A.factors)
    transpose!(Bt, B)

    mps_a = MPSMatrix(At)
    mps_b = MPSMatrix(Bt)
    mps_p = MPSMatrix(P)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        kernel = MPSMatrixSolveLU(dev, false, M, N)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
    end

    # Wait for the solve before reading `X`, matching the triangular solvers
    # defined alongside this method.
    wait_completed(buf)

    # Undo the transposition while writing the solution back into `B`.
    transpose!(B, X)
    return B
end
303
+
304
+
305
# Solve `A * X = B` in place for an upper-triangular `A`, using
# `MPSMatrixSolveTriangular`.
#
# Arguments
# - `A`: upper-triangular wrapper around a `MtlMatrix`.
# - `B`: right-hand side (vector or matrix); overwritten with the solution.
#
# Returns `B`.
#
# Throws `DimensionMismatch` when the order of `A` differs from the number of rows
# of `B`.
#
# Implementation note: the previous version allocated an `(M, M)` scratch matrix
# for the transposed right-hand side and copied the result back without undoing
# the transpose, so it could only run when `B` was square. This version follows
# the scheme of the other triangular solvers in this file: the buffers are passed
# untransposed and the kernel is asked for a right-hand-side solve on the opposite
# triangle (an upper-triangular column-major matrix read row-major is lower
# triangular — assumes MPS reads the buffers row-major; see the MPS reference).
function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    M, N = size(B, 1), size(B, 2)
    size(A, 1) == M || throw(DimensionMismatch(
        "triangular matrix has order $(size(A, 1)), but the right-hand side has $M rows"))

    dev = current_device()
    queue = global_queue(dev)

    Ad = MtlMatrix(A)
    # View `B` as an M×N matrix so vector right-hand sides take the same path;
    # `Br` aliases `B`, so copying into `Br` below updates `B`.
    Br = reshape(B, (M, N))
    X = similar(Br)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Br)
    mps_x = MPSMatrix(X)

    buf = MTLCommandBuffer(queue) do cmdbuf
        # Flags (right, upper, transpose, unit): right-side solve on the lower
        # triangle with a non-unit diagonal; order M with N right-hand sides.
        kernel = MPSMatrixSolveTriangular(dev, true, false, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    wait_completed(buf)

    copy!(Br, X)
    return B
end
330
+
331
+
332
# In-place solve of `A * X = B` for a unit upper-triangular `A` through
# `MPSMatrixSolveTriangular`. `B` receives the solution and is returned.
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    nrows = size(B, 1)
    nrhs = size(B, 2)

    gpu = current_device()
    cmdqueue = global_queue(gpu)

    # Dense copy of the coefficients, a matrix-shaped alias of `B`, and an output
    # buffer of the same shape.
    coeffs = MtlMatrix(A)
    rhs = reshape(B, (nrows, nrhs))
    solution = similar(rhs)

    mps_coeffs = MPSMatrix(coeffs)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    pending = MTLCommandBuffer(cmdqueue) do cmdbuf
        # Boolean flags are presumably (right, upper, transpose, unit), following
        # the MPS initializer — confirm against the MPS reference.
        solver = MPSMatrixSolveTriangular(
            gpu, true, false, false, true, nrows, nrhs, 1.0)
        encode!(cmdbuf, solver, mps_coeffs, mps_rhs, mps_solution)
    end

    # Block until the GPU has finished before reading the result.
    wait_completed(pending)

    # `rhs` is a reshape of `B`, so this writes the solution into `B`.
    copy!(rhs, solution)
    return B
end
356
+
357
+
358
# In-place solve of `A * X = B` for a lower-triangular `A` through
# `MPSMatrixSolveTriangular`. `B` receives the solution and is returned.
function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    nrows = size(B, 1)
    nrhs = size(B, 2)

    gpu = current_device()
    cmdqueue = global_queue(gpu)

    # Dense copy of the coefficients, a matrix-shaped alias of `B`, and an output
    # buffer of the same shape.
    coeffs = MtlMatrix(A)
    rhs = reshape(B, (nrows, nrhs))
    solution = similar(rhs)

    mps_coeffs = MPSMatrix(coeffs)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    pending = MTLCommandBuffer(cmdqueue) do cmdbuf
        # Boolean flags are presumably (right, upper, transpose, unit), following
        # the MPS initializer — confirm against the MPS reference.
        solver = MPSMatrixSolveTriangular(
            gpu, true, true, false, false, nrows, nrhs, 1.0)
        encode!(cmdbuf, solver, mps_coeffs, mps_rhs, mps_solution)
    end

    # Block until the GPU has finished before reading the result.
    wait_completed(pending)

    # `rhs` is a reshape of `B`, so this writes the solution into `B`.
    copy!(rhs, solution)
    return B
end
382
+
383
+
384
# In-place solve of `A * X = B` for a unit lower-triangular `A` through
# `MPSMatrixSolveTriangular`. `B` receives the solution and is returned.
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}},
                             B::MtlVecOrMat{T}) where {T<:MtlFloat}
    nrows = size(B, 1)
    nrhs = size(B, 2)

    gpu = current_device()
    cmdqueue = global_queue(gpu)

    # Dense copy of the coefficients, a matrix-shaped alias of `B`, and an output
    # buffer of the same shape.
    coeffs = MtlMatrix(A)
    rhs = reshape(B, (nrows, nrhs))
    solution = similar(rhs)

    mps_coeffs = MPSMatrix(coeffs)
    mps_rhs = MPSMatrix(rhs)
    mps_solution = MPSMatrix(solution)

    pending = MTLCommandBuffer(cmdqueue) do cmdbuf
        # Boolean flags are presumably (right, upper, transpose, unit), following
        # the MPS initializer — confirm against the MPS reference.
        solver = MPSMatrixSolveTriangular(
            gpu, true, true, false, true, nrows, nrhs, 1.0)
        encode!(cmdbuf, solver, mps_coeffs, mps_rhs, mps_solution)
    end

    # Block until the GPU has finished before reading the result.
    wait_completed(pending)

    # `rhs` is a reshape of `B`, so this writes the solution into `B`.
    copy!(rhs, solution)
    return B
end
0 commit comments