Skip to content

Commit ed148d5

Browse files
committed
[test] refactor gemm_sp test
1 parent 826a580 commit ed148d5

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Lines changed: 6 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter
99

1010
torch.backends.cuda.matmul.allow_tf32 = False
11-
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
1211
torch.manual_seed(42)
1312

1413

@@ -46,7 +45,7 @@ def main(
4645
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
4746
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
4847
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
49-
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
48+
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
5049
T.annotate_layout({
5150
E:
5251
make_metadata_layout(
@@ -60,7 +59,7 @@ def main(
6059
block_k=block_K),
6160
})
6261
T.disable_warp_group_reg_alloc()
63-
T.clear(C_local)
62+
T.clear(C_frag)
6463
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
6564
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
6665
if trans_A:
@@ -71,8 +70,8 @@ def main(
7170
T.copy(B[bx * block_N, k * block_K], B_shared)
7271
else:
7372
T.copy(B[k * block_K, bx * block_N], B_shared)
74-
T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
75-
T.copy(C_local, C[by * block_M, bx * block_N])
73+
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
74+
T.copy(C_frag, C[by * block_M, bx * block_N])
7675

7776
return main
7877

@@ -132,7 +131,7 @@ def main(
132131
T.copy(B[bx * block_N, k * block_K], B_shared)
133132
else:
134133
T.copy(B[k * block_K, bx * block_N], B_shared)
135-
T.gemm_sp_v2(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
134+
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
136135
T.copy(C_frag, C[by * block_M, bx * block_N])
137136

138137
return main
@@ -349,10 +348,4 @@ def test_gemm_sp_sm80():
349348

350349
if __name__ == "__main__":
351350
tilelang.disable_cache()
352-
# tilelang.testing.main()
353-
# run_gemm_sp_sm80(32, 64, 64, "float16", "float32", "float32", 32, 64, 64, 0, 32)
354-
# run_gemm_sp_sm80(32, 16, 64, "float16", "float32", "float32", 32, 16, 64, 0, 32, trans_B=True)
355-
# run_gemm_sp_sm80(32, 32, 32, "float32", "float32", "float32", 32, 32, 32, 0, 32, trans_B=False)
356-
# run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True)
357-
run_gemm_sp_sm80(128, 128, 128, "float8_e4m3", "float32", "float32", 128, 128, 64, 2, 32, False, True)
358-
# run_gemm_sp_sm80(128, 128, 128, "float8_e4m3", "float8_e4m3", "float32", 128, 128, 64, 2, 32, False, True)
351+
tilelang.testing.main()

0 commit comments

Comments
 (0)