[V1] Perf optimization for layers reusing shared KV cache #19719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 1 commit into base: main

235 changes: 235 additions & 0 deletions tests/v1/e2e/test_kv_sharing_skip_prefill.py
@@ -0,0 +1,235 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
from collections.abc import Iterable
from typing import Optional, Union

import pytest
import torch
from torch import nn
from transformers import Qwen2Config

from vllm import LLM, SamplingParams
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
                                              Qwen2Model)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              extract_layer_index,
                                              maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test

START_KV_SHARING_LAYER = 10


class Qwen2DecoderLayerWithKVSharing(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        attn_prefix = f"{prefix}.self_attn"
        layer_idx = extract_layer_index(prefix)
        kv_sharing_target_layer_name = None

        if layer_idx >= START_KV_SHARING_LAYER:
            # re-use KV cache from first 5 layers
            target_layer_idx = layer_idx % 5
            kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace(
                str(layer_idx), str(target_layer_idx))
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=attn_prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        )

        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen2ModelWithKVSharing(Qwen2Model):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        decode_indices = get_forward_context().decode_indices
        if decode_indices is None:
            decode_indices = torch.arange(positions.size(0),
                                          device=positions.device)

        # Forward with full inputs up to the first layer that shares KV cache
        for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if decode_indices is not None:
            decode_hidden_states = hidden_states[decode_indices]
            decode_positions = positions[decode_indices]
            decode_residual = (residual[decode_indices]
                               if residual is not None else None)
        else:
            decode_hidden_states = hidden_states
            decode_positions = positions
            decode_residual = residual

        # Optimization: forward with partial inputs only for last N layers
        for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]:
            decode_hidden_states, decode_residual = layer(
                decode_positions,
                decode_hidden_states,
                decode_residual,
            )

        # Merge results back
        if decode_hidden_states is not None:
            hidden_states[decode_indices] = decode_hidden_states
        if residual is not None:
            residual[decode_indices] = decode_residual

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class TestQwen2ForCausalLM(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen2ModelWithKVSharing(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            decoder_layer_type=Qwen2DecoderLayerWithKVSharing)
        self.lm_head = self.model.embed_tokens
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


# TODO: make it work with torch.compile
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
    prompt = "What is the capital of France?"
    ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=40)
    single_prompt = [prompt]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  enforce_eager=enforce_eager)
        responses = llm.generate(single_prompt, sampling_params)
        ref_output = responses[0].outputs[0].text

        del llm
        gc.collect()
        torch.cuda.empty_cache()

        m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  enforce_eager=enforce_eager)
        responses = llm.generate(single_prompt, sampling_params)
        output = responses[0].outputs[0].text
        assert output == ref_output
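
Illustration (not part of the diff): how the decoder layer above derives the KV-sharing target layer name from its prefix, assuming the usual vLLM prefix layout of "model.layers.<idx>.self_attn". Layer 12, for example, reuses the KV cache of layer 12 % 5 == 2:

# Standalone sketch of the target-layer name mapping used in the test model.
attn_prefix = "model.layers.12.self_attn"
layer_idx = 12
target_layer_idx = layer_idx % 5  # == 2
kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace(
    str(layer_idx), str(target_layer_idx))
assert kv_sharing_target_layer_name == "model.layers.2.self_attn.attn"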
3 changes: 3 additions & 0 deletions vllm/envs.py
@@ -128,6 +128,7 @@
    VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
    VLLM_SLEEP_WHEN_IDLE: bool = False
    VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
    VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False
Collaborator: I prefer to add it as a cli arg.
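
A minimal, hypothetical sketch of the reviewer's suggestion, kept deliberately generic: the flag name and the plain argparse wiring below are illustrative assumptions, not part of this PR or of vLLM's current CLI.

import argparse

def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical flag; would replace the VLLM_V1_KV_SHARING_SKIP_PREFILL env var.
    parser.add_argument("--kv-sharing-skip-prefill",
                        action="store_true",
                        help="Skip prefill for layers that reuse a shared KV cache.")
    return parser

args = add_cli_args(argparse.ArgumentParser()).parse_args([])
print(args.kv_sharing_skip_prefill)  # False unless --kv-sharing-skip-prefill is passed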



def get_default_cache_root():
@@ -879,6 +880,8 @@ def get_vllm_port() -> Optional[int]:
    # processes via zmq.
    "VLLM_MQ_MAX_CHUNK_BYTES_MB":
    lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),
    "VLLM_V1_KV_SHARING_SKIP_PREFILL":
    lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1",
}

# --8<-- [end:env-vars-definition]
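
Usage sketch mirroring the e2e test above (assumes a CUDA GPU and access to the Qwen/Qwen2-1.5B-Instruct weights): the new variable is read lazily, so setting it before constructing the LLM is enough to opt into the skip-prefill path.

import os
# Enable the V1 engine and the experimental skip-prefill optimization.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_V1_KV_SHARING_SKIP_PREFILL"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", enforce_eager=True)
outputs = llm.generate(["What is the capital of France?"],
                       SamplingParams(temperature=0.0, max_tokens=40))
print(outputs[0].outputs[0].text)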
3 changes: 3 additions & 0 deletions vllm/forward_context.py
@@ -95,6 +95,7 @@ class ForwardContext:
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    skip_cuda_graphs: bool = False
    decode_indices: Optional[torch.Tensor] = None


_forward_context: Optional[ForwardContext] = None
@@ -116,6 +117,7 @@ def set_forward_context(
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    skip_cuda_graphs: bool = False,
    decode_indices: Optional[torch.Tensor] = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
@@ -141,6 +143,7 @@
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        skip_cuda_graphs=skip_cuda_graphs,
        decode_indices=decode_indices,
    )

    try:
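
A short sketch of how model code consumes the new field (this is what the test model above does). get_forward_context() only works inside an active set_forward_context(...) block, so the helper below is illustrative rather than standalone.

import torch
from vllm.forward_context import get_forward_context

def gather_decode_rows(hidden_states: torch.Tensor) -> torch.Tensor:
    decode_indices = get_forward_context().decode_indices
    if decode_indices is None:
        # No decode-only selection was provided; keep every token.
        decode_indices = torch.arange(hidden_states.size(0),
                                      device=hidden_states.device)
    return hidden_states[decode_indices]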
4 changes: 3 additions & 1 deletion vllm/model_executor/models/qwen2.py
@@ -109,6 +109,7 @@ def __init__(
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
**attn_kwargs,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -170,7 +171,8 @@ def __init__(
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {})
            } if dual_chunk_attention_config else {},
            **attn_kwargs)

    def forward(
        self,
21 changes: 17 additions & 4 deletions vllm/v1/attention/backends/cpu_attn.py
@@ -1,4 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import numpy as np
import torch

@@ -119,11 +121,22 @@ def reorder_batch(self, input_batch: InputBatch,

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        decode_only_common_attn_metadata: Optional[
            CommonAttentionMetadata] = None,
    ):
        if decode_only_common_attn_metadata is not None:
            raise NotImplementedError(
                "CPU backend does not support decode-only attention yet.")
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        query_start_loc_np = (common_attn_metadata.query_start_loc_np
                              if common_attn_metadata.query_start_loc_np
                              is not None else self.runner.query_start_loc_np)

        runner = self.runner
        block_table = self.block_table
@@ -135,8 +148,8 @@ def build(self, common_prefix_len: int,
        ) if num_prompt_req < num_reqs else 0
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        num_prefill_tokens = query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()