Commit 8f5b161

fix hybrid kv cache
Signed-off-by: Chenyaaang <[email protected]>
1 parent c0c8192

File tree

6 files changed, +145 −72 lines


tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  # 6 is layer_name_to_kvcache_index
+        static_argnums=7,  # 7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
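The index bump appears to follow directly from jax.jit's positional numbering: once an explicit positions array is threaded in ahead of layer_name_to_kvcache_index (see the runner changes below), the static argument moves from slot 6 to slot 7. A minimal sketch of that pattern, with a toy function whose argument names are illustrative rather than the real run_model signature:

import jax
import jax.numpy as jnp

# Toy function mirroring the jitted argument layout: argument 7 (0-based) is
# the tuple of (layer_name, kv_cache_index) pairs, which must be static
# because it is hashable configuration rather than array data.
def toy_model(graphdef, state, kv_cache, input_ids, attn_metadata, positions,
              inputs_embeds, layer_name_to_kvcache_index):
    del graphdef, attn_metadata, layer_name_to_kvcache_index  # config pieces
    return state * input_ids + kv_cache + positions + inputs_embeds

# Inserting `positions` before `layer_name_to_kvcache_index` is what shifts
# the static argument from index 6 to index 7.
jit_model = jax.jit(toy_model, static_argnums=7)

x = jnp.ones(4)
out = jit_model(None, 2.0, x, x, None, x, x, (("layers.0.attn", 0),))
print(out.shape)  # (4,)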

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 3 additions & 2 deletions
@@ -160,6 +160,7 @@ def step_fun(
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
+        input_positions: jax.Array,
         input_embeds: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
@@ -187,8 +188,8 @@ def step_fun(
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(attn_metadata.input_positions),
-                "intermediate_tensors": intermediate_tensors,
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
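Two things change here: positions now arrive as an explicit input_positions argument instead of being read off attn_metadata, and intermediate_tensors is pinned to None. The first matters because, with a hybrid KV cache, attn_metadata may become a dict keyed by layer name, so there is no single input_positions field to read. A toy illustration of that constraint (the class below is a stand-in, not the real AttentionMetadata):

from typing import Dict, Union


class Meta:
    """Stand-in for AttentionMetadata; only the field used in this sketch."""

    def __init__(self, input_positions):
        self.input_positions = input_positions


def old_style_positions(attn_metadata: Union[Meta, Dict[str, Meta]]):
    # Only works when all layers share a single metadata object.
    return attn_metadata.input_positions


uniform = Meta(input_positions=[0, 1, 2])
hybrid = {"layers.0.attn": Meta([0, 1, 2]), "layers.1.attn": Meta([0, 1, 2])}

assert old_style_positions(uniform) == [0, 1, 2]
try:
    old_style_positions(hybrid)  # a dict has no .input_positions attribute
except AttributeError:
    pass  # hence positions are now threaded through as their own argument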

tpu_inference/platforms/tpu_platform.py

Lines changed: 4 additions & 0 deletions
@@ -266,3 +266,7 @@ def use_sync_weight_loader(cls) -> bool:
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
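This follows the usual platform-capability pattern: callers can gate hybrid KV cache behavior on a classmethod probe. A minimal standalone sketch of that pattern, with hypothetical class and function names rather than the actual vLLM Platform interface:

class BasePlatform:
    """Hypothetical base class standing in for the platform interface."""

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        # Conservative default: platforms opt in explicitly.
        return False


class TpuLikePlatform(BasePlatform):

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True


def pick_kv_cache_layout(platform, num_groups: int) -> str:
    # Keep separate KV cache groups (e.g. full vs. sliding-window attention)
    # only when the platform reports support; otherwise fall back to one group.
    if num_groups > 1 and platform.support_hybrid_kv_cache():
        return "hybrid"
    return "uniform"


assert pick_kv_cache_layout(TpuLikePlatform, num_groups=2) == "hybrid"
assert pick_kv_cache_layout(BasePlatform, num_groups=2) == "uniform"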

tpu_inference/runner/compilation_manager.py

Lines changed: 34 additions & 16 deletions
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,40 +146,64 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
             request_distribution,
             sharding=dp_sharding)
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
-                state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                state, kv_caches, input_ids, attention_metadata, positions,
+                inputs_embeds, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
         with self.runner.maybe_select_dummy_loras(
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+        if self.runner.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         self._run_compilation(
             name,
             model_fn_wrapper,
             self.runner.state,
             self.runner.kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             tuple(self.runner.layer_name_to_kvcache_index.items()),
             lora_metadata,
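The core of the fix is visible in the loop above: every KV cache group gets its own dummy block table and AttentionMetadata, and when the cache is hybrid (more than one group) that metadata is fanned out to each layer name in the group. A self-contained sketch of the fan-out, using simplified stand-in types rather than the real AttentionMetadata and kv_cache_config:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class GroupMetadata:  # stand-in for AttentionMetadata
    block_tables: List[int]


@dataclass
class KVCacheGroup:  # stand-in for a kv_cache_groups entry
    layer_names: List[str]
    block_tables: List[int]


def build_metadata(groups: List[KVCacheGroup], use_hybrid: bool):
    per_layer: Dict[str, GroupMetadata] = {}
    uniform: Optional[GroupMetadata] = None
    for group in groups:
        md = GroupMetadata(block_tables=group.block_tables)
        if not use_hybrid:
            # Single KV cache group: all layers share one metadata object.
            uniform = md
        else:
            # Every layer in this group points at the group's metadata.
            for name in group.layer_names:
                per_layer[name] = md
    return per_layer if use_hybrid else uniform


groups = [
    KVCacheGroup(["layers.0.attn", "layers.2.attn"], block_tables=[0, 1]),
    KVCacheGroup(["layers.1.attn", "layers.3.attn"], block_tables=[2, 3]),
]
md = build_metadata(groups, use_hybrid=len(groups) > 1)
assert md["layers.1.attn"].block_tables == [2, 3]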

tpu_inference/runner/kv_cache_manager.py

Lines changed: 7 additions & 0 deletions
@@ -3,12 +3,14 @@
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -174,6 +176,11 @@ def maybe_reinitialize_input_batch(self,
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
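The new block_tables_cpu list gives each KV cache group its own CPU block table, sized by that group's block size: a larger block size needs fewer block slots per request, hence the cdiv(max_model_len, block_size) column count. A small numeric sketch with illustrative values (not taken from any real config), using a local cdiv rather than the vLLM helper:

import numpy as np


def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the behavior of the vllm helper."""
    return -(-a // b)


max_num_reqs = 8
max_model_len = 1000

# Example: one full-attention group with block size 16 and one
# sliding-window group with block size 128 (hypothetical numbers).
block_sizes = [16, 128]

block_tables_cpu = [
    np.zeros((max_num_reqs, cdiv(max_model_len, block_size)), dtype=np.int32)
    for block_size in block_sizes
]

# ceil(1000 / 16) = 63 columns, ceil(1000 / 128) = 8 columns.
assert block_tables_cpu[0].shape == (8, 63)
assert block_tables_cpu[1].shape == (8, 8)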

tpu_inference/runner/tpu_runner.py

Lines changed: 96 additions & 53 deletions
@@ -438,8 +438,11 @@ def _init_inputs(self) -> None:
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-        self.block_table_cpu = np.zeros(
-            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
+        self.block_tables_cpu = [
+            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
+                     dtype=np.int32)
+        ]
+
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -535,6 +538,7 @@ def get_kv_cache_spec(self):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
+        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -701,6 +705,7 @@ def _execute_model(
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
+            input_positions,
             attn_metadata,
             _,
             logits_indices,
@@ -747,6 +752,7 @@ def _execute_model(
             self.kv_caches,
             input_ids,
             attn_metadata,
+            input_positions,
             inputs_embeds,
             tuple(self.layer_name_to_kvcache_index.items()),
             lora_metadata,
@@ -1303,16 +1309,6 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        for dp_rank in range(dp_size):
-            req_offset = dp_rank * max_num_reqs_per_dp_rank
-            _num_reqs = num_req_per_dp_rank[dp_rank]
-
-            block_tables[
-                req_offset:req_offset + _num_reqs, :self.
-                max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
@@ -1354,20 +1350,55 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         if self.uses_mrope:
             positions = mrope_positions
 
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution) = device_array(
+        (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+         request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution),
+            (input_ids, positions, query_start_loc, seq_lens, logits_indices,
+             request_distribution),
             sharding=data_parallel_attn_sharding,
         )
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            for dp_rank in range(dp_size):
+                req_offset = dp_rank * max_num_reqs_per_dp_rank
+                _num_reqs = num_req_per_dp_rank[dp_rank]
+
+                block_tables[
+                    req_offset:req_offset + _num_reqs, :self.
+                    max_num_blocks_per_req] = self.input_batch.block_table[
+                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
+
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
@@ -1396,20 +1427,13 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
                 padded_total_num_scheduled_tokens,
             )
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
-
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
         return (
             input_ids,
+            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
@@ -1516,9 +1540,6 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1547,16 +1568,44 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
-
-        # Convert block_tables to 1D on cpu.
-        block_tables = block_tables.reshape(-1)
-
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+
+        (input_ids, positions, query_start_loc, seq_lens,
         logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions, block_tables, query_start_loc,
-                        seq_lens, logits_indices, request_distribution))
+            self.mesh, (input_ids, positions, query_start_loc, seq_lens,
                        logits_indices, request_distribution))
+
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
+                                                               max_num_reqs]
+            block_tables[:num_reqs] = (
+                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
+                [:num_reqs])
+            # Convert block_tables to 1D on cpu.
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.mesh, (block_tables))
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution)
+            # This is for making these cpu buffers hidden during tracing
+            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
+            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
+
+            if not self.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1569,19 +1618,13 @@ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
             self.lora_utils.set_active_loras(
                 num_scheduled_tokens_per_req, total_num_scheduled_tokens,
                 padded_total_num_scheduled_tokens)
-
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution)
-
-        # This is for making these cpu buffers hidden during tracing
-        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
-        attention_metadata.seq_lens_cpu = seq_lens_cpu
         logits_indices_selector = None
-        return (input_ids, attention_metadata, sampling_metadata,
+
+        if self.use_hybrid_kvcache:
+            attention_metadata = attention_metadata_per_layer
+        else:
+            attention_metadata = uniform_attention_metadata
+        return (input_ids, positions, attention_metadata, sampling_metadata,
                 logits_indices, spec_decode_metadata, logits_indices_selector,
                 padded_num_reqs)

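Downstream of _prepare_inputs, attention_metadata is now either a single AttentionMetadata (uniform cache) or a dict keyed by layer name (hybrid cache). A hedged sketch of how a consumer could resolve the right metadata for a given layer; the helper name and stand-in class below are hypothetical and not part of tpu_runner:

from typing import Dict, Union


class Metadata:
    """Stand-in for AttentionMetadata; only the field used in this sketch."""

    def __init__(self, block_tables):
        self.block_tables = block_tables


def metadata_for_layer(
    attn_metadata: Union[Metadata, Dict[str, Metadata]],
    layer_name: str,
) -> Metadata:
    """Pick per-layer metadata when hybrid, otherwise the shared object."""
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata


shared = Metadata(block_tables=[0, 1, 2])
hybrid = {"layers.0.attn": Metadata([0]), "layers.1.attn": Metadata([1])}

assert metadata_for_layer(shared, "layers.0.attn") is shared
assert metadata_for_layer(hybrid, "layers.1.attn").block_tables == [1]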