25
25
from vllm .model_executor .custom_op import CustomOp
26
26
from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
27
27
is_rocm_aiter_moe_enabled )
28
+ from vllm .model_executor .layers .fused_moe .utils import (
29
+ collect_expert_usage_histogram )
28
30
from vllm .model_executor .layers .quantization .base_config import (
29
31
QuantizationConfig , QuantizeMethodBase )
32
+ from vllm .model_executor .models .utils import extract_layer_index
30
33
from vllm .model_executor .utils import set_weight_attrs
31
34
from vllm .platforms import current_platform
32
35
from vllm .platforms .interface import CpuArchEnum
@@ -415,6 +418,7 @@ def apply(
415
418
router_logits : torch .Tensor ,
416
419
top_k : int ,
417
420
renormalize : bool ,
421
+ layer_index : int ,
418
422
use_grouped_topk : bool = False ,
419
423
topk_group : Optional [int ] = None ,
420
424
num_expert_group : Optional [int ] = None ,
@@ -554,6 +558,7 @@ def apply(
554
558
router_logits : torch .Tensor ,
555
559
top_k : int ,
556
560
renormalize : bool ,
561
+ layer_index : int ,
557
562
use_grouped_topk : bool = False ,
558
563
topk_group : Optional [int ] = None ,
559
564
num_expert_group : Optional [int ] = None ,
@@ -571,6 +576,7 @@ def apply(
571
576
router_logits = router_logits ,
572
577
top_k = top_k ,
573
578
renormalize = renormalize ,
579
+ layer_index = layer_index ,
574
580
use_grouped_topk = use_grouped_topk ,
575
581
topk_group = topk_group ,
576
582
num_expert_group = num_expert_group ,
@@ -590,6 +596,7 @@ def forward_cuda(
590
596
top_k : int ,
591
597
router_logits : torch .Tensor ,
592
598
renormalize : bool ,
599
+ layer_index : int ,
593
600
topk_group : Optional [int ] = None ,
594
601
num_expert_group : Optional [int ] = None ,
595
602
global_num_experts : int = - 1 ,
@@ -607,6 +614,7 @@ def forward_cuda(
607
614
use_grouped_topk = use_grouped_topk ,
608
615
top_k = top_k ,
609
616
renormalize = renormalize ,
617
+ layer_index = layer_index ,
610
618
topk_group = topk_group ,
611
619
num_expert_group = num_expert_group ,
612
620
custom_routing_function = custom_routing_function ,
@@ -646,6 +654,7 @@ def forward_cpu(
646
654
top_k : int ,
647
655
router_logits : torch .Tensor ,
648
656
renormalize : bool ,
657
+ layer_index : int ,
649
658
topk_group : Optional [int ] = None ,
650
659
num_expert_group : Optional [int ] = None ,
651
660
global_num_experts : int = - 1 ,
@@ -680,6 +689,7 @@ def forward_hpu(
680
689
top_k : int ,
681
690
router_logits : torch .Tensor ,
682
691
renormalize : bool ,
692
+ layer_index : int ,
683
693
topk_group : Optional [int ] = None ,
684
694
num_expert_group : Optional [int ] = None ,
685
695
global_num_experts : int = - 1 ,
@@ -713,6 +723,7 @@ def forward_tpu(
713
723
top_k : int ,
714
724
router_logits : torch .Tensor ,
715
725
renormalize : bool ,
726
+ layer_index : int ,
716
727
topk_group : Optional [int ] = None ,
717
728
num_expert_group : Optional [int ] = None ,
718
729
global_num_experts : int = - 1 ,
@@ -861,6 +872,8 @@ def __init__(
861
872
compilation_config .static_forward_context [prefix ] = self
862
873
self .layer_name = prefix
863
874
875
+ self .layer_index = extract_layer_index (prefix )
876
+
864
877
# Determine expert maps
865
878
if self .use_ep :
866
879
self .local_num_experts , self .expert_map = determine_expert_map (
@@ -1282,6 +1295,7 @@ def select_experts(hidden_states: torch.Tensor,
1282
1295
top_k : int ,
1283
1296
use_grouped_topk : bool ,
1284
1297
renormalize : bool ,
1298
+ layer_index : int ,
1285
1299
topk_group : Optional [int ] = None ,
1286
1300
num_expert_group : Optional [int ] = None ,
1287
1301
custom_routing_function : Optional [Callable ] = None ,
@@ -1322,6 +1336,12 @@ def select_experts(hidden_states: torch.Tensor,
1322
1336
if indices_type is not None :
1323
1337
topk_ids = topk_ids .to (dtype = indices_type )
1324
1338
1339
+ expert_usage_histogram = get_forward_context ().expert_usage_histogram
1340
+
1341
+ if expert_usage_histogram is not None :
1342
+ collect_expert_usage_histogram (topk_ids ,
1343
+ expert_usage_histogram [layer_index ])
1344
+
1325
1345
return topk_weights , topk_ids
1326
1346
1327
1347
def must_reduce_shared_expert_outputs (self ) -> bool :
@@ -1354,10 +1374,12 @@ def maybe_all_reduce_tensor_model_parallel(
1354
1374
def forward(self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Run the MoE layer, directly or through the registered custom op.

    Args:
        hidden_states: Input activations, passed through unchanged.
        router_logits: Router scores, passed through unchanged.

    Returns:
        Whatever the selected path returns: ``self.forward_impl`` when
        ``self.use_direct_call`` is truthy, otherwise
        ``torch.ops.vllm.moe_forward``.
    """
    if self.use_direct_call:
        # Direct path: call the implementation on this instance.
        return self.forward_impl(hidden_states, router_logits,
                                 self.layer_index)

    # Op path: the layer is identified by name (and index) instead of by
    # object reference; the op looks the layer up again by that name.
    return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                      self.layer_name, self.layer_index)
1361
1383
1362
1384
def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
1363
1385
full_router_logits : torch .Tensor ):
@@ -1396,6 +1418,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
1396
1418
router_logits = staged_router_logits ,
1397
1419
top_k = self .top_k ,
1398
1420
renormalize = self .renormalize ,
1421
+ layer_index = self .layer_index ,
1399
1422
use_grouped_topk = self .use_grouped_topk ,
1400
1423
global_num_experts = self .global_num_experts ,
1401
1424
expert_map = self .expert_map ,
@@ -1432,7 +1455,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
1432
1455
return full_final_hidden_states
1433
1456
1434
1457
def forward_impl (self , hidden_states : torch .Tensor ,
1435
- router_logits : torch .Tensor ):
1458
+ router_logits : torch .Tensor , layer_index : int ):
1436
1459
assert self .quant_method is not None
1437
1460
if (self .moe_parallel_config .use_pplx_kernels
1438
1461
or self .moe_parallel_config .use_deepep_ll_kernels ):
@@ -1452,6 +1475,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
1452
1475
router_logits = router_logits ,
1453
1476
top_k = self .top_k ,
1454
1477
renormalize = self .renormalize ,
1478
+ layer_index = layer_index ,
1455
1479
use_grouped_topk = self .use_grouped_topk ,
1456
1480
global_num_experts = self .global_num_experts ,
1457
1481
expert_map = self .expert_map ,
@@ -1514,16 +1538,16 @@ def extra_repr(self) -> str:
1514
1538
1515
1539
1516
1540
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                layer_name: str, layer_index: int) -> torch.Tensor:
    """Look up an MoE layer by name and run its ``forward_impl``.

    Args:
        hidden_states: Input activations, forwarded to the layer.
        router_logits: Router scores, forwarded to the layer.
        layer_name: Key into ``no_compile_layers`` of the current forward
            context; selects the layer instance to run.
        layer_index: Forwarded unchanged as the third argument of
            ``forward_impl``.

    Returns:
        The result of the resolved layer's ``forward_impl``.
    """
    forward_context: ForwardContext = get_forward_context()
    # The layer object is resolved from the forward context by its name.
    self = forward_context.no_compile_layers[layer_name]
    assert self.quant_method is not None

    return self.forward_impl(hidden_states, router_logits, layer_index)
1523
1547
1524
1548
1525
1549
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                     layer_name: str, layer_index: int) -> torch.Tensor:
    """Return an uninitialized tensor shaped like ``hidden_states``.

    Mirrors the signature of ``moe_forward`` without running any layer.
    NOTE(review): presumably registered as the fake (shape-only)
    implementation of the ``moe_forward`` op; the registration is not
    visible here, so confirm against the op-registration call.

    Args:
        hidden_states: Tensor whose shape, dtype and device the result copies.
        router_logits: Unused.
        layer_name: Unused.
        layer_index: Unused.

    Returns:
        ``torch.empty_like(hidden_states)``; its contents are undefined.
    """
    return torch.empty_like(hidden_states)
1528
1552
1529
1553
0 commit comments