
Commit 7486448

[CK_TILE] fmha: Add backward pass support for padded inputs
Introduces support for padded sequence lengths in the backward pass of the variable-length flash attention (fmha_v3_varlen_bwd).

- Updated the Python and C++ function signatures to accept optional `cu_seqlens_q_padded` and `cu_seqlens_k_padded` arguments.
- Modified the underlying CUDA kernels and code-generation scripts to pass padding information via the new `seqlen_q_ptr` and `seqlen_k_ptr` fields in the CK `fmha_bwd_args` struct, and to handle pointers correctly for both padded and unpadded sequence data.
- Added comprehensive gradient verification to the test suite (`test_mha_varlen.py`) to ensure the correctness of the backward pass under various padding scenarios.
1 parent 9e99dc8 commit 7486448
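
Both new arguments are cumulative sequence-length ("cu_seqlens") tensors of shape [b+1]. As a minimal sketch of the relationship between the two layouts — not code from this commit, and assuming the usual int32 prefix-sum-with-leading-zero convention — the padded variant is simply the prefix sums of the per-sequence lengths after padding:

import torch

def cumulative(lengths: torch.Tensor) -> torch.Tensor:
    # [b] per-sequence lengths -> [b+1] prefix sums with a leading zero
    return torch.nn.functional.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))

seqlens = torch.tensor([5, 3, 7])          # actual tokens per batch entry
padded_seqlens = torch.tensor([8, 8, 8])   # per-sequence lengths after padding

cu_seqlens_q = cumulative(seqlens)                # tensor([ 0,  5,  8, 15], dtype=torch.int32)
cu_seqlens_q_padded = cumulative(padded_seqlens)  # tensor([ 0,  8, 16, 24], dtype=torch.int32)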

File tree

15 files changed, +547 -250 lines

3rdparty/composable_kernel

aiter/ops/mha.py

Lines changed: 18 additions & 8 deletions
@@ -846,6 +846,8 @@ def cmdGenFunc_mha_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> dict[str, Any]:
     md_name = "mha_varlen_bwd"
     filter1 = "*"  # get_bwd_dot_do_o_blobs()
@@ -1007,6 +1009,8 @@ def mha_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -1036,6 +1040,8 @@ def gen_fmha_v3_varlen_bwd_fake_tensor(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ):
     return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv)

@@ -1054,8 +1060,6 @@ def fmha_v3_varlen_bwd(
     softmax_lse: Tensor,
     cu_seqlens_q: Tensor,
     cu_seqlens_k: Tensor,
-    # cu_seqlens_q_padded: Tensor,
-    # cu_seqlens_k_padded: Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dropout_p: float,
@@ -1073,6 +1077,8 @@ def fmha_v3_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -1874,10 +1880,6 @@ def _flash_attn_varlen_backward(
     dv: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
-    # FIXME: this two args currently not support on ck side
-    # and has no host code on aiter side
-    # cu_seqlens_q_padded: Tensor,
-    # cu_seqlens_k_padded: Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dropout_p: float,
@@ -1891,6 +1893,8 @@ def _flash_attn_varlen_backward(
     is_v3_atomic_fp32: Optional[bool] = True,
     how_v3_bf16_cvt: Optional[int] = 1,
     zero_tensors: bool = False,
+    cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+    cu_seqlens_k_padded: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

     (_, nhead_q, hdim_q) = q.shape
@@ -1999,8 +2003,6 @@ def can_impl_fmha_v3_bwd_gfx950():
             softmax_lse,
             cu_seqlens_q,
             cu_seqlens_k,
-            # cu_seqlens_q_padded,
-            # cu_seqlens_k_padded,
             max_seqlen_q,
             max_seqlen_k,
             dropout_p,
@@ -2018,6 +2020,8 @@ def can_impl_fmha_v3_bwd_gfx950():
             alibi_slopes,
             rng_state,
             None,
+            cu_seqlens_q_padded,
+            cu_seqlens_k_padded,
         )
     else:
         (
@@ -2049,6 +2053,8 @@ def can_impl_fmha_v3_bwd_gfx950():
             alibi_slopes,
             rng_state,
             None,
+            cu_seqlens_q_padded,
+            cu_seqlens_k_padded,
             # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
         )
     return softmax_d
@@ -2138,6 +2144,8 @@ def forward(
         ctx.head_size_q_og = head_size_q_og
         ctx.is_v3_atomic_fp32 = is_v3_atomic_fp32
         ctx.how_v3_bf16_cvt = how_v3_bf16_cvt
+        ctx.cu_seqlens_q_padded = cu_seqlens_q_padded
+        ctx.cu_seqlens_k_padded = cu_seqlens_k_padded

         out = out_padded[..., :head_size_v_og]

@@ -2194,6 +2202,8 @@ def backward(ctx, dout, *args):
             rng_state=rng_state,
             is_v3_atomic_fp32=ctx.is_v3_atomic_fp32,
             how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
+            cu_seqlens_q_padded=ctx.cu_seqlens_q_padded,
+            cu_seqlens_k_padded=ctx.cu_seqlens_k_padded,
         )
         dq = dq[..., :head_size_q_og]  # We could have padded the head dimension
         dk = dk[..., :head_size_q_og]
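
The split between the two tensors matters when indexing the packed buffers: row offsets follow the padded layout when one is in use, while valid lengths always come from the unpadded cumulative sums. A hypothetical helper (for illustration only, not part of this commit) showing that convention:

def sequence_slice(i, cu_seqlens, cu_seqlens_padded=None):
    # Row range of sequence i in a packed [total_tokens, nhead, hdim] tensor.
    # The start offset follows the padded layout when one is given; the valid
    # length is always taken from the unpadded cumulative sums.
    start = cu_seqlens_padded[i] if cu_seqlens_padded is not None else cu_seqlens[i]
    length = cu_seqlens[i + 1] - cu_seqlens[i]
    return slice(int(start), int(start) + int(length))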

csrc/include/mha_bwd.h

Lines changed: 9 additions & 9 deletions
@@ -61,7 +61,7 @@ __attribute__((visibility("default"))) float mha_bwd(mha_bwd_args args,
                                                      int how_v3_bf16_cvt,
                                                      const void* seqlen_q_padded = nullptr,
                                                      const void* seqlen_k_padded = nullptr,
-                                                    bool is_v3_api_check = false);
+                                                     bool is_v3_api_check = false);

 struct __attribute__((packed)) fmha_bwd_v3_args
 {
@@ -364,9 +364,9 @@ struct __attribute__((packed)) fmha_bwd_dq_shuffle_args
     p3 _p9;
     unsigned int head_dim;
     p3 _p10;
-    const void *ptr_qseq;
+    const void* ptr_qseq;
     p2 _p11;
-    const void *ptr_qseq_padded;
+    const void* ptr_qseq_padded;
     p2 _p12;
     unsigned int max_seqlen_dq;
     p3 _p13;
@@ -418,17 +418,17 @@ namespace gfx942 {
 float fmha_bwd_v3(mha_bwd_traits t,
                   mha_bwd_args a,
                   const ck_tile::stream_config& s,
-                  const void* seqlen_q_padded = nullptr,
-                  const void* seqlen_k_padded = nullptr,
-                  bool is_v3_api_check = false);
+                  const void* seqlen_q_unpadded = nullptr,
+                  const void* seqlen_k_unpadded = nullptr,
+                  bool is_v3_api_check = false);
 }

 namespace gfx950 {
 float fmha_bwd_v3(mha_bwd_traits t,
                   mha_bwd_args a,
                   const ck_tile::stream_config& s,
-                  const void* seqlen_q_padded = nullptr,
-                  const void* seqlen_k_padded = nullptr,
-                  bool is_v3_api_check = false);
+                  const void* seqlen_q_unpadded = nullptr,
+                  const void* seqlen_k_unpadded = nullptr,
+                  bool is_v3_api_check = false);
 }
 } // namespace aiter

csrc/include/rocm_ops.hpp

Lines changed: 58 additions & 54 deletions
@@ -494,34 +494,36 @@
         py::arg("rng_state") = std::nullopt, \
         py::arg("gen") = std::nullopt);

-#define MHA_VARLEN_BWD_ASM_PYBIND \
-    m.def("fmha_v3_varlen_bwd", \
-          &aiter::torch_itfs::fmha_v3_varlen_bwd, \
-          py::arg("dout"), \
-          py::arg("q"), \
-          py::arg("k"), \
-          py::arg("v"), \
-          py::arg("out"), \
-          py::arg("softmax_lse"), \
-          py::arg("cu_seqlens_q"), \
-          py::arg("cu_seqlens_k"), \
-          py::arg("max_seqlen_q"), \
-          py::arg("max_seqlen_k"), \
-          py::arg("dropout_p"), \
-          py::arg("softmax_scale"), \
-          py::arg("zero_tensors"), \
-          py::arg("is_causal"), \
-          py::arg("window_size_left"), \
-          py::arg("window_size_right"), \
-          py::arg("deterministic"), \
-          py::arg("is_v3_atomic_fp32"), \
-          py::arg("how_v3_bf16_cvt"), \
-          py::arg("dq") = std::nullopt, \
-          py::arg("dk") = std::nullopt, \
-          py::arg("dv") = std::nullopt, \
-          py::arg("alibi_slopes") = std::nullopt, \
-          py::arg("rng_state") = std::nullopt, \
-          py::arg("gen") = std::nullopt);
+#define MHA_VARLEN_BWD_ASM_PYBIND \
+    m.def("fmha_v3_varlen_bwd", \
+          &aiter::torch_itfs::fmha_v3_varlen_bwd, \
+          py::arg("dout"), \
+          py::arg("q"), \
+          py::arg("k"), \
+          py::arg("v"), \
+          py::arg("out"), \
+          py::arg("softmax_lse"), \
+          py::arg("cu_seqlens_q"), \
+          py::arg("cu_seqlens_k"), \
+          py::arg("max_seqlen_q"), \
+          py::arg("max_seqlen_k"), \
+          py::arg("dropout_p"), \
+          py::arg("softmax_scale"), \
+          py::arg("zero_tensors"), \
+          py::arg("is_causal"), \
+          py::arg("window_size_left"), \
+          py::arg("window_size_right"), \
+          py::arg("deterministic"), \
+          py::arg("is_v3_atomic_fp32"), \
+          py::arg("how_v3_bf16_cvt"), \
+          py::arg("dq") = std::nullopt, \
+          py::arg("dk") = std::nullopt, \
+          py::arg("dv") = std::nullopt, \
+          py::arg("alibi_slopes") = std::nullopt, \
+          py::arg("rng_state") = std::nullopt, \
+          py::arg("gen") = std::nullopt, \
+          py::arg("cu_seqlens_q_padded") = std::nullopt, \
+          py::arg("cu_seqlens_k_padded") = std::nullopt);

 #define MHA_BWD_PYBIND \
     m.def("mha_bwd", \
@@ -612,32 +614,34 @@
          py::arg("alibi_slopes") = std::nullopt, \
          py::arg("gen") = std::nullopt);

-#define MHA_VARLEN_BWD_PYBIND \
-    m.def("mha_varlen_bwd", \
-          &aiter::torch_itfs::mha_varlen_bwd, \
-          py::arg("dout"), \
-          py::arg("q"), \
-          py::arg("k"), \
-          py::arg("v"), \
-          py::arg("out"), \
-          py::arg("softmax_lse"), \
-          py::arg("cu_seqlens_q"), \
-          py::arg("cu_seqlens_k"), \
-          py::arg("max_seqlen_q"), \
-          py::arg("max_seqlen_k"), \
-          py::arg("dropout_p"), \
-          py::arg("softmax_scale"), \
-          py::arg("zero_tensors"), \
-          py::arg("is_causal"), \
-          py::arg("window_size_left"), \
-          py::arg("window_size_right"), \
-          py::arg("deterministic"), \
-          py::arg("dq") = std::nullopt, \
-          py::arg("dk") = std::nullopt, \
-          py::arg("dv") = std::nullopt, \
-          py::arg("alibi_slopes") = std::nullopt, \
-          py::arg("rng_state") = std::nullopt, \
-          py::arg("gen") = std::nullopt);
+#define MHA_VARLEN_BWD_PYBIND \
+    m.def("mha_varlen_bwd", \
+          &aiter::torch_itfs::mha_varlen_bwd, \
+          py::arg("dout"), \
+          py::arg("q"), \
+          py::arg("k"), \
+          py::arg("v"), \
+          py::arg("out"), \
+          py::arg("softmax_lse"), \
+          py::arg("cu_seqlens_q"), \
+          py::arg("cu_seqlens_k"), \
+          py::arg("max_seqlen_q"), \
+          py::arg("max_seqlen_k"), \
+          py::arg("dropout_p"), \
+          py::arg("softmax_scale"), \
+          py::arg("zero_tensors"), \
+          py::arg("is_causal"), \
+          py::arg("window_size_left"), \
+          py::arg("window_size_right"), \
+          py::arg("deterministic"), \
+          py::arg("dq") = std::nullopt, \
+          py::arg("dk") = std::nullopt, \
+          py::arg("dv") = std::nullopt, \
+          py::arg("alibi_slopes") = std::nullopt, \
+          py::arg("rng_state") = std::nullopt, \
+          py::arg("gen") = std::nullopt, \
+          py::arg("cu_seqlens_q_padded") = std::nullopt, \
+          py::arg("cu_seqlens_k_padded") = std::nullopt);

 #define MOE_CK_2STAGES_PYBIND \
     m.def("ck_moe_stage1", \

csrc/include/torch/mha_v3_varlen_bwd.h

Lines changed: 3 additions & 5 deletions
@@ -14,10 +14,6 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
                    const at::Tensor& softmax_lse, // [b, hq, sq]
                    const at::Tensor& cu_seqlens_q, // [b+1]
                    const at::Tensor& cu_seqlens_k, // [b+1]
-                   // FIXME: this two args currently not support on ck side
-                   // and has no host code on aiter side
-                   // const at::Tensor& cu_seqlens_q_padded, // [b+1]
-                   // const at::Tensor& cu_seqlens_k_padded, // [b+1]
                    const int max_seqlen_q,
                    const int max_seqlen_k,
                    const float p_dropout,
@@ -34,7 +30,9 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
                    std::optional<at::Tensor> dv_, // [total_k, hk, d_v]
                    std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
                    std::optional<const at::Tensor> rng_state_,
-                   std::optional<at::Generator> gen_);
+                   std::optional<at::Generator> gen_,
+                   std::optional<const at::Tensor> cu_seqlens_q_padded = std::nullopt,
+                   std::optional<const at::Tensor> cu_seqlens_k_padded = std::nullopt);

 } // namespace torch_itfs
 } // namespace aiter

csrc/include/torch/mha_varlen_bwd.h

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include <torch/extension.h>

 namespace aiter {
@@ -28,6 +28,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d]
                std::optional<at::Tensor> dv, // [total_k, hk, d]
                std::optional<const at::Tensor> alibi_slopes, // [hq] or [b, hq]
                std::optional<const at::Tensor> rng_state,
-               std::optional<at::Generator> gen);
+               std::optional<at::Generator> gen,
+               std::optional<const at::Tensor> cu_seqlens_q_padded, // [b+1]
+               std::optional<const at::Tensor> cu_seqlens_k_padded  // [b+1]
+               );
 } // namespace torch_itfs
 } // namespace aiter

csrc/py_itfs_ck/mha_bwd_kernels.cu

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
         nullptr, // seqstart_q
         nullptr, // seqstart_k
         nullptr, // seqlen_k_ptr
+        nullptr, // seqlen_q_ptr
         seqlen_q,
         seqlen_k,
         b,
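
In this non-varlen path both per-sequence length pointers stay nullptr. For the varlen path, the commit message's note about handling pointers for both padded and unpadded data suggests a simple host-side fallback: when no padded cumulative lengths are supplied, the padded pointers can alias the unpadded ones, which the kernel then sees as a layout with zero per-sequence padding. A hedged sketch of that dispatch (names illustrative, not the commit's exact host code):

def resolve_cu_seqlens(cu_seqlens_q, cu_seqlens_k,
                       cu_seqlens_q_padded=None, cu_seqlens_k_padded=None):
    # Fall back to the unpadded tensors so downstream code always has a valid
    # pair of pointers to hand to the CK fmha_bwd_args fields.
    if cu_seqlens_q_padded is None:
        cu_seqlens_q_padded = cu_seqlens_q
    if cu_seqlens_k_padded is None:
        cu_seqlens_k_padded = cu_seqlens_k
    return cu_seqlens_q_padded, cu_seqlens_k_padded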
