[Core] Support Local Chunked Attention for Hybrid KV Cache #19351
base: main
Conversation
Hello @luccafong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request. Based on the title [wip] support local chunked attention for hybrid kv cache
and the code changes, this PR introduces support for a new KV cache management strategy: chunked local attention. This is intended to work within vLLM's hybrid KV cache framework. Since the description is a placeholder, the intent is primarily derived from the title and the implementation details in the code.
Highlights
- New KV Cache Specification: A new dataclass, ChunkedLocalAttentionSpec, is introduced in vllm/v1/kv_cache_interface.py to define the configuration for chunked local attention KV caches, including the attention_chunk_size (a rough sketch follows after this list).
- New KV Cache Manager: A new manager class, ChunkedLocalAttentionManager, is added in vllm/v1/core/single_type_kv_cache_manager.py. This manager implements the logic for finding cache hits and removing skipped blocks specifically for chunked local attention, handling the non-contiguous nature of the relevant KV cache blocks.
- Hybrid KV Cache Integration: The KV cache utility functions in vllm/v1/core/kv_cache_utils.py are updated to recognize ChunkedLocalAttentionSpec when determining hybrid configurations and unifying specs, ensuring compatibility with existing full attention and sliding window mechanisms.
- Model Runner Updates: The GpuModelRunner in vllm/v1/worker/gpu_model_runner.py is modified to select the ChunkedLocalAttentionSpec for a layer if attention_chunk_size is configured and the layer uses iRoPE (interleaved RoPE).
- New Tests: Dedicated tests for the ChunkedLocalAttentionManager are added in tests/v1/core/test_specialized_manager.py to verify the find_longest_cache_hit and remove_skipped_blocks logic for the new attention type.
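For orientation, here is a rough sketch of what such a spec could look like; apart from attention_chunk_size, the field names are assumptions mirroring vLLM's other attention specs rather than the PR's exact code.

```python
from dataclasses import dataclass

import torch


@dataclass
class ChunkedLocalAttentionSpec:
    """Assumed shape of the new KV cache spec for chunked local attention."""
    block_size: int            # tokens per KV cache block
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    attention_chunk_size: int  # size of each local-attention chunk, in tokens
```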
Changelog
- tests/v1/core/test_specialized_manager.py
  - Imported random.
  - Imported ChunkedLocalAttentionManager and ChunkedLocalAttentionSpec.
  - Added get_chunked_local_attention_manager helper function.
  - Added test_chunked_local_attention_possible_cached_prefix to test cache hit finding for chunked local attention.
  - Added test_chunked_local_attention_remove_skipped_blocks to test block removal logic.
  - Renamed test_get_num_blocks_to_allocate to test_sliding_window_get_num_blocks_to_allocate.
  - Added test_chunked_local_attention_get_num_blocks_to_allocate.
- vllm/attention/layer.py
  - Added use_irope attribute to the Attention class, initialized from extra_impl_args.
- vllm/executor/executor_base.py
  - Added a debug print statement showing block_size and num_gpu_blocks during cache initialization.
- vllm/v1/core/kv_cache_utils.py
  - Imported ChunkedLocalAttentionSpec.
  - Added a debug print statement in _get_kv_cache_config_uniform_page_size.
  - Updated is_hybrid and unify_hybrid_kv_cache_specs functions to include ChunkedLocalAttentionSpec in their logic for handling hybrid configurations.
- vllm/v1/core/single_type_kv_cache_manager.py
  - Imported ChunkedLocalAttentionSpec.
  - Added the ChunkedLocalAttentionManager class, inheriting from SingleTypeKVCacheManager.
  - Implemented find_longest_cache_hit and remove_skipped_blocks methods within ChunkedLocalAttentionManager.
  - Added debug print statements within ChunkedLocalAttentionManager methods.
  - Added ChunkedLocalAttentionManager to the spec_manager_map.
- vllm/v1/kv_cache_interface.py
  - Defined the ChunkedLocalAttentionSpec dataclass with attention_chunk_size.
  - Implemented type_id, max_memory_usage_bytes, and merge methods for ChunkedLocalAttentionSpec.
- vllm/v1/worker/gpu_model_runner.py
  - Imported ChunkedLocalAttentionSpec.
  - Modified get_kv_cache_spec to create a ChunkedLocalAttentionSpec for layers that have attention_chunk_size configured and use iRoPE.
Code Review
This pull request lays the groundwork for supporting local chunked attention in the hybrid KV cache. The introduction of ChunkedLocalAttentionManager and ChunkedLocalAttentionSpec, along with initial tests, is a good step forward.
However, as this is marked [wip], there are a few areas to address:
- Pull Request Description: The PR description is currently a template. Please fill it out with the purpose of these changes, a test plan (even if preliminary for WIP), and any expected outcomes or known limitations. This context is crucial for reviewers.
- Debugging Code: Several print statements used for debugging are present in the codebase. These should be removed before merging.
- Clarity and TODOs: There are a few comments and TODOs that need attention or clarification.
Overall, the direction seems good, and the core logic for the new manager and spec is taking shape. Addressing the points below will help improve the clarity and readiness of this PR.
Summary of Findings
- PR Description: The pull request description is currently a template and needs to be filled out with details about the purpose, test plan, and results of these changes. This is especially important for a work-in-progress PR.
- Debugging Code: Several print statements, likely used for debugging, are present in the codebase (e.g., in executor_base.py, kv_cache_utils.py, single_type_kv_cache_manager.py, kv_cache_interface.py). These should be removed before merging.
- Clarity and Completeness: Some comments, like the one in ChunkedLocalAttentionSpec.max_memory_usage_bytes, could be clarified. Additionally, the get_num_common_prefix_blocks method in the new manager (and the existing SlidingWindowManager) has a known simplification related to cascade attention that should be tracked.
- TODOs: A TODO comment exists in kv_cache_utils.py regarding making the hybrid spec unification more generic. This should ideally be tracked with a follow-up issue if it's a larger task.
Merge Readiness
This pull request is a work-in-progress and introduces significant new functionality for chunked local attention. Before it can be considered for merging, the PR description needs to be completed, all debugging print statements must be removed, and the identified points for clarification should be addressed. Given the WIP nature and these outstanding items, I recommend further changes before this PR is merged. I am unable to approve the pull request, and it should be reviewed by others before merging.
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Please don't forget to fill the task description :-)
Force-pushed from e8c8c6e to 50281ac
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 50281ac to b5157ad
Force-pushed from bd29016 to d8b297c
Force-pushed from d8b297c to 3a04e9a
Signed-off-by: Lucia Fang <[email protected]>
Force-pushed from 3a04e9a to 7913145
Signed-off-by: Lucia Fang <[email protected]>
vllm/v1/core/kv_cache_utils.py (Outdated)
@@ -845,6 +846,7 @@ def _get_kv_cache_config_uniform_page_size(
    # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
    # full.1, sw.1: share another Tensor with size=available_memory//2
    page_size = get_uniform_page_size(kv_cache_spec)
    # print(f"{page_size=}, {group_size=}")
# print(f"{page_size=}, {group_size=}") |
@@ -715,6 +715,7 @@ def use_cascade_attention(
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
Just curious: What is this for?
use_local_attention does not support cascade attention, as noted in vllm/vllm/v1/attention/backends/flash_attn.py, lines 693 to 694 in c3fec47:
    assert not use_local_attn, (
        "Cascade attention does not support local attention.")
Do you need this line in this function?
if use_local_attention: return False
I left it out when merging, will add it back.
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
    local_attention_start_idx = \
        (max_length - 1) // kv_cache_spec.attention_chunk_size \
Why -1 here?
We need the index of the last token rather than the length here to compute the attention window. E.g. given a max length of 128 and chunk size 64, the context is chunked as [0, 63] and [64, 127]; token 127 should attend within the window [64, 127], whose start index is 64 = (127 // 64) * 64, not 2 * 64 = 128.
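A minimal standalone sketch of this arithmetic (the helper name is hypothetical, not the PR's code):

```python
def chunk_start_for_token(token_idx: int, chunk_size: int) -> int:
    """Start of the local-attention window that contains token_idx."""
    return token_idx // chunk_size * chunk_size


chunk_size = 64
max_length = 128
# Using the last *index* (max_length - 1) gives the intended window start.
assert chunk_start_for_token(max_length - 1, chunk_size) == 64
# Using the *length* would overshoot to the start of the next, empty chunk.
assert max_length // chunk_size * chunk_size == 128
```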
Does the 1024-th token need the KV cache of tokens [0-1023] if attn_chunk_size is 1024? I think most of my questions come from this point.
    [block_pool.null_block] * local_attention_start_block_idx
    for _ in range(len(kv_cache_group_ids)))

for i in range(local_attention_start_block_idx, max_num_blocks):
Can you explain the rule for cache hits? For example, with block_size 1 and chunk_size 2, what is the expected result in the following cases?
- [miss miss] [miss miss] [miss miss]. Should it be 0 or 6?
- [miss miss] [hit miss] [miss miss]. Should it be 3 or 6?
And please add some comments to describe the expected behavior.
Yeah.
For the current token, we check for cache hits starting from the first block that contains the attention window, until we miss.
The number of computed blocks is marked as (previously unattended blocks) + (number of hit blocks), so even with zero hits it returns the previously unattended blocks.
So for your questions:
- It returns 4, since the last window missed.
- Still 4, since the last window missed.
For a case like [miss miss] [miss miss] [hit miss], it returns 5.
I will add more comments to explain.
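A standalone sketch of the counting rule described above, using block_size 1 and chunk_size 2 from the example; the helper is hypothetical, not the PR's implementation:

```python
def longest_local_cache_hit(hits: list[bool], block_size: int,
                            chunk_size: int) -> int:
    """Computed blocks = blocks before the current local-attention window
    (skippable, treated as null) + consecutive hit blocks inside the window."""
    num_tokens = len(hits) * block_size
    if num_tokens == 0:
        return 0
    window_start = (num_tokens - 1) // chunk_size * chunk_size
    first_useful_block = window_start // block_size
    computed = first_useful_block
    for hit in hits[first_useful_block:]:
        if not hit:
            break
        computed += 1
    return computed


# [miss miss] [miss miss] [miss miss] -> 4 (last window missed)
assert longest_local_cache_hit([False] * 6, block_size=1, chunk_size=2) == 4
# [miss miss] [hit miss] [miss miss]  -> still 4 (last window missed)
assert longest_local_cache_hit([False, False, True, False, False, False], 1, 2) == 4
# [miss miss] [miss miss] [hit miss]  -> 5
assert longest_local_cache_hit([False, False, False, False, True, False], 1, 2) == 5
```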
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
assert self.attention_chunk_size % block_size == 0
?
The logic should already cover the case where it's not divisible?
Yes, you are right.
local_attention_start_idx = (
    num_computed_tokens -
    1) // self.attention_chunk_size * self.attention_chunk_size
# 1024 -> 0, 1025 -> 1024
Why 1024 -> 0? Does the attention of the 1024-th token (the first token of the next chunk) need tokens 0-1023?
Here num_computed_tokens = 1024, so the last computed token has index 1023, and its local attention window starts at 0.
Can you update the comment? It does not need [0-1023] for the 1024-th token.
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Force-pushed from 1f53790 to 320ab71
use_local_attention = isinstance(kv_cache_spec,
                                 ChunkedLocalAttentionSpec)
Suggested change:
use_local_attention = (isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
                       or (isinstance(kv_cache_spec, FullAttentionSpec)
                           and kv_cache_spec.attention_chunk_size is not None))
# 4 tokens are computed. no token is out of the local attention window.
manager.remove_skipped_blocks("test", 4)
assert_block_id(block_table, original_block_ids)
In this test, since token 4 doesn't need the KV cache of tokens [0-3], why do you need to keep them?
Token 4 (if 1-indexed) needs the KV cache of [0-4].
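A minimal illustration of the removal rule being discussed; the helper and the chunk_size/block_size values are assumptions for the example, not the PR's code:

```python
def first_useful_block(num_computed_tokens: int, block_size: int,
                       chunk_size: int) -> int:
    """Blocks before this index only hold tokens outside the current
    local-attention window and could be replaced with null blocks."""
    if num_computed_tokens == 0:
        return 0
    window_start = (num_computed_tokens - 1) // chunk_size * chunk_size
    return window_start // block_size


# With chunk_size 4 and block_size 2: after 4 computed tokens the window
# still starts at token 0, so no block is skippable yet.
assert first_useful_block(4, block_size=2, chunk_size=4) == 0
# After 5 computed tokens the window starts at token 4 (block 2),
# so blocks 0 and 1 become skippable.
assert first_useful_block(5, block_size=2, chunk_size=4) == 2
```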
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Lu Fang <[email protected]>
Force-pushed from 19fbdee to f7b6961
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Lu Fang <[email protected]>
Thanks for the great job. I think we have aligned on the expected behavior. Can you write some examples in find_longest_cache_hit and remove_skipped_blocks to help people understand them?
from collections.abc import Sequence
from typing import Any, Dict, List, Optional, Union

import regex as re
why do you need this line?
@@ -385,12 +386,102 @@ def get_num_common_prefix_blocks(self, request_id: str,
    """
    NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
    So it's not correct to count ref_cnt like FullAttentionManager. Return
-   0 here for correctness. Need to support cascade attention + sliding
+   0 here for correctness. Need to support cascade attention + sliding
Can you revert?
break
if use_eagle and computed_blocks[0]:
    for computed in computed_blocks:
        computed.pop()
In eagle, we can't simply pop the last block.
For example, with chunk size 2 and block size 1:
[miss miss] [miss miss] -> cache_hit_length 4
If we remove the 3rd block (0-indexed), cache_hit_length becomes 3, but [miss miss] [miss] is not a valid cache-hit prefix. I think we should return cache_hit_length 2 in this case.
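A hypothetical sketch of the chunk-aligned truncation suggested here (names and behavior are assumptions, not the PR's code):

```python
def eagle_truncate_cache_hit(num_hit_blocks: int, block_size: int,
                             chunk_size: int) -> int:
    """Instead of popping a single block, fall back to the previous chunk
    boundary so the remaining prefix is still a valid local-attention hit
    (assumes chunk_size % block_size == 0)."""
    if num_hit_blocks == 0:
        return 0
    blocks_per_chunk = chunk_size // block_size
    # Drop at least one block, then round down to a chunk boundary.
    return (num_hit_blocks - 1) // blocks_per_chunk * blocks_per_chunk


# chunk size 2, block size 1: a hit length of 4 becomes 2 (not 3),
# matching the expected behavior in the comment above.
assert eagle_truncate_cache_hit(4, block_size=1, chunk_size=2) == 2
```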
# we marked blocks out of window as computed
# with null blocks, and blocks inside window
# based on cache lookup result
Can you change the length of each line to ~80 characters? And should [430-432] be put before [425-429]?
# [ block 0, ..., block x (x_start <= first_attention_token),
#   block x+1, ..., block N (N_end <= max_len), ...]
Why do you need this comment? What is x for?
local_attention_start_idx = (
    num_computed_tokens -
    1) // self.attention_chunk_size * self.attention_chunk_size
# 1024 -> 0, 1025 -> 1024
Can you update the comment?
) // self.attention_chunk_size * self.attention_chunk_size
# 1024 -> 0, 1025 -> 1024
first_useful_block_idx = local_attention_start_idx // self.block_size
# block size = 128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
Suggested change:
# if block size = 128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
# block size = 128, 0 -> block 0, 1024 -> block 8, 372 -> block 2
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
blockids = []
Why do you need this blockids?
return cdiv(num_tokens, self.block_size) * self.page_size_bytes

@classmethod
def merge(cls, specs: list[Self]) -> Self:
Remove this function after updating type_id.
Purpose
This PR follows #17996 to add Hybrid KV Cache support for local chunked attention, in order to support models like Llama 4 Maverick and Scout.
Test Plan
unit test:
eval:
Test Result
unit tests:
eval:
mmlu_pro: This PR vs. Baseline (trunk: 7e3e74c)
ruler niah_multikey_2: This PR vs. Baseline