diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index a2a5c2a02cb..90cad506ab1 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
 run_and_track_test 15 "test_spmd_model_weight_loading.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
+run_and_track_test 16 "test_kv_cache_update_kernel.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then
diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py
new file mode 100644
index 00000000000..e8045d1a2d8
--- /dev/null
+++ b/tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import pytest
+import torch
+import torch_xla
+
+import vllm.v1.attention.backends.pallas  # noqa: F401
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a test for TPU only")
+def test_kv_cache_update_kernel():
+    page_num = 1000
+    page_size = 32
+    combined_kv_head_num = 16
+    head_dim = 128
+    kernel_block_size = 16
+    padded_num_tokens = 128
+    kv_cache_cpu = torch.zeros(
+        (page_num * page_size, combined_kv_head_num, head_dim),
+        dtype=torch.bfloat16,
+        device="cpu")
+    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
+    new_kv_cpu = torch.randn(
+        (padded_num_tokens, combined_kv_head_num, head_dim),
+        dtype=torch.bfloat16,
+        device="cpu")
+    new_kv_xla = new_kv_cpu.to(torch_xla.device())
+    slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32)
+    kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488],
+                                      dtype=np.int32)
+    new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32)
+    slot_mapping = np.stack(
+        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
+    slot_mapping = np.pad(
+        slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]],
+        constant_values=0)
+    slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu")
+    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
+    torch_xla.sync()
+
+    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
+    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
+        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
+        kernel_block_size)
+    kv_cache_xla.copy_(new_kv_cache_xla)
+    torch_xla.sync()
+
+    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
+                          slice_lens):
+        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]
+
+    assert torch.allclose(kv_cache_xla.cpu(),
+                          kv_cache_cpu,
+                          atol=1e-4,
+                          rtol=1e-4)
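Reviewer note (not part of the patch): one property of the hand-written slices above is worth calling out, since the kernel relies on it implicitly — each slice stays inside a single cache page, so its length never exceeds `page_size` and it fits the kernel's per-slice scratch row. A quick NumPy check of the test data, as a sketch:

```python
import numpy as np

# Slice layout used by the test: (kv_cache_start, new_kv_start, slice_len).
page_size = 32
kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488])
slice_lens = np.array([7, 32, 32, 1, 1, 1, 9])

# No slice crosses a page boundary, so each one fits in the kernel's
# [page_size, num_combined_kv_heads, head_dim] scratch row.
assert ((kv_cache_start_indices % page_size) + slice_lens <= page_size).all()
```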
diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py
new file mode 100644
index 00000000000..a916eca3003
--- /dev/null
+++ b/vllm/attention/ops/pallas_kv_cache_update.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import functools
+
+import jax
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+
+def _kv_cache_update_kernel(
+    # Prefetch
+    slices_ref,  # [num_slices, 3]
+    # Input
+    new_kv_hbm_ref,  # [tokens, num_combined_kv_heads, head_dim]
+    kv_cache_hbm_ref,
+    # Output
+    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    # Scratch
+    scratch,  # [block_size, page_size, num_combined_kv_heads, head_dim]
+    sem,
+):
+    async_copies = []
+    block_idx = pl.program_id(0)
+    block_size = scratch.shape[0]
+
+    # Copy from new_kv_hbm_ref to scratch
+    for i in range(block_size):
+        offset_i = i + block_idx * block_size
+        new_kv_start = slices_ref[offset_i, 1]
+        length = slices_ref[offset_i, 2]
+        async_copy = pltpu.make_async_copy(
+            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
+            scratch.at[i, pl.ds(0, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+
+    for async_copy in async_copies:
+        async_copy.wait()
+
+    # Copy from scratch to kv_cache_hbm_ref
+    async_copies.clear()
+    for i in range(block_size):
+        offset_i = i + block_idx * block_size
+        kv_cache_start = slices_ref[offset_i, 0]
+        length = slices_ref[offset_i, 2]
+        async_copy = pltpu.make_async_copy(
+            scratch.at[i, pl.ds(0, length), ...],
+            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
+            sem,
+        )
+        async_copy.start()
+        async_copies.append(async_copy)
+    for async_copy in async_copies:
+        async_copy.wait()
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=["page_size", "block_size"],
+)
+def kv_cache_update(
+    new_kv: jax.Array,  # [total_num_token, num_combined_kv_heads, head_dim]
+    slices: jax.
+    Array,  # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len)
+    kv_cache: jax.
+    Array,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
+    *,
+    page_size: int = 32,
+    block_size: int = 8,
+):
+    assert slices.shape[0] % block_size == 0
+    _, num_combined_kv_heads, head_dim = new_kv.shape
+
+    in_specs = [
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    ]
+
+    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
+    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
+
+    scalar_prefetches = [slices]
+    scratch = pltpu.VMEM(
+        (block_size, page_size, num_combined_kv_heads, head_dim),
+        new_kv.dtype,
+    )
+
+    scratch_shapes = [
+        scratch,
+        pltpu.SemaphoreType.DMA,
+    ]
+
+    kernel = pl.pallas_call(
+        _kv_cache_update_kernel,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=len(scalar_prefetches),
+            in_specs=in_specs,
+            out_specs=out_specs,
+            grid=(slices.shape[0] // block_size, ),
+            scratch_shapes=scratch_shapes,
+        ),
+        out_shape=out_shape,
+        input_output_aliases={len(scalar_prefetches) + 1: 0},
+    )
+
+    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
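For review purposes, here is a plain-NumPy model of what `kv_cache_update` computes (a sketch, not part of the patch): every `slices` row copies `slice_len` consecutive token rows from `new_kv` into the flattened cache, and all-zero padding rows are no-ops. On TPU the kernel does this with DMA copies staged through the VMEM scratch buffer, and the `input_output_aliases` entry maps the `kv_cache` input to the kernel's only output, so the update lands in the donated cache buffer rather than a fresh allocation.

```python
import numpy as np

def kv_cache_update_reference(new_kv: np.ndarray, slices: np.ndarray,
                              kv_cache: np.ndarray) -> np.ndarray:
    """Reference semantics of the Pallas kernel (returns a copy)."""
    out = kv_cache.copy()
    for kv_cache_start, new_kv_start, slice_len in slices:
        if slice_len == 0:
            continue  # rows added only to pad num_slices to a block multiple
        out[kv_cache_start:kv_cache_start + slice_len] = \
            new_kv[new_kv_start:new_kv_start + slice_len]
    return out
```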
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 1069578cfd2..56fd9204758 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -5,8 +5,12 @@
 from typing import Any, Optional
 
 import torch
-# Required to register custom ops.
+import torch_xla.core.xla_builder as xb
 import torch_xla.experimental.custom_kernel  # noqa: F401
+# Required to register custom ops.
+from torch.library import impl
+from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -108,6 +112,7 @@ class PallasMetadata:
     context_lens: torch.Tensor
     query_start_loc: torch.Tensor
     num_seqs: torch.Tensor
+    kv_cache_update_block_size: int
 
 
 class PallasAttentionBackendImpl(AttentionImpl):
@@ -213,7 +218,10 @@ def forward(
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            kv_cache_update_block_size = \
+                attn_metadata.kv_cache_update_block_size
+            write_to_kv_cache(key, value, kv_cache, slot_mapping,
+                              kv_cache_update_block_size)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -245,6 +253,7 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    kv_cache_update_block_size: int,
 ) -> None:
     """ Write the key and values to the KV cache.
 
@@ -252,9 +261,9 @@ def write_to_kv_cache(
         key: shape = [num_tokens, num_kv_heads * head_size]
         value: shape = [num_tokens, num_kv_heads * head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
-
+        kv_cache_update_block_size: int
     """
-    _, _, num_combined_kv_heads, head_size = kv_cache.shape
+    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -263,4 +272,40 @@ def write_to_kv_cache(
     torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
 
     kv_cache = kv_cache.flatten(0, 1)
-    kv_cache.index_copy_(0, slot_mapping, kv)
+    new_kv_cache = torch.ops.xla.kv_cache_update_op(
+        kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size)
+    # NOTE: the in-place copy will be optimized away by the XLA compiler.
+    kv_cache.copy_(new_kv_cache)
+
+
+@requires_jax
+def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                            kv_cache: torch.Tensor, page_size: int,
+                            block_size: int):
+    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
+        "page_size": page_size,
+        "block_size": block_size
+    })
+    return new_kv_cache
+
+
+XLA_LIB.define(
+    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
+    "int page_size, int block_size) -> Tensor", )
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "XLA")
+def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                           kv_cache: torch.Tensor, page_size: int,
+                           block_size: int) -> torch.Tensor:
+    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                           page_size, block_size)
+    return new_kv_cache
+
+
+@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor, page_size: int,
+                               block_size: int) -> torch.Tensor:
+    return kv_cache
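A note on the two registrations (my reading, with a hedged sketch): only the XLA registration reaches the Pallas kernel via `xb.call_jax`; the `CompositeExplicitAutograd` one returns the cache untouched, which keeps the op traceable and shape-correct on non-XLA backends. Assuming `torch_xla` is installed, that fallback path can be exercised on CPU roughly like this:

```python
import torch

import vllm.v1.attention.backends.pallas  # noqa: F401  (registers the op)

kv = torch.zeros(16, 4, 8)
slot_mapping = torch.zeros(8, 3, dtype=torch.int32)
kv_cache = torch.zeros(128, 4, 8)

# On a non-XLA device the CompositeExplicitAutograd impl runs: no update,
# same shape and dtype back.
out = torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache, 32, 8)
assert out.shape == kv_cache.shape and torch.equal(out, kv_cache)
```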
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 774caa1a3d9..f31b95a76c4 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -53,12 +53,11 @@
 
 logger = init_logger(__name__)
 
-# Here we utilize the behavior that out-of-bound index is ignored.
-# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
-_PAD_SLOT_ID = 1_000_000_000
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
+# Block size used by the KV cache update kernel
+KV_CACHE_UPDATE_BLOCK_SIZE = 8
 
 
 #########################################################
@@ -515,6 +514,69 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
         return kv_cache_spec
 
+    def _get_slot_mapping_metadata(self, num_reqs,
+                                   num_scheduled_tokens_per_req):
+        """
+        Computes metadata for mapping slots to blocks in the key-value (KV)
+        cache for a batch of requests.
+
+        This function determines, for each request in the batch, how the
+        scheduled tokens are distributed across memory blocks, and generates
+        metadata needed to map slices of tokens to their corresponding
+        positions in the KV cache.
+
+        Args:
+            num_reqs (int): Number of requests in the current batch.
+            num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
+                to be scheduled for each request.
+
+        Returns:
+            np.ndarray: A 2D array of shape (total_block_len, 3), where each
+                row contains:
+                - kv_cache_start_index (int): The starting index in the KV
+                  cache for the corresponding slice.
+                - new_kv_start_index (int): The starting index in the new KV
+                  cache for the corresponding slice.
+                - slice_len (int): The length of the slice.
+        """
+        slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
+        slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \
+            num_scheduled_tokens_per_req
+        local_block_start_idx = slices_start // self.block_size
+        local_block_end_idx = (slices_end - 1) // self.block_size
+        no_repeat_req_indices = self.arange_np[:num_reqs]
+        global_block_start_idx = (
+            no_repeat_req_indices * self.max_num_blocks_per_req +
+            local_block_start_idx)
+        block_lens = local_block_end_idx - local_block_start_idx + 1
+        global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
+        slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
+        global_block_indices = global_block_start_idx + slice_arange
+        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
+        total_block_len = np.sum(block_lens)
+        slot_mapping_slices = np.repeat(np.array([[0, self.block_size]],
+                                                 dtype=np.int32),
+                                        total_block_len,
+                                        axis=0)
+        cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
+        np.cumsum(block_lens, out=cu_block_lens[1:])
+        for req_idx in range(num_reqs):
+            slot_mapping_slices[cu_block_lens[req_idx]][
+                0] = slices_start[req_idx] % self.block_size
+            slot_mapping_slices[
+                cu_block_lens[req_idx + 1] -
+                1][1] = (slices_end[req_idx] - 1) % self.block_size + 1
+        slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
+        cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
+        np.cumsum(slice_lens, out=cu_slices_lens[1:])
+        kv_cache_start_indices = slot_mapping_slices[:, 0] + \
+            (block_numbers * self.block_size)
+        new_kv_start_indices = cu_slices_lens[:-1]
+        slot_mapping_metadata = np.stack(
+            [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1)
+        return slot_mapping_metadata
+
     def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
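To make the returned layout concrete, here is a hand-worked, single-request version of the same computation (an illustrative sketch with made-up numbers, not the method itself): with `block_size = 4`, a request that already has 6 tokens cached and 5 newly scheduled ones, and block-table rows 1 and 2 pointing at pages 8 and 9, one slice is emitted per touched page.

```python
import numpy as np

def slot_mapping_for_one_request(num_computed, num_scheduled,
                                 block_row_to_page, block_size):
    start, end = num_computed, num_computed + num_scheduled
    first_blk, last_blk = start // block_size, (end - 1) // block_size
    rows, new_kv_start = [], 0
    for blk in range(first_blk, last_blk + 1):
        lo = start % block_size if blk == first_blk else 0
        hi = (end - 1) % block_size + 1 if blk == last_blk else block_size
        rows.append([block_row_to_page[blk] * block_size + lo, new_kv_start,
                     hi - lo])
        new_kv_start += hi - lo
    return np.array(rows, dtype=np.int32)

print(slot_mapping_for_one_request(6, 5, {1: 8, 2: 9}, block_size=4))
# [[34  0  2]    tokens 6-7  -> page 8, offsets 2-3
#  [36  2  3]]   tokens 8-10 -> page 9, offsets 0-2
```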
@@ -567,26 +629,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
 
-        # Calculate the slot mapping.
-        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
-        # where K is the max_num_blocks_per_req and the block size is 2.
-        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
-        # because M (max_model_len) is not necessarily divisible by block_size.
-        # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
-        # NOTE(woosuk): We use torch.index_select instead of np.take here
-        # because torch.index_select is much faster than np.take for large
-        # tensors.
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.input_batch.block_table[0].
-               slot_mapping_np[:total_num_scheduled_tokens])
-
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         np.cumsum(num_scheduled_tokens_per_req,
@@ -609,12 +651,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         self.position_ids = self.positions_cpu[:
                                                padded_total_num_scheduled_tokens].to(
                                                    self.device)
-        self.input_batch.block_table[0].slot_mapping_cpu[
-            total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = (
-            self.input_batch.block_table[0].
-            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
-                self.device))
         block_tables = self.block_table_cpu[:self.max_num_reqs]
         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
@@ -623,6 +659,18 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
             self.device)
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
 
+        slot_mapping_metadata = self._get_slot_mapping_metadata(
+            num_reqs, num_scheduled_tokens_per_req)
+        padded_num_slices = _get_padded_num_kv_cache_update_slices(
+            padded_total_num_scheduled_tokens, self.max_num_reqs,
+            self.block_size)
+        slot_mapping_metadata = np.pad(
+            slot_mapping_metadata,
+            [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
+            constant_values=0)
+        slot_mapping_metadata = torch.tensor(slot_mapping_metadata,
+                                             device=self.device)
+
         if self.lora_config is not None:
             # We need to respect padding when activating LoRA adapters
             padded_num_scheduled_tokens_per_req = np.copy(
@@ -635,13 +683,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
                                   padded_num_scheduled_tokens_per_req)
 
         attn_metadata = PallasMetadata(
-            slot_mapping=slot_mapping,
+            slot_mapping=slot_mapping_metadata,
             block_tables=block_tables,
            context_lens=seq_lens,
             query_start_loc=query_start_loc,
             num_seqs=torch.tensor([num_reqs],
                                   dtype=torch.int32,
                                   device=self.device),
+            kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
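The padding step above keeps the slice array's shape a function of `padded_num_slices` only, so the compiled TPU graph does not depend on how many slices a particular batch produced; the all-zero rows are harmless because a zero-length copy is a no-op in the kernel. A minimal sketch of that step in isolation (values assumed for illustration):

```python
import numpy as np

slot_mapping_metadata = np.array([[34, 0, 2], [36, 2, 3]], dtype=np.int32)
padded_num_slices = 8  # assumed output of _get_padded_num_kv_cache_update_slices

slot_mapping_metadata = np.pad(
    slot_mapping_metadata,
    [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
    constant_values=0)
assert slot_mapping_metadata.shape == (padded_num_slices, 3)
```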
@@ -1033,8 +1082,10 @@ def _dummy_run(self, num_tokens: int) -> None:
         actual_num_reqs = min(num_tokens, self.max_num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
-        slot_mapping = torch.zeros(num_tokens,
-                                   dtype=torch.int64).to(self.device)
+        padded_num_slices = _get_padded_num_kv_cache_update_slices(
+            num_tokens, self.max_num_reqs, self.block_size)
+        slot_mapping = torch.zeros((padded_num_slices, 3),
+                                   dtype=torch.int32).to(self.device)
         block_tables = torch.zeros(
             (self.max_num_reqs, self.block_table_cpu.shape[1]),
             dtype=torch.int32).to(self.device)
@@ -1053,6 +1104,7 @@ def _dummy_run(self, num_tokens: int) -> None:
             context_lens=context_lens,
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
+            kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE,
         )
 
         if self.is_multimodal_model:
@@ -1646,6 +1698,18 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
+    """Calculates the padded number of KV cache update slices to avoid
+    recompilation."""
+    padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
+    padded_num_slices = min(padded_num_slices, num_tokens)
+    padded_num_slices = (
+        padded_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
+        1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
+    return padded_num_slices
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(
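Worked example of the padding arithmetic above (illustrative numbers, not part of the patch): with `KV_CACHE_UPDATE_BLOCK_SIZE = 8`, a padded batch of 128 tokens, `max_num_reqs = 4` and a 32-token page, the bound is `2 * 4 + 128 // 32 = 12`, which is below 128 and rounds up to 16 — so the slot-mapping tensor always has shape `(16, 3)` for this configuration, independent of how many slices the batch actually needs.

```python
KV_CACHE_UPDATE_BLOCK_SIZE = 8

def padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                      page_size: int) -> int:
    # Mirrors _get_padded_num_kv_cache_update_slices, for illustration only.
    n = 2 * max_num_reqs + num_tokens // page_size
    n = min(n, num_tokens)
    return ((n + KV_CACHE_UPDATE_BLOCK_SIZE - 1) //
            KV_CACHE_UPDATE_BLOCK_SIZE) * KV_CACHE_UPDATE_BLOCK_SIZE

assert padded_num_kv_cache_update_slices(128, 4, 32) == 16
```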