@@ -197,17 +197,16 @@ end
197
197
198
198
199
199
function LinearAlgebra. ldiv! (A:: LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}} , B:: MtlVecOrMat{T} ) where {T<: MtlFloat }
200
- orig = size (B)
201
- M,N = size (B)[1 ], ndims (B) > 1 ? size (B)[2 ] : 1
200
+ M,N = size (B,1 ), size (B,2 )
202
201
dev = current_device ()
203
202
queue = global_queue (dev)
204
203
205
- B = reshape (B, (N,M))
204
+ Bt = reshape (B, (N,M))
206
205
P = reshape ((A. ipiv .- UInt32 (1 )), (1 ,M))
207
206
X = similar (B)
208
207
209
208
mps_a = MPSMatrix (A. factors)
210
- mps_b = MPSMatrix (B )
209
+ mps_b = MPSMatrix (Bt )
211
210
mps_p = MPSMatrix (P)
212
211
mps_x = MPSMatrix (X)
213
212
@@ -216,86 +215,98 @@ function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::M
216
215
encode! (cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
217
216
end
218
217
219
- B .= X
220
- B = reshape (B, orig)
218
+ Bt .= X
219
+ return B
221
220
end
222
221
223
- function LinearAlgebra. ldiv! (A:: UnitUpperTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
224
- M,N = size (B)
222
+
223
"""
    LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}

Solve the triangular system `A * X = B` with an `MPSMatrixSolveTriangular` kernel,
overwrite `B` with the result, and return `B`.

# Arguments
- `A`: upper-triangular wrapper around a device matrix.
- `B`: right-hand side (vector or matrix); overwritten with the solution.

# Throws
- `DimensionMismatch` if the first dimensions of `A` and `B` differ.
"""
function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # size(B, 2) is 1 for a vector, so vectors and matrices share one code path.
    M, N = size(B, 1), size(B, 2)
    size(A, 1) == M || throw(DimensionMismatch(
        "first dimension of A, $(size(A, 1)), does not match first dimension of B, $M"))

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a plain MtlMatrix for MPSMatrix.
    Ad = MtlMatrix(A; storage=Private)
    # (N, M) reshape of B; reshape is expected to alias B's buffer, so writing
    # into Bt below updates B itself.
    Bt = reshape(B, (N, M))
    # The result buffer must have Bt's shape: `similar(B)` would be (M, N) (or a
    # length-M vector) and `Bt .= X` would then fail whenever M != N.
    X = similar(Bt)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Bt)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        # Flags after `dev` are presumably (right, upper, transpose, unit), followed by
        # order, number of right-hand sides and alpha.
        # NOTE(review): confirm the flag values against the layout MPS assumes.
        kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    Bt .= X
    return B
end
243
244
244
- function LinearAlgebra. ldiv! (A:: LowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
245
- M,N = size (B)
245
+
246
"""
    LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}

Solve the unit-diagonal triangular system `A * X = B` with an
`MPSMatrixSolveTriangular` kernel, overwrite `B` with the result, and return `B`.

# Arguments
- `A`: unit upper-triangular wrapper around a device matrix.
- `B`: right-hand side (vector or matrix); overwritten with the solution.

# Throws
- `DimensionMismatch` if the first dimensions of `A` and `B` differ.
"""
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # size(B, 2) is 1 for a vector, so vectors and matrices share one code path.
    M, N = size(B, 1), size(B, 2)
    size(A, 1) == M || throw(DimensionMismatch(
        "first dimension of A, $(size(A, 1)), does not match first dimension of B, $M"))

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a plain MtlMatrix for MPSMatrix.
    Ad = MtlMatrix(A; storage=Private)
    # (N, M) reshape of B; reshape is expected to alias B's buffer, so writing
    # into Bt below updates B itself.
    Bt = reshape(B, (N, M))
    # The result buffer must have Bt's shape: `similar(B)` would be (M, N) (or a
    # length-M vector) and `Bt .= X` would then fail whenever M != N.
    X = similar(Bt)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Bt)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        # Flags after `dev` are presumably (right, upper, transpose, unit), followed by
        # order, number of right-hand sides and alpha.
        # NOTE(review): confirm the flag values against the layout MPS assumes.
        kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    Bt .= X
    return B
end
263
267
264
- function LinearAlgebra. ldiv! (A:: UnitLowerTriangular{T, <:MtlMatrix{T}} , B:: MtlVecOrMat{T} ) where {T}
265
- M,N = size (B)
268
+
269
"""
    LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}

Solve the triangular system `A * X = B` with an `MPSMatrixSolveTriangular` kernel,
overwrite `B` with the result, and return `B`.

# Arguments
- `A`: lower-triangular wrapper around a device matrix.
- `B`: right-hand side (vector or matrix); overwritten with the solution.

# Throws
- `DimensionMismatch` if the first dimensions of `A` and `B` differ.
"""
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # size(B, 2) is 1 for a vector, so vectors and matrices share one code path.
    M, N = size(B, 1), size(B, 2)
    size(A, 1) == M || throw(DimensionMismatch(
        "first dimension of A, $(size(A, 1)), does not match first dimension of B, $M"))

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a plain MtlMatrix for MPSMatrix.
    Ad = MtlMatrix(A; storage=Private)
    # (N, M) reshape of B; reshape is expected to alias B's buffer, so writing
    # into Bt below updates B itself.
    Bt = reshape(B, (N, M))
    # The result buffer must have Bt's shape: `similar(B)` would be (M, N) (or a
    # length-M vector) and `Bt .= X` would then fail whenever M != N.
    X = similar(Bt)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Bt)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        # Flags after `dev` are presumably (right, upper, transpose, unit), followed by
        # order, number of right-hand sides and alpha.
        # NOTE(review): confirm the flag values against the layout MPS assumes.
        kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    Bt .= X
    return B
