
Commit ee1531b

[Bugfix][2/n] Fix speculative decoding CI - Fix test_ngram_e2e_greedy_correctness (#19644)

1 parent: e13945f

5 files changed: +50 −3

tests/spec_decode/e2e/test_integration.py

Lines changed: 10 additions & 1 deletion

@@ -14,10 +14,13 @@
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
+        "model_name": "JackFram/llama-68m",
 
         # Verify equality when cuda graphs allowed.
         "enforce_eager": False,
-        "model_name": "JackFram/llama-68m",
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
@@ -59,6 +62,9 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [])
 @pytest.mark.parametrize(
@@ -117,6 +123,9 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_logprobs.py

Lines changed: 17 additions & 1 deletion

@@ -17,7 +17,10 @@
         "model_name": "JackFram/llama-160m",
 
         # Skip cuda graph recording for fast test.
-        "enforce_eager": True
+        "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -75,6 +78,9 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -128,6 +134,9 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -182,6 +191,9 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -256,8 +268,12 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
     "common_llm_kwargs",
     [{
         "model_name": "JackFram/llama-160m",
+
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 3 additions & 0 deletions

@@ -494,6 +494,9 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 18 additions & 0 deletions

@@ -40,6 +40,9 @@
 
         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -97,6 +100,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
 
         # Print spec metrics.
         "disable_log_stats": False,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -160,6 +166,9 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
@@ -221,6 +230,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -281,6 +293,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -337,6 +352,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
+
+        # The original model is float32, keep it for numerical stability.
+        "dtype": "float32",
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])

vllm/model_executor/models/eagle.py

Lines changed: 2 additions & 1 deletion

@@ -74,6 +74,7 @@ class EAGLE(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        self.dtype = vllm_config.model_config.dtype
         self.config = config
 
         architectures = getattr(self.config.model, "architectures", [])
@@ -250,7 +251,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         lm_head_weight = torch.zeros(
             self.lm_head.org_vocab_size,
             self.lm_head.embedding_dim,
-            dtype=self.config.torch_dtype,
+            dtype=self.dtype,
         )
 
         weight_loader = getattr(self.lm_head.weight, "weight_loader",
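This is the functional half of the fix: self.config.torch_dtype comes from the draft model's HF config and reflects the checkpoint, while vllm_config.model_config.dtype is what the engine actually runs in (including when the tests above force dtype="float32"). Allocating the placeholder lm_head weight from the former leaves it out of step with the rest of the parameters whenever the two differ. A minimal illustration of the distinction (simplified; the surrounding EAGLE code is elided):

import torch

# A checkpoint may declare one dtype while the engine runs in another,
# e.g. when CI passes dtype="float32" explicitly.
hf_config_torch_dtype = torch.float16  # checkpoint's declared dtype
runtime_dtype = torch.float32          # vllm_config.model_config.dtype

# Before: the placeholder took the checkpoint's declared dtype.
lm_head_weight_old = torch.zeros(8, 4, dtype=hf_config_torch_dtype)
# After: it takes the runtime dtype, matching the live parameters.
lm_head_weight_new = torch.zeros(8, 4, dtype=runtime_dtype)

assert lm_head_weight_new.dtype == runtime_dtype
assert lm_head_weight_old.dtype != runtime_dtype  # the old mismatch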
