[V1] [Spec decode] Llama4 type eagle support in v1 #18369

Open · wants to merge 13 commits into base: main
44 changes: 29 additions & 15 deletions tests/v1/e2e/test_spec_decode.py
@@ -9,6 +9,10 @@

from vllm import LLM, SamplingParams

TP8_REQUIRED_MODELS = [
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
]


@pytest.fixture
def test_prompts():
@@ -53,14 +57,6 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
@@ -105,13 +101,20 @@ def test_ngram_correctness(
del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
@pytest.mark.parametrize(
"method_model_and_draft_model",
[("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"),
("eagle", "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
"ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct"),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B")],
ids=["llama3_eagle", "llama4_eagle", "llama3_eagle3"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
method_model_and_draft_model: tuple[str, str, str],
):
'''
Compare the outputs of an original LLM and a speculative LLM
@@ -120,17 +123,28 @@ def test_eagle_correctness(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name, max_model_len=2048)
model_name = method_model_and_draft_model[1]

tp = 1

if model_name in TP8_REQUIRED_MODELS:
tp = 8

ref_llm = LLM(model=model_name,
tensor_parallel_size=tp,
max_model_len=2048)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm

spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
method = method_model_and_draft_model[0]
spec_model_name = method_model_and_draft_model[2]

spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
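Outside the test harness, the same configuration can be exercised directly. The sketch below is illustrative rather than part of this diff: it assumes eight GPUs for the FP8 Maverick checkpoint (it is listed in TP8_REQUIRED_MODELS above) and reuses the target/draft model names from the new parametrization.

import os

os.environ["VLLM_USE_V1"] = "1"  # the test enables the v1 engine via monkeypatch

from vllm import LLM, SamplingParams

# Target/draft pair introduced by this PR's test parametrization.
target_model = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
draft_model = "ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct"

llm = LLM(
    model=target_model,
    tensor_parallel_size=8,  # Maverick FP8 requires TP=8 per TP8_REQUIRED_MODELS
    max_model_len=2048,
    trust_remote_code=True,
    speculative_config={
        "method": "eagle",  # Llama 4 uses the plain EAGLE method here, not eagle3
        "model": draft_model,
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)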
35 changes: 24 additions & 11 deletions tests/v1/spec_decode/test_eagle.py
@@ -12,12 +12,16 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.v1.spec_decode.eagle import EagleProposer

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
llama3_model_dir = "meta-llama/Llama-3.1-8B-Instruct"
llama3_eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
llama3_eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

llama4_model_dir = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
llama4_eagle_dir = "ronaldbxu/EAGLE-Llama-4-Maverick-17B-128E-Instruct"

def _create_proposer(method: str, k: int) -> EagleProposer:

def _create_proposer(method: str, model_dir: str, draft_model_dir: str,
k: int) -> EagleProposer:
model_config = ModelConfig(model=model_dir,
task="generate",
max_model_len=100,
@@ -27,9 +31,6 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
seed=None,
trust_remote_code=False)

# Choose model directory based on method
draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir

speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
@@ -115,8 +116,14 @@ def test_prepare_inputs():


@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
("eagle",
lambda k: _create_proposer("eagle", llama3_model_dir, llama3_eagle_dir, k)
),
("eagle",
lambda k: _create_proposer("eagle", llama4_model_dir, llama4_eagle_dir, k)
),
("eagle3", lambda k: _create_proposer("eagle3", llama3_model_dir,
llama3_eagle3_dir, k)),
])
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@@ -196,7 +203,12 @@ class _TargetModelStub(LlamaForCausalLM):


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
@pytest.mark.parametrize("model_and_draft_model",
[(llama3_model_dir, llama3_eagle_dir),
(llama4_model_dir, llama4_eagle_dir)])
def test_propose(num_speculative_tokens, model_and_draft_model):
model_dir = model_and_draft_model[0]
draft_model_dir = model_and_draft_model[1]
# Use GPU device
device = torch.device('cuda')

@@ -208,7 +220,8 @@ def test_propose(num_speculative_tokens):
vocab_size = 100

# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
proposer = _create_proposer("eagle", model_dir, draft_model_dir,
num_speculative_tokens)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size

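With this change the unit tests take the target and draft checkpoints explicitly instead of hard-coding the Llama 3.1 pair, so covering another EAGLE-style head only requires extending the parametrize lists in tests/v1/spec_decode/test_eagle.py. A minimal sketch of such an extension is shown below; the "my-org/..." identifiers are hypothetical placeholders, not published checkpoints, and the snippet is meant to live inside that test module where _create_proposer is defined.

# Hypothetical additional target/draft pair for tests/v1/spec_decode/test_eagle.py.
# The "my-org/..." names are placeholders and do not refer to real checkpoints.
new_model_dir = "my-org/Some-Target-Instruct"
new_eagle_dir = "my-org/EAGLE-Some-Target-Instruct"

# Extra entry for the ("method", "proposer_helper") parametrization above.
extra_proposer_case = (
    "eagle",
    lambda k: _create_proposer("eagle", new_model_dir, new_eagle_dir, k),
)

# Extra entry for the test_propose "model_and_draft_model" parametrization.
extra_propose_case = (new_model_dir, new_eagle_dir)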