@@ -386,7 +386,7 @@ def get_num_common_prefix_blocks(self, request_id: str,
386
386
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
387
387
So it's not correct to count ref_cnt like FullAttentionManager. Return
388
388
0 here for correctness. Need to support cascade attention + sliding
389
- window in the future
389
+ window in the future.
390
390
"""
391
391
return 0
392
392
@@ -414,22 +414,23 @@ def find_longest_cache_hit(
414
414
"chunked local attention groups" )
415
415
max_num_blocks = max_length // kv_cache_spec .block_size
416
416
if max_length > 0 :
417
- local_attention_start_idx = (
418
- ( max_length - 1 ) // kv_cache_spec .attention_chunk_size
419
- * kv_cache_spec .attention_chunk_size )
417
+ local_attention_start_idx = (( max_length - 1 ) //
418
+ kv_cache_spec .attention_chunk_size *
419
+ kv_cache_spec .attention_chunk_size )
420
420
else :
421
421
local_attention_start_idx = 0
422
422
# [ block 0, ..., block x(x_start<=first_attention_token),
423
423
# block x+1, .., block N (N_end <=max_len), ...]
424
- local_attention_start_block_idx = (
425
- local_attention_start_idx // kv_cache_spec .block_size )
424
+ local_attention_start_block_idx = (local_attention_start_idx //
425
+ kv_cache_spec .block_size )
426
426
computed_blocks : tuple [list [KVCacheBlock ], ...] = tuple (
427
427
[block_pool .null_block ] * local_attention_start_block_idx
428
428
for _ in range (len (kv_cache_group_ids )))
429
- # for local chunked attention, we marked blocks out of window as computed
430
- # with null blocks, and blocks inside window based on cache lookup result
431
- # [null] [null] ... [null] [hit block 1 (1st block contain last window)]
432
- # [hit block 2] ... [hit block x][
429
+ # we marked blocks out of window as computed
430
+ # with null blocks, and blocks inside window
431
+ # based on cache lookup result
432
+ # [null] [null] ... [null] [hit block 1 (1st block contain last window)]
433
+ # [hit block 2] ... [hit block x]
433
434
for i in range (local_attention_start_block_idx , max_num_blocks ):
434
435
block_hash = block_hashes [i ]
435
436
if cached_block := block_pool .get_cached_block (
0 commit comments