
[Tests] V1 EAGLE Tests for Acceptance Rate #19104

Draft · wants to merge 9 commits into base: main
2 changes: 1 addition & 1 deletion benchmarks/kernels/bench_fp8_gemm.py
@@ -4,11 +4,11 @@
import itertools

import torch
import triton
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton


@triton.testing.perf_report(
149 changes: 145 additions & 4 deletions tests/v1/e2e/test_spec_decode.py
@@ -4,9 +4,35 @@
import random
from typing import Any

import numpy as np
import pytest

from vllm import LLM, SamplingParams
from vllm.v1.metrics.reader import Metric


def get_spec_acceptance_metrics(metrics: list[Metric], k: int):
num_drafts = 0
num_accepted = 0
acceptance_counts = [0] * k
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
num_drafts += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accepted += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
for pos in range(len(metric.values)):
acceptance_counts[pos] += metric.values[pos]
acceptance_rate_per_pos = [
count / num_drafts for count in acceptance_counts
]
mean_acceptance_length = 1 + (num_accepted / num_drafts)
return {
"num_drafts": num_drafts,
"num_accepted": num_accepted,
"acceptance_rate_per_pos": acceptance_rate_per_pos,
"mean_acceptance_length": mean_acceptance_length,
}


@pytest.fixture
@@ -42,6 +68,34 @@ def test_prompts():
return prompts


@pytest.fixture
def test_ngram_acceptance_rate_prompts():
prompts = []
words = ["test", "temp", "hello", "where"]
for word in words:
prompt = f"Please repeat the word '{word}' 50 times.\n"
prompt += "Here is an example of how it should look like: " + " ".join(
[word] * 10) + "...\n"
prompt += "Give no other output than the word at least "
prompt += "fifty times in a row in lowercase "
prompt += "with spaces between each word and without quotes."
prompts.append([{"role": "user", "content": prompt}])
return prompts


@pytest.fixture
def test_draft_acceptance_rate_prompts():
prompts = [
"Please write a short story about a cat that loves to chase mice.",
"What is the capital of France?",
"Explain the theory of relativity in simple terms.",
"Describe the process of photosynthesis in plants.",
"What are the main ingredients in a traditional pizza?",
]
return [[{"role": "user", "content": prompt}] for prompt in prompts]


@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@@ -98,9 +152,9 @@ def test_ngram_correctness(
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 70% of the prompts to match exactly
# Heuristic: expect at least 65% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
assert matches > int(0.65 * (matches + misses))
del spec_llm


@@ -147,7 +201,94 @@ def test_eagle_correctness(
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 66% of the prompts to match exactly
# Heuristic: expect at least 65% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
assert matches > int(0.65 * len(ref_outputs))
del spec_llm


def test_ngram_acceptance_rate(
monkeypatch: pytest.MonkeyPatch,
test_ngram_acceptance_rate_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Test the acceptance rate of speculative decoding using the ngram method.
The acceptance rate should be very high on the sample prompts,
as they are designed for 100% matches with the ngram method.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
disable_log_stats=False,
)
sampling_config.max_tokens = 50
spec_llm.chat(test_ngram_acceptance_rate_prompts, sampling_config)

metrics = get_spec_acceptance_metrics(spec_llm.get_metrics(), k=3)

# Expect nearly all (>90%) of the drafted tokens to be accepted
mean_acceptance_rate = np.mean(metrics["acceptance_rate_per_pos"])
assert mean_acceptance_rate > 0.90

# Expect the average acceptance length to be greater than 3
assert metrics["mean_acceptance_length"] > 3

del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_acceptance_rate(
monkeypatch: pytest.MonkeyPatch,
test_draft_acceptance_rate_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Test the acceptance rate of speculative decoding using EAGLE methods.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
max_model_len=2048,
disable_log_stats=False,
)
sampling_config.max_tokens = 50
spec_llm.chat(test_draft_acceptance_rate_prompts, sampling_config)

metrics = get_spec_acceptance_metrics(spec_llm.get_metrics(), k=3)

# Expect a large fraction of the drafted tokens to be accepted
if use_eagle3:
# EAGLE3 is more accurate, so we expect a higher acceptance rate
assert metrics["acceptance_rate_per_pos"][0] > 0.75
assert metrics["acceptance_rate_per_pos"][2] > 0.4
assert metrics["mean_acceptance_length"] > 2.75
else:
assert metrics["acceptance_rate_per_pos"][0] > 0.6
assert metrics["acceptance_rate_per_pos"][2] > 0.2
assert metrics["mean_acceptance_length"] > 2

del spec_llm
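
For reference, the acceptance numbers asserted above come straight from the helper at the top of this file. A standalone worked example of that arithmetic, not part of the diff and using made-up counter values:

# Worked example of the math in get_spec_acceptance_metrics();
# the counter values below are hypothetical.
num_drafts = 100                   # vllm:spec_decode_num_drafts
num_accepted = 180                 # vllm:spec_decode_num_accepted_tokens
acceptance_counts = [90, 60, 30]   # vllm:spec_decode_num_accepted_tokens_per_pos (k=3)

acceptance_rate_per_pos = [c / num_drafts for c in acceptance_counts]
# -> [0.9, 0.6, 0.3]: 90% of drafts had their first speculative token
#    accepted, 60% their second, 30% their third.

mean_acceptance_length = 1 + num_accepted / num_drafts
# -> 2.8: each verification step emits one token from the target model plus,
#    on average, num_accepted / num_drafts accepted draft tokens.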
64 changes: 40 additions & 24 deletions tests/v1/spec_decode/test_eagle.py
@@ -8,6 +8,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.v1.spec_decode.eagle import EagleProposer

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -112,21 +113,26 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)


@pytest.mark.parametrize(
"method,proposer_helper,draft_model_dir,target_attribute_path", [
("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
('lm_head', )),
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
('model', 'embed_tokens')),
])
@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
])
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):

# Setup model mock
proposer_helper, pp_size, use_distinct_embed_tokens):
# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
# Some models can have a different hidden size than the target model,
# so we test that their embed_tokens doesn't get overwritten
mock_model.model.embed_tokens.weight.shape = (131072, 2048)
else:
mock_model.model.embed_tokens.weight.shape = (131072, 4096)

mock_get_model.return_value = mock_model

# Setup mocks for attention layers
@@ -144,22 +150,24 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,

# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_pp_group.world_size = pp_size
mock_get_pp_group.return_value = mock_pp_group

# Setup target model with the appropriate attributes
target_model = mock.MagicMock()
# Setup the target model mock with a custom class so that
# isinstance() checks match the expected type.
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock

# Create the necessary attributes on the target model
current_obj = target_model
for i, attr in enumerate(target_attribute_path):
if i == len(target_attribute_path) - 1:
# Set the last attribute in the path to a MagicMock
setattr(current_obj, attr, mock.MagicMock())
else:
# Create intermediate objects if needed
setattr(current_obj, attr, mock.MagicMock())
current_obj = getattr(current_obj, attr)
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)

from vllm.model_executor.models import SupportsMultiModal
assert not isinstance(target_model, SupportsMultiModal)

if method == "eagle":
target_model.lm_head = mock.MagicMock()

# Create proposer using the helper function
proposer = proposer_helper(k=8)
@@ -170,10 +178,18 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
# Verify common interactions
mock_get_model.assert_called_once()

# Verify the specific attribute sharing based on the method
# Verify that EAGLE models gain the lm head from the target model
if method == "eagle":
assert proposer.model.lm_head == target_model.lm_head

# Verify that the embed tokens are set correctly
# If pp_size is > 1, the embed tokens should be distinct
if pp_size > 1 or use_distinct_embed_tokens:
assert proposer.model.model.embed_tokens != \
target_model.model.embed_tokens
else:
# When pp_size is 1 and the draft and target models have
# embed_tokens of the same shape, they should be shared.
assert proposer.model.model.embed_tokens == \
target_model.model.embed_tokens

14 changes: 6 additions & 8 deletions vllm/model_executor/models/llama_eagle.py
@@ -54,13 +54,11 @@ def __init__(
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size

# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -163,4 +161,4 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
return loader.load_weights(model_weights.items())
loader.load_weights(model_weights.items())
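
With the PP conditional removed above, the draft model always constructs its own VocabParallelEmbedding; whether those weights end up shared with the target model is decided later by the proposer. A rough sketch of that decision, inferred from the updated test_load_model assertions (the function name is hypothetical, not the actual EagleProposer code):

def maybe_share_embed_tokens(draft_model, target_model, pp_world_size: int) -> None:
    # Hypothetical sketch: share only when PP is disabled and the draft and
    # target embeddings have the same (vocab_size, hidden_size) shape.
    draft_shape = draft_model.model.embed_tokens.weight.shape
    target_shape = target_model.model.embed_tokens.weight.shape
    if pp_world_size == 1 and draft_shape == target_shape:
        draft_model.model.embed_tokens = target_model.model.embed_tokens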
24 changes: 14 additions & 10 deletions vllm/model_executor/models/llama_eagle3.py
@@ -9,7 +9,6 @@

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -94,13 +93,11 @@ def __init__(
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size

# if PP disabled then draft will share embed with target
if get_pp_group().world_size > 1:
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)

self.layers = nn.ModuleList([
LlamaDecoderLayer(
@@ -239,6 +236,7 @@ def combine_hidden_states(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_weights = {}
includes_draft_id_mapping = False
includes_embed_tokens = False
for name, loaded_weight in weights:
if "t2d" in name:
continue
@@ -247,12 +245,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
includes_draft_id_mapping = True
elif "lm_head" not in name:
name = "model." + name
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight

skip_substrs = []
if not includes_draft_id_mapping:
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
skip_substrs=["draft_id_to_target_id"] \
if not includes_draft_id_mapping else None,
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -157,6 +157,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
def get_current_memory_usage(cls,
device: Optional[torch.types.Device] = None
) -> float:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.max_memory_allocated(device)
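
The empty_cache() call added here, combined with the gc.collect() added in vllm/utils.py below, means the reading is taken only after dead Python objects and cached allocator blocks have been released. A standalone sketch of the combined pattern (the helper name is made up, not a vLLM API):

import gc
from typing import Optional

import torch

def snapshot_gpu_memory(device: Optional[torch.types.Device] = None) -> float:
    gc.collect()                                # drop dead Python references
    torch.cuda.empty_cache()                    # return cached blocks to the driver
    torch.cuda.reset_peak_memory_stats(device)  # peak counter := current usage
    # Right after the reset, max_memory_allocated() equals the memory held by
    # live tensors, so stale allocations no longer inflate the reading.
    return torch.cuda.max_memory_allocated(device)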

1 change: 1 addition & 0 deletions vllm/utils.py
@@ -898,6 +898,7 @@ def __init__(self, device: Optional[torch.types.Device] = None):
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect()
return current_platform.get_current_memory_usage(self.device)

def __enter__(self):