Use FusedMoEQuantConfig everywhere #19921

Draft: wants to merge 40 commits into base: main

Commits (40)
3d288bf
turn try_get_optimal_moe_config into an op so it can be torch.compiled
bnellnm Jun 16, 2025
385e0c5
lint
bnellnm Jun 16, 2025
c98ffbe
torch.compile tests
bnellnm Jun 16, 2025
c1c362a
add tests
bnellnm Jun 16, 2025
776ad95
add compiler + cudagraph tests
bnellnm Jun 16, 2025
961b5e8
tests
bnellnm Jun 16, 2025
bd9bd37
reduce number of compile/cudagraph tests
bnellnm Jun 16, 2025
23f26c9
lint
bnellnm Jun 16, 2025
5d564f6
fix lint
bnellnm Jun 16, 2025
06b4583
fix lint
bnellnm Jun 17, 2025
463ccaa
replace import that lint removed
bnellnm Jun 17, 2025
4ab6af7
fixes
bnellnm Jun 17, 2025
695203d
lint
bnellnm Jun 17, 2025
287a204
opify at a higher level
bnellnm Jun 18, 2025
1c9fd39
de-opify deepgemm kernels
bnellnm Jun 18, 2025
79a1962
remove cruft
bnellnm Jun 18, 2025
07d3aae
MoE refactoring
bnellnm Jun 12, 2025
847ec16
make FusedMoEModularKernel a Leaf
bnellnm Jun 13, 2025
5859222
make FusedMoEModularKernel a Leaf
bnellnm Jun 13, 2025
10137bb
fix format
bnellnm Jun 13, 2025
c73d6ba
config stuff + add more tests
bnellnm Jun 14, 2025
230a1fe
fixes
bnellnm Jun 14, 2025
782c3a0
wip test
bnellnm Jun 16, 2025
1bae03b
fix mergea
bnellnm Jun 16, 2025
7a95679
disable buggy fp8 tests
bnellnm Jun 17, 2025
5e22409
fixes
bnellnm Jun 17, 2025
12e42ea
more lint
bnellnm Jun 17, 2025
0b2f817
more lint
bnellnm Jun 17, 2025
4fdeb70
merge
bnellnm Jun 18, 2025
6b4e406
fix merge
bnellnm Jun 18, 2025
f1572d1
fix deep gemm test
bnellnm Jun 18, 2025
4c35a6c
add supports_expert_map method + cleanup select_gemm_impl methods
bnellnm Jun 19, 2025
69f878b
lint
bnellnm Jun 19, 2025
df3a90e
revert random linter changes
bnellnm Jun 19, 2025
b9046e7
fix comments + lint
bnellnm Jun 20, 2025
875a9c4
remove some logging
bnellnm Jun 20, 2025
ebb9e13
remove unused method
bnellnm Jun 20, 2025
e79b40a
try to fix lint
bnellnm Jun 20, 2025
b5d7cba
add some asserts to make lint happy
bnellnm Jun 20, 2025
1b57e9d
Use FusedMoEQuantConfig everywhere
bnellnm Jun 20, 2025
6 changes: 1 addition & 5 deletions tests/kernels/moe/deepep_utils.py
@@ -138,9 +138,7 @@ def make_deepep_ht_a2a(pg: ProcessGroup,
     rank=pgi.rank,
     dp_size=dp_size,
     rank_expert_offset=pgi.rank *
-    ht_args.num_local_experts,
-    quant_dtype=q_dtype,
-    block_shape=block_shape)
+    ht_args.num_local_experts)


 def make_deepep_ll_a2a(pg: ProcessGroup,
@@ -168,8 +166,6 @@ def make_deepep_ll_a2a(pg: ProcessGroup,
     world_size=pgi.world_size,
     dp_size=dp_size,
     max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
-    quant_dtype=q_dtype,
-    block_shape=block_shape,
     use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
 )

214 changes: 179 additions & 35 deletions tests/kernels/moe/test_batched_moe.py
@@ -2,18 +2,36 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import pytest
import torch
import triton.language as tl

from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@dataclass
class BatchedMMConfig:
dtype: torch.dtype
in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
out_dtype: torch.dtype
Review comment on lines +32 to +34 (Contributor, medium):
Consider adding docstrings to explain the purpose of each field in the BatchedMMConfig dataclass.
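A minimal sketch of what such docstrings could look like; the field descriptions below are inferred from how make_tensors and the tests use these fields, and are not taken from this PR:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class BatchedMMConfig:
    """Shape and dtype configuration for one batched-matmul test case."""

    # Dtype of the unquantized activations and weights.
    in_dtype: torch.dtype
    # Quantized dtype (e.g. torch.float8_e4m3fn), or None when no quantization is applied.
    quant_dtype: Optional[torch.dtype]
    # Dtype of the output tensor C.
    out_dtype: torch.dtype
    # Number of experts, i.e. the leading batch dimension.
    num_experts: int
    # Upper bound on the number of tokens routed to a single expert.
    max_tokens_per_expert: int
    # Reduction (inner) dimension of the matmul.
    K: int
    # Output feature dimension of the matmul.
    N: int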

num_experts: int
max_tokens_per_expert: int
K: int
@@ -32,84 +50,210 @@ def make_tensors(config: BatchedMMConfig):
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
dtype=config.in_dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
dtype=config.out_dtype)

num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)


def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:

num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)

for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)

return C
return BatchedMMTensors(A, B, C, num_expert_tokens)


@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
current_platform.seed_everything(7)

config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn

test_output = tensors.C
ref_output = test_output.clone()
if block_shape is not None and not use_fp8_w8a8:
pytest.skip("Don't test blocking for non-quantized types.")

if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None

#print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}")

num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)

A, A_q, A_scale = make_quantized_test_activations(
num_experts,
max_tokens_per_expert,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)

B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
N // 2,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
)

out_shape = (num_experts, max_tokens_per_expert, N)
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]

assert A_q.dtype == B_q.dtype

invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
A_q,
B_q,
test_output,
tensors.num_expert_tokens,
num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
A_scale,
B_scale,
None,
# Quantization schemes
False,
use_fp8_w8a8,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
},
Review comment on lines +154 to +155 (Contributor, medium):
The block size is hardcoded here. Consider adding it as a parameter to the test function.
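One way the reviewer's suggestion could be realized, sketched as a standalone pytest parametrization; the tile-size values and test name here are illustrative assumptions, not part of this PR:

import pytest

# Illustrative tile configurations to sweep instead of hardcoding a single dict.
BLOCK_CONFIGS = [
    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16},
    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
]


@pytest.mark.parametrize("block_config", BLOCK_CONFIGS)
def test_block_config_is_well_formed(block_config: dict):
    # In the real test, block_config would be forwarded to
    # invoke_moe_batched_triton_kernel(..., config=block_config, ...)
    # in place of the hardcoded literal shown above.
    assert all(size > 0 and size % 16 == 0 for size in block_config.values())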

block_shape=block_shape,
)

ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
ref_output = native_batched_masked_quant_matmul(
A,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)

q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)

rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]

torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
current_platform.seed_everything(7)

use_fp8_w8a8 = dtype == torch.float8_e4m3fn

if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
pytest.skip("Skip quantization test for non-quantized type")

if per_act_token_quant and block_shape is not None or topk > e:
pytest.skip("Skip illegal quantization test")

a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None

_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)

torch.set_printoptions(profile="full")

with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, quant_dtype, per_act_token_quant,
block_shape)
baseline_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, quant_dtype, per_act_token_quant,
block_shape)

torch.testing.assert_close(triton_output,
baseline_output,
atol=2e-2,
rtol=2e-2)

torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)