end
283
290
284
- # function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285
- # require_one_based_indexing(A, B)
286
- # m, n = size(A)
287
- # if m == n
288
- # if istril(A)
289
- # if istriu(A)
290
- # return Diagonal(A) \ B
291
- # else
292
- # return LowerTriangular(A) \ B
293
- # end
294
- # end
295
- # if istriu(A)
296
- # return UpperTriangular(A) \ B
297
- # end
298
- # return lu(A) \ B
299
- # end
300
- # return qr(A, ColumnNorm()) \ B
301
- # end
291
+
292
"""
    LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}

Solve the unit-diagonal triangular system `A * X = B` with an
`MPSMatrixSolveTriangular` kernel, overwrite `B` with the result, and return `B`.

# Arguments
- `A`: unit lower-triangular wrapper around a device matrix.
- `B`: right-hand side (vector or matrix); overwritten with the solution.

# Throws
- `DimensionMismatch` if the first dimensions of `A` and `B` differ.
"""
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
    # size(B, 2) is 1 for a vector, so vectors and matrices share one code path.
    M, N = size(B, 1), size(B, 2)
    size(A, 1) == M || throw(DimensionMismatch(
        "first dimension of A, $(size(A, 1)), does not match first dimension of B, $M"))

    dev = current_device()
    queue = global_queue(dev)

    # Materialize the triangular wrapper as a plain MtlMatrix for MPSMatrix. Bound to
    # a new name (not `A`) so the argument keeps a single type throughout the method.
    Ad = MtlMatrix(A; storage=Private)
    # (N, M) reshape of B; reshape is expected to alias B's buffer, so writing
    # into Bt below updates B itself.
    Bt = reshape(B, (N, M))
    # The result buffer must have Bt's shape: `similar(B)` would be (M, N) (or a
    # length-M vector) and `Bt .= X` would then fail whenever M != N.
    X = similar(Bt)

    mps_a = MPSMatrix(Ad)
    mps_b = MPSMatrix(Bt)
    mps_x = MPSMatrix(X)

    MTLCommandBuffer(queue) do cmdbuf
        # Flags after `dev` are presumably (right, upper, transpose, unit), followed by
        # order, number of right-hand sides and alpha.
        # NOTE(review): confirm the flag values against the layout MPS assumes.
        kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
        encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
    end

    Bt .= X
    return B
end
0 commit comments