from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter

torch.backends.cuda.matmul.allow_tf32 = False
- torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
torch.manual_seed(42)


@@ -46,7 +45,7 @@ def main(
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
-           C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+           C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.annotate_layout({
                E:
                    make_metadata_layout(
@@ -60,7 +59,7 @@ def main(
                        block_k=block_K),
            })
            T.disable_warp_group_reg_alloc()
-           T.clear(C_local)
+           T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                if trans_A:
@@ -71,8 +70,8 @@ def main(
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
-               T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
-           T.copy(C_local, C[by * block_M, bx * block_N])
+               T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+           T.copy(C_frag, C[by * block_M, bx * block_N])

    return main

@@ -132,7 +131,7 @@ def main(
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
-               T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
+               T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main
@@ -349,10 +348,4 @@ def test_gemm_sp_sm80():

if __name__ == "__main__":
    tilelang.disable_cache()
-   # tilelang.testing.main()
-   # run_gemm_sp_sm80(32, 64, 64, "float16", "float32", "float32", 32, 64, 64, 0, 32)
-   # run_gemm_sp_sm80(32, 16, 64, "float16", "float32", "float32", 32, 16, 64, 0, 32, trans_B=True)
-   # run_gemm_sp_sm80(32, 32, 32, "float32", "float32", "float32", 32, 32, 32, 0, 32, trans_B=False)
-   # run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True)
-   run_gemm_sp_sm80(128, 128, 128, "float8_e4m3", "float32", "float32", 128, 128, 64, 2, 32, False, True)
-   # run_gemm_sp_sm80(128, 128, 128, "float8_e4m3", "float8_e4m3", "float32", 128, 128, 64, 2, 32, False, True)
+   tilelang.testing.main()
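
For orientation, here is a minimal, hypothetical sketch of the kernel these hunks touch, consolidated with the changes applied (`C_frag` instead of `C_local`, `T.gemm_sp` instead of `T.gemm_sp_v2`). The `T.prim_func`/`T.Kernel` wrappers, the concrete block sizes and dtypes, the 2:4 `E_factor` of 8, and the non-transposed copy paths are assumptions chosen for illustration; the `T.annotate_layout`/`make_metadata_layout` call for `E` is elided. This is not the file's actual definition of `main`.

```python
import tilelang
import tilelang.language as T


def sparse_matmul_sketch(M=128, N=128, K=128, block_M=64, block_N=64, block_K=64,
                         in_dtype="float16", accum_dtype="float32", num_stages=2):
    # Assumed 2:4 sparsity: A is stored compressed to K // 2 columns and the
    # metadata tensor E packs the structure info at K // E_factor uint8s per row.
    E_factor = 8

    @T.prim_func
    def main(A: T.Tensor((M, K // 2), in_dtype),
             E: T.Tensor((M, K // E_factor), "uint8"),
             B: T.Tensor((K, N), in_dtype),
             C: T.Tensor((M, N), accum_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
            B_shared = T.alloc_shared((block_K, block_N), in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
            # T.annotate_layout({E: make_metadata_layout(...)}) is elided here;
            # see the hunk above for the arguments the test actually passes.
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
                T.copy(A[by * block_M, k * block_K // 2], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_frag, False, False)
            T.copy(C_frag, C[by * block_M, bx * block_N])

    return main
```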