Jeff-Huang (Contributor)
Introduces support for padded sequence lengths in the backward pass of the variable-length flash attention (`fmha_v3_varlen_bwd`).

- Updated the Python and C++ function signatures to accept optional `cu_seqlens_q_padded` and `cu_seqlens_k_padded` arguments (a usage sketch follows this list).
- Modified the underlying kernels and code generation scripts to pass padding information via the new `seqlen_q_ptr` and `seqlen_k_ptr` fields in the CK `fmha_bwd_args` struct, so that pointers are handled correctly for both padded and unpadded sequence data (see the sketch under Technical Details).
- Added gradient verification to the test suite (`test_mha_varlen.py`) covering the backward pass under various padding scenarios (see the sketch under Test Plan).
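As a rough usage sketch of the new signature: only `fmha_v3_varlen_bwd` and the two `*_padded` arguments come from this PR; the import path, the remaining parameters, and the tensor layout are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

# from aiter import fmha_v3_varlen_bwd   # assumed import path

nheads, head_dim = 4, 64

# Three sequences with real lengths 5, 3, 7, each padded to a stride of 8.
seqlens        = torch.tensor([5, 3, 7], dtype=torch.int32, device="cuda")
seqlens_padded = torch.tensor([8, 8, 8], dtype=torch.int32, device="cuda")

# Cumulative offsets: unpadded [0, 5, 8, 15] vs. padded [0, 8, 16, 24].
cu_seqlens_q        = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
cu_seqlens_q_padded = F.pad(torch.cumsum(seqlens_padded, 0, dtype=torch.int32), (1, 0))

# Packed tensors sized by the padded layout (assumed [total_tokens, heads, dim]).
total = int(cu_seqlens_q_padded[-1])
q = torch.randn(total, nheads, head_dim, dtype=torch.float16, device="cuda")
k, v, out, dout = (torch.randn_like(q) for _ in range(4))
softmax_lse = torch.randn(nheads, total, dtype=torch.float32, device="cuda")  # assumed shape

dq, dk, dv = fmha_v3_varlen_bwd(
    dout, q, k, v, out, softmax_lse,           # assumed argument order
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_q,                 # self-attention: K shares Q's layout
    cu_seqlens_q_padded=cu_seqlens_q_padded,   # new optional argument
    cu_seqlens_k_padded=cu_seqlens_q_padded,   # new optional argument
    max_seqlen_q=8, max_seqlen_k=8,
)
```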

Motivation

Technical Details
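
The summary notes that padding information reaches the kernels through the `seqlen_q_ptr` and `seqlen_k_ptr` fields of CK's `fmha_bwd_args`. The PR does not spell out the addressing scheme, but the distinction the two sets of arrays encode can be illustrated host-side, assuming the usual packed layout: padded cumulative offsets locate each sequence in memory, while unpadded lengths bound the actual computation.

```python
import torch

seqlens   = torch.tensor([5, 3, 7])        # real (unpadded) lengths
cu_padded = torch.tensor([0, 8, 16, 24])   # padded memory offsets
buf = torch.arange(int(cu_padded[-1]))     # stand-in for a packed token buffer

for i, n in enumerate(seqlens.tolist()):
    start = int(cu_padded[i])              # padded offset: where sequence i lives
    valid = buf[start : start + n]         # unpadded length: how much of it is real
    # Rows start+n .. cu_padded[i+1]-1 are padding slots the kernel must skip.
    print(f"seq {i}: valid rows {start}..{start + n - 1}")
```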

Test Plan
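
The summary mentions gradient verification in `test_mha_varlen.py` across padding scenarios. A minimal sketch of such a check, with the methodology assumed (the PR does not spell out its reference), compares the fused backward's `dq` against autograd on a per-sequence SDPA reference:

```python
import torch
import torch.nn.functional as F

def reference_dq(q, k, v, dout, seqlens, cu_padded):
    """Per-sequence SDPA under autograd as ground truth (assumed methodology)."""
    dq = torch.zeros_like(q)
    for i, n in enumerate(seqlens.tolist()):
        s = int(cu_padded[i])
        qi = q[s : s + n].detach().requires_grad_(True)
        ki = k[s : s + n].detach()
        vi = v[s : s + n].detach()
        # [seqlen, heads, dim] -> [heads, seqlen, dim] for SDPA, then back.
        out = F.scaled_dot_product_attention(
            qi.transpose(0, 1), ki.transpose(0, 1), vi.transpose(0, 1)
        ).transpose(0, 1)
        out.backward(dout[s : s + n])
        dq[s : s + n] = qi.grad  # padding rows stay zero
    return dq

# torch.testing.assert_close(dq_kernel, reference_dq(q, k, v, dout, seqlens, cu_padded))
```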

Test Result

Submission Checklist

