Skip to content

Commit f5d932a

Browse files
committed
fix other solvers
1 parent ef68e4d commit f5d932a

File tree

1 file changed

+69
-58
lines changed

1 file changed

+69
-58
lines changed

lib/mps/linalg.jl

Lines changed: 69 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,16 @@ end
197197

198198

199199
function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
200-
orig = size(B)
201-
M,N = size(B)[1], ndims(B) > 1 ? size(B)[2] : 1
200+
M,N = size(B,1), size(B,2)
202201
dev = current_device()
203202
queue = global_queue(dev)
204203

205-
B = reshape(B, (N,M))
204+
Bt = reshape(B, (N,M))
206205
P = reshape((A.ipiv .- UInt32(1)), (1,M))
207206
X = similar(B)
208207

209208
mps_a = MPSMatrix(A.factors)
210-
mps_b = MPSMatrix(B)
209+
mps_b = MPSMatrix(Bt)
211210
mps_p = MPSMatrix(P)
212211
mps_x = MPSMatrix(X)
213212

@@ -216,86 +215,98 @@ function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::M
216215
encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x)
217216
end
218217

219-
B .= X
220-
B = reshape(B, orig)
218+
Bt .= X
219+
return B
221220
end
222221

223-
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
224-
M,N = size(B)
222+
223+
function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
224+
M,N = size(B,1), size(B,2)
225225
dev = current_device()
226226
queue = global_queue(dev)
227-
cmdbuf = MTLCommandBuffer(queue)
228-
enqueue!(cmdbuf)
229227

230-
Bh = reshape(B, )
231-
X = MtlMatrix{T}(undef, size(B))
228+
Ad = MtlMatrix(A; storage=Private)
229+
Bt = reshape(B, (N,M))
230+
X = similar(B)
232231

233-
mps_a = MPSMatrix(A)
234-
mps_b = MPSMatrix(Bh) # TODO reshape to matrix if B is a vector
232+
mps_a = MPSMatrix(Ad)
233+
mps_b = MPSMatrix(Bt)
235234
mps_x = MPSMatrix(X)
236235

237-
solve_kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
238-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
239-
commit!(cmdbuf)
236+
MTLCommandBuffer(queue) do cmdbuf
237+
kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0)
238+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
239+
end
240240

241-
return X
241+
Bt .= X
242+
return B
242243
end
243244

244-
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
245-
M,N = size(B)
245+
246+
function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
247+
M,N = size(B,1), size(B,2)
246248
dev = current_device()
247249
queue = global_queue(dev)
248-
cmdbuf = MTLCommandBuffer(queue)
249-
enqueue!(cmdbuf)
250250

251-
X = MtlMatrix{T}(undef, size(B))
251+
Ad = MtlMatrix(A; storage=Private)
252+
Bt = reshape(B, (N,M))
253+
X = similar(B)
252254

253-
mps_a = MPSMatrix(A)
254-
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
255+
mps_a = MPSMatrix(Ad)
256+
mps_b = MPSMatrix(Bt)
255257
mps_x = MPSMatrix(X)
256258

257-
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
258-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
259-
commit!(cmdbuf)
259+
MTLCommandBuffer(queue) do cmdbuf
260+
kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0)
261+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
262+
end
260263

261-
return X
264+
Bt .= X
265+
return B
262266
end
263267

264-
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T}
265-
M,N = size(B)
268+
269+
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
270+
M,N = size(B,1), size(B,2)
266271
dev = current_device()
267272
queue = global_queue(dev)
268-
cmdbuf = MTLCommandBuffer(queue)
269-
enqueue!(cmdbuf)
270273

271-
X = MtlMatrix{T}(undef, size(B))
274+
Ad = MtlMatrix(A; storage=Private)
275+
Bt = reshape(B, (N,M))
276+
X = similar(B)
272277

273-
mps_a = MPSMatrix(A)
274-
mps_b = MPSMatrix(B) # TODO reshape to matrix if B is a vector
278+
mps_a = MPSMatrix(Ad)
279+
mps_b = MPSMatrix(Bt)
275280
mps_x = MPSMatrix(X)
276281

277-
solve_kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
278-
encode!(cmdbuf, solve_kernel, mps_a, mps_b, mps_x)
279-
commit!(cmdbuf)
282+
MTLCommandBuffer(queue) do cmdbuf
283+
kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0)
284+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
285+
end
280286

281-
return X
287+
Bt .= X
288+
return B
282289
end
283290

284-
# function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
285-
# require_one_based_indexing(A, B)
286-
# m, n = size(A)
287-
# if m == n
288-
# if istril(A)
289-
# if istriu(A)
290-
# return Diagonal(A) \ B
291-
# else
292-
# return LowerTriangular(A) \ B
293-
# end
294-
# end
295-
# if istriu(A)
296-
# return UpperTriangular(A) \ B
297-
# end
298-
# return lu(A) \ B
299-
# end
300-
# return qr(A, ColumnNorm()) \ B
301-
# end
291+
292+
function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrMat{T}) where {T<:MtlFloat}
293+
M,N = size(B,1), size(B,2)
294+
dev = current_device()
295+
queue = global_queue(dev)
296+
297+
A = MtlMatrix(A; storage=Private)
298+
Bt = reshape(B, (N,M))
299+
X = similar(B)
300+
301+
mps_a = MPSMatrix(A)
302+
mps_b = MPSMatrix(Bt)
303+
mps_x = MPSMatrix(X)
304+
305+
MTLCommandBuffer(queue) do cmdbuf
306+
kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0)
307+
encode!(cmdbuf, kernel, mps_a, mps_b, mps_x)
308+
end
309+
310+
Bt .= X
311+
return B
312+
end

0 commit comments

Comments
 (0)