# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse

import tilelang
import tilelang.language as T

from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress, randn_semi_sparse, arange_semi_sparse
from tilelang.contrib import nvcc
from tilelang.utils.tensor import torch_assert_close, map_torch_type

from triton.testing import do_bench

import torch

torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)

DEFAULT_CONFIG = {  # take best config from autotune script
    "4090": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 64,
            'num_stages': 1,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 256,
            'block_N': 128,
            'block_K': 64,
            'num_stages': 2,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    },
    "h20": {
        'float': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        },
        'float16': {
            'block_M': 128,
            'block_N': 64,
            'block_K': 128,
            'num_stages': 3,
            'thread_num': 128,
            'policy': T.GemmWarpPolicy.Square,
            'enable_rasterization': True
        }
    }
}

ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
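# Note (descriptive, inferred from how E is laid out as K // e_factor below): ARCH_INFO
# maps a compute capability to (e_factor, metadata dtype). Each metadata word covers
# e_factor elements along K, e.g. one int16 packs 4 groups x 4 bits of indices for 16
# fp16 elements on sm8x, while sm90 uses one uint8 per 8 elements.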


@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
                                   enable_rasterization):
    e_factor, e_dtype = (16, "int16")

    @T.prim_func
    def gemm_sp_fp16_custom_compress(
            A_sparse: T.Tensor((M, K // 2), 'float16'),
            E: T.Tensor((M, K // e_factor), e_dtype),
            B: T.Tensor((K, N), 'float16'),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            B_shared = T.alloc_shared((block_K, block_N), 'float16')
            C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # T.disable_warp_group_reg_alloc()
            # T.use_swizzle(panel_size=10, enable=enable_rasterization)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_sp_fp16_custom_compress
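
# Usage sketch (purely illustrative, mirrors main() below): build the JIT kernel for a
# given shape/config and call it with the compressed operands; C is allocated and
# returned by the kernel because out_idx=[-1].
#   kernel = matmul_sp_fp16_custom_compress(M, N, K, "float", **DEFAULT_CONFIG["4090"]["float"])
#   c = kernel(a_sparse, e, b)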


def torch_compress(dense):
    """
    A naive compression function, where each 4-bit metadata nibble encodes one group of
    4 consecutive elements of the original matrix (row-major): the low 2 bits give the
    index of the first kept element and the high 2 bits the index of the second.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, _m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, _m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encoding quadruples of True/False values as follows:
    # [True, True, False, False] -> 0b0100
    # [True, False, True, False] -> 0b1000
    # [False, True, True, False] -> 0b1001
    # [True, False, False, True ] -> 0b1100
    # [False, True, False, True ] -> 0b1101
    # [False, False, True, True ] -> 0b1110
    # Thus, lower two bits in the encoding are index of the True value
    # at the lowest index in the quadruple, and the higher two bits in
    # the encoding are index of the other True value in the quadruple.
    # In case there are fewer than two True values, then False value or
    # values at some index or indices are considered True for the
    # encoding. In case there are more than two True values, then the
    # excess True value(s) at some indices are considered False for
    # the encoding. The exact encodings used for these cases are as
    # follows:
    # [False, False, False, False] -> 0b1110
    # [False, False, False, True ] -> 0b1110
    # [False, False, True, False] -> 0b1110
    # [False, True, False, False] -> 0b1001
    # [False, True, True, True ] -> 0b1101
    # [True, False, False, False] -> 0b1000
    # [True, False, True, True ] -> 0b1100
    # [True, True, False, True ] -> 0b0100
    # [True, True, True, False] -> 0b0100
    # [True, True, True, True ] -> 0b0100
    # These particular encodings are chosen, with the help of the Espresso
    # logic minimizer software, for the purpose of minimization of
    # corresponding Boolean functions, that translate non-zero flags
    # into encoding bits. Note also possible choices for the first
    # and last of these encodings were limited only to (0b0100,
    # 0b1110), in order to produce valid encodings for the 1:2 sparsity
    # case.
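    # Worked example (illustrative, consistent with the table above): a quadruple with
    # values [0, 7, 0, 5] has its non-zeros at indices 1 and 3, so idxs0 = 1 (0b01) and
    # idxs1 = 3 (0b11); the 4-bit code is 0b1101 and the kept value pair is (7, 5).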

    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]

    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    return (sparse, meta)
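
# Usage sketch for torch_compress (illustrative; shapes follow from the checks above):
#   dense = randn_semi_sparse(32, 64, device='cuda', dtype=torch.half)
#   sparse, meta = torch_compress(dense)  # sparse: (32, 32) fp16, meta: (32, 4) int16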


def decode_2to4_metadata_2x2bit(meta: torch.Tensor, M, K) -> torch.Tensor:
    """Unpack 2:4 metadata words into per-group (idx0, idx1) pairs for inspection."""
    meta = meta.view(-1)
    groups_per_meta = 16 // 4  # each int16 metadata word packs 4 groups of 4 bits
    out = []

    for g in range(groups_per_meta):
        group_bits = (meta >> (g * 4)) & 0xF
        idx0 = group_bits & 0x3
        idx1 = (group_bits >> 2) & 0x3
        out.append(torch.stack([idx0, idx1], dim=-1))
    # stack along a new group axis (rather than concatenating along dim 0) so each
    # metadata word keeps its groups adjacent and row-major order survives the reshape
    return torch.stack(out, dim=1).to(torch.int32).view(M, K // 4, 2)
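
# Example (illustrative): decoding the single word 0b0100_1101_1110_0100 group by group
# from the low bits yields the index pairs [[0, 1], [2, 3], [1, 3], [0, 1]].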


def layout_mapping(i, j):
    return i, j


@tilelang.jit(out_idx=[1, 2], pass_configs={
    tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype):
    e_factor, e_dtype = ARCH_INFO["8.0"]
    e_K = K // e_factor
    elem, group = 2, 4

    assert M % block_M == 0, "M must be divisible by block_M"
    assert K % block_K == 0, "K must be divisible by block_K"
    assert K % e_factor == 0, "K must be divisible by e_factor"
    assert block_K % e_factor == 0, "block_K must be divisible by e_factor"

    @T.prim_func
    def kernel(
            A: T.Tensor((M, K), dtype),
            A_sp: T.Tensor((M, K // 2), dtype),
            E: T.Tensor((M, e_K), e_dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
            # TODO: alloc_var seems buggy here
            non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
            non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
            T.copy(A[bx * block_M, by * block_K], A_shared)
            for tm in T.Parallel(block_M):
                for g_i in range(0, block_K // group):
                    a_k = g_i * group
                    # Zero-initialize the packed metadata word before OR-ing index bits
                    # into it below; shared memory is not guaranteed to start out zeroed.
                    if g_i % (e_factor // group) == 0:
                        E_shared[tm, a_k // e_factor] = 0
                    T.clear(non_zero_cnt)
                    T.clear(non_zero_elt_log_idx)
                    for i in range(group):
                        val = A_shared[tm, a_k + i]
                        if val != 0.0:
                            non_zero_elt_log_idx[non_zero_cnt[0]] = i
                            A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
                            non_zero_cnt[0] += 1
                    # TODO: use T.device_assert(non_zero_cnt <= 2) after rebasing main
                    if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
                        non_zero_elt_log_idx[0] = 0
                        non_zero_elt_log_idx[1] = 3
                        A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
                        A_sp_shared[tm, a_k // 2] = 0.0
                    elif non_zero_cnt[0] == 1:
                        A_sp_shared[tm, a_k // 2 + 1] = 0
                        non_zero_elt_log_idx[1] = 3
                    for i in T.serial(elem):
                        val = non_zero_elt_log_idx[i]
                        E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
            T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
            T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])

    return kernel
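
# Note (descriptive, inferred from the shift expression above): within each int16 word
# of E, the group at position p (covering 4 elements of K) occupies bits [4*p, 4*p + 4),
# with the first kept index in the low 2 bits and the second in the high 2 bits, i.e.
# the same nibble layout torch_compress produces.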


def main():

    parser = argparse.ArgumentParser(description="2:4 sparse MatMul benchmark with custom compression")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    parser.add_argument(
        "--use_torch_sparse",
        action='store_true',
        help="Compress with the PyTorch reference (torch_compress) instead of the TileLang kernel")
    parser.add_argument(
        "--accum_dtype",
        type=str,
        default="float",
        choices=["float", "float16"],
        help="Accumulation datatype")
    parser.add_argument("--cfg", type=str, choices=list(DEFAULT_CONFIG.keys()), required=True)
    args = parser.parse_args()
    kernel = matmul_sp_fp16_custom_compress(args.m, args.n, args.k, args.accum_dtype,
                                            **DEFAULT_CONFIG[args.cfg][args.accum_dtype])

    a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    # a = arange_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
    # b = torch.eye(args.k, device='cuda', dtype=torch.half)  # identity B is handy when debugging metadata

    if args.use_torch_sparse:
        a_sparse, e = torch_compress(a)
    else:
        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16")(a)
    c = kernel(a_sparse, e, b)

    ref_c = a @ b

    assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
    torch_assert_close(c, ref_c.to(c.dtype), rtol=1e-3, atol=1e-3)
    print(f"Precision check passed. Max diff: {(c - ref_c).abs().max()}, Mean diff: {(c - ref_c).abs().mean()}")

    latency = do_bench(lambda: kernel(a_sparse, e, b))
    ref_latency = do_bench(lambda: a @ b)

    total_flops = 2 * args.m * args.n * args.k
    # do_bench reports latency in milliseconds, so flops / ms / 1e9 gives TFLOP/s
    tflops = total_flops / latency / 1e9
    ref_tflops = total_flops / ref_latency / 1e9
    print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3:.6f} s")
    print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency / 1e3:.6f} s")


if __name__ == "__main__":
    main()