Commit 11b52d2

tmp

1 parent f48de8f commit 11b52d2

2 files changed, +350 -6 lines changed
Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse, arange_semi_sparse
from tilelang.contrib import nvcc
from tilelang.utils.tensor import torch_assert_close, map_torch_type

from triton.testing import do_bench

import torch

torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)


DEFAULT_CONFIG = {  # take best config from autotune script
    "4090": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 64,
            'num_stages': 1,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 256,
            'block_N': 128,
            'block_K': 64,
            'num_stages': 2,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    },
    "h20": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    }
}

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
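ARCH_INFO maps a compute-capability string to the metadata packing used by the sparse tensor cores: on SM 8.0/8.9 each "int16" element carries metadata for 16 fp16 values (e_factor = 16), while SM 9.0 packs it into "uint8" with e_factor = 8. A minimal sketch of picking the packing at runtime, mirroring the arch handling in example_gemm_sp.py further down (the fallback default here is an assumption):

# Sketch: choose the metadata packing for the current GPU (fallback value is an assumption).
arch = nvcc.get_target_compute_version()        # e.g. "8.9" on an RTX 4090
e_factor, e_dtype = ARCH_INFO.get(arch, (16, "int16"))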

@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
                                   enable_rasterization):
    e_factor, e_dtype = (16, "int16")

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
            A_sparse: T.Tensor((M, K // 2), 'float16'),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), 'float16'),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), 'float16')
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # T.disable_warp_group_reg_alloc()
            # T.use_swizzle(panel_size=10, enable=enable_rasterization)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16_custom_compress

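For orientation, the factory above returns a JIT-compiled kernel whose last parameter C is the output (out_idx=[-1]). A minimal usage sketch with the 4090/float config, mirroring main() below (problem sizes are placeholders and assume a CUDA device):

# Sketch: compile and call the sparse GEMM with placeholder sizes.
cfg = DEFAULT_CONFIG["4090"]["float"]
kernel = matmul_sp_fp16_custom_compress(1024, 1024, 1024, "float", **cfg)
# kernel(A_sparse, E, B) returns C of shape (M, N), accumulated in float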
def torch_compress(dense):
    """
    A naive compression function, where each 4-bit metadata nibble corresponds to 4
    elements of the original matrix in row-major layout.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    #     [True,  True,  False, False] -> 0b0100
    #     [True,  False, True,  False] -> 0b1000
    #     [False, True,  True,  False] -> 0b1001
    #     [True,  False, False, True ] -> 0b1100
    #     [False, True,  False, True ] -> 0b1101
    #     [False, False, True,  True ] -> 0b1110
    # Thus, the lower two bits of the encoding are the index of the True value
    # at the lowest index in the quadruple, and the higher two bits are the
    # index of the other True value in the quadruple.
    # If there are fewer than two True values, then the False value or values
    # at some index or indices are treated as True for the encoding. If there
    # are more than two True values, then the excess True value(s) at some
    # indices are treated as False for the encoding. The exact encodings used
    # for these cases are as follows:
    #     [False, False, False, False] -> 0b1110
    #     [False, False, False, True ] -> 0b1110
    #     [False, False, True,  False] -> 0b1110
    #     [False, True,  False, False] -> 0b1001
    #     [False, True,  True,  True ] -> 0b1101
    #     [True,  False, False, False] -> 0b1000
    #     [True,  False, True,  True ] -> 0b1100
    #     [True,  True,  False, True ] -> 0b0100
    #     [True,  True,  True,  False] -> 0b0100
    #     [True,  True,  True,  True ] -> 0b0100
    # These particular encodings are chosen, with the help of the Espresso
    # logic minimizer, to minimize the corresponding Boolean functions that
    # translate non-zero flags into encoding bits. Note also that the possible
    # choices for the first and last of these encodings were limited to
    # (0b0100, 0b1110), in order to produce valid encodings for the 1:2
    # sparsity case.

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    return (sparse, meta)

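As a small worked illustration of the encoding table above (an editor's sketch, not part of the committed file): for a quadruple [5, 0, 7, 0] the kept values sit at indices 0 and 2, so the nibble is 0b1000 (idx0 = 0 in the low two bits, idx1 = 2 in the high two bits). The shapes below are chosen to satisfy the divisibility checks in torch_compress:

# Worked example of the 2:4 metadata nibble produced by torch_compress.
dense = torch.zeros(32, 16, dtype=torch.half)
dense[0, 0], dense[0, 2] = 5.0, 7.0                  # first group of row 0 is [5, 0, 7, 0]
sparse, meta = torch_compress(dense)
assert sparse[0, 0] == 5.0 and sparse[0, 1] == 7.0   # the two kept values
assert int(meta[0, 0]) & 0xF == 0b1000               # low nibble: idx0 = 0, idx1 = 2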
def decode_2to4_metadata_2x2bit(meta: torch.Tensor, M, K) -> torch.Tensor:
    meta = meta.view(-1)
    groups_per_meta = 16 // 4  # 4 groups per uint16
    out = []

    for g in range(groups_per_meta - 1, -1, -1):
        group_bits = (meta >> (g * 4)) & 0xF
        idx0 = group_bits & 0x3
        idx1 = (group_bits >> 2) & 0x3
        out.append(torch.stack([idx0, idx1], dim=-1))
    return torch.cat(out, dim=0).to(torch.int32).view(M, K // 4, 2)


def layout_mapping(i, j):
    return i, j


@tilelang.jit(out_idx=[1, 2], pass_configs={
    tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    elem, group = 2, 4

    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
            A: T.Tensor((M, K), dtype),
            A_sp: T.Tensor((M, K // 2), dtype),
            E: T.Tensor((M, e_K), e_dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            # TODO: alloc_var seems buggy here
            non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
            non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    T.clear(non_zero_cnt)
                    T.clear(non_zero_elt_log_idx)
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel

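A shape sketch for the compressor above (editor's illustration; assumes a CUDA device and sizes that satisfy the asserts). Because of out_idx=[1, 2], A_sp and E are returned as outputs:

# Sketch: run the tilelang compressor on a small 2:4-sparse input and check output shapes.
a = randn_semi_sparse(64, 64, device='cuda', dtype=torch.half)
a_sp, e = compress_kernel(64, 64, 32, 32, "float16")(a)
assert a_sp.shape == (64, 32)   # K // 2 values kept per row
assert e.shape == (64, 4)       # one int16 metadata word per e_factor = 16 columns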
def main():

    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument("--use_torch_sparse", action='store_true', help="Use torch sparse for reference")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
    args = parser.parse_args()
    kernel = matmul_sp_fp16_custom_compress(args.m, args.n, args.k, args.accum_dtype,
                                            **DEFAULT_CONFIG[args.cfg][args.accum_dtype])

    a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    # a = arange_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    # b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
    b = torch.eye(args.k, device='cuda', dtype=torch.half)

    if args.use_torch_sparse:
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16")(a)
    print(a, e)
    res = decode_2to4_metadata_2x2bit(e, args.m, args.k)
    print(res.shape, res.stride())
    print(res)
    print(f'{res[0, 0]=} {res[0, 1]=}')
    exit()
    c = kernel(a_sparse, e, b)

    ref_c = a @ b

    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    total_flops = 2 * args.m * args.n * args.k
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3} s")


if __name__ == "__main__":
    main()
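One note on units in the benchmark section: do_bench reports latency in milliseconds, so total_flops / latency / 1e9 already yields TFLOPS, and latency / 1e3 converts to seconds. A quick arithmetic check with a made-up latency:

# Units check for the TFLOPS formula in main() (hypothetical latency value).
total_flops = 2 * 16384 ** 3      # FLOPs of one 16384^3 matmul
latency_ms = 10.0                 # pretend do_bench() returned 10 ms
tflops = total_flops / latency_ms / 1e9
assert abs(tflops - total_flops / (latency_ms * 1e-3) / 1e12) < 1e-6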

examples/gemm_sp/example_gemm_sp.py

Lines changed: 5 additions & 6 deletions
@@ -14,9 +14,7 @@
 
 arch = nvcc.get_target_compute_version()
 
-ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
-
-default_config = {  # take best config from autotune script
+DEFAULT_CONFIG = {  # take best config from autotune script
     "4090": {
         'float': {
             'block_M': 128,
@@ -59,6 +57,7 @@
     }
 }
 
+ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
 
 @tilelang.jit(out_idx=[-1])
 def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
@@ -98,7 +97,7 @@ def gemm_sp_fp16(
                 T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                 T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                 T.copy(B[k * block_K, bx * block_N], B_shared)
-                T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
+                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
 
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])
@@ -120,15 +119,15 @@ def main():
     parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
     args = parser.parse_args()
     kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
-                            **default_config[args.cfg][args.accum_dtype])
+                            **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
 
     a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
     b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
 
     a_sparse, e = compress(
         a,
         transposed=False,
-        block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
+        block_k=DEFAULT_CONFIG[args.cfg][args.accum_dtype]['block_K'],
         arch=arch)
     c = kernel(a_sparse, e, b)
 