Skip to content

Commit 19fbdee

Browse files
author
Lu Fang
committed
fix issue of local attention start idx based on num computed tokens
Signed-off-by: Lu Fang <[email protected]>
1 parent f2887f6 commit 19fbdee

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

tests/v1/core/test_specialized_manager.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,24 @@ def run_one_case(block_is_cached, tail_token, expect_length):
7474

7575
run_one_case([True], 0, 1)
7676
run_one_case([True], 1, 1)
77-
run_one_case([True, False], 0, 1)
77+
run_one_case([True, False], 0, 2)
7878
run_one_case([True, False], 1, 2)
7979
run_one_case([True, True], 0, 2)
8080
run_one_case([True, True], 1, 2)
8181
run_one_case([True, True, False], 0, 2)
8282
run_one_case([True, True, False], 1, 2)
8383
run_one_case([True, True, True], 0, 3)
8484
run_one_case([True, True, True], 1, 3)
85-
run_one_case([True, True, True, False], 0, 3)
85+
run_one_case([True, True, True, False], 0, 4)
8686
run_one_case([True, True, True, False], 1, 4)
87+
run_one_case([random.choice([True, False])] * 8 + [True], 1, 9)
88+
run_one_case([random.choice([True, False])] * 8 + [False], 1, 8)
8789
run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
88-
run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 9)
90+
run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10)
8991
run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
90-
run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 8)
92+
run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10)
9193
run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
92-
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 8)
94+
run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
9395
run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
9496

9597

@@ -198,23 +200,18 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
198200
manager.remove_skipped_blocks("test", 0)
199201
assert_block_id(block_table, original_block_ids)
200202

201-
# 4 tokens are computed. no token is out of the local attention window.
203+
# For 4th token (0-indexed), token 0-3 is out of the local attention window.
202204
manager.remove_skipped_blocks("test", 4)
203-
assert_block_id(block_table, original_block_ids)
204-
205-
# 5 tokens are computed. token 0 is out of the local attention window.
206-
# no block can be removed.
207-
manager.remove_skipped_blocks("test", 5)
208-
assert_block_id(block_table, [null_block_id])
205+
assert_block_id(block_table, [null_block_id] * 2)
209206

210-
# 6 tokens are computed. token 4 - 5 are in local attention window,
207+
# For 6th token (0-indexed), token 4 - 6 are in local attention window,
211208
# token 0 - 3 are out, 2 blocks can be removed.
212209
manager.remove_skipped_blocks("test", 6)
213210
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
214-
# 11 tokens are computed. token 8 - 11 are in local attention window,
215-
# token 0-7 are out, 4 block can be removed.
216-
manager.remove_skipped_blocks("test", 11)
217-
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
211+
# For 12th token (0-indexed),
212+
# token 0-11 are out, 6 block can be removed.
213+
manager.remove_skipped_blocks("test", 12)
214+
assert_block_id(block_table, [null_block_id] * 6)
218215

219216

220217
def test_sliding_window_remove_skipped_blocks():

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
396396
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
397397
block_pool: BlockPool, **kwargs) -> None:
398398
super().__init__(kv_cache_spec, block_pool, **kwargs)
399+
print("local attetntion manager is init")
399400
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
400401
self._null_block = block_pool.null_block
401402

@@ -414,7 +415,7 @@ def find_longest_cache_hit(
414415
"chunked local attention groups")
415416
max_num_blocks = max_length // kv_cache_spec.block_size
416417
if max_length > 0:
417-
local_attention_start_idx = ((max_length - 1) //
418+
local_attention_start_idx = (max_length //
418419
kv_cache_spec.attention_chunk_size *
419420
kv_cache_spec.attention_chunk_size)
420421
else:
@@ -450,12 +451,12 @@ def remove_skipped_blocks(self, request_id: str,
450451
# chunked attention window and skipped
451452
# during the attention computation.
452453

453-
# (N-1) // chunk_size * chunk_size
454+
# N // chunk_size * chunk_size
454455
# [chunk 0][chunk 1]local_attention_start_idx ... current
455456

456457
local_attention_start_idx = (
457-
num_computed_tokens -
458-
1) // self.attention_chunk_size * self.attention_chunk_size
458+
num_computed_tokens
459+
) // self.attention_chunk_size * self.attention_chunk_size
459460
# 1024-> 0, 1025-> 1024
460461
first_useful_block_idx = local_attention_start_idx // self.block_size
461462
# block size =128, 0 -> block 0, 1024 -> block 8, 372 -> block 2

0 commit comments

Comments (0)