@@ -74,22 +74,24 @@ def run_one_case(block_is_cached, tail_token, expect_length):
74
74
75
75
run_one_case ([True ], 0 , 1 )
76
76
run_one_case ([True ], 1 , 1 )
77
- run_one_case ([True , False ], 0 , 1 )
77
+ run_one_case ([True , False ], 0 , 2 )
78
78
run_one_case ([True , False ], 1 , 2 )
79
79
run_one_case ([True , True ], 0 , 2 )
80
80
run_one_case ([True , True ], 1 , 2 )
81
81
run_one_case ([True , True , False ], 0 , 2 )
82
82
run_one_case ([True , True , False ], 1 , 2 )
83
83
run_one_case ([True , True , True ], 0 , 3 )
84
84
run_one_case ([True , True , True ], 1 , 3 )
85
- run_one_case ([True , True , True , False ], 0 , 3 )
85
+ run_one_case ([True , True , True , False ], 0 , 4 )
86
86
run_one_case ([True , True , True , False ], 1 , 4 )
87
+ run_one_case ([random .choice ([True , False ])] * 8 + [True ], 1 , 9 )
88
+ run_one_case ([random .choice ([True , False ])] * 8 + [False ], 1 , 8 )
87
89
run_one_case ([random .choice ([True , False ])] * 8 + [True , True ], 1 , 10 )
88
- run_one_case ([random .choice ([True , False ])] * 8 + [True , False ], 0 , 9 )
90
+ run_one_case ([random .choice ([True , False ])] * 8 + [True , False ], 0 , 10 )
89
91
run_one_case ([random .choice ([True , False ])] * 8 + [True , False ], 1 , 10 )
90
- run_one_case ([random .choice ([True , False ])] * 8 + [False , True ], 0 , 8 )
92
+ run_one_case ([random .choice ([True , False ])] * 8 + [False , True ], 0 , 10 )
91
93
run_one_case ([random .choice ([True , False ])] * 8 + [False , True ], 1 , 10 )
92
- run_one_case ([random .choice ([True , False ])] * 8 + [False , False ], 0 , 8 )
94
+ run_one_case ([random .choice ([True , False ])] * 8 + [False , False ], 0 , 10 )
93
95
run_one_case ([random .choice ([True , False ])] * 8 + [False , False ], 1 , 10 )
94
96
95
97
@@ -198,23 +200,18 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
198
200
manager .remove_skipped_blocks ("test" , 0 )
199
201
assert_block_id (block_table , original_block_ids )
200
202
201
- # 4 tokens are computed. no token is out of the local attention window.
203
+ # For 4th token (0-indexed), token 0-3 is out of the local attention window.
202
204
manager .remove_skipped_blocks ("test" , 4 )
203
- assert_block_id (block_table , original_block_ids )
204
-
205
- # 5 tokens are computed. token 0 is out of the local attention window.
206
- # no block can be removed.
207
- manager .remove_skipped_blocks ("test" , 5 )
208
- assert_block_id (block_table , [null_block_id ])
205
+ assert_block_id (block_table , [null_block_id ] * 2 )
209
206
210
- # 6 tokens are computed. token 4 - 5 are in local attention window,
207
+ # For 6th token (0-indexed), token 4 - 6 are in local attention window,
211
208
# token 0 - 3 are out, 2 blocks can be removed.
212
209
manager .remove_skipped_blocks ("test" , 6 )
213
210
assert_block_id (block_table , [null_block_id ] * 2 + original_block_ids [2 :])
214
- # 11 tokens are computed. token 8 - 11 are in local attention window ,
215
- # token 0-7 are out, 4 block can be removed.
216
- manager .remove_skipped_blocks ("test" , 11 )
217
- assert_block_id (block_table , [null_block_id ] * 4 + original_block_ids [ 4 :] )
211
+ # For 12th token (0-indexed) ,
212
+ # token 0-11 are out, 6 block can be removed.
213
+ manager .remove_skipped_blocks ("test" , 12 )
214
+ assert_block_id (block_table , [null_block_id ] * 6 )
218
215
219
216
220
217
def test_sliding_window_remove_skipped_blocks ():
0 commit comments