Skip to content

Conversation

@Micky774
Copy link
Contributor

@Micky774 Micky774 commented Oct 28, 2025

Description

Feature update PR which includes several iterative changes for client-driven optimization targets. This PR includes both API changes for CK/AITER as well as changes in internal integration. See the list of changes for specifics.

Note that this will not be ready for merger until ROCm/aiter#1212 is merged in and this PR's AITER commit is updated.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Integrated support for native padding kernels in fwd/bwd
  • Added BSHD + Padding --> THD + Padding conversion mechanism
  • Streamlined memory allocation logic
  • Added runtime max_seqlen calculation gated by new env var NVTE_CK_RUNTIME_MAX_SEQLEN
  • Adds v3_api_check support (temporary)
  • Implements new AITER/CK API
  • Update MQA post-processing kernels
  • Remove pad_between_seqs (need to follow-up with a PR cleaning up test suite for old pad_between_seqs edge-cases)
  • Added NVTE_CK_RUNTIME_NUM_SEGMENTS to guard runtime-calculation of the number of segments in the JAX integration

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copy link
Collaborator

@wangye805 wangye805 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, I think we can try to remove all memset except for dq, dq_acc. We can confirm with aiter/ck people

@wangye805
Copy link
Collaborator

Let's also add how to use the runtime segment/max seqlen in readme under https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#fused-attention-backends-on-rocm. Remind our customers that this will break the cudagraph

@Micky774
Copy link
Contributor Author

Let's also add how to use the runtime segment/max seqlen in readme under https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#fused-attention-backends-on-rocm. Remind our customers that this will break the cudagraph

@wangye805 I've now updated the readme, but let me know if you have specific thoughts on it.

Copy link
Collaborator

@wangye805 wangye805 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take a look at several unresolved conversation previously

- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices
@wenchenvincent
Copy link
Collaborator

@Micky774 Could you rebase/merge latest dev to incorporate the hot fixes for sgpu tests?

@wangye805
Copy link
Collaborator

pytorch test_numerics also shows some fused-attn related failures:
FAILED tests/pytorch/test_numerics.py::test_kv_cache_accuracy[False-FusedAttention-TransformerLayer-sbhd-False-126m-1-dtype1] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [0, 650] with -0.90625 vs 0.5654296875 (diff 1.4716796875).

Not sure whether this is related to our decision to remove memsettings.

@Micky774
Copy link
Contributor Author

pytorch test_numerics also shows some fused-attn related failures: FAILED tests/pytorch/test_numerics.py::test_kv_cache_accuracy[False-FusedAttention-TransformerLayer-sbhd-False-126m-1-dtype1] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [0, 650] with -0.90625 vs 0.5654296875 (diff 1.4716796875).

Not sure whether this is related to our decision to remove memsettings.

Those failures were due to a mix of not correctly dispatching to the is_SBHD workflow when dealing with SBHD_2BSHD formats, and miscalculating stride in the case of the same format. Resolved now.

Copy link
Collaborator

@wangye805 wangye805 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For those newly added hybrid qkv formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle all other 3? Or is there only SBHD_2BSHD pytests now?

NV upstream is separating format and is_ragged on q/kv and do subsequent processings accordingly:

NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);

Maybe we can try similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.

Or it's okay if you want to leave this for another PR.

By the way, there is an "extra line" comment you may have ignored :-)

@wangye805
Copy link
Collaborator

In fact, I saw some level 3 pytorch cp pytest failures by run level 3 ci locally:

=========================== short test summary info ============================
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_1_0-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_1_0', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_1_1-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_1_1', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_2_0-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_2_0', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
FAILED tests/pytorch/fused_attn/test_fused_attn_with_cp.py::test_cp_with_fused_attention[False-p2p-thd-cp_2_1-bf16] - subprocess.CalledProcessError: Command '['python3', '-m', 'torch.distributed.launch', '--nproc-per-node=2', '/workspace/te_native_bshd_thd/tests/pytorch/fused_attn/run_fused_attn_with_cp.py', 'dtype=bf16', 'model=cp_2_1', 'qkv_format=thd', 'kernel_backend=FusedAttention', 'cp_comm_type=p2p', 'fp8_mha=False']' returned non-zero exit status 1.
SKIPPED [48] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:68: CP implementation with KV P2P does not support sliding window yet!
SKIPPED [16] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:70: CP implementation with KV all-gather does not support THD format yet!
SKIPPED [24] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:74: CP implementation with QKVO A2A does not support THD format yet!
SKIPPED [240] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:133: FP8 attention has not been supported on ROCm yet!
SKIPPED [40] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:153: CP implementation with KV P2P does not support sliding window yet!
SKIPPED [64] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:137: THD format does not support post_scale_bias yet!
SKIPPED [32] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:155: CP implementation with KV all-gather does not support bias yet!
SKIPPED [24] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:139: CP implementation with KV all-gather does not support THD format yet!
SKIPPED [64] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:157: CP implementation with QKVO A2A does not support bias yet!
SKIPPED [48] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:141: CP implementation with QKVO A2A does not support THD format yet!
SKIPPED [104] tests/pytorch/fused_attn/test_fused_attn_with_cp.py:164: Only fp8 works with fp8_mha=True!
===== 4 failed, 204 passed, 704 skipped, 2 warnings in 3065.08s (0:51:05) ======
Error in test [ck] fused_attn/test_fused_attn_with_cp.py
Done [ck] fused_attn/test_fused_attn_with_cp.py
Got 1 test errors during run at level 3

Attached you can find the detailed log
torch_mgpu.txt

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants