Fix numerical issue on hybrid kv cache allocation #1139
base: main
Conversation
This PR doesn't have any tests. Please add the following tests:
- e2e correctness test: output with and without hybrid allocation is the same (a rough sketch is given after this list)
- e2e performance test: performance with the hybrid allocator is higher than without it
- unit tests for the changed Python files and the runner. We need to keep coverage above 70%, and we need our PRs to come with enough tests.
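A minimal sketch of what the correctness test could look like, assuming a hypothetical `disable_hybrid_kv_cache_manager` engine argument to toggle the hybrid allocator (the actual knob in this repo may be named differently):

```python
from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is", "Explain KV caching in one sentence:"]


def _greedy_outputs(disable_hybrid: bool) -> list[str]:
    llm = LLM(
        model="google/gemma-3-27b-it",
        tensor_parallel_size=8,
        # Hypothetical toggle for the hybrid kv cache allocator; replace
        # with the actual engine argument used in this repo.
        disable_hybrid_kv_cache_manager=disable_hybrid,
    )
    params = SamplingParams(temperature=0.0, max_tokens=32)
    return [out.outputs[0].text for out in llm.generate(PROMPTS, params)]


def test_hybrid_allocator_matches_baseline():
    # Greedy decoding should produce identical text with and without the
    # hybrid allocator if per-layer attention metadata is built correctly.
    assert _greedy_outputs(True) == _greedy_outputs(False)
```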
# TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
(
    input_ids,
    input_positions,
Why are we returning input_positions here? Shouldn't it be in attn_metadata?
"Shouldn't it be in attn_metadata"- Yes, it is in attn_metadata.
But if we use hybrid kv cache, attn_metadata is a dict instead of a single metadata obj, which means we need to get it by attn_metadata[any_layer_name].input_positions. Considering either pass in layer name or input_positions directly, I chose the later way.
I mean in self.model_fn, what's the difference between input_positions inside attn_metadata and the input_positions that you are passing directly? It doesn't seem clean to me that now there are two different fields for input_positions
There's no difference between those two input_positions; it's just that attn_metadata becomes a dict[layer_name, attn_metadata_for_that_layer] instead of a single attn_metadata shared by every layer. So inside vllm_model_wrapper's step_fun, when we need to pass input_positions to the model, we used to get it from attn_metadata.input_positions; now we would need to get it as attn_metadata[layer0_name].input_positions, which requires knowing the layer name, so I chose to pass input_positions directly.
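A minimal sketch of the two access patterns being described here; the names (`attn_metadata`, the layer keys, and the metadata attributes) are illustrative, not the exact objects in this PR:

```python
from typing import Any


def get_input_positions(attn_metadata: Any):
    if isinstance(attn_metadata, dict):
        # Hybrid kv cache: attn_metadata is a dict keyed by layer name,
        # e.g. {"model.layers.0.self_attn": <metadata>, ...}.
        # input_positions is identical in every entry, but reading it still
        # means picking some layer (or an arbitrary entry, as below).
        return next(iter(attn_metadata.values())).input_positions
    # Non-hybrid path: a single metadata object shared by every layer.
    return attn_metadata.input_positions
```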
I think it makes the code messy due to having two redundant fields. It won't be clear whether one should read from input_positions or attn_metadata.input_positions.
Also, isn't input_positions the same for all layers? If so, at each layer we can do something like attn_metadata.values()[0].input_positions. If not, I think it's better to get it using the layer name.
Overall, having two fields for the same thing doesn't look good I think. wdyt?
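Note that a literal `attn_metadata.values()[0]` wouldn't run, since dict views aren't indexable in Python 3; the suggestion would look roughly like this, with names taken from the thread:

```python
# Read input_positions from an arbitrary entry of the per-layer dict,
# relying on it being identical across layers.
input_positions = next(iter(attn_metadata.values())).input_positions
```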
py4 left a comment
Does this also work for the JAX path? If not, can we also make the JAX path work?
It should be backend agnostic, but to enable it in JAX we need to modify the individual JAX models. Previously none of the JAX models needed the hybrid kv cache, so it was not enabled there. The numerical issue was also reported with a vLLM model rather than flax nnx.
With this PR, on gpt-oss, I've verified that the numerical issue has been solved, and a performance issue that stemmed from the numerical issue has also been resolved.
Description
Fix a numerical issue with hybrid kv cache allocation. When hybrid kv cache is enabled, the block_ids assigned in each kv cache allocation round differ between kv cache groups, which means different layers write to different block_ids. We therefore need to create individual attention metadata for each layer instead of sharing the same attention metadata across all layers.
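A rough sketch of the per-layer construction described above; `kv_cache_groups`, `layer_names`, `block_tables`, and the `AttentionMetadata` fields are assumed names for illustration, not the exact classes in this repo:

```python
from dataclasses import dataclass


@dataclass
class AttentionMetadata:
    block_tables: list[list[int]]  # block ids this layer reads/writes
    input_positions: list[int]     # identical across layers


def build_per_layer_metadata(kv_cache_groups, input_positions):
    """Return one metadata object per layer, keyed by layer name."""
    attn_metadata: dict[str, AttentionMetadata] = {}
    for group in kv_cache_groups:
        # Each kv cache group allocates its own block ids, so layers in
        # different groups must not share a single metadata object.
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = AttentionMetadata(
                block_tables=group.block_tables,
                input_positions=input_positions,
            )
    return attn_metadata
```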
Tests
python examples/offline_inference.py --model google/gemma-3-27b-it --tensor-parallel-size 8
Checklist
Before submitting this PR, please make sure: