Skip to content

Commit 69aa51e

Browse files
authored
MPSMatrix improvements (#157)
* Add batched MPSMatrix * MPSMatrix from SubArray
1 parent 19f3df1 commit 69aa51e

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

lib/mps/matrix.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
1818

1919
export MPSMatrixDescriptor
2020

21-
@objcwrapper MPSMatrixDescriptor <: NSObject
21+
@objcwrapper immutable=false MPSMatrixDescriptor <: NSObject
22+
23+
@objcproperties MPSMatrixDescriptor begin
24+
@autoproperty rows::NSUInteger setter=setRows
25+
@autoproperty columns::NSUInteger setter=setColumns
26+
@autoproperty matrices::NSUInteger
27+
@autoproperty dataType::MPSDataType setter=setDataType
28+
@autoproperty rowBytes::NSUInteger setter=setRowBytes
29+
@autoproperty matrixBytes::NSUInteger
30+
end
31+
2232

2333
# Mapping from Julia types to the Performance Shader bitfields
2434
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
@@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
4959
return obj
5060
end
5161

62+
function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dataType)
63+
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
64+
columns:columns::NSUInteger
65+
matrices:matrices::NSUInteger
66+
rowBytes:rowBytes::NSUInteger
67+
matrixBytes:matrixBytes::NSUInteger
68+
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
69+
obj = MPSMatrixDescriptor(desc)
70+
# XXX: who releases this object?
71+
return obj
72+
end
5273

5374
#
5475
# matrix object
@@ -58,6 +79,19 @@ export MPSMatrix
5879

5980
@objcwrapper immutable=false MPSMatrix <: NSObject
6081

82+
@objcproperties MPSMatrix begin
83+
@autoproperty device::id{MTLDevice}
84+
@autoproperty rows::NSUInteger
85+
@autoproperty columns::NSUInteger
86+
@autoproperty matrices::NSUInteger
87+
@autoproperty dataType::MPSDataType
88+
@autoproperty rowBytes::NSUInteger
89+
@autoproperty matrixBytes::NSUInteger
90+
@autoproperty offset::NSUInteger
91+
@autoproperty data::id{MTLBuffer}
92+
end
93+
94+
6195
"""
6296
MPSMatrix(arr::MtlMatrix)
6397
@@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
71105
desc = MPSMatrixDescriptor(n_rows, n_cols, sizeof(T)*n_cols, T)
72106
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
73107
obj = MPSMatrix(mat)
108+
offset = arr.offset * sizeof(T)
74109
finalizer(release, obj)
75110
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
111+
offset:offset::NSUInteger
76112
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
77113
return obj
78114
end
79115

80116

117+
"""
118+
MPSMatrix(arr::MtlArray{T,3})
119+
120+
Metal batched matrix representation used in Performance Shaders.
121+
122+
Note that this results in a transposed view of the input,
123+
as Metal stores matrices row-major instead of column-major.
124+
"""
125+
function MPSMatrix(arr::MtlArray{T,3}) where T
126+
n_cols, n_rows, n_matrices = size(arr)
127+
row_bytes = sizeof(T)*n_cols
128+
desc = MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
129+
mat = @objc [MPSMatrix alloc]::id{MPSMatrix}
130+
obj = MPSMatrix(mat)
131+
offset = arr.offset * sizeof(T)
132+
finalizer(release, obj)
133+
@objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
134+
offset:offset::NSUInteger
135+
descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
136+
return obj
137+
end
138+
81139
#
82140
# matrix multiplication
83141
#

lib/mps/vector.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ function MPSVector(arr::MtlVector{T}) where T
5353
desc = MPSVectorDescriptor(len, T)
5454
vec = @objc [MPSVector alloc]::id{MPSVector}
5555
obj = MPSVector(vec)
56+
offset = arr.offset * sizeof(T)
5657
finalizer(release, obj)
5758
@objc [obj::id{MPSVector} initWithBuffer:arr::id{MTLBuffer}
59+
offset:offset::NSUInteger
5860
descriptor:desc::id{MPSVectorDescriptor}]::id{MPSVector}
5961
return obj
6062
end

test/mps.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@ if MPS.is_supported(current_device())
3535
end
3636
end
3737

38+
@testset "test matrix vector multiplication of views" begin
39+
N = 20
40+
a = rand(Float32, N,N)
41+
b = rand(Float32, N)
42+
43+
mtl_a = mtl(a)
44+
mtl_b = mtl(b)
45+
46+
view_a = @view a[:,10:end]
47+
view_b = @view b[10:end]
48+
49+
mtl_view_a = @view mtl_a[:,10:end]
50+
mtl_view_b = @view mtl_b[10:end]
51+
52+
mtl_c = mtl_view_a * mtl_view_b
53+
c = view_a * view_b
54+
55+
@test mtl_c == mtl(c)
56+
end
57+
3858
@testset "mixed-precision matrix vector multiplication" begin
3959
N = 10
4060
rows = N

0 commit comments

Comments
 (0)