[TPU] add kv cache update kernel #19928
base: main
Conversation
Signed-off-by: Chengji Yao <[email protected]>
Summary of Changes
Hello @yaochengji, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a specialized kernel for updating the KV cache on TPUs, addressing the inefficiency of scatter-update operations on this hardware. By leveraging a Pallas-based kernel, it enables more performant and consecutive writes to the KV cache, which is crucial for optimizing large language model inference on TPU platforms. The changes involve implementing the new kernel, integrating it into the existing attention backend, and updating the data preparation logic to support the kernel's specific input requirements.
Highlights
- TPU KV Cache Optimization: Implemented a new Pallas-based kernel (`_kv_cache_update_kernel`) for efficient KV cache updates on TPUs. This kernel is specifically designed to handle consecutive memory writes, which are significantly more performant than scatter-update operations on TPU architecture, addressing a key performance bottleneck.
- Integration with vLLM Attention Backend: Integrated the new KV cache update kernel into the vLLM Pallas attention backend. The previous `kv_cache.index_copy_` method has been replaced with a custom XLA operation (`torch.ops.xla.kv_cache_update_op`) that invokes the new Pallas kernel.
- Dynamic Slot Mapping for Kernel: Introduced a new `_get_slot_mapping_metadata` function to dynamically compute detailed slice mappings (source and destination indices, and lengths) for KV cache updates. This metadata is crucial for the new kernel to efficiently identify and transfer data segments. A conceptual sketch of this idea is shown below.
- CI and Testing: Added a dedicated kernel test (`test_kv_cache_update_kernel.py`) to the `.buildkite` CI pipeline. This ensures the correctness and accuracy of the new KV cache update operation on TPUs by comparing its output against a CPU-based reference implementation.
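For intuition, here is a minimal CPU-side sketch of how per-request token ranges could be split into page-aligned slices of the form (cache_start, new_kv_start, length). The function name, argument names, and metadata layout are illustrative assumptions for this explanation, not the PR's actual `_get_slot_mapping_metadata` implementation.

```python
import numpy as np

def reference_slot_mapping_metadata(block_tables: np.ndarray,
                                    query_start_loc: np.ndarray,
                                    seq_lens: np.ndarray,
                                    page_size: int) -> np.ndarray:
    """Hypothetical CPU reference: split each request's new tokens into
    page-aligned slices described as (cache_start, new_kv_start, length)."""
    slices = []
    for req in range(len(seq_lens)):
        new_kv_start = int(query_start_loc[req])       # offset into the new K/V tensor
        num_new = int(query_start_loc[req + 1]) - new_kv_start
        pos = int(seq_lens[req]) - num_new             # first new token's position in the sequence
        while num_new > 0:
            page, offset = divmod(pos, page_size)
            length = min(page_size - offset, num_new)  # stop at the page boundary
            cache_start = int(block_tables[req, page]) * page_size + offset
            slices.append((cache_start, new_kv_start, length))
            pos += length
            new_kv_start += length
            num_new -= length
    return np.asarray(slices, dtype=np.int32)

# Example: two requests, page_size=4; request 0 appends 6 tokens starting at
# position 2, request 1 appends 3 tokens starting at position 0.
block_tables = np.array([[5, 7, 0], [9, 0, 0]])
query_start_loc = np.array([0, 6, 9])
seq_lens = np.array([8, 3])
print(reference_slot_mapping_metadata(block_tables, query_start_loc, seq_lens, 4))
# -> [[22 0 2], [28 2 4], [36 6 3]]
```

Each row describes one consecutive copy, which is the unit of work the TPU kernel can handle efficiently.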
The kernel implementation is based on the discussion with @bythew3i and @vanbasten23. We can observe some performance improvement.
Code Review
The pull request introduces a new kernel for updating the KV cache on TPUs, which aims to improve performance by reducing scatter updates. The changes include adding a test for the new kernel, incorporating the kernel into the vLLM workflow, and defining necessary functions and constants. The code appears well-structured and includes necessary documentation. However, some comments could be added to improve readability.
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
    new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size,
    kernel_block_size)
Consider adding comments to explain the purpose of each parameter in the `kv_cache_update_op` function call to improve readability:
new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
new_kv_xla, # new key values to be written
slot_mapping_xla, # mapping of slots to blocks in KV cache
kv_cache_xla, # KV cache to be updated
page_size, # size of each page
kernel_block_size # block size used in the kernel
)
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
KV_CACHE_UPDATE_BLOCK_SIZE = 8
Consider adding a comment explaining why `KV_CACHE_UPDATE_BLOCK_SIZE` is set to 8.
Suggested change:
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
KV_CACHE_UPDATE_BLOCK_SIZE = 8  # Block size used for updating KV cache
padded_num_slices = _get_padded_num_kv_cache_update_slices(
    padded_total_num_scheduled_tokens, self.max_num_reqs,
    self.block_size)
Consider adding a brief comment explaining the purpose of the `_get_padded_num_kv_cache_update_slices` function call and what the padded number of slices represents.
Suggested change:
slot_mapping_metadata = self._get_slot_mapping_metadata(
    num_reqs, num_scheduled_tokens_per_req)
# Calculate the padded number of slices for KV cache update
padded_num_slices = _get_padded_num_kv_cache_update_slices(
    padded_total_num_scheduled_tokens, self.max_num_reqs,
    self.block_size)
""" A fixed shape of slot_mapping_metadata tensor is required to avoid | ||
recompilation. | ||
""" |
Consider rephrasing the docstring to be more concise and clear. For example, you could say: "Calculates the padded number of KV cache update slices to avoid recompilation."
""" A fixed shape of slot_mapping_metadata tensor is required to avoid | |
recompilation. | |
""" | |
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, | |
page_size: int) -> int: | |
"""Calculates the padded number of KV cache update slices to avoid recompilation.""" |
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
pagged_num_slices = min(padded_num_slices, num_tokens)
pagged_num_slices = (
    pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE
return pagged_num_slices
Consider adding a comment to explain the purpose of each step in the calculation of `pagged_num_slices` to improve readability.
Suggested change:
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
pagged_num_slices = min(padded_num_slices, num_tokens)
pagged_num_slices = (
    pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE -  # Add block size - 1
    1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE  # Divide and multiply by block size to get the padded value
return pagged_num_slices
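As a quick sanity check of the rounding logic, here is a worked example with made-up numbers (these values are illustrative, not taken from the PR):

```python
# Assume max_num_reqs = 4, num_tokens = 100, page_size = 16,
# KV_CACHE_UPDATE_BLOCK_SIZE = 8 (illustrative values only).
padded = 2 * 4 + 100 // 16   # 8 + 6 = 14: upper bound on the number of slices
padded = min(padded, 100)    # 14: there can never be more slices than tokens
padded = (padded + 8 - 1) // 8 * 8
print(padded)                # 16: rounded up to a multiple of the block size
```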
Purpose
TPUs handle scatter-update operations poorly. With the KV cache update kernel, consecutive new KV entries are written to the cache together, so the update becomes a series of larger consecutive writes instead of many small scattered ones.
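To illustrate the distinction (a minimal sketch with made-up shapes and values, not code from this PR): a scatter-style `index_copy_` writes one destination slot per token, whereas the new kernel can exploit consecutive slots and copy whole slices at once.

```python
import torch

# Illustrative shapes: [num_slots, num_kv_heads, head_dim] cache and 16 new tokens.
kv_cache = torch.zeros(1024, 8, 128)
new_kv = torch.randn(16, 8, 128)
slot_mapping = torch.arange(100, 116)  # one destination slot per new token

# Scatter-style update (one slot at a time; slow on TPU hardware).
kv_cache.index_copy_(0, slot_mapping, new_kv)

# Slice-style update: because the 16 slots are consecutive, the same write can be
# expressed as a single (cache_start, new_kv_start, length) slice copy.
for cache_start, new_kv_start, length in [(100, 0, 16)]:
    kv_cache[cache_start:cache_start + length] = new_kv[new_kv_start:new_kv_start + length]
```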
Test Plan
Kernel test: pytest -s -v tests/v1/tpu/test_kv_cache_update_kernel.py
Accuracy test: pytest -s -v tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
Test Result
Passed.