
Commit fb4ae17

Track expert selection metrics
Use VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM=1 to enable this feature. Make sure that PROMETHEUS_MULTIPROC_DIR is set, otherwise the metrics are not exported properly. Expect at most a ~2% end-to-end overhead when the feature is enabled; the GPU-side overhead is negligible. Note that this PR does not enable anything by default, so performance is untouched out of the box.

Signed-off-by: Thibault Schueller <[email protected]>
Signed-off-by: 'Thibault Schueller' <'[email protected]'>
Signed-off-by: Thibault Schueller <[email protected]>
1 parent 79f2f1c commit fb4ae17
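For reference, enabling the feature end to end only requires the two environment variables named above to be visible before the engine starts. The snippet below is an illustrative sketch, not part of this commit; the multiproc directory path and the model name are placeholders:

import os

# Both variables must be set before the vLLM engine process starts.
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc"  # placeholder path
os.environ["VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM"] = "1"

from vllm import LLM

# Any MoE model works here; this checkpoint name is only illustrative.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
outputs = llm.generate("Hello, my name is")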

File tree

20 files changed: +277, -13 lines

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
+
+
+@pytest.mark.parametrize("topk_experts,expert_count,topk_ids_dtype",
+                         [(4, 32, torch.int32), (1, 1, torch.int64)])
+@pytest.mark.parametrize("token_count", [256, 7])
+def test_collect_expert_usage_histogram(topk_experts: int, expert_count: int,
+                                        token_count: int,
+                                        topk_ids_dtype: torch.dtype):
+    device = torch.device('cuda')
+
+    # Make a uniform distribution of expert usage
+    topk_ids = torch.stack([torch.arange(topk_experts, dtype=topk_ids_dtype)] *
+                           token_count)
+
+    topk_ids_gpu = topk_ids.to(device)
+
+    expert_usage_histogram_gpu = torch.zeros(expert_count,
+                                             dtype=torch.int32,
+                                             device=device)
+
+    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)
+
+    # Every expert is used the same amount, so expect token_count for
+    # each expert set in the topk_ids tensor.
+    assert torch.equal(
+        expert_usage_histogram_gpu[:topk_experts],
+        torch.full([topk_experts],
+                   token_count,
+                   dtype=torch.int32,
+                   device=device))
+
+    # The rest of the experts weren't used, so they should be zero.
+    assert expert_usage_histogram_gpu[topk_experts:].sum() == 0
+
+
+@pytest.mark.parametrize("topk_experts,expert_count", [(16, 32)])
+@pytest.mark.parametrize("token_count", [1])
+@pytest.mark.parametrize("seed", [0xDEADBEEF, 0xCAFEBABE])
+def test_collect_expert_usage_histogram_random(topk_experts: int,
+                                               expert_count: int,
+                                               token_count: int, seed: int):
+    device = torch.device('cuda')
+
+    generator = torch.Generator()
+    generator.manual_seed(seed)
+
+    # Make a random distribution of expert usage
+    topk_ids_cpu = torch.stack(
+        [torch.randperm(topk_experts, generator=generator, dtype=torch.int32)
+         ] * token_count)
+
+    # Compute ground truth
+    torch_histogram = torch.histogram(topk_ids_cpu.to(torch.float),
+                                      bins=expert_count,
+                                      range=(0, expert_count - 1))
+
+    # Use our function
+    expert_usage_histogram_gpu = torch.zeros(expert_count,
+                                             dtype=torch.int32,
+                                             device=device)
+
+    topk_ids_gpu = topk_ids_cpu.to(device)
+
+    collect_expert_usage_histogram(topk_ids_gpu, expert_usage_histogram_gpu)
+
+    assert torch.equal(expert_usage_histogram_gpu,
+                       torch_histogram.hist.to(torch.int32).to(device))

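The kernel under test, collect_expert_usage_histogram, lives in vllm/model_executor/layers/fused_moe/utils.py and is not reproduced in this extract. As a reading aid, a pure-PyTorch reference with the contract these tests assume (counts of each expert id in topk_ids accumulated into a caller-provided int32 histogram) could look like the sketch below; it is not the actual implementation:

import torch


def collect_expert_usage_histogram_reference(
        topk_ids: torch.Tensor,
        expert_usage_histogram: torch.Tensor) -> None:
    # Count how often each expert id occurs across all tokens and top-k
    # slots, then accumulate into the preallocated histogram in place.
    counts = torch.bincount(topk_ids.flatten().to(torch.int64),
                            minlength=expert_usage_histogram.numel())
    expert_usage_histogram += counts.to(expert_usage_histogram.dtype)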
vllm/config.py

Lines changed: 13 additions & 5 deletions
@@ -981,7 +981,7 @@ def _verify_bnb_config(self) -> None:

         self.enforce_eager = True

-    def _verify_with_expert_parallelism(self) -> None:
+    def get_total_num_experts(self) -> int:
         num_expert_names = [
             "moe_num_experts",  # Dbrx
             "num_experts",  # Jamba
@@ -993,7 +993,10 @@ def _verify_with_expert_parallelism(self) -> None:
             num_experts = getattr(self.hf_text_config, name, 0)
             if num_experts > 0:
                 break
-        if num_experts < 1:
+        return num_experts
+
+    def _verify_with_expert_parallelism(self) -> None:
+        if self.get_total_num_experts() < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
                 "when expert parallelism is enabled.")
@@ -1222,16 +1225,21 @@ def get_num_attention_heads(self,
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size

-    def get_layers_start_end_indices(
-            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
-        from vllm.distributed.utils import get_pp_indices
+    def get_total_num_hidden_layers(self) -> int:
         if (self.hf_text_config.model_type == "deepseek_mtp"
                 or self.hf_config.model_type == "mimo_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)
+        return total_num_hidden_layers
+
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> tuple[int, int]:
+        from vllm.distributed.utils import get_pp_indices
+        total_num_hidden_layers = self.get_total_num_hidden_layers()
+
         # the layout order is: DP x PP x TP
         pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size
                    ) % parallel_config.pipeline_parallel_size

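These two getters refactor logic that already lived in _verify_with_expert_parallelism and get_layers_start_end_indices. A plausible reason to expose them is to size the per-layer histogram buffer; the sketch below is an assumption about how a caller could combine them (the actual allocation site is not part of the hunks shown here):

import torch

# Hypothetical wiring: `model_config` stands for a vllm ModelConfig instance.
num_layers = model_config.get_total_num_hidden_layers()
num_experts = model_config.get_total_num_experts()

# One int32 row of expert counts per hidden layer.
expert_usage_histogram = torch.zeros((num_layers, num_experts),
                                     dtype=torch.int32,
                                     device="cuda")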
vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -131,6 +131,7 @@
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
     VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
+    VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM: bool = False


 def get_default_cache_root():
@@ -905,6 +906,10 @@ def get_vllm_port() -> Optional[int]:
     # or bad hardware but it may add compute overhead.
     "VLLM_COMPUTE_NANS_IN_LOGITS":
     lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
+
+    # Collects expert routing histogram per layer
+    "VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM":
+    lambda: bool(int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))),
 }

 # --8<-- [end:env-vars-definition]

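The new entry follows the existing envs.py pattern: the value is parsed with bool(int(...)), so only integer strings such as "0" or "1" are accepted. A quick standard-library check of that behavior:

import os

os.environ["VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM"] = "1"
assert bool(int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))) is True

del os.environ["VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM"]
assert bool(int(os.getenv("VLLM_COLLECT_EXPERT_USAGE_HISTOGRAM", "0"))) is False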
vllm/forward_context.py

Lines changed: 7 additions & 0 deletions
@@ -95,6 +95,8 @@ class ForwardContext:
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
     skip_cuda_graphs: bool = False
+    # Set when recording usage histogram
+    expert_usage_histogram: Optional[torch.Tensor] = None


 _forward_context: Optional[ForwardContext] = None
@@ -116,6 +118,7 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     skip_cuda_graphs: bool = False,
+    expert_usage_histogram: Optional[torch.Tensor] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -132,6 +135,9 @@ def set_forward_context(
                                       attn_metadata, num_tokens or 0,
                                       num_tokens_across_dp)

+    if expert_usage_histogram is not None:
+        expert_usage_histogram.zero_()
+
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -141,6 +147,7 @@ def set_forward_context(
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
+        expert_usage_histogram=expert_usage_histogram,
     )

     try:

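The buffer is zeroed on entry to set_forward_context, so each forward pass starts from fresh counts. How the model runner threads the tensor through is not shown in the hunks above; the snippet below is a hypothetical call site illustrating the intent (names such as attn_metadata, vllm_config, model, and expert_usage_histogram stand in for the runner's own objects):

# Hypothetical wiring; the real call site lives in the GPU model runner.
with set_forward_context(attn_metadata,
                         vllm_config,
                         expert_usage_histogram=expert_usage_histogram):
    hidden_states = model(input_ids, positions)

# After the pass, row i holds the per-expert selection counts of MoE layer i
# for this step only, since the context manager zeroed the buffer on entry.
step_counts = expert_usage_histogram.cpu()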
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 30 additions & 6 deletions
@@ -25,8 +25,11 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -415,6 +418,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -554,6 +558,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
@@ -571,6 +576,7 @@ def apply(
             router_logits=router_logits,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             use_grouped_topk=use_grouped_topk,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
@@ -590,6 +596,7 @@ def forward_cuda(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -607,6 +614,7 @@ def forward_cuda(
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
+            layer_index=layer_index,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
@@ -646,6 +654,7 @@ def forward_cpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -680,6 +689,7 @@ def forward_hpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -713,6 +723,7 @@ def forward_tpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         global_num_experts: int = -1,
@@ -861,6 +872,8 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix

+        self.layer_index = extract_layer_index(prefix)
+
         # Determine expert maps
         if self.use_ep:
             self.local_num_experts, self.expert_map = determine_expert_map(
@@ -1282,6 +1295,7 @@ def select_experts(hidden_states: torch.Tensor,
                        top_k: int,
                        use_grouped_topk: bool,
                        renormalize: bool,
+                       layer_index: int,
                        topk_group: Optional[int] = None,
                        num_expert_group: Optional[int] = None,
                        custom_routing_function: Optional[Callable] = None,
@@ -1322,6 +1336,12 @@ def select_experts(hidden_states: torch.Tensor,
         if indices_type is not None:
             topk_ids = topk_ids.to(dtype=indices_type)

+        expert_usage_histogram = get_forward_context().expert_usage_histogram
+
+        if expert_usage_histogram is not None:
+            collect_expert_usage_histogram(topk_ids,
+                                           expert_usage_histogram[layer_index])
+
         return topk_weights, topk_ids

     def must_reduce_shared_expert_outputs(self) -> bool:
@@ -1354,10 +1374,12 @@ def maybe_all_reduce_tensor_model_parallel(
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:
-            return self.forward_impl(hidden_states, router_logits)
+            return self.forward_impl(hidden_states, router_logits,
+                                     self.layer_index)
         else:
             return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                              self.layer_name)
+                                              self.layer_name,
+                                              self.layer_index)

     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
@@ -1396,6 +1418,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
+                layer_index=self.layer_index,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
                 expert_map=self.expert_map,
@@ -1432,7 +1455,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
         return full_final_hidden_states

     def forward_impl(self, hidden_states: torch.Tensor,
-                     router_logits: torch.Tensor):
+                     router_logits: torch.Tensor, layer_index: int):
         assert self.quant_method is not None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels):
@@ -1452,6 +1475,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
+            layer_index=layer_index,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
             expert_map=self.expert_map,
@@ -1514,16 +1538,16 @@ def extra_repr(self) -> str:


 def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-                layer_name: str) -> torch.Tensor:
+                layer_name: str, layer_index: int) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.quant_method is not None

-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)


 def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
-                     layer_name: str) -> torch.Tensor:
+                     layer_name: str, layer_index: int) -> torch.Tensor:
     return torch.empty_like(hidden_states)


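To make the indexing contract in select_experts concrete: the collected histogram has one int32 row of expert counts per layer, and each MoE layer accumulates into its own row via layer_index. As an illustration of how such raw counts might be consumed downstream (not part of this commit; shapes and values invented for the example), a small post-processing sketch:

import torch

# Convert raw per-layer counts into usage fractions, e.g. for logging or a
# Prometheus gauge. Two layers, eight experts, made-up counts.
expert_usage_histogram = torch.tensor(
    [[4, 0, 2, 2, 0, 0, 0, 0],
     [2, 0, 0, 3, 0, 2, 0, 1]], dtype=torch.int32)

totals = expert_usage_histogram.sum(dim=-1, keepdim=True).clamp(min=1)
usage_fraction = expert_usage_histogram.float() / totals

print(usage_fraction[1])
# tensor([0.2500, 0.0000, 0.0000, 0.3750, 0.0000, 0.2500, 0.0000, 0.1250])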