[Triton] A8W8 blockscale GEMM tuning for Qwen3 #1195
Conversation
vgokhale left a comment:
You have a ton of changes that your editor presumably auto added. Can you revert all changes except the ones to the blockscale GEMM?
Branch force-pushed from dfc4e1d to 0726067.
@vgokhale all clean now
Provide FP8 blockscale GEMM configs for the Qwen3 model and improve performance by optimizing the block configs and by switching from loading BLOCK_SIZE_N FP32 scale factors of the B tensor to loading only the BLOCK_SIZE_N / GROUP_N unique scale factors and group-broadcasting them.
Main branch: (benchmark plot)
This branch: (benchmark plot)
A finer-grained set of GEMM configs has been provisioned to maximize gains across the various M sizes of the Qwen3 model shapes.
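As a rough illustration of what finer-grained, M-dependent configs mean, a selection scheme could look like the sketch below; the M thresholds and block sizes here are hypothetical placeholders rather than the values tuned in this PR.

```python
# Illustrative only: per-M-bucket config selection for the A8W8 blockscale GEMM.
# The thresholds and block sizes are made-up placeholders, not the tuned values.
def pick_config(M: int) -> dict:
    if M <= 32:       # small-M (decode-like) shapes
        return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}
    elif M <= 256:    # mid-size M
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}
    else:             # large-M (prefill-like) shapes
        return {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}
```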
The current kernel loads BLOCK_SIZE_N FP32 scale factors for the B tensor, which can be wasteful: given the blockscale shape (128, 128) used in op_tests/op_benchmarks/triton/bench_gemm_a8w8_blockscale.py, only BLOCK_SIZE_N / 128 of those values are actually unique. Adopting the idea from [Triton] e2e fused MoE for small N and fp8 blockscale MoE benching #1126, only the BLOCK_SIZE_N / 128 unique scalars for the current tile are loaded and then group-broadcast (similar to torch.repeat_interleave). This potentially reduces wait time on tl.load; a sketch of the load-and-broadcast step follows below.
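For reference, here is a minimal standalone sketch of the load-and-broadcast step (not the actual kernel changed by this PR): it loads only BLOCK_SIZE_N / GROUP_N unique B-tensor scale factors per N tile and expands them in registers. GROUP_N, b_scale_ptr, and the helper names are assumptions made for this example.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _broadcast_b_scales(b_scale_ptr, out_ptr,
                        BLOCK_SIZE_N: tl.constexpr, GROUP_N: tl.constexpr):
    pid_n = tl.program_id(0)
    # Load only the unique scale factors covering this N tile.
    offs_gn = pid_n * (BLOCK_SIZE_N // GROUP_N) + tl.arange(0, BLOCK_SIZE_N // GROUP_N)
    b_scale = tl.load(b_scale_ptr + offs_gn)  # shape: (BLOCK_SIZE_N // GROUP_N,)
    # Group broadcast: each unique value covers GROUP_N consecutive columns,
    # similar to torch.repeat_interleave.
    b_scale = tl.broadcast_to(b_scale[:, None], (BLOCK_SIZE_N // GROUP_N, GROUP_N))
    b_scale = tl.reshape(b_scale, (BLOCK_SIZE_N,))
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    tl.store(out_ptr + offs_n, b_scale)


def broadcast_b_scales(unique_scales: torch.Tensor,
                       BLOCK_SIZE_N: int = 256, GROUP_N: int = 128) -> torch.Tensor:
    N = unique_scales.numel() * GROUP_N
    out = torch.empty(N, device=unique_scales.device, dtype=unique_scales.dtype)
    _broadcast_b_scales[(N // BLOCK_SIZE_N,)](unique_scales, out,
                                              BLOCK_SIZE_N=BLOCK_SIZE_N, GROUP_N=GROUP_N)
    return out


if __name__ == "__main__":
    # Quick check against the torch.repeat_interleave reference.
    scales = torch.rand(4, device="cuda", dtype=torch.float32)  # 4 unique scales -> N = 512
    assert torch.equal(broadcast_b_scales(scales), scales.repeat_interleave(128))
```

With GROUP_N = 128 and BLOCK_SIZE_N = 256, each tile loads 2 scale values instead of 256, which is where the reduced wait on tl.load comes from.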