Commit 7deefb3

[TRTLLM-7192][feat] optimize MLA chunked prefill && support fp8 mla chunked prefill (#7477)
Signed-off-by: Mingyang Jiang <[email protected]>
1 parent 24fc1f9 commit 7deefb3

22 files changed: +591 -326 lines changed

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 1 addition & 2 deletions
@@ -178,8 +178,7 @@ def test_trtllm_context_mla_attention_fmha(dtype, s):
         check=True)

     # For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
-    # Currently fp8 kernel doesn't support saving softmax.
-    if dtype == "-bf16":
+    if dtype in ["-bf16", "-e4m3"]:
         # padding mask
         subprocess.run(
             f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "

cpp/kernels/fmha_v2/setup.py

Lines changed: 123 additions & 118 deletions
@@ -3815,124 +3815,126 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
     combinations = product([False, True], \
             [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
              InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
-            [False, True])
-    for (alibi, input_layout, enable_attn_logit_softcapping) in combinations:
+            [False, True], [False, True])
+    for (alibi, input_layout, enable_attn_logit_softcapping,
+         return_softmax) in combinations:
         # alibi and bmm1_tanh_scale shouldn't be used together.
         if alibi and enable_attn_logit_softcapping:
             continue
-        # D <= 64: KV_STEP = 256
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0, # support any sequence length
-                head_size=[32, 40, 48, 64],
-                warps_m=4, #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False, # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1, # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=4, # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False, # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 64 < D <=128: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0, # support any sequence length
-                head_size=[80, 96, 104, 128],
-                warps_m=4, #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False, # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1, # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=2, # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False, # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 128 < D <=256: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0, # support any sequence length
-                head_size=[160, 192, 256],
-                warps_m=4, #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False, # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1, # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=
-                128, # use 128 kv step size to avoid register spilling
-                kv_tile_buffers=2, # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False, # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # context MLA (192x128)
-        # we could use param 'output_dtype' of enumerate_qgmma_flash_warpspec_kernels(),
-        # but it will generate many unnecessary kernels and they are not easy to filter out.
-        for output_type in [None, 'bf16']:
+        # for normal attention, we do not need return softmax for ws fp8 kernels currently.
+        # also fp8 input and bf16 output is only needed for MLA kernel.
+        skip_combination = return_softmax or (output_dtype is not None)
+        # for context mla, we need separate qkv as input layout when returning softmax.
+        skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
+        if not skip_combination:
+            # D <= 64: KV_STEP = 256
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0, # support any sequence length
+                    head_size=[32, 40, 48, 64],
+                    warps_m=4, #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False, # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1, # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=4, # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 64 < D <=128: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0, # support any sequence length
+                    head_size=[80, 96, 104, 128],
+                    warps_m=4, #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False, # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1, # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=2, # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 128 < D <=256: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0, # support any sequence length
+                    head_size=[160, 192, 256],
+                    warps_m=4, #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False, # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1, # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=
+                    128, # use 128 kv step size to avoid register spilling
+                    kv_tile_buffers=2, # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+        if not skip_mla_combination:
+            # context MLA (192x128)
             specs.append(
                 kernel_spec(
                     sm=sm,
@@ -3962,12 +3964,11 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
                     warp_specialization=True,
                     alibi=alibi,
                     enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                    return_softmax_stats=
-                    False, # return softmax stats is not supported for fp8 now
+                    return_softmax_stats=return_softmax,
                     scheduling_mode=scheduling_mode,
                     input_layout=input_layout,
                     sage_block_sizes=sage_block_sizes,
-                    output_dtype=output_type))
+                    output_dtype=output_dtype))


 def enumerate_igmma_kernels(specs, sm=90):
@@ -6215,6 +6216,10 @@ def enumerate_kernels():
     enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
     enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
     enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
+    enumerate_qgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='e4m3',
+                                           output_dtype="bf16")

     # For now SageAttention only needs BF16
     # block_size_q should be divisible by 64
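To make the new kernel selection in setup.py easier to follow on its own, here is a minimal, standalone Python sketch of the combination filtering introduced above. The InputLayout enum and the enumerate_combinations helper are illustrative stand-ins that mirror the names in the diff; the real generator emits kernel_spec entries instead of strings.

```python
# Standalone sketch of the combination filtering added in this commit.
from enum import Enum, auto
from itertools import product


class InputLayout(Enum):
    PACKED_QKV = auto()
    CONTIGUOUS_Q_KV = auto()
    Q_PAGED_KV = auto()
    SEPARATE_Q_K_V = auto()


def enumerate_combinations(output_dtype=None):
    """Yield (combo, kinds) describing which spec groups the generator would emit."""
    combinations = product([False, True],      # alibi
                           list(InputLayout),  # input_layout
                           [False, True],      # enable_attn_logit_softcapping
                           [False, True])      # return_softmax (the new axis)
    for alibi, input_layout, softcap, return_softmax in combinations:
        # alibi and bmm1_tanh_scale shouldn't be used together.
        if alibi and softcap:
            continue
        # Normal-attention ws fp8 kernels don't need to return softmax, and
        # fp8-in / bf16-out is only needed for the MLA kernel.
        skip_combination = return_softmax or (output_dtype is not None)
        # Context MLA requires the separate-q-k-v layout when returning softmax.
        skip_mla_combination = (return_softmax
                                and input_layout != InputLayout.SEPARATE_Q_K_V)
        kinds = []
        if not skip_combination:
            kinds.append("head_size groups (D <= 64, <= 128, <= 256)")
        if not skip_mla_combination:
            kinds.append("context MLA (192x128)")
        yield (alibi, input_layout, softcap, return_softmax), kinds


# With output_dtype="bf16" (the new registration in enumerate_kernels), only the
# MLA spec survives, and return_softmax=True exists only for SEPARATE_Q_K_V.
for combo, kinds in enumerate_combinations(output_dtype="bf16"):
    if kinds:
        print(combo, "->", kinds)
```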

cpp/kernels/fmha_v2/src/fmha/fragment.h

Lines changed: 92 additions & 0 deletions
@@ -1749,6 +1749,98 @@ struct Tile_o_normalizer<Ada_qmma_e4m3_fp32_traits, Cta_tile>

     // Default ctor
     Tile_o_normalizer() = default;
+
+    // The fragment accumulator.
+    using Fragment_accu = Fragment_accumulator<Traits>;
+
+    // The Mma tile.
+    using Mma_tile = typename Traits::template Mma_tile<Cta_tile>;
+
+    // The number of MMAs in the M dimension.
+    enum
+    {
+        MMAS_M = Mma_tile::MMAS_M
+    };
+
+    // The number of MMAs in the N dimension.
+    enum
+    {
+        MMAS_N = Mma_tile::VALID_MMAS_N
+    };
+
+    // The number of rows per thread.
+    enum
+    {
+        ROWS_PER_THREAD = 2 * MMAS_M
+    };
+
+    // The number of registers per thread.
+    enum
+    {
+        REGS_PER_THREAD = 8
+    };
+
+    // Warps.
+    enum
+    {
+        WARPS_M = Cta_tile::WARPS_M
+    };
+
+    enum
+    {
+        WARPS_N = Cta_tile::WARPS_N
+    };
+
+    enum
+    {
+        WARPS_K = Cta_tile::WARPS_K
+    };
+
+    // Softmax data bytes.
+    enum
+    {
+        BYTES_PER_ELEMENT = sizeof(float)
+    };
+
+    // Update O after P * V; the only difference from the base class is that we dequantize the sum for the softmax saver.
+    inline __device__ void final_update(Fragment_accu (&acc_o)[MMAS_M][MMAS_N], float (&sum)[ROWS_PER_THREAD])
+    {
+
+        constexpr float dequant_scale = Traits::SOFTMAX_FP_DEQUANT_SCALE;
+#pragma unroll
+        for (int mi = 0; mi < MMAS_M; ++mi)
+        {
+
+            // Precompute the scaling factors for the 2 rows.
+            float beta[2];
+#pragma unroll
+            for (int ii = 0; ii < 2; ++ii)
+            {
+                // The row.
+                int jj = 2 * mi + ii;
+
+                // The divisor.
+                beta[ii] = (sum[jj] == 0.f || sum[jj] != sum[jj]) ? 1.f : 1.f / sum[jj];
+                // The softmax saver needs the original sum.
+                sum[jj] = sum[jj] * dequant_scale;
+            }
+
+#pragma unroll
+            for (int ni = 0; ni < MMAS_N; ++ni)
+            {
+#pragma unroll
+                for (int ii = 0; ii < REGS_PER_THREAD; ++ii)
+                {
+                    // The register for O.
+                    float acc_o_f = acc_o[mi][ni].elt(ii);
+                    // Compute the next accumulator.
+                    acc_o_f = acc_o_f * beta[(ii & 2) / 2];
+                    // Update the accumulator.
+                    acc_o[mi][ni].elt(ii) = acc_o_f;
+                }
+            }
+        }
+    }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
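In equation form, the new final_update performs the following for each accumulator row j, where l_j is the row's exp-sum as accumulated on the fp8 path and s_dq = Traits::SOFTMAX_FP_DEQUANT_SCALE; reading s_dq as the factor that undoes the fp8 softmax quantization is an inference from the trait name, not something stated in the diff:

```latex
\beta_j =
\begin{cases}
1, & \ell_j = 0 \ \text{or}\ \ell_j \neq \ell_j \ (\text{NaN}),\\
1/\ell_j, & \text{otherwise,}
\end{cases}
\qquad
o_{j,k} \leftarrow \beta_j \, o_{j,k} \ \ \forall k,
\qquad
\ell_j \leftarrow s_{\mathrm{dq}} \, \ell_j .
```

The accumulator rows are normalized by the raw sum, as before, while the sum handed to the softmax saver is rescaled, so downstream consumers of the saved statistics see the true exp-sum rather than the quantized-domain one.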

cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h

Lines changed: 5 additions & 0 deletions
@@ -1318,6 +1318,11 @@ struct Tile_o_epilogue<Hopper_qgmma_e4m3_fp32_traits, Kernel_traits>
 #else
         float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi;
 #endif
+        if constexpr (Kernel_traits::RETURN_SOFTMAX_STATS)
+        {
+            // Save the dequant exp sum for softmax saver.
+            global_sum[mi] *= Traits_o::SOFTMAX_FP_DEQUANT_SCALE;
+        }
         // Assume only N has multiple MMAs (MMAS_M = 1).
 #pragma unroll
         for (int mma_ni = 0; mma_ni < Mma_tile_o::MMAS_N; mma_ni++)
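The softmax statistics that these kernels now save (and dequantize in the epilogue above) are what chunked prefill needs in order to combine attention computed over successive KV chunks. The numpy sketch below shows the merge identity in isolation; it is a generic illustration under the assumption that the saved stats are the per-row max and exp-sum, and merge_chunks is a hypothetical helper, not the repository's merge kernel.

```python
# Generic illustration of why chunked prefill saves softmax stats: given each
# chunk's normalized partial output plus its row max m_i and exp-sum l_i, the
# exact full-sequence attention output can be recovered by rescale-and-sum.
import numpy as np


def merge_chunks(outs, maxes, sums):
    """outs[i]: (rows, dv) partial output of chunk i, already divided by sums[i].
    maxes[i], sums[i]: (rows,) row-wise max and exp-sum saved for that chunk."""
    m = np.max(np.stack(maxes), axis=0)                          # global row max
    l = sum(np.exp(mi - m) * li for mi, li in zip(maxes, sums))  # global exp-sum
    o = sum(np.exp(mi - m)[:, None] * li[:, None] * oi           # undo per-chunk norm
            for oi, mi, li in zip(outs, maxes, sums))
    return o / l[:, None]


# Reference check against single-pass attention on random data.
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64))
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 32))
s = q @ k.T
p = np.exp(s - s.max(-1, keepdims=True))
ref = (p / p.sum(-1, keepdims=True)) @ v

outs, maxes, sums = [], [], []
for kc, vc in zip(np.split(k, 2), np.split(v, 2)):
    sc = q @ kc.T
    mc = sc.max(-1)
    pc = np.exp(sc - mc[:, None])
    lc = pc.sum(-1)
    outs.append((pc / lc[:, None]) @ vc)
    maxes.append(mc)
    sums.append(lc)
assert np.allclose(merge_chunks(outs, maxes, sums), ref)
```

This is why the fp8 path must rescale its saved sums by SOFTMAX_FP_DEQUANT_SCALE: the merge only produces the exact result when the stats are in the same (dequantized) domain as the bf16 path.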
