[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend to enable Llama-4 #19904
Conversation
Signed-off-by: tjtanaa <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @tjtanaa, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request implements a crucial bug fix for the AITER Flash Attention backend on ROCm, specifically to enable proper support for Llama-4 models. The changes ensure that the appropriate attention backend is utilized for ROCm environments and introduce necessary logic to handle advanced attention patterns, such as KV cache sharing, which are characteristic of models like Llama-4. This significantly enhances vLLM's compatibility and performance on AMD hardware for modern large language models.
Highlights
- ROCm Attention Backend Enforcement: Explicitly forces the `TORCH_SDPA` attention backend when running on ROCm platforms. This ensures compatibility and stability for AMD GPUs, overriding the default Flash Attention/XFormers logic.
- KV Cache Sharing Support: Introduces a new `kv_sharing_target_layer_name` parameter to the ROCm AITER Flash Attention backend. This enables conditional skipping of KV cache updates, supporting models like Llama-4 that may share KV cache across different attention layers (e.g., global and local attention).
- Type Coercion for Sequence Lengths: Adds explicit integer casting (`int()`) for the results of `max()` operations when calculating `local_max_query_len` and `local_max_seq_len`. This addresses potential type mismatches or implicit conversion issues in the ROCm AITER Flash Attention backend.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for the GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request introduces targeted fixes to the AITER Flash Attention backend for ROCm, aimed at enabling support for Llama-4 models. The changes include:
- Forcing `MultiHeadAttention` to use the `TORCH_SDPA` backend on ROCm for better compatibility.
- Explicitly casting NumPy scalar types to Python `int` for sequence lengths in the `AiterFlashAttentionMetadataBuilder`, enhancing type safety.
- Adding support for KV cache sharing in `AiterFlashAttentionImpl` by introducing a `kv_sharing_target_layer_name` parameter and conditionally skipping KV cache writes.
The changes appear correct, well-justified, and align with the linked GitHub issue. The provided test results also indicate successful application of these fixes.
if self.kv_sharing_target_layer_name is None:
    # Reshape the input keys and values and store them in the cache.
    # Skip this if sharing KV cache with an earlier attention layer.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens]
    # and value[:num_actual_tokens] because the reshape_and_cache_flash
    # op uses the slot_mapping's shape to determine the number of
    # actual tokens.
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
The introduction of the `kv_sharing_target_layer_name` parameter (initialized in the `__init__` method at line 390 and stored at line 412) and its use here to conditionally call `torch.ops._C_cache_ops.reshape_and_cache_flash` is a clean implementation for KV cache sharing. By skipping the cache update when `kv_sharing_target_layer_name` is set, it correctly avoids redundant writes and ensures that the shared KV cache from a target layer is utilized. The added comment `# Skip this if sharing KV cache with an earlier attention layer.` effectively clarifies the logic. This change is crucial for models leveraging KV cache sharing.
if current_platform.is_rocm():
    # currently, only torch_sdpa is supported on rocm
    self.attn_backend = _Backend.TORCH_SDPA
else:
    if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        backend = _Backend.XFORMERS

    self.attn_backend = backend if backend in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
    } else _Backend.TORCH_SDPA
The modification to set `self.attn_backend = _Backend.TORCH_SDPA` for `MultiHeadAttention` when running on ROCm is a clear and sensible platform-specific adjustment. Given the comment `# currently, only torch_sdpa is supported on rocm`, this change directly addresses compatibility or support limitations on ROCm, ensuring that a known-working backend is utilized.
local_max_query_len = int(seqlens_q_local_np.max())
local_max_seq_len = int(virt_k_seqlens_np.max())
Explicitly casting the results of `seqlens_q_local_np.max()` and `virt_k_seqlens_np.max()` to `int` is a good practice. It ensures that `local_max_query_len` and `local_max_seq_len` are Python integers, which can prevent potential type mismatches with downstream operations or library calls that expect standard integer types rather than NumPy scalar types (e.g., `numpy.int64`). This enhances type safety and robustness.
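As a quick standalone illustration of the scalar-type difference being guarded against (a sketch, not code from this PR):

```python
import numpy as np

seqlens = np.array([3, 7, 5], dtype=np.int32)

raw_max = seqlens.max()        # NumPy scalar (numpy.int32), not a Python int
cast_max = int(seqlens.max())  # plain Python int, as used in this PR

print(type(raw_max))           # <class 'numpy.int32'>
print(type(cast_max))          # <class 'int'>

# NumPy integer scalars are not instances of Python's int, so strictly typed
# consumers (e.g. C++ extension bindings expecting an int) may reject them.
assert not isinstance(raw_max, int)
assert isinstance(cast_max, int)
```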
thanks @tjtanaa for the fix.
@hongxiayang Thank you for the feedback.
Wondering why ChartQA score decreased?
Updates: @houseroad I think it is just noise. I have rerun the evaluation to obtain another data point.

ChartQA (Vision Language Benchmark)
Signed-off-by: tjtanaa <[email protected]>
Essential Elements of an Effective PR Description Checklist
- (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose
This is to fix issue #19867.
PR #18212, which introduced cross-layer KV cache sharing, added a new argument to the `AttentionImpl` `__init__` function: `kv_sharing_target_layer_name: Optional[str] = None`.
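For illustration, a minimal sketch of what accepting this new keyword in an `AttentionImpl`-style constructor looks like; apart from `kv_sharing_target_layer_name` (from #18212), the class and parameter names below are hypothetical and simplified, not the actual vLLM code:

```python
from typing import Optional


class SketchAttentionImpl:
    """Hypothetical, simplified stand-in for an AttentionImpl subclass."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_sharing_target_layer_name: Optional[str] = None,  # new argument from #18212
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        # A backend whose __init__ does not accept this keyword fails when the
        # model runner passes it, which is the irope error described in the
        # Causes section below.
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
```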
Causes
There are two bugs:
- The irope error is due to the missing argument `kv_sharing_target_layer_name` introduced in [V1] Support cross-layer KV sharing #18212.
- The xformers error is due to introducing the `AiterFlashAttentionImpl` backend as `FLASH_ATTN_VLLM_V1` in [Hardware][AMD] integrate aiter chunked prefill into vllm #18596. Previously, `FLASH_ATTN_VLLM_V1` was only available on CUDA, and CUDA supports xformers; when the AITER Flash Attention backend was introduced, the backend-selection condition was not updated for the ROCm platform. Before the AITER Flash Attention backend was introduced, the backend was named `TRITON_ATTN_VLLM_V1`, which caused the `MHA` attention to use `TORCH_SDPA` (so there was no issue running Llama-4 on ROCm).

Test Plan
Perform lm_eval (gsm8k) on two models: Llama-3.1-70B (global attention) and the Llama-4-Scout BF16 model (global + local attention and irope).
example command:
Evaluate the correctness of the vision part of Llama-4 using this tool: https://github.com/mistralai/mistral-evals
Test Result
NOTE: The test results by no means reflect the full performance of the models. The values are used only to gauge end-to-end implementation correctness, with the No-AITER runs as the accuracy baseline.
ChartQA (Vision Language Benchmark)
Visual benchmark
mistral-eval: MM-MT-Bench dataset (with AITER)
mmmu (With AITER)
Updates
Fixed in commit 2dda1d9
Error
Reproduce Steps
Server
Client
Longbench
vllm (pretrained=RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic,tensor_parallel_size=8,max_model_len=100000,max_num_seqs=512,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
(Optional) Documentation Update