Conversation

Chenyaaang (Collaborator) commented Nov 20, 2025

Description

Fix a numerical issue with hybrid KV cache allocation. When hybrid KV cache is enabled, each allocation round assigns different block_ids to each KV cache group, which means layers in different groups write to different block_ids. We therefore need to create individual attention metadata for each layer, instead of sharing a single attention metadata object across all layers.
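
As a rough illustration of the fix, here is a minimal sketch of building per-layer attention metadata; `KVCacheGroup`, `AttentionMetadata`, and `build_per_layer_metadata` are hypothetical stand-ins, not the actual tpu-inference API:

```python
# Hypothetical sketch of the fix (all names illustrative, not the actual
# tpu-inference API): build one metadata object per KV cache group and map
# every layer to its own group's metadata, so each layer's attention kernel
# reads/writes the block ids of the group that layer belongs to.
from dataclasses import dataclass

@dataclass
class KVCacheGroup:
    layer_names: list[str]         # layers that share this group's blocks
    block_tables: list[list[int]]  # per-request block ids for this group

@dataclass
class AttentionMetadata:
    block_tables: list[list[int]]
    input_positions: list[int]     # token positions; identical for all layers

def build_per_layer_metadata(
    kv_cache_groups: list[KVCacheGroup],
    input_positions: list[int],
) -> dict[str, AttentionMetadata]:
    """Map each layer name to the metadata of its KV cache group."""
    attn_metadata: dict[str, AttentionMetadata] = {}
    for group in kv_cache_groups:
        # Groups get different block_ids each allocation round, which is why a
        # single metadata object shared by all layers produced wrong results.
        group_meta = AttentionMetadata(group.block_tables, input_positions)
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = group_meta
    return attn_metadata
```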

Tests

  • Unit tests in tpu_worker and tpu_runner pass.
  • Results with and without hybrid KV cache are identical when running offline_inference.py with a Gemma model: python examples/offline_inference.py --model google/gemma-3-27b-it --tensor-parallel-size 8
  • CI: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/5787 (all tasks are green except LoRA, which I believe fails due to an upstream change unrelated to this PR).

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have added the necessary comments to my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.


py4 (Collaborator) left a comment

This PR doesn't have any tests. Please add the following tests:

  1. e2e correctness test: output with and without hybrid allocation is the same (sketched below).
  2. e2e performance test: performance with the hybrid allocator is higher than without it.
  3. Unit tests for the changed Python files and the runner. We need to keep coverage above 70%, and our PRs need to come with enough tests.
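
For the correctness test, a hedged sketch of what it could look like, assuming vLLM's offline LLM API; the hybrid-allocator toggle (the VLLM_DISABLE_HYBRID_KV_CACHE env var below) is invented for illustration, and the real switch in tpu-inference may differ:

```python
# Hedged sketch of the requested e2e correctness test, assuming vLLM's offline
# LLM API. VLLM_DISABLE_HYBRID_KV_CACHE is an invented toggle for the hybrid
# allocator; the real switch in tpu-inference may differ.
import os
from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is", "To be or not to be,"]

def generate(enable_hybrid: bool) -> list[str]:
    os.environ["VLLM_DISABLE_HYBRID_KV_CACHE"] = "0" if enable_hybrid else "1"
    llm = LLM(model="google/gemma-3-27b-it", tensor_parallel_size=8)
    # Greedy decoding so both runs are deterministic and comparable.
    params = SamplingParams(temperature=0.0, max_tokens=32)
    return [out.outputs[0].text for out in llm.generate(PROMPTS, params)]

def test_hybrid_kv_cache_matches_baseline():
    assert generate(enable_hybrid=True) == generate(enable_hybrid=False)
```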

# TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
(
input_ids,
input_positions,
Collaborator

Why are we returning input_positions here? Shouldn't it be in attn_metadata?

Chenyaaang (Collaborator, Author) commented Nov 20, 2025

"Shouldn't it be in attn_metadata"- Yes, it is in attn_metadata.

But if we use hybrid kv cache, attn_metadata is a dict instead of a single metadata obj, which means we need to get it by attn_metadata[any_layer_name].input_positions. Considering either pass in layer name or input_positions directly, I chose the later way.

Collaborator

I mean in self.model_fn, what's the difference between the input_positions inside attn_metadata and the input_positions you are passing directly? It doesn't seem clean to me that there are now two different fields for input_positions.

Collaborator Author

There's no difference between those two input_positions; it's just that attn_metadata becomes a dict[layer_name, attn_metadata_for_that_layer] instead of a single attn_metadata shared by every layer. So inside vllm_model_wrapper's step_fun, when we need to pass input_positions to the model, we used to get it from attn_metadata.input_positions; now we would have to get it as attn_metadata[layer0_name].input_positions, which requires knowing the layer name, so I chose to pass input_positions directly.
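
For illustration, a minimal self-contained sketch of the two access patterns under discussion; AttentionMetadata and the model_fn call shape are stand-ins based on this thread, not the actual classes:

```python
# Illustrative only: how input_positions is reached once attn_metadata is a
# per-layer dict. AttentionMetadata here is a stand-in, not the real class.
from dataclasses import dataclass

@dataclass
class AttentionMetadata:
    input_positions: list[int]

attn_metadata = {
    "model.layers.0.self_attn": AttentionMetadata(input_positions=[0, 1, 2]),
    "model.layers.1.self_attn": AttentionMetadata(input_positions=[0, 1, 2]),
}

# Reading positions out of the dict requires knowing (or picking) a layer name:
layer0_name = next(iter(attn_metadata))
positions = attn_metadata[layer0_name].input_positions

# This PR instead passes input_positions to the model as its own argument, so
# the wrapper's step function never needs a layer name (hypothetical call):
# outputs = model_fn(input_ids, positions, attn_metadata)
```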

Collaborator

I think it makes the code messy due to having two redundant fields. It won't be clear whether one should read from input_positions or from attn_metadata.input_positions.
Also, isn't input_positions the same for all layers? If so, at each layer we can do something like next(iter(attn_metadata.values())).input_positions. If not, I think it's better to get it using the layer name.
Overall, having two fields for the same thing doesn't look good, I think. wdyt?

py4 (Collaborator) left a comment

Does this also work for the JAX path? If not, can we also make the JAX path work?


Chenyaaang (Collaborator, Author) commented Nov 21, 2025

> Does this also work for the JAX path? If not, can we also make the JAX path work?

It should be backend agnostic, but to enable it on the JAX path we need to modify each individual JAX model. Previously, none of the JAX models needed hybrid KV cache, so it isn't enabled there. The numerical issue was also reported with a vLLM model, not with flax nnx.

Chenyaaang closed this Nov 21, 2025
Chenyaaang reopened this Nov 21, 2025
kyuyeunk (Collaborator) commented

With this PR, I've verified on gpt-oss that the numeric issue has been solved, and that a performance issue which stemmed from the numeric issue has also been resolved.
