diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index 4d674d1d..6624af5e 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -91,18 +91,28 @@ def deallocate(self, seq: Sequence):
         seq.num_cached_tokens = 0
         seq.block_table.clear()
 
+    def need_append(self, seq: Sequence) -> bool:
+        """Return True when the newest token of seq opens a new block."""
+        return len(seq) % self.block_size == 1
+
     def can_append(self, seq: Sequence) -> bool:
-        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
+        """Return True when seq needs no new block or a free block exists."""
+        return not self.need_append(seq) or bool(self.free_block_ids)
+
+    def append(self, seq: Sequence):
+        """Give seq one more block; call only when need_append(seq) holds."""
+        block_table = seq.block_table
+        # The block being left behind is full, so it must already be hashed.
+        assert self.blocks[block_table[-1]].hash != -1
+        block_id = self.free_block_ids[0]
+        self._allocate_block(block_id)
+        block_table.append(block_id)
 
-    def may_append(self, seq: Sequence):
+    def check_and_update_hash(self, seq: Sequence):
+        """Hash the last block of seq once it is full; call after append()."""
         block_table = seq.block_table
         last_block = self.blocks[block_table[-1]]
-        if len(seq) % self.block_size == 1:
-            assert last_block.hash != -1
-            block_id = self.free_block_ids[0]
-            self._allocate_block(block_id)
-            block_table.append(block_id)
-        elif len(seq) % self.block_size == 0:
+        if len(seq) % self.block_size == 0:
             assert last_block.hash == -1
             token_ids = seq.block(seq.num_blocks-1)
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py
index 5bc19fe0..a504b3cc 100644
--- a/nanovllm/engine/scheduler.py
+++ b/nanovllm/engine/scheduler.py
@@ -43,16 +43,19 @@ def schedule(self) -> tuple[list[Sequence], bool]:
         # decode
         while self.running and num_seqs < self.max_num_seqs:
             seq = self.running.popleft()
             while not self.block_manager.can_append(seq):
                 if self.running:
                     self.preempt(self.running.pop())
                 else:
                     self.preempt(seq)
                     break
             else:
                 num_seqs += 1
-                self.block_manager.may_append(seq)
+                if self.block_manager.need_append(seq):
+                    self.block_manager.append(seq)
+                # Must follow append(): it inspects the newest token's block.
+                self.block_manager.check_and_update_hash(seq)
                 scheduled_seqs.append(seq)
         assert scheduled_seqs
         self.running.extendleft(reversed(scheduled_seqs))
         return scheduled_seqs, False