17 changes: 12 additions & 5 deletions tests/runner/test_tpu_runner_dp.py
@@ -46,9 +46,16 @@ def setup_method(self):
self.runner.query_start_loc_cpu = np.zeros(10, dtype=np.int32)
self.runner.seq_lens_cpu = np.zeros(8, dtype=np.int32)
self.runner.logits_indices_cpu = np.zeros(8, dtype=np.int32)
- self.runner.block_table_cpu = np.zeros((8, 8), dtype=np.int32)
+ self.runner.block_tables_cpu = [np.zeros((8, 8), dtype=np.int32)]
self.runner.arange_cpu = np.arange(64, dtype=np.int64)

+ # mock kv cache group
+ mock_kv_cache_config = MagicMock()
+ mock_kv_cache_group = MagicMock()
+ mock_kv_cache_config.kv_cache_groups = [mock_kv_cache_group]
+ self.runner.kv_cache_config = mock_kv_cache_config
+ self.runner.use_hybrid_kvcache = False
+
# Mock scheduler config for async scheduling
self.runner.scheduler_config = MagicMock()
self.runner.scheduler_config.async_scheduling = False # Default to False for most tests
@@ -102,8 +109,8 @@ def test_prepare_inputs_dp_basic_functionality(self,
result = self.runner._prepare_inputs_dp(scheduler_output)

# Basic assertions
- assert len(result) == 7
- input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+ assert len(result) == 8
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

# Verify utility functions were called
mock_runner_utils.get_padded_token_len.assert_called()
@@ -380,7 +387,7 @@ def mock_get_padded_token_len(paddings_list, val):

# Execute the method
result = self.runner._prepare_inputs_dp(scheduler_output)
- input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
# 1. Verify input_ids content
expected_input_ids = np.zeros(16, dtype=np.int32)
expected_input_ids[:2] = [1006, 1007]
@@ -494,7 +501,7 @@ def mock_get_padded_token_len(paddings_list, val):

# Execute the method
result = self.runner._prepare_inputs_dp(scheduler_output)
- input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

# 1. Verify input_ids
expected_input_ids = np.zeros(16, dtype=np.int32)
2 changes: 1 addition & 1 deletion tpu_inference/models/common/model_loader.py
@@ -217,7 +217,7 @@ def get_flax_model(
hidden_states_sharding, # aux hidden states
),
donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
- static_argnums=6, #6 is layer_name_to_kvcache_index
+ static_argnums=7, #7 is layer_name_to_kvcache_index
)
def run_model(graphdef, state, *args):
model = nnx.merge(graphdef, state)
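Note, not part of the diff: the index moves from 6 to 7 because the new positions array threaded through in the other files of this change is assumed to be passed just before layer_name_to_kvcache_index, pushing that argument back by one slot. A toy, self-contained sketch of why that argument is declared static at all; the names are illustrative, not the real signature:

import functools

import jax
import jax.numpy as jnp

# Illustrative only: a hashable tuple such as layer_name_to_kvcache_index is a
# compile-time constant for jax.jit, so it is listed in static_argnums; adding
# a traced array in front of it shifts that index by one.
@functools.partial(jax.jit, static_argnums=2)
def toy_step(x, positions, layer_name_to_kvcache_index):
    del layer_name_to_kvcache_index  # treated as a static Python value
    return x + positions

out = toy_step(jnp.ones(4), jnp.arange(4), (("decoder.layers.0", 0),))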
5 changes: 3 additions & 2 deletions tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -161,6 +161,7 @@ def step_fun(
input_ids: jax.Array,
attn_metadata: AttentionMetadata,
input_embeds: jax.Array,
+ input_positions: jax.Array,
layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
lora_metadata,
intermediate_tensors: JaxIntermediateTensors = None,
@@ -187,8 +188,8 @@ def step_fun(
torch_view(params_and_buffers),
kwargs={
"input_ids": torch_view(input_ids),
"positions": torch_view(attn_metadata.input_positions),
"intermediate_tensors": intermediate_tensors,
"positions": torch_view(input_positions),
"intermediate_tensors": None,
"inputs_embeds": None,
},
tie_weights=False,
4 changes: 4 additions & 0 deletions tpu_inference/platforms/tpu_platform.py
@@ -266,3 +266,7 @@ def use_sync_weight_loader(cls) -> bool:
Returns if the current platform needs to sync weight loader.
"""
return True

+ @classmethod
+ def support_hybrid_kv_cache(cls) -> bool:
+ return True
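A minimal sketch of how this capability flag might be combined with the cache-group layout to drive the runner's use_hybrid_kvcache switch seen elsewhere in this change; the helper and its wiring are hypothetical, not the actual vLLM call site:

# Hypothetical gating helper; `platform` and `kv_cache_config` stand in for the
# real objects, and attribute names mirror the ones used in this PR.
def wants_hybrid_kv_cache(platform, kv_cache_config) -> bool:
    # Hybrid KV cache only matters when the platform opts in and the model has
    # more than one KV cache group (e.g. full attention plus sliding window).
    return (platform.support_hybrid_kv_cache()
            and len(kv_cache_config.kv_cache_groups) > 1)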
48 changes: 33 additions & 15 deletions tpu_inference/runner/compilation_manager.py
@@ -1,6 +1,6 @@
import os
import time
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -135,12 +135,6 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

# Keep existing pattern for complex array operations
- block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(self.runner.mesh,
- block_tables,
- sharding=dp_sharding)
-
seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
jnp.int32, dp_sharding)
query_start_loc = self._create_dummy_tensor(
@@ -152,40 +146,64 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
request_distribution,
sharding=dp_sharding)

- attention_metadata = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+ uniform_attention_metadata: AttentionMetadata = None
+ for kv_cache_gid, kv_cache_group in enumerate(
+ self.runner.kv_cache_config.kv_cache_groups):
+ block_tables = self.runner.block_tables_cpu[
+ kv_cache_gid][:self.runner.max_num_reqs]
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(self.runner.mesh,
+ block_tables,
+ sharding=dp_sharding)
+
+ attention_metadata_gid = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )
+ if not self.runner.use_hybrid_kvcache:
+ # all layers share the same attention metadata
+ uniform_attention_metadata = attention_metadata_gid
+ else:
+ for layer_name in kv_cache_group.layer_names:
+ attention_metadata_per_layer[
+ layer_name] = attention_metadata_gid

def model_fn_wrapper(
state,
kv_caches,
input_ids,
attention_metadata,
+ positions,
inputs_embeds,
layer_name_to_kvcache_index,
lora_metadata,
):
kv_caches, hidden_states, _ = self.runner.model_fn(
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
- layer_name_to_kvcache_index, lora_metadata)
+ positions, layer_name_to_kvcache_index, lora_metadata)
self.runner.kv_caches = kv_caches
return hidden_states

with self.runner.maybe_select_dummy_loras(
self.runner.lora_config, np.array([num_tokens],
dtype=np.int32)):
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+ if self.runner.use_hybrid_kvcache:
+ attention_metadata = attention_metadata_per_layer
+ else:
+ attention_metadata = uniform_attention_metadata
self._run_compilation(
name,
model_fn_wrapper,
self.runner.state,
self.runner.kv_caches,
input_ids,
attention_metadata,
+ positions,
inputs_embeds,
tuple(self.runner.layer_name_to_kvcache_index.items()),
lora_metadata,
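To make the two paths above concrete: depending on use_hybrid_kvcache, the compiled model now receives either one shared AttentionMetadata or a dict keyed by layer name. A minimal sketch of how a consumer could resolve the metadata for a single layer; the helper is hypothetical, not part of this PR:

from typing import Dict, Union

def resolve_attention_metadata(
        attention_metadata: Union["AttentionMetadata",
                                  Dict[str, "AttentionMetadata"]],
        layer_name: str) -> "AttentionMetadata":
    if isinstance(attention_metadata, dict):
        # Hybrid KV cache: each KV cache group built its own block tables, so
        # a layer looks up the metadata of the group it belongs to.
        return attention_metadata[layer_name]
    # Uniform case: every layer shares the single metadata object.
    return attention_metadata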
7 changes: 7 additions & 0 deletions tpu_inference/runner/kv_cache_manager.py
@@ -3,12 +3,14 @@

import jax
import jax.numpy as jnp
+ import numpy as np
import vllm.envs as envs
from jax.sharding import NamedSharding, PartitionSpec
from torchax.ops.mappings import t2j_dtype
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.config import get_layers_from_vllm_config
+ from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MLAAttentionSpec,
SlidingWindowSpec)
@@ -174,6 +176,11 @@ def maybe_reinitialize_input_batch(self,
)
self.runner.input_batch = new_input_batch
self.runner.persistent_batch_manager.input_batch = new_input_batch
+ self.runner.block_tables_cpu = [
+ np.zeros((self.runner.max_num_reqs,
+ cdiv(self.runner.max_model_len, block_size)),
+ dtype=np.int32) for block_size in block_sizes
+ ]

def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.maybe_reinitialize_input_batch(kv_cache_config)
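As a worked example of the shapes the new list produces, a standalone sketch assuming max_num_reqs=8, max_model_len=1024, and two KV cache groups with block sizes 16 and 128 (the real values come from the KVCacheConfig at runtime):

import numpy as np

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division, as in vllm.utils.math_utils.cdiv

max_num_reqs, max_model_len = 8, 1024  # assumed values for illustration
block_sizes = [16, 128]                # one (assumed) block size per group

# Mirrors the list comprehension added above: one block table per KV cache
# group, wide enough to address every block of a maximum-length request.
block_tables_cpu = [
    np.zeros((max_num_reqs, cdiv(max_model_len, block_size)), dtype=np.int32)
    for block_size in block_sizes
]
print([bt.shape for bt in block_tables_cpu])  # [(8, 64), (8, 8)]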