diff --git a/.gitignore b/.gitignore index 2fd004dc70..4f9e38b261 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ scripts csrc/flash_attn_ck .eggs log +*.rocprof* *.log core.* gpucore.* diff --git a/flash_attn/flash_attn_triton_amd/bwd.py b/flash_attn/flash_attn_triton_amd/bwd.py index 085232cedc..4d4c22866d 100755 --- a/flash_attn/flash_attn_triton_amd/bwd.py +++ b/flash_attn/flash_attn_triton_amd/bwd.py @@ -19,286 +19,147 @@ tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) -def get_autotune_configs(): - if False: - if is_cdna(): - # shared meta-parameters - NUM_STAGES = 1 - NUM_WARPS = 4 - WAVES_PER_EU = 2 - MATRIX_INSTR_NONKDIM = 16 - - preprocess_autotune_configs = [ - triton.Config( - { - "PRE_BLOCK": 128, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "PRE_BLOCK": 64, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "PRE_BLOCK": 32, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "PRE_BLOCK": 16, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - preprocess_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - causal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - causal_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - noncausal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), # og config - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 128, - "BLOCK_M2": 128, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 16, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 16, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - triton.Config( - { - "BLOCK_M1": 32, - "BLOCK_N1": 64, - "BLOCK_M2": 64, - "BLOCK_N2": 32, - "BLK_SLICE_FACTOR": 2, - "waves_per_eu": WAVES_PER_EU, - "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), - ] - noncausal_autotune_keys = [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", - ] - - return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), - ) - else: - raise ValueError("Unknown Device Type") - else: - # meta-parameters - # TODO: fix num_stages later - NUM_WARPS, NUM_STAGES = 4, 1 - WAVES_PER_EU = 1 - PRE_BLOCK = 128 - BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 - BLK_SLICE_FACTOR = 2 - - assert BLOCK_N1 == BLOCK_M2 +def get_bwd_configs(autotune = False): + # default config + if not autotune: + # preprocess params + PRE_BLOCK = 64 + PRE_WAVES_PER_EU=2 + PRE_NUM_STAGES=2 + PRE_NUM_WARPS=8 # configs for the kernels preprocess_autotune_configs = [ - triton.Config( - {"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": PRE_WAVES_PER_EU}, num_stages=PRE_NUM_STAGES, num_warps=PRE_NUM_WARPS), ] preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", + "ACTUAL_HEAD_DIM", "IS_VARLEN", ] + + # main params + NUM_STAGES=1 + NUM_WARPS= 4 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 64 + BLK_SLICE_FACTOR = 2 + MATRIX_INSTR_NONKDIM=16 + assert BLOCK_N1 == BLOCK_M2 + causal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": BLOCK_M1, - "BLOCK_N1": BLOCK_N1, - "BLOCK_M2": BLOCK_M2, - "BLOCK_N2": BLOCK_N2, - "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, - "waves_per_eu": WAVES_PER_EU, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ - triton.Config( - { - "BLOCK_M1": BLOCK_M1, - "BLOCK_N1": BLOCK_N1, - "BLOCK_M2": BLOCK_M2, - "BLOCK_N2": BLOCK_N2, - "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, - "waves_per_eu": WAVES_PER_EU, - }, - num_stages=NUM_STAGES, - num_warps=NUM_WARPS, - ), + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ - "dropout_p", - "max_seqlen_q", - "max_seqlen_k", - "ACTUAL_HEAD_DIM_QK", - "ACTUAL_HEAD_DIM_V", - "IS_VARLEN", - "HQ", - "HK", + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", ] - return ( - (preprocess_autotune_configs, preprocess_autotune_keys), - (causal_autotune_configs, causal_autotune_keys), - (noncausal_autotune_configs, noncausal_autotune_keys), - ) + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + + + # params + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS=[1, 2] + PRE_NUM_STAGES_OPTIONS=[1, 2] + PRE_NUM_WARPS_OPTIONS=[4, 8] + + + # Preprocess configs + preprocess_autotune_configs = [] + for pre_num_warps in PRE_NUM_WARPS_OPTIONS: + for pre_num_stages in PRE_NUM_STAGES_OPTIONS: + for pre_waves in PRE_WAVES_PER_EU_OPTIONS: + for pre_block in PRE_BLOCK_OPTIONS: + preprocess_autotune_configs.append( + triton.Config({ + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, num_stages=pre_num_stages, num_warps=pre_num_warps) + ) + + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + MATRIX_INSTR_NONKDIM_OPTIONS = [16, 32] # og: 16 + BLOCK_M1_OPTIONS = [ # og: 32 + 32, 64 + ] + BLOCK_N1_M2_OPTIONS = [ # og: 128 + 64, 128 + ] + BLOCK_N2_OPTIONS = [ # og: 32 + 32, 64 + ] + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + + # build configs + causal_autotune_configs = [] + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for matrix_instr_nonkdim in MATRIX_INSTR_NONKDIM_OPTIONS: + # Causal and non-causal configs + for m1 in BLOCK_M1_OPTIONS: + for n1 in BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in BLOCK_N2_OPTIONS: + # Ensure constraint + assert n1 == m2, f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: + causal_autotune_configs.append( + triton.Config({ + "BLOCK_M1": m1, "BLOCK_N1": n1, + "BLOCK_M2": m2, "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim + }, num_stages=num_stages, num_warps=num_warps) + ) + + noncausal_autotune_configs.append( + triton.Config({ + "BLOCK_M1": m1, "BLOCK_N1": n1, + "BLOCK_M2": m2, "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + "matrix_instr_nonkdim": matrix_instr_nonkdim + }, num_stages=num_stages, num_warps=num_warps) + ) + + # kernel keys + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", "IS_VARLEN", + ] + + causal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + noncausal_autotune_keys = [ + "dropout_p", "max_seqlen_q", "max_seqlen_k", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + return (preprocess_autotune_configs, preprocess_autotune_keys), \ + (causal_autotune_configs, causal_autotune_keys), \ + (noncausal_autotune_configs, noncausal_autotune_keys) ( (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys), -) = get_autotune_configs() +) = get_bwd_configs() # This function computes delta given output Out and gradient DO