@@ -18,7 +18,17 @@ Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)
18
18
19
19
export MPSMatrixDescriptor
20
20
21
- @objcwrapper MPSMatrixDescriptor <: NSObject
21
+ @objcwrapper immutable= false MPSMatrixDescriptor <: NSObject
22
+
23
+ @objcproperties MPSMatrixDescriptor begin
24
+ @autoproperty rows:: NSUInteger setter= setRows
25
+ @autoproperty columns:: NSUInteger setter= setColumns
26
+ @autoproperty matrices:: NSUInteger
27
+ @autoproperty dataType:: MPSDataType setter= setDataType
28
+ @autoproperty rowBytes:: NSUInteger setter= setRowBytes
29
+ @autoproperty matrixBytes:: NSUInteger
30
+ end
31
+
22
32
23
33
# Mapping from Julia types to the Performance Shader bitfields
24
34
const jl_typ_to_mps = Dict {DataType,MPSDataType} (
@@ -49,6 +59,17 @@ function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
49
59
return obj
50
60
end
51
61
62
+ function MPSMatrixDescriptor (rows, columns, matrices, rowBytes, matrixBytes, dataType)
63
+ desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows: rows:: NSUInteger
64
+ columns: columns:: NSUInteger
65
+ matrices: matrices:: NSUInteger
66
+ rowBytes: rowBytes:: NSUInteger
67
+ matrixBytes: matrixBytes:: NSUInteger
68
+ dataType: jl_typ_to_mps[dataType]:: MPSDataType ]:: id{MPSMatrixDescriptor}
69
+ obj = MPSMatrixDescriptor (desc)
70
+ # XXX : who releases this object?
71
+ return obj
72
+ end
52
73
53
74
#
54
75
# matrix object
@@ -58,6 +79,19 @@ export MPSMatrix
58
79
59
80
@objcwrapper immutable= false MPSMatrix <: NSObject
60
81
82
+ @objcproperties MPSMatrix begin
83
+ @autoproperty device:: id{MTLDevice}
84
+ @autoproperty rows:: NSUInteger
85
+ @autoproperty columns:: NSUInteger
86
+ @autoproperty matrices:: NSUInteger
87
+ @autoproperty dataType:: MPSDataType
88
+ @autoproperty rowBytes:: NSUInteger
89
+ @autoproperty matrixBytes:: NSUInteger
90
+ @autoproperty offset:: NSUInteger
91
+ @autoproperty data:: id{MTLBuffer}
92
+ end
93
+
94
+
61
95
"""
62
96
MPSMatrix(arr::MtlMatrix)
63
97
@@ -71,13 +105,37 @@ function MPSMatrix(arr::MtlMatrix{T}) where T
71
105
desc = MPSMatrixDescriptor (n_rows, n_cols, sizeof (T)* n_cols, T)
72
106
mat = @objc [MPSMatrix alloc]:: id{MPSMatrix}
73
107
obj = MPSMatrix (mat)
108
+ offset = arr. offset * sizeof (T)
74
109
finalizer (release, obj)
75
110
@objc [obj:: id{MPSMatrix} initWithBuffer: arr:: id{MTLBuffer}
111
+ offset: offset:: NSUInteger
76
112
descriptor: desc:: id{MPSMatrixDescriptor} ]:: id{MPSMatrix}
77
113
return obj
78
114
end
79
115
80
116
117
+ """
118
+ MPSMatrix(arr::MtlArray{T,3})
119
+
120
+ Metal batched matrix representation used in Performance Shaders.
121
+
122
+ Note that this results in a transposed view of the input,
123
+ as Metal stores matrices row-major instead of column-major.
124
+ """
125
+ function MPSMatrix (arr:: MtlArray{T,3} ) where T
126
+ n_cols, n_rows, n_matrices = size (arr)
127
+ row_bytes = sizeof (T)* n_cols
128
+ desc = MPSMatrixDescriptor (n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T)
129
+ mat = @objc [MPSMatrix alloc]:: id{MPSMatrix}
130
+ obj = MPSMatrix (mat)
131
+ offset = arr. offset * sizeof (T)
132
+ finalizer (release, obj)
133
+ @objc [obj:: id{MPSMatrix} initWithBuffer: arr:: id{MTLBuffer}
134
+ offset: offset:: NSUInteger
135
+ descriptor: desc:: id{MPSMatrixDescriptor} ]:: id{MPSMatrix}
136
+ return obj
137
+ end
138
+
81
139
#
82
140
# matrix multiplication
83
141
#
0 commit comments