
Conversation


@anhminhnguyenhoang anhminhnguyenhoang commented Oct 14, 2025

Add FP8 blockscale GEMM configs for the Qwen3 model and improve performance by optimizing the block configs and by switching from loading BLOCK_SIZE_N FP32 scale factors of the B tensor to loading only the BLOCK_SIZE_N / GROUP_N unique scaling factors and group-broadcasting them.

Main branch

# as2 machine
# Triton
       M     N     K time (ms) throughput (TFLOPs) bandwidth (GB/s)
0      4  2048  2048  0.044884            0.765203        98.176164
1      4  2048  4096  0.051416            1.334147       175.813529
2      4  5120  2048  0.047369            1.832975        235.86101
3      4  5120  4096  0.053619            3.349956       424.900691
4      8  2048  2048  0.045761            1.525462        96.695001
5      8  2048  4096  0.052602            2.678513       176.520806
6      8  5120  2048  0.045526            3.720364       233.846262
7      8  5120  4096  0.056298            6.533555       409.235979
8     16  2048  2048  0.044299            3.044586        99.351376
9     16  2048  4096  0.053322            5.161822       172.536526
10    16  5120  2048  0.045502            7.634204       240.190141
11    16  5120  4096  0.057426           12.418994       410.933916
12    32  2048  2048  0.044556            6.032158        99.159314
13    32  2048  4096  0.054893           10.336277       174.055657
14    32  5120  2048  0.045565           15.363404       247.125622
15    32  5120  4096  0.057851           24.442517       406.420687
16    64  2048  2048  0.045051           11.809648       102.739497
17    64  2048  4096  0.054262           20.430484        176.43738
18    64  5120  2048  0.043906            30.66497       262.429258
19    64  5120  4096  0.058412            48.24189       411.622043
20   128  2048  2048  0.044563           25.185225       115.487478
21   128  2048  4096  0.055542            40.03018       183.565406
22   128  5120  2048     0.045           59.288782       271.388807
23   128  5120  4096  0.058244           95.863216       416.591913
24  2000  2048  2048    0.0443           376.81302       374.848162
25  2000  2048  4096  0.060013          602.498478       455.203495
26  2000  5120  2048  0.072109          624.182899       522.545955
27  2000  5120  4096  0.134485          697.554756       413.410244
28  6017  2048  2048  0.077969          735.033547       598.352723
29  6017  2048  4096  0.149863          819.231361        468.36132
30  6017  5120  2048  0.190446          794.536631        530.44712
31  6017  5120  4096  0.346628          844.637226       358.635609

# as2 machine
# gluon
# triton-lang commit e8bfb7143cbe5089ad1a801868b3a1a52e8ca83b, since the current Gluon kernel implementation is no longer compatible with newer commits.
       M     N     K time (ms) throughput (TFLOPs) bandwidth (GB/s)
0      4  2048  2048  0.075699            0.463032        57.730388
1      4  2048  4096  0.073197            0.941938       118.520119
2      4  5120  2048  0.076534            1.122416       142.971448
3      4  5120  4096  0.069746            2.429273       308.169057
4      8  2048  2048  0.072825            0.941465        59.962762
5      8  2048  4096   0.07439            1.840966       116.330017
6      8  5120  2048   0.07446            2.330315       148.040062
7      8  5120  4096  0.073379            4.747941       302.227857
8     16  2048  2048  0.073538             1.84331        59.314222
9     16  2048  4096  0.075244            3.619397       115.959241
10    16  5120  2048  0.074378            4.639287       149.506264
11    16  5120  4096  0.072334            9.245479       295.470649
12    32  2048  2048  0.073097            3.717178        61.448586
13    32  2048  4096  0.074192            7.419994       121.780336
14    32  5120  2048   0.07354            9.306815       152.747637
15    32  5120  4096  0.069097           19.317026       313.199682
16    64  2048  2048  0.076548             7.38374        63.794082
17    64  2048  4096  0.073109           14.871409        124.26918
18    64  5120  2048  0.072776           18.884707       159.027927
19    64  5120  4096  0.071117            38.52674       315.213188
20   128  2048  2048  0.072452           14.777889        69.157756
21   128  2048  4096   0.07252           29.550503       131.424833
22   128  5120  2048  0.071206           38.181305       172.862903
23   128  5120  4096  0.072851           73.202771       311.767788
24  2000  2048  2048    0.0415          400.369726       399.737302
25  2000  2048  4096  0.055314          646.335089       486.098201
26  2000  5120  2048   0.06591          688.189257       575.408083
27  2000  5120  4096  0.122081          785.203565       464.516893
28  6017  2048  2048  0.070392          806.004282       656.499792
29  6017  2048  4096   0.13965           908.77939       520.484559
30  6017  5120  2048  0.175361          901.032022       601.333348
31  6017  5120  4096  0.324464          983.571025       417.522658

This branch

# as2 machine
# Triton
       M     N     K time (ms) throughput (TFLOPs) bandwidth (GB/s)
0      4  2048  2048  0.042128            0.808035       104.391878
1      4  2048  4096  0.044991            1.507517       189.565502
2      4  5120  2048  0.044691            1.906203        244.74171
3      4  5120  4096  0.043438            3.882657       492.766106
4      8  2048  2048  0.044033            1.587628       102.851222
5      8  2048  4096  0.045833            2.941682       188.516195
6      8  5120  2048  0.045219            3.914195       247.703457
7      8  5120  4096  0.045584            7.486542        483.71276
8     16  2048  2048  0.043344             3.22513       105.697929
9     16  2048  4096  0.045595            5.929107       190.877712
10    16  5120  2048  0.045112            7.503683       248.361898
11    16  5120  4096  0.044858            14.32205       490.563671
12    32  2048  2048  0.046295              5.8333        97.104687
13    32  2048  4096  0.045604           12.292135       197.319677
14    32  5120  2048  0.044028           15.521966         261.1442
15    32  5120  4096    0.0426           31.631936        515.37714
16    64  2048  2048  0.046673            11.83755       103.104087
17    64  2048  4096  0.048191           23.527294       199.161407
18    64  5120  2048  0.045907           29.589281       253.732356
19    64  5120  4096  0.043696           63.120865        520.84965
20   128  2048  2048  0.045763           23.640731       111.687063
21   128  2048  4096  0.047464           47.622922       214.777576
22   128  5120  2048  0.045783           60.249895       273.563605
23   128  5120  4096  0.043929          124.887072       537.376262
24  2000  2048  2048  0.044769          384.162537       380.781505
25  2000  2048  4096  0.047697          757.493571       559.122508
26  2000  5120  2048  0.059897          756.744829       634.698691
27  2000  5120  4096  0.118632          844.934505       501.869989
28  6017  2048  2048  0.069221          804.572937       657.740146
29  6017  2048  4096  0.141831          900.863615        515.24398
30  6017  5120  2048  0.159564          971.303027       648.734135
31  6017  5120  4096  0.272049         1144.121317       484.582279

A finer-grained set of GEMM configs has been provisioned to maximize gains across the various M sizes of the Qwen3 model shapes.

The current kernel loads BLOCK_SIZE_N FP32 scale factors of the B tensor per tile, which is potentially wasteful. In fact, only BLOCK_SIZE_N / 128 of those values are unique, since the blockscale shape (128, 128) is used in op_tests/op_benchmarks/triton/bench_gemm_a8w8_blockscale.py.

Adopting the idea from #1126 ([Triton] e2e fused MoE for small N and fp8 blockscale MoE benching), only the BLOCK_SIZE_N / 128 unique scalar values in the current tile are loaded, then group-broadcast (similar to torch.repeat_interleave). This potentially reduces wait time on tl.load.
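As a rough illustration of the group-broadcast step, here is a NumPy sketch (not the actual Triton kernel; BLOCK_SIZE_N, GROUP_N, and the scale values are assumed example numbers, not this PR's tuned configs):

```python
import numpy as np

# Assumed example tile parameters, for illustration only.
BLOCK_SIZE_N = 256  # columns of B covered by one tile
GROUP_N = 128       # blockscale granularity along N

# Instead of loading BLOCK_SIZE_N per-column scales, load only the
# BLOCK_SIZE_N // GROUP_N unique per-group scales for this tile...
unique_scales = np.array([0.5, 2.0], dtype=np.float32)
assert unique_scales.size == BLOCK_SIZE_N // GROUP_N

# ...then group-broadcast them back to per-column width, analogous to
# torch.repeat_interleave. In the Triton kernel this is an in-register
# broadcast, replacing redundant tl.load traffic with compute.
col_scales = np.repeat(unique_scales, GROUP_N)
assert col_scales.shape == (BLOCK_SIZE_N,)
```

In the kernel itself the broadcast is expressed with Triton's reshape/broadcast primitives rather than NumPy, but the semantics are the same: each of the BLOCK_SIZE_N output columns picks up the scale of the 128-wide group it belongs to.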

@anhminhnguyenhoang anhminhnguyenhoang self-assigned this Oct 14, 2025
@anhminhnguyenhoang anhminhnguyenhoang changed the title Anguyenh/a8w8 blockscale gemm tuning qwen3 [Triton] A8W8 blockscale GEMM tuning for QWEN3 Oct 14, 2025
@anhminhnguyenhoang anhminhnguyenhoang changed the title [Triton] A8W8 blockscale GEMM tuning for QWEN3 [Triton] A8W8 blockscale GEMM tuning for Qwen3 Oct 14, 2025
@anhminhnguyenhoang anhminhnguyenhoang marked this pull request as ready for review November 18, 2025 14:47

@vgokhale vgokhale left a comment


You have a ton of changes that your editor presumably auto added. Can you revert all changes except the ones to the blockscale GEMM?

@anhminhnguyenhoang anhminhnguyenhoang force-pushed the anguyenh/a8w8-blockscale-gemm-tuning-qwen3 branch from dfc4e1d to 0726067 Compare November 20, 2025 21:47

anhminhnguyenhoang commented Nov 21, 2025

> You have a ton of changes that your editor presumably auto added. Can you revert all changes except the ones to the blockscale GEMM?

@vgokhale all clean now
