@@ -3815,124 +3815,126 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
    combinations = product([False, True], \
                           [InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
                            InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
-                          [False, True])
-    for (alibi, input_layout, enable_attn_logit_softcapping) in combinations:
+                          [False, True], [False, True])
+    for (alibi, input_layout, enable_attn_logit_softcapping,
+         return_softmax) in combinations:
        # alibi and bmm1_tanh_scale shouldn't be used together.
        if alibi and enable_attn_logit_softcapping:
            continue
-        # D <= 64: KV_STEP = 256
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[32, 40, 48, 64],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=4,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 64 < D <=128: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[80, 96, 104, 128],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=256,
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # 128 < D <=256: KV_STEP = 128
-        specs.append(
-            kernel_spec(
-                sm=sm,
-                sm_mma=90,
-                dtype=dtype,
-                seq_len=0,  # support any sequence length
-                head_size=[160, 192, 256],
-                warps_m=4,  #4x1 warpgroups
-                warps_n=1,
-                version=2,
-                interleaved=False,
-                ldgsts_q=
-                False,  # for Hopper kernels, ldgsts = False signals TMA usage.
-                ldgsts_k=False,
-                ldgsts_v=False,
-                share_smem_k_v=False,
-                loop_step=64,
-                q_tile_buffers=1,  # only used by warp specialized kernels
-                has_noloop=0,
-                noloop_step=64,
-                kv_loop_step=
-                128,  # use 128 kv step size to avoid register spilling
-                kv_tile_buffers=2,  # only used by warp specialized kernels
-                unroll_threshold=1,
-                has_scale_max=False,
-                flash_attention=True,
-                warp_specialization=True,
-                alibi=alibi,
-                enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                return_softmax_stats=
-                False,  # return softmax stats is not supported for fp8 now
-                scheduling_mode=scheduling_mode,
-                input_layout=input_layout,
-                sage_block_sizes=sage_block_sizes,
-                output_dtype=output_dtype))
-
-        # context MLA (192x128)
-        # we could use param 'output_dtype' of enumerate_qgmma_flash_warpspec_kernels(),
-        # but it will generate many unnecessary kernels and they are not easy to filter out.
-        for output_type in [None, 'bf16']:
+        # for normal attention, we do not need return softmax for ws fp8 kernels currently.
+        # also fp8 input and bf16 output is only needed for MLA kernel.
+        skip_combination = return_softmax or (output_dtype is not None)
+        # for context mla, we need separate qkv as input layout when returning softmax.
+        skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
+        if not skip_combination:
+            # D <= 64: KV_STEP = 256
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[32, 40, 48, 64],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=4,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 64 < D <=128: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[80, 96, 104, 128],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=256,
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+            # 128 < D <=256: KV_STEP = 128
+            specs.append(
+                kernel_spec(
+                    sm=sm,
+                    sm_mma=90,
+                    dtype=dtype,
+                    seq_len=0,  # support any sequence length
+                    head_size=[160, 192, 256],
+                    warps_m=4,  #4x1 warpgroups
+                    warps_n=1,
+                    version=2,
+                    interleaved=False,
+                    ldgsts_q=
+                    False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+                    ldgsts_k=False,
+                    ldgsts_v=False,
+                    share_smem_k_v=False,
+                    loop_step=64,
+                    q_tile_buffers=1,  # only used by warp specialized kernels
+                    has_noloop=0,
+                    noloop_step=64,
+                    kv_loop_step=
+                    128,  # use 128 kv step size to avoid register spilling
+                    kv_tile_buffers=2,  # only used by warp specialized kernels
+                    unroll_threshold=1,
+                    has_scale_max=False,
+                    flash_attention=True,
+                    warp_specialization=True,
+                    alibi=alibi,
+                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+                    return_softmax_stats=return_softmax,
+                    scheduling_mode=scheduling_mode,
+                    input_layout=input_layout,
+                    sage_block_sizes=sage_block_sizes,
+                    output_dtype=output_dtype))
+
+        if not skip_mla_combination:
+            # context MLA (192x128)
            specs.append(
                kernel_spec(
                    sm=sm,
@@ -3962,12 +3964,11 @@ def enumerate_qgmma_flash_warpspec_kernels(specs,
                    warp_specialization=True,
                    alibi=alibi,
                    enable_attn_logit_softcapping=enable_attn_logit_softcapping,
-                    return_softmax_stats=
-                    False,  # return softmax stats is not supported for fp8 now
+                    return_softmax_stats=return_softmax,
                    scheduling_mode=scheduling_mode,
                    input_layout=input_layout,
                    sage_block_sizes=sage_block_sizes,
-                    output_dtype=output_type))
+                    output_dtype=output_dtype))


def enumerate_igmma_kernels(specs, sm=90):
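For readers skimming the patch, the following is a minimal sketch of the pruning logic the new return_softmax axis introduces. It is illustrative only and not part of the patch: InputLayout is replaced by a stand-in enum and count_specs is a hypothetical helper; the real generator appends kernel_spec entries rather than counting them.

# Illustrative sketch only: mirrors the combination/skip logic above with dummy types.
from enum import Enum
from itertools import product


class InputLayout(Enum):  # stand-in for the real InputLayout enum
    PACKED_QKV = 0
    CONTIGUOUS_Q_KV = 1
    Q_PAGED_KV = 2
    SEPARATE_Q_K_V = 3


def count_specs(output_dtype=None):  # hypothetical helper, not in the script
    normal, mla = 0, 0
    for (alibi, input_layout, softcapping, return_softmax) in product(
            [False, True], list(InputLayout), [False, True], [False, True]):
        # alibi and softcapping are mutually exclusive, as in the generator.
        if alibi and softcapping:
            continue
        # normal kernels: skipped when softmax stats are requested or when an
        # explicit output dtype (fp8 in / bf16 out) is set.
        if not (return_softmax or output_dtype is not None):
            normal += 3  # three head-size groups (<=64, <=128, <=256)
        # context MLA kernels: softmax stats require separate Q/K/V input.
        if not (return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V):
            mla += 1
    return normal, mla


print(count_specs())        # fp8-output pass keeps normal + MLA kernels
print(count_specs('bf16'))  # bf16-output pass keeps only the MLA kernels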
@@ -6215,6 +6216,10 @@ def enumerate_kernels():
    enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
    enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='bf16')
    enumerate_qgmma_flash_warpspec_kernels(specs, sm=90, dtype='e4m3')
+    enumerate_qgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='e4m3',
+                                           output_dtype="bf16")

    # For now SageAttention only needs BF16
    # block_size_q should be divisible by 64
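As a usage note, the second enumerate_qgmma_flash_warpspec_kernels call added above drives the fp8-input/bf16-output pass. Under the skip flags from the first hunk (and assuming output_dtype defaults to None and the enumerator's other parameters keep their defaults), that pass should contribute only the context-MLA kernels. A hedged sketch of how the two passes could be compared:

# Hypothetical comparison harness, not part of the patch.
specs_fp8_out, specs_bf16_out = [], []
enumerate_qgmma_flash_warpspec_kernels(specs_fp8_out, sm=90, dtype='e4m3')
enumerate_qgmma_flash_warpspec_kernels(specs_bf16_out,
                                       sm=90,
                                       dtype='e4m3',
                                       output_dtype="bf16")
# With output_dtype set, skip_combination is always True, so the second list
# should contain only the context-MLA (192x128) kernel specs.
print(len(specs_fp8_out), len(specs_bf16_out))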