Skip to content

Commit ab81b56

Browse files
authored
Merge pull request #1119 from max-Hawkins/wmma_int
Add Int8 WMMA Support
2 parents 9d69cab + a383dbd commit ab81b56

File tree

2 files changed

+260
-121
lines changed

2 files changed

+260
-121
lines changed

src/device/intrinsics/wmma.jl

Lines changed: 154 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,72 @@ using Core: LLVMPtr
1010

1111
# Maps PTX types to Julia array types
1212
const map_ptx_to_jl_array = Dict(
13+
"u8" => UInt8,
14+
"s8" => Int8,
15+
"s32" => Int32,
1316
"f16" => Float16,
1417
"f32" => Float32
1518
)
1619

1720
# Maps PTX types to Julia fragment types
1821
const map_ptx_to_jl_frag = Dict(
22+
"u8" => UInt32,
23+
"s8" => UInt32,
24+
"s32" => Int32,
1925
"f16" => NTuple{2, VecElement{Float16}},
2026
"f32" => Float32
2127
)
2228

2329
# Maps matrix & PTX types to fragment sizes
2430
const map_frag_sizes = Dict(
25-
"a.f16" => 8,
26-
"b.f16" => 8,
27-
"c.f16" => 4,
28-
"c.f32" => 8,
29-
"d.f16" => 4,
30-
"d.f32" => 8
31+
# A
32+
"a.u8.m16n16k16" => 2,
33+
"a.u8.m8n32k16" => 1,
34+
"a.u8.m32n8k16" => 4,
35+
36+
"a.s8.m16n16k16" => 2,
37+
"a.s8.m8n32k16" => 1,
38+
"a.s8.m32n8k16" => 4,
39+
40+
"a.f16.m16n16k16" => 8,
41+
"a.f16.m8n32k16" => 8,
42+
"a.f16.m32n8k16" => 8,
43+
# B
44+
"b.u8.m16n16k16" => 2,
45+
"b.u8.m8n32k16" => 4,
46+
"b.u8.m32n8k16" => 1,
47+
48+
"b.s8.m16n16k16" => 2,
49+
"b.s8.m8n32k16" => 4,
50+
"b.s8.m32n8k16" => 1,
51+
52+
"b.f16.m16n16k16" => 8,
53+
"b.f16.m8n32k16" => 8,
54+
"b.f16.m32n8k16" => 8,
55+
# C
56+
"c.s32.m16n16k16" => 8,
57+
"c.s32.m8n32k16" => 8,
58+
"c.s32.m32n8k16" => 8,
59+
60+
"c.f16.m16n16k16" => 4,
61+
"c.f16.m8n32k16" => 4,
62+
"c.f16.m32n8k16" => 4,
63+
64+
"c.f32.m16n16k16" => 8,
65+
"c.f32.m8n32k16" => 8,
66+
"c.f32.m32n8k16" => 8,
67+
# D
68+
"d.s32.m16n16k16" => 8,
69+
"d.s32.m8n32k16" => 8,
70+
"d.s32.m32n8k16" => 8,
71+
72+
"d.f16.m16n16k16" => 4,
73+
"d.f16.m8n32k16" => 4,
74+
"d.f16.m32n8k16" => 4,
75+
76+
"d.f32.m16n16k16" => 8,
77+
"d.f32.m8n32k16" => 8,
78+
"d.f32.m32n8k16" => 8,
3179
)
3280

3381
# Maps PTX AS to CUDA.AS
@@ -37,15 +85,41 @@ const map_ptx_as_to_as_ty = Dict(
3785
"global" => AS.Global
3886
)
3987

88+
# Valid WMMA Operation configurations: Shape (M,N,K), Matrix, Element Type
89+
90+
# Half-Precision Floating Point
91+
const ldst_half_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["f16"]
92+
const ldst_half_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["f16", "f32"]
93+
const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f32"], ["f16", "f32"]
94+
# Integer
95+
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
96+
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
97+
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
98+
99+
const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
100+
ldst_int_ab_ops, ldst_int_cd_ops)
101+
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
102+
103+
# Valid WMMA operation shapes
104+
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
105+
40106
################################################################################
41107
# HELPER FUNCTIONS
42108
################################################################################
43109

