Skip to content

Commit d566a49

Browse files
tgymnichchristiangnrd
authored andcommitted
add solvers
1 parent 5440fbe commit d566a49

File tree

4 files changed

+90
-4
lines changed

4 files changed

+90
-4
lines changed

lib/mps/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ include("vector.jl")
3939
include("matrixrandom.jl")
4040
include("ndarray.jl")
4141
include("decomposition.jl")
42+
include("solve.jl")
4243
include("copy.jl")
4344

4445
# integrations
4546
include("random.jl")
4647
include("linalg.jl")
47-
4848
end

lib/mps/libmps.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,11 +535,11 @@ end
535535

536536
@objcwrapper immutable = false MPSMatrixVectorMultiplication <: MPSMatrixBinaryKernel
537537

538-
@objcwrapper immutable = true MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel
538+
@objcwrapper immutable = false MPSMatrixSolveTriangular <: MPSMatrixBinaryKernel
539539

540-
@objcwrapper immutable = true MPSMatrixSolveLU <: MPSMatrixBinaryKernel
540+
@objcwrapper immutable = false MPSMatrixSolveLU <: MPSMatrixBinaryKernel
541541

542-
@objcwrapper immutable = true MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel
542+
@objcwrapper immutable = false MPSMatrixSolveCholesky <: MPSMatrixBinaryKernel
543543

544544
@cenum MPSMatrixDecompositionStatus::Int32 begin
545545
MPSMatrixDecompositionStatusSuccess = 0

lib/mps/solve.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
2+
export MPSMatrixSolveTriangular
3+
4+
# @objcwrapper immutable=false MPSMatrixSolveTriangular <: MPSMatrixUnaryKernel
5+
6+
function MPSMatrixSolveTriangular(device, right, upper, unit, order, numberOfRightHandSides, alpha)
7+
kernel = @objc [MPSMatrixSolveTriangular alloc]::id{MPSMatrixSolveTriangular}
8+
obj = MPSMatrixSolveTriangular(kernel)
9+
finalizer(release, obj)
10+
@objc [obj::id{MPSMatrixSolveTriangular} initWithDevice:device::id{MTLDevice}
11+
right:right::Bool
12+
upper:upper::Bool
13+
transpose:transpose::Bool
14+
unit:unit::Bool
15+
order:order::NSUInteger
16+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger
17+
alpha:alpha::Float64]::id{MPSMatrixSolveTriangular}
18+
return obj
19+
end
20+
21+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveTriangular, sourceMatrix, resultMatrix, pivotIndices, status)
22+
@objc [kernel::id{MPSMatrixSolveTriangular} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
23+
sourceMatrix:sourceMatrix::id{MPSMatrix}
24+
resultMatrix:resultMatrix::id{MPSMatrix}
25+
pivotIndices:pivotIndices::id{MPSMatrix}
26+
status:status::id{MPSMatrix}]::Nothing
27+
end
28+
29+
30+
export MPSMatrixSolveLU
31+
32+
# @objcwrapper immutable=false MPSMatrixSolveLU <: MPSMatrixUnaryKernel
33+
34+
function MPSMatrixSolveLU(device, transpose, order, numberOfRightHandSides)
35+
kernel = @objc [MPSMatrixSolveLU alloc]::id{MPSMatrixSolveLU}
36+
obj = MPSMatrixSolveLU(kernel)
37+
finalizer(release, obj)
38+
@objc [obj::id{MPSMatrixSolveLU} initWithDevice:device::id{MTLDevice}
39+
transpose:transpose::Bool
40+
order:order::NSUInteger
41+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveLU}
42+
return obj
43+
end
44+
45+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveLU, sourceMatrix, rightHandSideMatrix, pivotIndices, solutionMatrix)
46+
@objc [kernel::id{MPSMatrixSolveLU} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
47+
sourceMatrix:sourceMatrix::id{MPSMatrix}
48+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
49+
pivotIndices:pivotIndices::id{MPSMatrix}
50+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
51+
end
52+
53+
54+
55+
56+
export MPSMatrixSolveCholesky
57+
58+
# @objcwrapper immutable=false MPSMatrixSolveCholesky <: MPSMatrixUnaryKernel
59+
60+
function MPSMatrixSolveCholesky(device, upper, order, numberOfRightHandSides)
61+
kernel = @objc [MPSMatrixSolveCholesky alloc]::id{MPSMatrixSolveCholesky}
62+
obj = MPSMatrixSolveCholesky(kernel)
63+
finalizer(release, obj)
64+
@objc [obj::id{MPSMatrixSolveCholesky} initWithDevice:device::id{MTLDevice}
65+
upper:upper::Bool
66+
order:order::NSUInteger
67+
numberOfRightHandSides:numberOfRightHandSides::NSUInteger]::id{MPSMatrixSolveCholesky}
68+
return obj
69+
end
70+
71+
function encode!(cmdbuf::MTLCommandBuffer, kernel::MPSMatrixSolveCholesky, sourceMatrix, rightHandSideMatrix, solutionMatrix)
72+
@objc [kernel::id{MPSMatrixSolveCholesky} encodeToCommandBuffer:cmdbuf::id{MTLCommandBuffer}
73+
sourceMatrix:sourceMatrix::id{MPSMatrix}
74+
rightHandSideMatrix:rightHandSideMatrix::id{MPSMatrix}
75+
pivotIndices:pivotIndices::id{MPSMatrix}
76+
solutionMatrix:solutionMatrix::id{MPSMatrix}]::Nothing
77+
end

res/wrap/libmps.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@ immutable=false
102102
[api.MPSMatrixSoftMax]
103103
immutable=false
104104

105+
[api.MPSMatrixSolveTriangular]
106+
immutable=false
107+
108+
[api.MPSMatrixSolveLU]
109+
immutable=false
110+
111+
[api.MPSMatrixSolveCholesky]
112+
immutable=false
113+
105114
[api.MPSMatrixUnaryKernel]
106115
immutable=false
107116

0 commit comments

Comments
 (0)