[TPU] add kv cache update kernel #19928

Open · wants to merge 2 commits into main
Conversation

@yaochengji (Collaborator) commented on Jun 21, 2025

Purpose

TPUs handle scatter updates poorly. With the new KV cache update kernel, consecutive new KV entries are written to the cache together instead of one slot at a time.

Test Plan

Kernel test: pytest -s -v tests/v1/tpu/test_kv_cache_update_kernel.py
Accuracy test: pytest -s -v tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine

Test Result

Passed.

Signed-off-by: Chengji Yao <[email protected]>
Signed-off-by: Chengji Yao <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @yaochengji, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a specialized kernel for updating the KV cache on TPUs, addressing the inefficiency of scatter-update operations on this hardware. By leveraging a Pallas-based kernel, it enables more performant and consecutive writes to the KV cache, which is crucial for optimizing large language model inference on TPU platforms. The changes involve implementing the new kernel, integrating it into the existing attention backend, and updating the data preparation logic to support the kernel's specific input requirements.

Highlights

  • TPU KV Cache Optimization: Implemented a new Pallas-based kernel (_kv_cache_update_kernel) for efficient KV cache updates on TPUs. This kernel is specifically designed to handle consecutive memory writes, which are significantly more performant than scatter-update operations on TPU architecture, addressing a key performance bottleneck.
  • Integration with vLLM Attention Backend: Integrated the new KV cache update kernel into the vLLM Pallas attention backend. The previous kv_cache.index_copy_ method has been replaced with a custom XLA operation (torch.ops.xla.kv_cache_update_op) that invokes the new Pallas kernel.
  • Dynamic Slot Mapping for Kernel: Introduced a new _get_slot_mapping_metadata function to dynamically compute detailed slice mappings (source and destination indices, and lengths) for KV cache updates. This metadata is crucial for the new kernel to efficiently identify and transfer data segments.
  • CI and Testing: Added a dedicated kernel test (test_kv_cache_update_kernel.py) to the .buildkite CI pipeline. This ensures the correctness and accuracy of the new KV cache update operation on TPUs by comparing its output against a CPU-based reference implementation (a minimal sketch of such a reference follows this list).

@mergify bot added the ci/build, v1, and tpu (Related to Google TPUs) labels on Jun 21, 2025
@yaochengji (Collaborator, Author) commented:

The kernel implementation is based on the discussion with @bythew3i and @vanbasten23.

We can observe some performance improvement:
Llama 8B: 7.85 reqs/s -> 8.04 reqs/s
Llama 70B: 6.41 reqs/s -> 6.71 reqs/s

@gemini-code-assist bot (Contributor) left a comment

Code Review

The pull request introduces a new kernel for updating the KV cache on TPUs, which aims to improve performance by reducing scatter updates. The changes include adding a test for the new kernel, incorporating the kernel into the vLLM workflow, and defining necessary functions and constants. The code appears well-structured and includes necessary documentation. However, some comments could be added to improve readability.

Comment on lines +46 to +48
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
    new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
    kernel_block_size)

Severity: medium

Consider adding comments to explain the purpose of each parameter in the kv_cache_update_op function call to improve readability.

    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla,  # new key values to be written
        slot_mapping_xla,  # mapping of slots to blocks in KV cache
        kv_cache_xla,  # KV cache to be updated
        page_size,  # size of each page
        kernel_block_size  # block size used in the kernel
    )

Comment on lines 56 to +59
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
KV_CACHE_UPDATE_BLOCK_SIZE = 8

Severity: medium

Consider adding a comment explaining why KV_CACHE_UPDATE_BLOCK_SIZE is set to 8.

Suggested change
  INVALID_TOKEN_ID = -1
  # Smallest output size
  MIN_NUM_SEQS = 8
- KV_CACHE_UPDATE_BLOCK_SIZE = 8
+ KV_CACHE_UPDATE_BLOCK_SIZE = 8  # Block size used for updating KV cache

Comment on lines +663 to +665
padded_num_slices = _get_padded_num_kv_cache_update_slices(
    padded_total_num_scheduled_tokens, self.max_num_reqs,
    self.block_size)

Severity: medium

Consider adding a brief comment explaining the purpose of the _get_padded_num_kv_cache_update_slices function call and what the padded number of slices represents.

Suggested change
  slot_mapping_metadata = self._get_slot_mapping_metadata(
      num_reqs, num_scheduled_tokens_per_req)
+ # Calculate the padded number of slices for KV cache update
  padded_num_slices = _get_padded_num_kv_cache_update_slices(
      padded_total_num_scheduled_tokens, self.max_num_reqs,
      self.block_size)

Comment on lines +1702 to +1704
""" A fixed shape of slot_mapping_metadata tensor is required to avoid
recompilation.
"""

Severity: medium

Consider rephrasing the docstring to be more concise and clear. For example, you could say: "Calculates the padded number of KV cache update slices to avoid recompilation."

Suggested change
  def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
                                             page_size: int) -> int:
-     """ A fixed shape of slot_mapping_metadata tensor is required to avoid
-     recompilation.
-     """
+     """Calculates the padded number of KV cache update slices to avoid recompilation."""

Comment on lines +1705 to +1710
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
pagged_num_slices = min(padded_num_slices, num_tokens)
pagged_num_slices = (
    pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
return pagged_num_slices

Severity: medium

Consider adding a comment to explain the purpose of each step in the calculation of pagged_num_slices to improve readability.

Suggested change
  padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
  pagged_num_slices = min(padded_num_slices, num_tokens)
  pagged_num_slices = (
-     pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
-     1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
+     pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -  # Add block size - 1
+     1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE  # Divide and multiply by block size to get the padded value
  return pagged_num_slices

@yaochengji requested review from mgoin and vanbasten23 on Jun 21, 2025 at 05:53