110+
# Returns shape information as a string
111+
function get_hl_shape(M, N, K)
112+
if (M, N, K) in valid_shapes
113+
return "m$(M)n$(N)k$(K)"
114+
end
115+
error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)")
116+
end
117+
44118
# Returns (Julia array type, Julia fragment type, fragment size)
45-
get_frag_info(matrix, ptx_el_type) = (
119+
get_frag_info(matrix, ptx_el_type, shape) = (
46120
map_ptx_to_jl_array[ptx_el_type],
47121
map_ptx_to_jl_frag[ptx_el_type],
48-
map_frag_sizes["$matrix.$ptx_el_type"]
122+
map_frag_sizes["$matrix.$ptx_el_type.$shape"]
49123
)
50124

51125
get_addrspace_info(addr_space) = convert(Int, map_ptx_as_to_as_ty[addr_space])
@@ -86,27 +160,26 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.load.{matrix}.sync.{layout}.{
86160
# Placeholders
87161
- `{matrix}`: The matrix to load. Can be `a`, `b` or `c`.
88162
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
89-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
163+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
90164
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
91-
- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point). Note that `f32` is only valid for the matrix ``C``.
165+
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
166+
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
167+
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
92168
"""
93169
llvm_wmma_load() = error("Cannot call llvm_wmma_load without values for placeholders!")
94170
export llvm_wmma_load
95171

96-
for mat in ["a", "b", "c"],
172+
for ops in all_ldst_ops,
173+
mnk in ops[1],
174+
mat in ops[2],
175+
elem_type in ops[3],
97176
layout in ["col", "row"],
98-
shape in ["m16n16k16"],
99177
addr_space in ["", "shared", "global"],
100-
stride in ["stride"],
101-
elem_type in ["f16", "f32"]
178+
stride in ["stride"]
102179

180+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
103181
# TODO: Non-stride versions?
104182

105-
# Float32 is only supported for C
106-
if (elem_type == "f32") && (mat != "c")
107-
continue
108-
end
109-
110183
addr_space_int = get_addrspace_info(addr_space)
111184

112185
# Name of the Julia wrapper function
@@ -116,7 +189,7 @@ for mat in ["a", "b", "c"],
116189
llvm_intr = "llvm.nvvm.wmma.$shape.load.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8"
117190

118191
# Determine types + size for this (matrix, elem_type) combination
119-
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type)
192+
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)
120193

121194
ccall_name = "extern $llvm_intr"
122195

@@ -144,19 +217,28 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.store.d.sync.{layout}.{shape}
144217
145218
# Placeholders
146219
- `{layout}`: The storage layout for the matrix. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively.
147-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
220+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
148221
- `{addr_space}`: The address space of `src_addr`. Can be empty (generic addressing), `shared` or `global`.
149-
- `{elem_type}`: The type of each element in the matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
222+
- `{elem_type}`: The type of each element in the matrix. For `a` and `b` matrices, valid values are `u8` (byte unsigned integer),
223+
`s8` (byte signed integer), and `f16` (half precision floating point). For `c` and `d` matrices, valid values are
224+
`s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
150225
"""
151226
llvm_wmma_store() = error("Cannot call llvm_wmma_store without values for placeholders!")
152227
export llvm_wmma_store
153228

154-
for mat in ["d"],
155-
layout in ["col", "row"],
156-
shape in ["m16n16k16"],
157-
addr_space in ["", "shared", "global"],
158-
stride in ["stride"],
159-
elem_type in ["f16", "f32"]
229+
for ops in all_ldst_ops,
230+
mnk in ops[1],
231+
mat in ops[2],
232+
elem_type in ops[3],
233+
layout in ["col", "row"],
234+
addr_space in ["", "shared", "global"],
235+
stride in ["stride"]
236+
237+
if mat != "d"
238+
continue
239+
end
240+
241+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
160242

161243
# TODO: Non-stride versions?
162244

@@ -169,7 +251,7 @@ for mat in ["d"],
169251
llvm_intr = "llvm.nvvm.wmma.$shape.store.$mat.$layout.stride.$elem_type.p$(addr_space_int)i8"
170252

171253
# Determine types + size for this (matrix, elem_type) combination
172-
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type)
254+
arr_ty, frag_ty, sz = get_frag_info(mat, elem_type, shape)
173255

