
[Bugfix][V1][ROCm] Fix AITER Flash Attention Backend to enable Llama-4 #19904


Open

wants to merge 3 commits into base: main

Conversation

tjtanaa
Contributor

@tjtanaa tjtanaa commented Jun 20, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR fixes issue #19867.
PR #18212, which introduced cross-layer KV cache sharing, added a new argument to the AttentionImpl __init__ function (kv_sharing_target_layer_name: Optional[str] = None).

Causes

There are two bugs:

  1. The irope error is caused by the missing kv_sharing_target_layer_name argument introduced in [V1] Support cross-layer KV sharing #18212.

  2. The xformers error is caused by introducing the AiterFlashAttentionImpl backend as FLASH_ATTN_VLLM_V1 in [Hardware][AMD] integrate aiter chunked prefill into vllm #18596. Previously, FLASH_ATTN_VLLM_V1 was only available on CUDA, and CUDA supports xformers.

When the AITER Flash Attention backend was introduced, the MultiHeadAttention backend-selection condition was not updated for the ROCm platform.

Before the AITER Flash Attention backend was introduced, the ROCm backend was named TRITON_ATTN_VLLM_V1, which caused MHA attention to fall back to TORCH_SDPA, so there was no issue running Llama-4 on ROCm. A simplified sketch of the two fixes follows.
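
As a rough, simplified sketch of the two fixes (the class, enum, and helper names below are illustrative stand-ins, not the actual vLLM code; the real diff hunks appear in the review comments further below):

```python
from typing import Optional

# Sketch only: simplified stand-ins for vLLM's real classes and enums.

class _Backend:
    TORCH_SDPA = "TORCH_SDPA"
    XFORMERS = "XFORMERS"
    FLASH_ATTN = "FLASH_ATTN"
    FLASH_ATTN_VLLM_V1 = "FLASH_ATTN_VLLM_V1"
    PALLAS_VLLM_V1 = "PALLAS_VLLM_V1"


def is_rocm() -> bool:
    # Stand-in for current_platform.is_rocm().
    return True


# Bug 1 (irope error): the AITER attention impl must accept the argument added by
# PR #18212, otherwise constructing it with kv_sharing_target_layer_name raises TypeError.
class AiterFlashAttentionImplSketch:
    def __init__(self, num_heads: int, head_size: int, scale: float,
                 kv_sharing_target_layer_name: Optional[str] = None) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        # When set, this layer reuses the target layer's KV cache and skips its own
        # cache write (see the reshape_and_cache_flash snippet below).
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name


# Bug 2 (xformers error): MHA backend selection must not fall through to xformers on ROCm.
def select_mha_backend(backend: str) -> str:
    if is_rocm():
        # Only torch_sdpa is currently supported for MHA on ROCm.
        return _Backend.TORCH_SDPA
    if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        backend = _Backend.XFORMERS
    return backend if backend in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
    } else _Backend.TORCH_SDPA
```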

Test Plan

Run lm_eval (gsm8k) on two models: Llama-3.1-70B (global attention) and the Llama-4-Scout BF16 model (global + local attention with irope).

example command:

VLLM_RPC_TIMEOUT=1800000 \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MHA=1 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_USE_TRITON_FLASH_ATTN=0 \
SAFETENSORS_FAST_GPU=1 \
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=4,max_model_len=10000 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Evaluate the correctness of the vision part of Llama-4 using this tool: https://github.com/mistralai/mistral-evals

Test Result

NOTE: These test results by no means reflect the full performance of the models. The values are used to gauge end-to-end implementation correctness by comparing against the No-AITER runs as the accuracy baseline.

| Model | AITER | Flexible-Extract Score | Flexible-Extract Stderr | Strict-Match Score | Strict-Match Stderr |
|---|---|---|---|---|---|
| Llama-3.1-70B-Instruct | No | 0.9204 | ±0.0075 | 0.8779 | ±0.0090 |
| Llama-3.1-70B-Instruct | Yes | 0.9234 | ±0.0073 | 0.8870 | ±0.0087 |
| Llama-4-Scout-17B-16E-Instruct | No | 0.9158 | ±0.0076 | 0.8969 | ±0.0084 |
| Llama-4-Scout-17B-16E-Instruct | Yes | 0.9181 | ±0.0076 | 0.8984 | ±0.0083 |

ChartQA (Vision Language Benchmark)

| AITER | Explicit Prompt Relaxed Correctness | Anywhere in Answer Relaxed Correctness |
|---|---|---|
| No | 0.8720 | 0.8720 |
| Yes | 0.8696 | 0.8696 |
Visual benchmark

mistral-eval: MM-MT-Bench dataset (with AITER)

================================================================================
Metrics:
{
    "micro_average_score": 6.87603305785124,
    "macro_average_score": 6.957930702758288,
    "charts_average": 7.066666666666666,
    "1_average": 6.565217391304348,
    "0_average": 6.934782608695652,
    "2_average": 7.0,
    "diagrams_average": 6.833333333333333,
    "tables_average": 6.862068965517241,
    "3_average": 8.0,
    "pdf_pages_average": 6.904761904761905,
    "misc_average": 6.454545454545454
}
================================================================================

mmmu (With AITER)

Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.67,
    "anywhere_in_answer_relaxed_correctness": 0.6822222222222222
}

Updates

Fixed in commit 2dda1d9

Error

```
(VllmWorker rank=0 pid=21649) ERROR 06-22 11:05:45 [multiproc_executor.py:527] WorkerProc hit an exception.
Traceback (most recent call last):
  File "/app/benchmarkrc2/0624_rc2/vllm/v1/executor/multiproc_executor.py", line 522, in worker_busy_loop
    output = func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/app/benchmarkrc2/0624_rc2/vllm/v1/worker/gpu_worker.py", line 303, in execute_model
    output = self.model_runner.execute_model(scheduler_output,
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/app/benchmarkrc2/0624_rc2/vllm/v1/worker/gpu_model_runner.py", line 1350, in execute_model
    model_output = self.model(
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/benchmarkrc2/0624_rc2/vllm/model_executor/models/mllama4.py", line 841, in forward
    return self.language_model(input_ids, positions, intermediate_tensors,
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/benchmarkrc2/0624_rc2/vllm/model_executor/models/llama.py", line 581, in forward
    model_output = self.model(input_ids, positions, intermediate_tensors,
  File "/app/benchmarkrc2/0624_rc2/vllm/compilation/decorators.py", line 246, in __call__
    model_output = self.forward(*args, **kwargs)
  File "/app/benchmarkrc2/0624_rc2/vllm/model_executor/models/llama.py", line 368, in forward
    def forward(
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 406, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 393, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.98", line 638, in forward
    submod_1 = self.submod_1(getitem, s0, getitem_1, getitem_2, getitem_3);  getitem = getitem_1 = getitem_2 = submod_1 = None
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 406, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 393, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.2", line 5, in forward
    unified_attention_with_output = torch.ops.vllm.unified_attention_with_output(query_2, key_2, value, output_1, 'language_model.model.layers.0.self_attn.attn');  query_2 = key_2 = value = output_1 = unified_attention_with_output = None
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/app/benchmarkrc2/0624_rc2/vllm/attention/layer.py", line 451, in unified_attention_with_output
    self.impl.forward(self,
  File "/app/benchmarkrc2/0624_rc2/vllm/v1/attention/backends/rocm_aiter_fa.py", line 545, in forward
    torch.ops.vllm.flash_attn_varlen_func(
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/app/benchmarkrc2/0624_rc2/vllm/v1/attention/backends/rocm_aiter_fa.py", line 135, in flash_attn_varlen_func_impl
    output = aiter.flash_attn_varlen_func(
  File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 1444, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 1243, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 959, in _flash_attn_varlen_forward
    out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd(
  File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 607, in wrapper
    return op(*args, **kwargs)
RuntimeError: cu_seqlens_k.value() must have shape (batch_size + 1)
```

Reproduce Steps


Server

#!/bin/bash
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
SAFETENSORS_FAST_GPU=1 \
vllm serve RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic \
    --disable-log-requests \
    -tp 8 \
    --max_num_batched_tokens 32768 \
    --max-num-seqs 1024 \
    --max-model-len 36000 \
> server_RedHatAI_Llama-4-Scout-17B-16E-Instruct-FP8-dynamic4.log 2>&1

Client

#!/bin/bash
for pair in "10000,500" "30000,100"; do
  input_len=$(echo $pair | cut -d',' -f1)
  output_len=$(echo $pair | cut -d',' -f2)
  
  python3 vllm/benchmarks/benchmark_serving.py \
    --host localhost \
    --model RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic \
    --dataset-name random \
    --ignore-eos \
    --num-prompts 640 \
    --request-rate 64 \
    --random-input-len $input_len \
    --random-output-len $output_len \
    --save-result \
    --result-dir results/llama-4 \
    --result-filename vllm-$input_len-$output_len.json \
    > vllm-$input_len-$output_len.log 2>&1
    
  sleep 30
done

Longbench


vllm (pretrained=RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic,tensor_parallel_size=8,max_model_len=100000,max_num_seqs=512,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| longbench_2wikimqa | 3 | none | 0 | qa_f1_score | 0.4244 | ± 0.0334 |
| longbench_dureader | 3 | none | 0 | rouge_zh_score | 0.2672 | ± 0.0166 |
| longbench_gov_report | 3 | none | 0 | rouge_score | 0.3458 | ± 0.0056 |
| longbench_hotpotqa | 3 | none | 0 | qa_f1_score | 0.5766 | ± 0.0310 |
| longbench_lcc | 3 | none | 0 | code_sim_score | 0.1737 | ± 0.0107 |
| longbench_lsht | 3 | none | 0 | classification_score | 0.4250 | ± 0.0349 |
| longbench_multi_news | 3 | none | 0 | rouge_score | 0.2343 | ± 0.0070 |
| longbench_multifieldqa_en | 3 | none | 0 | qa_f1_score | 0.4903 | ± 0.0301 |
| longbench_multifieldqa_zh | 3 | none | 0 | qa_f1_zh_score | 0.3859 | ± 0.0261 |
| longbench_musique | 3 | none | 0 | qa_f1_score | 0.3486 | ± 0.0309 |
| longbench_narrativeqa | 3 | none | 0 | qa_f1_score | 0.2849 | ± 0.0250 |
| longbench_passage_count | 3 | none | 0 | count_score | 0.0362 | ± 0.0131 |
| longbench_passage_retrieval_en | 3 | none | 0 | retrieval_score | 0.6150 | ± 0.0341 |
| longbench_passage_retrieval_zh | 3 | none | 0 | retrieval_zh_score | 0.9050 | ± 0.0208 |
| longbench_qasper | 3 | none | 0 | qa_f1_score | 0.2071 | ± 0.0250 |
| longbench_qmsum | 3 | none | 0 | rouge_score | 0.1591 | ± 0.0093 |
| longbench_repobench-p | 3 | none | 0 | code_sim_score | 0.1897 | ± 0.0094 |
| longbench_samsum | 3 | none | 0 | rouge_score | 0.3843 | ± 0.0140 |
| longbench_trec | 3 | none | 0 | classification_score | 0.6496 | ± 0.0318 |
| longbench_triviaqa | 3 | none | 0 | qa_f1_score | 0.9046 | ± 0.0180 |
| longbench_vcsum | 3 | none | 0 | rouge_zh_score | 0.1524 | ± 0.0037 |

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @tjtanaa, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements a crucial bug fix for the AITER Flash Attention backend on ROCm, specifically to enable proper support for Llama-4 models. The changes ensure that the appropriate attention backend is utilized for ROCm environments and introduce necessary logic to handle advanced attention patterns, such as KV cache sharing, which are characteristic of models like Llama-4. This significantly enhances vLLM's compatibility and performance on AMD hardware for modern large language models.

Highlights

  • ROCm Attention Backend Enforcement: Explicitly forces the TORCH_SDPA attention backend when running on ROCm platforms. This ensures compatibility and stability for AMD GPUs, overriding the default Flash Attention/XFormers logic.
  • KV Cache Sharing Support: Introduces a new kv_sharing_target_layer_name parameter to the ROCm AITER Flash Attention backend. This enables conditional skipping of KV cache updates, supporting models like Llama-4 that might share KV cache across different attention layers (e.g., global and local attention).
  • Type Coercion for Sequence Lengths: Adds explicit integer casting (int()) for the results of max() operations when calculating local_max_query_len and local_max_seq_len. This addresses potential type mismatches or implicit conversion issues in the ROCm AITER Flash Attention backend.

@mergify mergify bot added rocm Related to AMD ROCm v1 labels Jun 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces targeted fixes to the AITER Flash Attention backend for ROCm, aimed at enabling support for Llama-4 models. The changes include:

  1. Forcing MultiHeadAttention to use the TORCH_SDPA backend on ROCm for better compatibility.
  2. Explicitly casting NumPy scalar types to Python int for sequence lengths in the AiterFlashAttentionMetadataBuilder, enhancing type safety.
  3. Adding support for KV cache sharing in AiterFlashAttentionImpl by introducing a kv_sharing_target_layer_name parameter and conditionally skipping KV cache writes.

The changes appear correct, well-justified, and align with the linked GitHub issue. The provided test results also indicate successful application of these fixes.

Comment on lines +484 to +501
```python
if self.kv_sharing_target_layer_name is None:
    # Reshape the input keys and values and store them in the cache.
    # Skip this if sharing KV cache with an earlier attention layer.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens]
    # and value[:num_actual_tokens] because the reshape_and_cache_flash
    # op uses the slot_mapping's shape to determine the number of
    # actual tokens.
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
```
Contributor

high

The introduction of the kv_sharing_target_layer_name parameter (initialized in the __init__ method at line 390 and stored at line 412) and its use here to conditionally call torch.ops._C_cache_ops.reshape_and_cache_flash is a clean implementation for KV cache sharing. By skipping the cache update when kv_sharing_target_layer_name is set, it correctly avoids redundant writes and ensures that the shared KV cache from a target layer is utilized. The added comment # Skip this if sharing KV cache with an earlier attention layer. effectively clarifies the logic. This change is crucial for models leveraging KV cache sharing.

Comment on lines +309 to +318
```python
if current_platform.is_rocm():
    # currently, only torch_sdpa is supported on rocm
    self.attn_backend = _Backend.TORCH_SDPA
else:
    if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        backend = _Backend.XFORMERS

    self.attn_backend = backend if backend in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
    } else _Backend.TORCH_SDPA
```
Contributor

medium

The modification to set self.attn_backend = _Backend.TORCH_SDPA for MultiHeadAttention when running on ROCm is a clear and sensible platform-specific adjustment. Given the comment # currently, only torch_sdpa is supported on rocm, this change directly addresses compatibility or support limitations on ROCm, ensuring that a known-working backend is utilized.

Comment on lines +246 to +247
```python
local_max_query_len = int(seqlens_q_local_np.max())
local_max_seq_len = int(virt_k_seqlens_np.max())
```
Contributor

medium

Explicitly casting the results of seqlens_q_local_np.max() and virt_k_seqlens_np.max() to int is a good practice. It ensures that local_max_query_len and local_max_seq_len are Python integers, which can prevent potential type mismatches with downstream operations or library calls that expect standard integer types rather than NumPy scalar types (e.g., numpy.int64). This enhances type safety and robustness.
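
For illustration only (this example is not part of the diff): a NumPy reduction returns a NumPy scalar rather than a Python int, which the explicit cast normalizes:

```python
import numpy as np

seqlens = np.array([3, 7, 5], dtype=np.int32)

max_len = seqlens.max()
print(type(max_len))              # <class 'numpy.int32'> -- a NumPy scalar, not a Python int
print(isinstance(max_len, int))   # False, so strict int checks or some bindings may reject it

max_len_py = int(seqlens.max())   # explicit cast yields a plain Python int
print(type(max_len_py))           # <class 'int'>
```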

@hongxiayang
Collaborator

thanks @tjtanaa for the fix.

@tjtanaa
Contributor Author

tjtanaa commented Jun 20, 2025

@hongxiayang Thank you for the feedback.

Collaborator

@houseroad houseroad left a comment

Wondering why ChartQA score decreased?

@tjtanaa
Contributor Author

tjtanaa commented Jun 20, 2025

> Wondering why ChartQA score decreased?

Will investigate further; however, the drop is within 1% (~0.2%).

Updates:

@houseroad I think it is just noise. I have rerun the evaluation to obtain another data point.

ChartQA (Vision Language Benchmark)

| AITER | Temperature | Explicit Prompt Relaxed Correctness | Anywhere in Answer Relaxed Correctness |
|---|---|---|---|
| No | 0.0 | 0.8720 | 0.8720 |
| Yes | 0.0 | 0.8696 | 0.8696 |
| Yes (new data point) | 0.5 | 0.8756 | 0.8756 |

@mergify mergify bot added the llama Related to Llama models label Jun 22, 2025