
[TPU] add kv cache update kernel #19928

Open · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
    "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then
59 changes: 59 additions & 0 deletions tests/v1/tpu/test_kv_cache_update_kernel.py
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import pytest
import torch
import torch_xla

import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
def test_kv_cache_update_kernel():
    page_num = 1000
    page_size = 32
    combined_kv_head_num = 16
    head_dim = 128
    kernel_block_size = 16
    padded_num_tokens = 128
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32)
    kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488],
                                      dtype=np.int32)
    new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32)
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    slot_mapping = np.pad(
        slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]],
        constant_values=0)
    slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu")
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    torch_xla.sync()

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
Comment from @vanbasten23 (Collaborator), Jun 23, 2025:
why do we want to do this torch.ops.xla.dynamo_set_buffer_donor_?
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
        kernel_block_size)
Comment from a Contributor on lines +46 to +48 (severity: medium):
Consider adding comments to explain the purpose of each parameter in the kv_cache_update_op function call to improve readability.

    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla,  # new key values to be written
        slot_mapping_xla,  # mapping of slots to blocks in KV cache
        kv_cache_xla,  # KV cache to be updated
        page_size,  # size of each page
        kernel_block_size  # block size used in the kernel
    )

    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(),
                          kv_cache_cpu,
                          atol=1e-4,
                          rtol=1e-4)
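
The kernel consumes slices in groups of kernel_block_size, so slot_mapping must contain a multiple of that many rows; the all-zero padding rows describe zero-length copies and leave the cache untouched. A small illustrative helper (not part of this PR) that generalizes the np.pad call above:

    import numpy as np

    def pad_slot_mapping(slot_mapping: np.ndarray, block_size: int) -> np.ndarray:
        """Pad a [num_slices, 3] array of (kv_cache_start, new_kv_start,
        slice_len) rows with zeros until num_slices is a multiple of
        block_size. Padding rows have slice_len == 0 and copy nothing."""
        pad_rows = -slot_mapping.shape[0] % block_size
        return np.pad(slot_mapping, [[0, pad_rows], [0, 0]], constant_values=0)

    # As in the test above: 7 real slices padded up to 16 rows.
    padded = pad_slot_mapping(np.zeros((7, 3), dtype=np.int32), 16)
    assert padded.shape == (16, 3)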
109 changes: 109 additions & 0 deletions vllm/attention/ops/pallas_kv_cache_update.py
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [num_slices, 3]
    # Input
    new_kv_hbm_ref,  # [tokens, kv_head_num, head_dim]
    kv_cache_hbm_ref,
    # Output
    _,  # [total_num_pages * page_size, kv_head_num, head_dim]
    # Scratch
    scratch,  # [block_size, page_size, kv_head_num, head_dim]
    sem,
):
    async_copies = []
    block_idx = pl.program_id(0)
    block_size = scratch.shape[0]

    # Copy from new_kv_hbm_ref to scratch
    for i in range(block_size):
        offset_i = i + block_idx * block_size
        new_kv_start = slices_ref[offset_i, 1]
        length = slices_ref[offset_i, 2]
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)

    for async_copy in async_copies:
        async_copy.wait()

    # Copy from scratch to kv_cache_hbm_ref
    async_copies.clear()
    for i in range(block_size):
        offset_i = i + block_idx * block_size
        kv_cache_start = slices_ref[offset_i, 0]
        length = slices_ref[offset_i, 2]
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()


@functools.partial(
    jax.jit,
    static_argnames=["page_size", "block_size"],
)
def kv_cache_update(
    new_kv: jax.Array,  # [total_num_token, kv_head_num, head_dim]
    slices: jax.Array,  # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len)
    kv_cache: jax.Array,  # [total_num_pages * page_size, kv_head_num, head_dim]
Comment from a Collaborator:
nit: use num_combined_kv_heads to be consistent

    *,
    page_size: int = 32,
    block_size: int = 8,
):
    assert slices.shape[0] % block_size == 0
    _, kv_head_num, head_dim = new_kv.shape

    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices]
    scratch = pltpu.VMEM(
        (block_size, page_size, kv_head_num, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            grid=(slices.shape[0] // block_size, ),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        input_output_aliases={len(scalar_prefetches) + 1: 0},
Comment from a Collaborator:
I guess this maps kv_cache_hbm_ref to the output so that you don't need to specify the output in "_kv_cache_update_kernel"?

    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
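
For reference, a minimal JAX-side invocation might look like the sketch below (illustrative shapes and values, not part of this PR, and runnable only on a TPU host). Because input_output_aliases maps the kv_cache operand (index len(scalar_prefetches) + 1) to output 0, the returned array reuses the cache buffer, which appears to be why _kv_cache_update_kernel writes through kv_cache_hbm_ref and never touches the anonymous output ref.

    import jax.numpy as jnp
    import numpy as np

    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update

    page_size, block_size = 32, 8
    num_combined_kv_heads, head_dim = 16, 128

    kv_cache = jnp.zeros((1000 * page_size, num_combined_kv_heads, head_dim),
                         dtype=jnp.bfloat16)
    new_kv = jnp.ones((64, num_combined_kv_heads, head_dim), dtype=jnp.bfloat16)

    # One real slice: copy new_kv[0:8] into kv_cache[128:136]. The remaining
    # rows are zero-length padding so that num_slices % block_size == 0.
    slices = np.zeros((block_size, 3), dtype=np.int32)
    slices[0] = (128, 0, 8)  # (kv_cache_start, new_kv_start, slice_len)

    updated = kv_cache_update(new_kv, jnp.asarray(slices), kv_cache,
                              page_size=page_size, block_size=block_size)
    assert updated.shape == kv_cache.shape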
55 changes: 50 additions & 5 deletions vllm/v1/attention/backends/pallas.py
@@ -5,8 +5,12 @@
from typing import Any, Optional

import torch
# Required to register custom ops.
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
# Required to register custom ops.
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
@@ -108,6 +112,7 @@ class PallasMetadata:
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor
    kv_cache_update_block_size: int


class PallasAttentionBackendImpl(AttentionImpl):
@@ -213,7 +218,10 @@ def forward(
        # Write input keys and values to the KV cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        slot_mapping = attn_metadata.slot_mapping
        kv_cache_update_block_size = \
            attn_metadata.kv_cache_update_block_size
        write_to_kv_cache(key, value, kv_cache, slot_mapping,
                          kv_cache_update_block_size)

        output = torch.ops.xla.ragged_paged_attention(
            query,
@@ -245,16 +253,17 @@ def write_to_kv_cache(
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_update_block_size: int,
) -> None:
    """ Write the key and values to the KV cache.
    Args:
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
        kv_cache_update_block_size: int
    """
    _, _, num_combined_kv_heads, head_size = kv_cache.shape
    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
    head_size = cdiv(head_size,
                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
@@ -263,4 +272,40 @@
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    kv_cache = kv_cache.flatten(0, 1)
    kv_cache.index_copy_(0, slot_mapping, kv)
    new_kv_cache = torch.ops.xla.kv_cache_update_op(
        kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size)
    # NOTE: the in-place copy will be optimized away by XLA compiler.
    kv_cache.copy_(new_kv_cache)


@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
                            kv_cache: torch.Tensor, page_size: int,
                            block_size: int):
    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
    new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), {
        "page_size": page_size,
        "block_size": block_size
    })
    return new_kv_cache


XLA_LIB.define(
    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, "
    "int page_size, int block_size) -> Tensor", )


@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                           kv_cache: torch.Tensor, page_size: int,
                           block_size: int) -> torch.Tensor:
    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
                                           page_size, block_size)
    return new_kv_cache


@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                               kv_cache: torch.Tensor, page_size: int,
                               block_size: int) -> torch.Tensor:
    return kv_cache
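
A rough sketch of what the CompositeExplicitAutograd fallback implies (not part of this PR; it assumes standard PyTorch dispatcher behavior and a torch_xla installation): on non-XLA backends the op resolves to kv_cache_update_op_non_xla and simply returns the cache, so only shape and dtype information flows through, while the actual Pallas update runs only under the XLA dispatch key via xb.call_jax.

    import torch

    import vllm.v1.attention.backends.pallas  # noqa: F401  # registers the op

    # Plain CPU tensors: the non-XLA registration is used, so this call is a
    # metadata-only pass-through and the cache contents are not modified.
    kv = torch.zeros((128, 16, 128), dtype=torch.bfloat16)
    slot_mapping = torch.zeros((16, 3), dtype=torch.int32)
    kv_cache = torch.zeros((1000 * 32, 16, 128), dtype=torch.bfloat16)

    out = torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache, 32, 16)
    assert out.shape == kv_cache.shape and out.dtype == kv_cache.dtype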