174256
ccall_name = "extern $llvm_intr"
175257
frag_types = ntuple(i -> frag_ty, sz)
@@ -187,9 +269,11 @@ end
187269
# --------------------------
188270

189271
@doc """
190-
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c)
272+
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{d_elem_type}_{c_elem_type}(a, b, c) or
273+
WMMA.llvm_wmma_mma_{a_layout}_{b_layout}_{shape}_{a_elem_type}(a, b, c)
191274
192-
Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`.
275+
For floating point operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{d_elem_type}.{c_elem_type}`.
276+
For all other operations: wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout}.{shape}.{a_elem_type}`.
193277
194278
# Arguments
195279
- `a`: The WMMA fragment corresponding to the matrix ``A``.
@@ -199,9 +283,10 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout
199283
# Placeholders
200284
- `{a_layout}`: The storage layout for matrix ``A``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
201285
- `{b_layout}`: The storage layout for matrix ``B``. Can be `row` or `col`, for row major (C style) or column major (Julia style), respectively. Note that this must match the layout used in the load operation.
202-
- `{shape}`: The overall shape of the MAC operation. The only valid value is `m16n16k16`.
203-
- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
204-
- `{c_elem_type}`: The type of each element in the ``C`` matrix. Can be `f16` (half precision floating point) or `f32` (full precision floating point).
286+
- `{shape}`: The overall shape of the MAC operation. Valid values are `m16n16k16`, `m32n8k16`, and `m8n32k16`.
287+
- `{a_elem_type}`: The type of each element in the ``A`` matrix. Valid values are `u8` (byte unsigned integer), `s8` (byte signed integer), and `f16` (half precision floating point).
288+
- `{d_elem_type}`: The type of each element in the resultant ``D`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
289+
- `{c_elem_type}`: The type of each element in the ``C`` matrix. Valid values are `s32` (32-bit signed integer), `f16` (half precision floating point), and `f32` (full precision floating point).
205290
206291
!!! warning
207292
@@ -211,25 +296,34 @@ Wrapper around the LLVM intrinsic `@llvm.nvvm.wmma.mma.sync.{a_layout}.{b_layout
211296
llvm_wmma_mma() = error("Cannot call llvm_wmma_mma without values for placeholders!")
212297
export llvm_wmma_mma
213298

214-
for a_layout in ["col", "row"],
299+
for ops in all_wmma_ops,
300+
a_layout in ["col", "row"],
215301
b_layout in ["col", "row"],
216-
shape in ["m16n16k16"],
217-
d_elem_type in ["f16", "f32"],
218-
c_elem_type in ["f16", "f32"],
219-
b_elem_type in ["f16"],
220-
a_elem_type in ["f16"]
302+
mnk in ops[1],
303+
d_elem_type in ops[4],
304+
c_elem_type in ops[3],
305+
b_elem_type in ops[2]
221306

222-
# Name of the Julia wrapper function
223-
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
307+
a_elem_type = b_elem_type
308+
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
224309

225310
# Name of the LLVM intrinsic
226-
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
311+
# If integer/sub-byte/bit A/B types, name is determined by A/B types
312+
if d_elem_type == "s32"
313+
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
314+
# Name of the Julia wrapper function
315+
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
316+
else # Name defined by D/C types
317+
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
318+
# Name of the Julia wrapper function
319+
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
320+
end
227321

228322
# Determine types + size for the (matrix, elem_type) combinations for matrix A, B, C and D
229-
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type)
230-
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type)
231-
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type)
232-
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type)
323+
a_arr_ty, a_frag_ty, a_sz = get_frag_info("a", a_elem_type, shape)
324+
b_arr_ty, b_frag_ty, b_sz = get_frag_info("b", b_elem_type, shape)
325+
c_arr_ty, c_frag_ty, c_sz = get_frag_info("c", c_elem_type, shape)
326+
d_arr_ty, d_frag_ty, d_sz = get_frag_info("d", d_elem_type, shape)
233327

234328
ccall_name = "extern $llvm_intr"
235329

@@ -439,17 +533,9 @@ function get_hl_layout(L)
439533
end
440534
end
441535

442-
function get_hl_shape(M, N, K)
443-
if (M, N, K) != (16, 16, 16)
444-
error("Invalid shape for WMMA: (M, N, K) = ($M, $N, $K)")
445-
end
446-
447-
return "m$(M)n$(N)k$(K)"
448-
end
449-
450536
get_hl_mat_use(mat) = map_matrix_to_use[mat]
451537

452-
function get_hl_frag_info(matrix, T)
538+
function get_hl_frag_info(matrix, T, shape)
453539
ptx_ty = nothing
454540

455541
try
@@ -460,7 +546,7 @@ function get_hl_frag_info(matrix, T)
460546

461547
try
462548
return (map_num_elems[(matrix, T)],
463-
map_frag_sizes["$matrix.$ptx_ty"],
549+
map_frag_sizes["$matrix.$ptx_ty.$shape"],
464550
map_ptx_to_jl_frag[ptx_ty],
465551
ptx_ty)
466552
catch
@@ -507,7 +593,7 @@ for mat in ["a", "b", "c"]
507593
as_str = get_hl_as_info(AS)
508594
layout = get_hl_layout(L)
509595
shape = get_hl_shape(M, N, K)
510-
num_els, _, _, arr_str = get_hl_frag_info($mat, T)
596+
num_els, _, _, arr_str = get_hl_frag_info($mat, T, shape)
511597
U = get_hl_mat_use($mat)
512598
L_ret = ($mat == "c") ? Unspecified : L
513599

@@ -552,15 +638,17 @@ mma
552638
c::Fragment{M, N, K, C_SZ, C_T, Unspecified, Accumulator},
553639
config::Type{Config{M, N, K, D_T}}) where {M, N, K, A_SZ, A_T, A_L, B_SZ, B_T, B_L, C_SZ, C_T, D_T}
554640

555-
_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T)
556-
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T)
557-
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T)
558-
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T)
559-
560641
a_layout = get_hl_layout(A_L)
561642
b_layout = get_hl_layout(B_L)
562643
shape = get_hl_shape(M, N, K)
563644

645+
_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape)
646+
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
647+
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
648+
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
649+
650+
651+
564652
# Name of the Julia wrapper
565653
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
566654

@@ -611,7 +699,7 @@ store_d
611699
as_str = get_hl_as_info(AS)
612700
layout = get_hl_layout(L)
613701
shape = get_hl_shape(M, N, K)
614-
num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T)
702+
num_els, frag_sz, frag_ty, arr_str = get_hl_frag_info("d", T, shape)
615703

616704
# Name of the Julia wrapper
617705
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "store", "d", layout, shape, as_str, "stride", arr_str]), "_"))
@@ -648,7 +736,8 @@ fill_c
648736

649737
# We can't use closures in @generated functions, so we'll have to do it this way instead of
650738
# ntuple(i -> val, $num_els)
651-
num_els, _, _ = get_hl_frag_info("c", T)
739+
shape = get_hl_shape(M, N, K)
740+
num_els, _, _ = get_hl_frag_info("c", T, shape)
652741

653742
args = [:value for i=1:num_els]
654743
expr = :(tuple($(args...)))

0 commit comments

Comments (0)