Commit 320ab71
Author: Lu Fang

add attention_chunk_size in full attention spec

Signed-off-by: Lu Fang <[email protected]>

Parent: 78385bc

File tree

vllm/v1/core/kv_cache_utils.py
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/kv_cache_interface.py

3 files changed: +13 -10 lines

vllm/v1/core/kv_cache_utils.py (1 addition, 0 deletions)

@@ -927,6 +927,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
                 head_size=spec.head_size,
                 dtype=spec.dtype,
                 use_mla=spec.use_mla,
+                attention_chunk_size=spec.attention_chunk_size,
             )

     if is_hybrid(kv_cache_spec):

vllm/v1/core/single_type_kv_cache_manager.py (11 additions, 10 deletions)

@@ -386,7 +386,7 @@ def get_num_common_prefix_blocks(self, request_id: str,
         NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
         So it's not correct to count ref_cnt like FullAttentionManager. Return
         0 here for correctness. Need to support cascade attention + sliding
-        window in the future
+        window in the future.
         """
         return 0

@@ -414,22 +414,23 @@ def find_longest_cache_hit(
                 "chunked local attention groups")
         max_num_blocks = max_length // kv_cache_spec.block_size
         if max_length > 0:
-            local_attention_start_idx = (
-                (max_length-1) // kv_cache_spec.attention_chunk_size
-                * kv_cache_spec.attention_chunk_size)
+            local_attention_start_idx = ((max_length - 1) //
+                                         kv_cache_spec.attention_chunk_size *
+                                         kv_cache_spec.attention_chunk_size)
         else:
             local_attention_start_idx = 0
         # [ block 0, ..., block x(x_start<=first_attention_token),
         # block x+1, .., block N (N_end <=max_len), ...]
-        local_attention_start_block_idx = (
-            local_attention_start_idx // kv_cache_spec.block_size)
+        local_attention_start_block_idx = (local_attention_start_idx //
+                                           kv_cache_spec.block_size)
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [block_pool.null_block] * local_attention_start_block_idx
             for _ in range(len(kv_cache_group_ids)))
-        # for local chunked attention, we marked blocks out of window as computed
-        # with null blocks, and blocks inside window based on cache lookup result
-        # [null] [null] ... [null] [hit block 1 (1st block contain last window)]
-        # [hit block 2] ... [hit block x][
+        # we marked blocks out of window as computed
+        # with null blocks, and blocks inside window
+        # based on cache lookup result
+        # [null] [null] ... [null] [hit block 1 (1st block contain last window)]
+        # [hit block 2] ... [hit block x]
         for i in range(local_attention_start_block_idx, max_num_blocks):
             block_hash = block_hashes[i]
             if cached_block := block_pool.get_cached_block(
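
As a rough illustration of the index arithmetic in this hunk, here is a small standalone sketch; the concrete numbers are made up for the example and are not taken from any real model config:

    # Hypothetical values, for illustration only.
    attention_chunk_size = 8192   # tokens covered by one local-attention chunk
    block_size = 16               # tokens per KV-cache block
    max_length = 10000            # prefix length being checked for a cache hit

    # Same arithmetic as the diff: start of the chunk containing the last token.
    local_attention_start_idx = ((max_length - 1) //
                                 attention_chunk_size *
                                 attention_chunk_size)          # -> 8192

    # First block that can hold tokens of the active chunk.
    local_attention_start_block_idx = (local_attention_start_idx //
                                       block_size)              # -> 512

    # Blocks 0..511 lie entirely before the active chunk, so they are treated
    # as computed and filled with the null block; blocks from 512 up to
    # max_num_blocks - 1 go through the normal cache lookup.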

vllm/v1/kv_cache_interface.py (1 addition, 0 deletions)

@@ -86,6 +86,7 @@ def page_size_bytes(self) -> int:
 @dataclass
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
+    attention_chunk_size: Optional[int] = None
     """
     When hybrid allocator is disabled and the model contains both full
     attention layers and sliding window attention layers, sliding
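
For context, a minimal sketch of the spec after this change, assuming a simplified AttentionSpec base that carries only the fields referenced elsewhere in this commit (the real definition in vllm/v1/kv_cache_interface.py has more fields and logic):

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class AttentionSpec:
        # Simplified stand-in; only fields referenced in this commit are shown.
        block_size: int
        head_size: int
        dtype: torch.dtype
        use_mla: bool

    @dataclass
    class FullAttentionSpec(AttentionSpec):
        sliding_window: Optional[int] = None
        # New field added by this commit: chunk size for chunked local
        # attention layers that get folded into a full-attention spec.
        attention_chunk_size: Optional[int] = None

    # Hypothetical construction mirroring the kv_cache_utils.py hunk, which now
    # propagates attention_chunk_size from the per-layer spec:
    spec = FullAttentionSpec(block_size=16, head_size=128, dtype=torch.bfloat16,
                             use_mla=False, attention_chunk_size=8192)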

0 commit comments