Skip to content

Commit 83a4fe5

Browse files
committed
ldiv
1 parent 69aa51e commit 83a4fe5

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

lib/mps/linalg.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,5 +261,147 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T}
261261

262262
commit!(cmdbuf)
263263

264+
wait_completed(cmdbuf)
265+
266+
return B
267+
end
268+
269+
270+
function LinearAlgebra.:(\)(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
271+
C = deepcopy(B)
272+
LinearAlgebra.ldiv!(A, C)
273+
return C
274+
end
275+
276+
277+
function LinearAlgebra.ldiv!(A::LU{T,<:MtlMatrix{T},<:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
278+
M, N = size(B, 1), size(B, 2)
279+
dev = current_device()
280+
queue = global_queue(dev)
281+
282+
At = similar(A.factors)
283+
Bt = similar(B, (N, M))
284+
P = reshape((A.ipiv .- UInt32(1)), (1, M))
285+
X = similar(B, (N, M))
286+
287+
transpose!(At, A.factors)
288+
transpose!(Bt, B)
289+
290+
mps_a = MPSMatrix(At)
291+
mps_b = MPSMatrix(Bt)
292+
mps_p = MPSMatrix(P)
293+
mps_x = MPSMatrix(X)
294+
295+
MTLCommandBuffer(queue) do cmdbuf
296+
kernel = MPSMatrixSolveLU(dev, false, M, N)
297+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
298+
end
299+
300+
transpose!(B, X)
301+
return B
302+
end
303+
304+
305+
function LinearAlgebra.ldiv!(A::UpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
306+
M, N = size(B, 1), size(B, 2)
307+
dev = current_device()
308+
queue = global_queue(dev)
309+
310+
Ad = MtlMatrix(A')
311+
Br = similar(B, (M, M))
312+
X = similar(Br)
313+
314+
transpose!(Br, B)
315+
316+
mps_a = MPSMatrix(Ad)
317+
mps_b = MPSMatrix(Br)
318+
mps_x = MPSMatrix(X)
319+
320+
buf = MTLCommandBuffer(queue) do cmdbuf
321+
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0)
322+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
323+
end
324+
325+
wait_completed(buf)
326+
327+
copy!(B, X)
328+
return B
329+
end
330+
331+
332+
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
333+
M, N = size(B, 1), size(B, 2)
334+
dev = current_device()
335+
queue = global_queue(dev)
336+
337+
Ad = MtlMatrix(A)
338+
Br = reshape(B, (M, N))
339+
X = similar(Br)
340+
341+
mps_a = MPSMatrix(Ad)
342+
mps_b = MPSMatrix(Br)
343+
mps_x = MPSMatrix(X)
344+
345+
346+
buf = MTLCommandBuffer(queue) do cmdbuf
347+
kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0)
348+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
349+
end
350+
351+
wait_completed(buf)
352+
353+
copy!(Br, X)
354+
return B
355+
end
356+
357+
358+
function LinearAlgebra.ldiv!(A::LowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
359+
M, N = size(B, 1), size(B, 2)
360+
dev = current_device()
361+
queue = global_queue(dev)
362+
363+
Ad = MtlMatrix(A)
364+
Br = reshape(B, (M, N))
365+
X = similar(Br)
366+
367+
mps_a = MPSMatrix(Ad)
368+
mps_b = MPSMatrix(Br)
369+
mps_x = MPSMatrix(X)
370+
371+
372+
buf = MTLCommandBuffer(queue) do cmdbuf
373+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0)
374+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
375+
end
376+
377+
wait_completed(buf)
378+
379+
copy!(Br, X)
264380
return B
265381
end
382+
383+
384+
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T,<:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
385+
M, N = size(B, 1), size(B, 2)
386+
dev = current_device()
387+
queue = global_queue(dev)
388+
389+
Ad = MtlMatrix(A)
390+
Br = reshape(B, (M, N))
391+
X = similar(Br)
392+
393+
mps_a = MPSMatrix(Ad)
394+
mps_b = MPSMatrix(Br)
395+
mps_x = MPSMatrix(X)
396+
397+
398+
buf = MTLCommandBuffer(queue) do cmdbuf
399+
kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0)
400+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
401+
end
402+
403+
wait_completed(buf)
404+
405+
copy!(Br, X)
406+
return B
407+
end

0 commit comments

Comments
 (0)